diff --git a/matrix_opencl.cpp b/matrix_opencl.cpp index 47c6e778ef169b7a36473a8fa3eb7287cc76c7bc..a2689fa3c557f375c7deedd75f1d7a252ae07d61 100644 --- a/matrix_opencl.cpp +++ b/matrix_opencl.cpp @@ -86,19 +86,43 @@ const std::string kernel_source_transpose = R"( } )";*/ const std::string kernel_source_matrix_mul = R"( - __kernel void matrix_mul(__global const float* A,__global const float* B, __global float* C, int M, int K, int N) { - int i = get_global_id(0); - float Awrk[1024]; - for (int k = 0; k < K; ++k) { - Awrk[k] = A[i * K + k]; - } + __kernel void matrix_mul(__global const float* A, __global const float* B, __global float* C, int A_rows, int A_cols, int B_cols) { + int row = get_global_id(0); + int col = get_global_id(1); - for (int j = 0; j < N; ++j) { - float tmp = 0.0f; - for (int k = 0; k < K; ++k) { - tmp += Awrk[k] * B[k * N + j]; - } - C[i * N + j] = tmp; + int local_row = get_local_id(0); + int local_col = get_local_id(1); + + const int TILE_SIZE = 16; + + __local float As[TILE_SIZE][TILE_SIZE]; + __local float Bs[TILE_SIZE][TILE_SIZE]; + + float sum = 0.0f; + + for (int t = 0; t < (A_cols + TILE_SIZE - 1) / TILE_SIZE; t++) { + int tile_row = t * TILE_SIZE + local_row; + int tile_col = t * TILE_SIZE + local_col; + + if (row < A_rows && tile_col < A_cols) + As[local_row][local_col] = A[row * A_cols + tile_col]; + else + As[local_row][local_col] = 0.0f; + + if (tile_row < A_cols && col < B_cols) + Bs[local_row][local_col] = B[tile_row * B_cols + col]; + else + Bs[local_row][local_col] = 0.0f; + + barrier(CLK_LOCAL_MEM_FENCE); + + for (int k = 0; k < TILE_SIZE; ++k) + sum += As[local_row][k] * Bs[k][local_col]; + + barrier(CLK_LOCAL_MEM_FENCE); + } + if (row < A_rows && col < B_cols) { + C[row * B_cols + col] = sum; } } )";