Skip to content
Extraits de code Groupes Projets
Valider c0cd2b9a rédigé par JordanHanotiaux's avatar JordanHanotiaux
Parcourir les fichiers

Update matrix_opencl.cpp

parent 1eb2c355
Aucune branche associée trouvée
Aucune étiquette associée trouvée
Aucune requête de fusion associée trouvée
......@@ -86,46 +86,38 @@ const std::string kernel_source_transpose = R"(
}
)";*/
const std::string kernel_source_matrix_mul = R"(
__kernel void matrix_mul(__global const float* A, __global const float* B, __global float* C, int A_rows, int A_cols, int B_cols) {
int row = get_global_id(0);
int col = get_global_id(1);
__kernel void matrix_mul(__global const float* A,
__global const float* B,
__global float* C,
int M, int K, int N,
__local float* Bwrk) { // <<< nouveau paramètre
int local_row = get_local_id(0);
int local_col = get_local_id(1);
const int TILE_SIZE = 16;
__local float As[TILE_SIZE][TILE_SIZE];
__local float Bs[TILE_SIZE][TILE_SIZE];
float sum = 0.0f;
for (int t = 0; t < (A_cols + TILE_SIZE - 1) / TILE_SIZE; t++) {
int tile_row = t * TILE_SIZE + local_row;
int tile_col = t * TILE_SIZE + local_col;
if (row < A_rows && tile_col < A_cols)
As[local_row][local_col] = A[row * A_cols + tile_col];
else
As[local_row][local_col] = 0.0f;
if (tile_row < A_cols && col < B_cols)
Bs[local_row][local_col] = B[tile_row * B_cols + col];
else
Bs[local_row][local_col] = 0.0f;
barrier(CLK_LOCAL_MEM_FENCE);
for (int k = 0; k < TILE_SIZE; ++k)
sum += As[local_row][k] * Bs[k][local_col];
barrier(CLK_LOCAL_MEM_FENCE);
}
if (row < A_rows && col < B_cols) {
C[row * B_cols + col] = sum;
}
int j, k;
int i = get_global_id(0);
int iloc = get_local_id(0);
int nloc = get_local_size(0);
float tmp;
float Awrk[1024];
for (k = 0; k < K; k++)
Awrk[k] = A[i * K + k];
for (j = 0; j < N; j++) {
for (k = iloc; k < K; k += nloc)
Bwrk[k] = B[k * N + j];
barrier(CLK_LOCAL_MEM_FENCE);
tmp = 0.0f;
for (k = 0; k < K; k++)
tmp += Awrk[k] * Bwrk[k];
C[i * N + j] = tmp;
barrier(CLK_LOCAL_MEM_FENCE);
}
)";
})";
const std::string kernel_source_sigmoid = R"(
__kernel void sigmoid(__global const float* input, __global float* output, int rows, int cols) {
int idx = get_global_id(0);
......@@ -352,7 +344,7 @@ MatrixCL MatrixCL::operator+(const MatrixCL& other) const {
return result;
}
MatrixCL MatrixCL::operator*(const MatrixCL& other) const {
/*MatrixCL MatrixCL::operator*(const MatrixCL& other) const {
if (cols_ != other.rows_)
throw std::runtime_error("Matrix dimension error.");
......@@ -366,18 +358,42 @@ MatrixCL MatrixCL::operator*(const MatrixCL& other) const {
kernel.setArg(4, cols_);
kernel.setArg(5, other.cols_);
const size_t TILE_SIZE = 16;
queue_.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange(rows_, cols_));
} catch (const cl::Error& err) {
throw std::runtime_error("OpenCL error during matrix multiplication: " + std::string(err.what()) + " (" + std::to_string(err.err()) + ")");
}
return result;
}*/
MatrixCL MatrixCL::operator*(const MatrixCL& other) const {
if (cols_ != other.rows_)
throw std::runtime_error("Matrix dimension error.");
// Align global work size to the nearest multiple of TILE_SIZE
size_t global_rows = ((rows_ + TILE_SIZE - 1) / TILE_SIZE) * TILE_SIZE;
size_t global_cols = ((other.numCols() + TILE_SIZE - 1) / TILE_SIZE) * TILE_SIZE;
MatrixCL result(rows_, other.cols_, context_, queue_);
try {
cl::Kernel& kernel = kernels_->kernel_matrix_mul;
int M = rows_;
int K = cols_;
int N = other.cols_;
// 1. Arguments classiques
kernel.setArg(0, buffer_);
kernel.setArg(1, other.buffer_);
kernel.setArg(2, result.buffer_);
kernel.setArg(3, M);
kernel.setArg(4, K);
kernel.setArg(5, N);
cl::NDRange global_work_size(global_rows, global_cols);
cl::NDRange local_work_size(TILE_SIZE, TILE_SIZE);
// 2. Ajoute le buffer local
kernel.setArg(6, cl::Local(sizeof(float) * K)); // Bwrk[K]
queue_.enqueueNDRangeKernel(kernel, cl::NullRange, global_work_size, local_work_size);
// 3. Définis NDRange (1D ici sur M lignes)
cl::NDRange global(M);
cl::NDRange local = cl::NullRange; // Ou un local size calculé
//queue_.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange(rows_, cols_));
queue_.enqueueNDRangeKernel(kernel, cl::NullRange, global, local);
} catch (const cl::Error& err) {
throw std::runtime_error("OpenCL error during matrix multiplication: " + std::string(err.what()) + " (" + std::to_string(err.err()) + ")");
}
......
0% Chargement en cours ou .
You are about to add 0 people to the discussion. Proceed with caution.
Terminez d'abord l'édition de ce message.
Veuillez vous inscrire ou vous pour commenter