Skip to content
Extraits de code Groupes Projets
Valider 5f13c54b rédigé par JordanHanotiaux's avatar JordanHanotiaux
Parcourir les fichiers

Update matrix_opencl.cpp

parent 433dfadf
Aucune branche associée trouvée
Aucune étiquette associée trouvée
Aucune requête de fusion associée trouvée
......@@ -76,7 +76,7 @@ const std::string kernel_source_transpose = R"(
B[output_idx] = A[input_idx];
}
)";
const std::string kernel_source_matrix_mul = R"(
/*const std::string kernel_source_matrix_mul = R"(
__kernel void matrix_mul(__global const float* A, __global const float* B, __global float* C, int A_rows, int A_cols, int B_cols) {
int row = get_global_id(0);
int col = get_global_id(1);
......@@ -84,24 +84,48 @@ const std::string kernel_source_matrix_mul = R"(
C[row * B_cols + col] += A[row * A_cols + k] * B[k * B_cols + col];
}
}
)";
/*const std::string kernel_source_matrix_mul = R"(
__kernel void matrix_mul(__global const float* A,__global const float* B, __global float* C, int M, int K, int N) {
int i = get_global_id(0);
float Awrk[1024];
for (int k = 0; k < K; ++k) {
Awrk[k] = A[i * K + k];
}
)";*/
const std::string kernel_source_matrix_mul = R"(
__kernel void fast_matrix_mul(__global const float* A, __global const float* B, __global float* C, int A_rows, int A_cols, int B_cols) {
int row = get_global_id(0);
int col = get_global_id(1);
for (int j = 0; j < N; ++j) {
float tmp = 0.0f;
for (int k = 0; k < K; ++k) {
tmp += Awrk[k] * B[k * N + j];
}
C[i * N + j] = tmp;
int local_row = get_local_id(0);
int local_col = get_local_id(1);
const int TILE_SIZE = 16;
__local float As[TILE_SIZE][TILE_SIZE];
__local float Bs[TILE_SIZE][TILE_SIZE];
float sum = 0.0f;
for (int t = 0; t < (A_cols + TILE_SIZE - 1) / TILE_SIZE; t++) {
int tile_row = t * TILE_SIZE + local_row;
int tile_col = t * TILE_SIZE + local_col;
if (row < A_rows && tile_col < A_cols)
As[local_row][local_col] = A[row * A_cols + tile_col];
else
As[local_row][local_col] = 0.0f;
if (tile_row < A_cols && col < B_cols)
Bs[local_row][local_col] = B[tile_row * B_cols + col];
else
Bs[local_row][local_col] = 0.0f;
barrier(CLK_LOCAL_MEM_FENCE);
for (int k = 0; k < TILE_SIZE; ++k)
sum += As[local_row][k] * Bs[k][local_col];
barrier(CLK_LOCAL_MEM_FENCE);
}
if (row < A_rows && col < B_cols) {
C[row * B_cols + col] = sum;
}
}
)";*/
)";
const std::string kernel_source_sigmoid = R"(
__kernel void sigmoid(__global const float* input, __global float* output, int rows, int cols) {
int idx = get_global_id(0);
......
0% Chargement en cours ou .
You are about to add 0 people to the discussion. Proceed with caution.
Terminez d'abord l'édition de ce message.
Veuillez vous inscrire ou vous pour commenter