From 013155420c89f623f5d6510287e57eeb74d6a133 Mon Sep 17 00:00:00 2001 From: JordanHanotiaux <103147288+JordanHanotiaux@users.noreply.github.com> Date: Mon, 19 May 2025 10:06:55 +0200 Subject: [PATCH] up --- matrix_opencl.cpp | 10 +++++----- matrix_opencl.hpp | 5 +++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/matrix_opencl.cpp b/matrix_opencl.cpp index c367b45..190932e 100644 --- a/matrix_opencl.cpp +++ b/matrix_opencl.cpp @@ -85,7 +85,7 @@ const std::string kernel_source_matrix_mul = R"( } } )"; -const std::string kernel_source_matrix_mul_V2 = R"( +const std::string kernel_source_fast_matrix_mul = R"( __kernel void fast_matrix_mul(__global const float* A, __global const float* B, __global float* C, int A_rows, int A_cols, int B_cols) { int row = get_global_id(0); int col = get_global_id(1); @@ -191,8 +191,8 @@ void KernelCache::compileKernels(cl::Context context, const std::vector<cl::Devi cl::Program prog_matrix_mul = loadAndBuildProgram(context, devices, kernel_source_matrix_mul, "matrix_mul"); kernel_matrix_mul = cl::Kernel(prog_matrix_mul, "matrix_mul"); - cl::Program prog_matrix_mul_V2 = loadAndBuildProgram(context, devices, kernel_source_matrix_mul_V2, "matrix_mul_V2"); - kernel_matrix_mul_v2 = cl::Kernel(prog_matrix_mul_V2, "matrix_mul_V2"); + cl::Program prog_fast_matrix_mul = loadAndBuildProgram(context, devices, kernel_source_fast_matrix_mul, "fast_matrix_mul"); + kernel_fast_matrix_mul = cl::Kernel(prog_fast_matrix_mul, "fast_matrix_mul"); cl::Program prog_sigmoid = loadAndBuildProgram(context, devices, kernel_source_sigmoid, "sigmoid"); kernel_sigmoid = cl::Kernel(prog_sigmoid, "sigmoid"); @@ -377,10 +377,10 @@ MatrixCL MatrixCL::operator*(const MatrixCL& other) const { return result; } -MatrixCL MatrixCL::matrix_mul_V2(const MatrixCL& other) const { +MatrixCL MatrixCL::fast_matrix_mul(const MatrixCL& other) const { MatrixCL result(rows_, other.numCols(), context_, queue_); - cl::Kernel kernel = kernels_->kernel_matrix_mul_v2; + cl::Kernel kernel = kernels_->kernel_fast_matrix_mul; kernel.setArg(0, buffer_); kernel.setArg(1, other.getBuffer()); kernel.setArg(2, result.getBuffer()); diff --git a/matrix_opencl.hpp b/matrix_opencl.hpp index 6e9d5df..07c1bab 100644 --- a/matrix_opencl.hpp +++ b/matrix_opencl.hpp @@ -25,7 +25,7 @@ struct KernelCache { cl::Kernel kernel_sub_mul; cl::Kernel kernel_transpose; cl::Kernel kernel_matrix_mul; - cl::Kernel kernel_matrix_mul_v2; + cl::Kernel kernel_fast_matrix_mul; cl::Kernel kernel_sigmoid; cl::Kernel kernel_sigmoid_backward; cl::Kernel kernel_bce_elementwise; @@ -94,7 +94,8 @@ public: // Matrix multiplication: C = A * B MatrixCL operator*(const MatrixCL& other) const; - MatrixCL matrix_mul_V2(const MatrixCL& other) const; + // Fast matrix multiplication: C = A * B (optimized for large matrices) + MatrixCL fast_matrix_mul(const MatrixCL& other) const; // Transpose: returns a new Matrix that is the transpose (B = A^T) MatrixCL transpose() const; -- GitLab