diff --git a/main.cpp b/main.cpp index 4b6db4434a91a7c16b4abcab3d45c2dc7745cc47..bd4711b7e56fd3b800852d1c3bb90b99c1116f84 100644 --- a/main.cpp +++ b/main.cpp @@ -9,11 +9,62 @@ #include <limits> #include <chrono> +// Helper function to print a matrix (copies to host first) +void printMatrix(const std::string& label, const MatrixCL& mat) { + std::cout << label << " (" << mat.numRows() << "x" << mat.numCols() << "):\n"; + try { + std::vector<float> host_data = mat.copyToHost(); + for (int i = 0; i < mat.numRows(); ++i) { + std::cout << " ["; + for (int j = 0; j < mat.numCols(); ++j) { + std::cout << " " << host_data[i * mat.numCols() + j]; + } + std::cout << " ]\n"; + } + std::cout << std::endl; + } catch (const std::runtime_error& e) { + std::cerr << "Error printing matrix: " << e.what() << std::endl; + } +} + // Helper function for approximate float comparison bool approxEqual(float a, float b, float epsilon = 1e-5f) { return std::abs(a - b) < epsilon; } +// Helper function to verify matrix contents +bool verifyMatrix(const std::string& label, const MatrixCL& mat, const std::vector<float>& expected, float epsilon = 1e-5f) { + std::cout << "Verifying " << label << "..." << std::endl; + if (static_cast<size_t>(mat.numRows() * mat.numCols()) != expected.size()) { + std::cerr << "Verification failed: Dimension mismatch for " << label << ". Got " + << mat.numRows() << "x" << mat.numCols() << ", expected " << expected.size() << " elements." << std::endl; + return false; + } + try { + std::vector<float> actual = mat.copyToHost(); + bool match = true; + for (size_t i = 0; i < actual.size(); ++i) { + if (!approxEqual(actual[i], expected[i], epsilon)) { + std::cerr << "Verification failed for " << label << " at index " << i + << ". Got " << actual[i] << ", expected " << expected[i] << std::endl; + match = false; + // Don't break, report all mismatches if desired, or break here for efficiency + break; + } + } + if (match) { + std::cout << label << " verified successfully." << std::endl; + } else { + std::cout << label << " verification failed." << std::endl; + } + return match; + } catch (const std::runtime_error& e) { + std::cerr << "Error verifying matrix " << label << ": " << e.what() << std::endl; + return false; + } +} + + cl_ulong getElapsedTime(const cl::Event& event) { cl_ulong start = event.getProfilingInfo<CL_PROFILING_COMMAND_START>(); cl_ulong end = event.getProfilingInfo<CL_PROFILING_COMMAND_END>();