From 1eb2c35586419d37bc0fd98898457ef12dff1a36 Mon Sep 17 00:00:00 2001 From: JordanHanotiaux <103147288+JordanHanotiaux@users.noreply.github.com> Date: Mon, 19 May 2025 10:32:52 +0200 Subject: [PATCH] Update matrix_opencl.cpp --- matrix_opencl.cpp | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/matrix_opencl.cpp b/matrix_opencl.cpp index a2689fa..4a4c1f8 100644 --- a/matrix_opencl.cpp +++ b/matrix_opencl.cpp @@ -366,7 +366,18 @@ MatrixCL MatrixCL::operator*(const MatrixCL& other) const { kernel.setArg(4, cols_); kernel.setArg(5, other.cols_); - queue_.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange(rows_, cols_)); + const size_t TILE_SIZE = 16; + + // Align global work size to the nearest multiple of TILE_SIZE + size_t global_rows = ((rows_ + TILE_SIZE - 1) / TILE_SIZE) * TILE_SIZE; + size_t global_cols = ((other.numCols() + TILE_SIZE - 1) / TILE_SIZE) * TILE_SIZE; + + cl::NDRange global_work_size(global_rows, global_cols); + cl::NDRange local_work_size(TILE_SIZE, TILE_SIZE); + + queue_.enqueueNDRangeKernel(kernel, cl::NullRange, global_work_size, local_work_size); + + //queue_.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange(rows_, cols_)); } catch (const cl::Error& err) { throw std::runtime_error("OpenCL error during matrix multiplication: " + std::string(err.what()) + " (" + std::to_string(err.err()) + ")"); } -- GitLab