diff --git a/matrix_opencl.cpp b/matrix_opencl.cpp index c367b4539fb36ede1797719e7229fb6b2ae6fd8d..190932ea5c7fe7c9581d1ba64457cfb6b00f6430 100644 --- a/matrix_opencl.cpp +++ b/matrix_opencl.cpp @@ -85,7 +85,7 @@ const std::string kernel_source_matrix_mul = R"( } } )"; -const std::string kernel_source_matrix_mul_V2 = R"( +const std::string kernel_source_fast_matrix_mul = R"( __kernel void fast_matrix_mul(__global const float* A, __global const float* B, __global float* C, int A_rows, int A_cols, int B_cols) { int row = get_global_id(0); int col = get_global_id(1); @@ -191,8 +191,8 @@ void KernelCache::compileKernels(cl::Context context, const std::vector<cl::Devi cl::Program prog_matrix_mul = loadAndBuildProgram(context, devices, kernel_source_matrix_mul, "matrix_mul"); kernel_matrix_mul = cl::Kernel(prog_matrix_mul, "matrix_mul"); - cl::Program prog_matrix_mul_V2 = loadAndBuildProgram(context, devices, kernel_source_matrix_mul_V2, "matrix_mul_V2"); - kernel_matrix_mul_v2 = cl::Kernel(prog_matrix_mul_V2, "matrix_mul_V2"); + cl::Program prog_fast_matrix_mul = loadAndBuildProgram(context, devices, kernel_source_fast_matrix_mul, "fast_matrix_mul"); + kernel_fast_matrix_mul = cl::Kernel(prog_fast_matrix_mul, "fast_matrix_mul"); cl::Program prog_sigmoid = loadAndBuildProgram(context, devices, kernel_source_sigmoid, "sigmoid"); kernel_sigmoid = cl::Kernel(prog_sigmoid, "sigmoid"); @@ -377,10 +377,10 @@ MatrixCL MatrixCL::operator*(const MatrixCL& other) const { return result; } -MatrixCL MatrixCL::matrix_mul_V2(const MatrixCL& other) const { +MatrixCL MatrixCL::fast_matrix_mul(const MatrixCL& other) const { MatrixCL result(rows_, other.numCols(), context_, queue_); - cl::Kernel kernel = kernels_->kernel_matrix_mul_v2; + cl::Kernel kernel = kernels_->kernel_fast_matrix_mul; kernel.setArg(0, buffer_); kernel.setArg(1, other.getBuffer()); kernel.setArg(2, result.getBuffer()); diff --git a/matrix_opencl.hpp b/matrix_opencl.hpp index 6e9d5dfa9929aa5d59374d4833c3334aaced9d68..07c1bab5bef2f8315cd3b98dea33d0d17bc85f24 100644 --- a/matrix_opencl.hpp +++ b/matrix_opencl.hpp @@ -25,7 +25,7 @@ struct KernelCache { cl::Kernel kernel_sub_mul; cl::Kernel kernel_transpose; cl::Kernel kernel_matrix_mul; - cl::Kernel kernel_matrix_mul_v2; + cl::Kernel kernel_fast_matrix_mul; cl::Kernel kernel_sigmoid; cl::Kernel kernel_sigmoid_backward; cl::Kernel kernel_bce_elementwise; @@ -94,7 +94,8 @@ public: // Matrix multiplication: C = A * B MatrixCL operator*(const MatrixCL& other) const; - MatrixCL matrix_mul_V2(const MatrixCL& other) const; + // Fast matrix multiplication: C = A * B (optimized for large matrices) + MatrixCL fast_matrix_mul(const MatrixCL& other) const; // Transpose: returns a new Matrix that is the transpose (B = A^T) MatrixCL transpose() const;