diff --git a/matrix_opencl.cpp b/matrix_opencl.cpp index a2689fa3c557f375c7deedd75f1d7a252ae07d61..4a4c1f8ceb7655e8e882f8428ad62de7d2490ffa 100644 --- a/matrix_opencl.cpp +++ b/matrix_opencl.cpp @@ -366,7 +366,18 @@ MatrixCL MatrixCL::operator*(const MatrixCL& other) const { kernel.setArg(4, cols_); kernel.setArg(5, other.cols_); - queue_.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange(rows_, cols_)); + const size_t TILE_SIZE = 16; + + // Round global work size up to the next multiple of TILE_SIZE + size_t global_rows = ((rows_ + TILE_SIZE - 1) / TILE_SIZE) * TILE_SIZE; + size_t global_cols = ((other.numCols() + TILE_SIZE - 1) / TILE_SIZE) * TILE_SIZE; + + cl::NDRange global_work_size(global_rows, global_cols); + cl::NDRange local_work_size(TILE_SIZE, TILE_SIZE); + + queue_.enqueueNDRangeKernel(kernel, cl::NullRange, global_work_size, local_work_size); + + //queue_.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange(rows_, cols_)); } catch (const cl::Error& err) { throw std::runtime_error("OpenCL error during matrix multiplication: " + std::string(err.what()) + " (" + std::to_string(err.err()) + ")"); }