diff --git a/matrix_opencl.cpp b/matrix_opencl.cpp
index 47c6e778ef169b7a36473a8fa3eb7287cc76c7bc..a2689fa3c557f375c7deedd75f1d7a252ae07d61 100644
--- a/matrix_opencl.cpp
+++ b/matrix_opencl.cpp
@@ -86,19 +86,43 @@ const std::string kernel_source_transpose = R"(
     }
 )";*/
 const std::string kernel_source_matrix_mul = R"(
-    __kernel void matrix_mul(__global const float* A,__global const float* B, __global float* C, int M, int K, int N) {
-        int i = get_global_id(0);
-        float Awrk[1024];
-        for (int k = 0; k < K; ++k) {
-            Awrk[k] = A[i * K + k];
-        }
+    // Tiled matrix multiply, row-major: C (A_rows x B_cols) = A (A_rows x A_cols) * B (A_cols x B_cols).
+    // Launch as a 2D NDRange with a 16x16 work-group (must equal TILE_SIZE); dim 0 = row, dim 1 = column.
+    __kernel void matrix_mul(__global const float* A, __global const float* B, __global float* C, int A_rows, int A_cols, int B_cols) {
+        int row = get_global_id(0);
+        int col = get_global_id(1);
 
-        for (int j = 0; j < N; ++j) {
-            float tmp = 0.0f;
-            for (int k = 0; k < K; ++k) {
-                tmp += Awrk[k] * B[k * N + j];
-            }
-            C[i * N + j] = tmp;
+        int local_row = get_local_id(0);
+        int local_col = get_local_id(1);
+
+        // Enumerator rather than a const int: OpenCL C has no variable-length arrays, so the
+        // bounds of the __local tiles below must be true integer constant expressions.
+        enum { TILE_SIZE = 16 };
+        __local float As[TILE_SIZE][TILE_SIZE];
+        __local float Bs[TILE_SIZE][TILE_SIZE];
+
+        float sum = 0.0f;
+        for (int t = 0; t < (A_cols + TILE_SIZE - 1) / TILE_SIZE; t++) {
+            int tile_row = t * TILE_SIZE + local_row;
+            int tile_col = t * TILE_SIZE + local_col;
+            // Elements outside A or B are loaded as 0 so partial edge tiles add nothing to the sum.
+            if (row < A_rows && tile_col < A_cols)
+                As[local_row][local_col] = A[row * A_cols + tile_col];
+            else
+                As[local_row][local_col] = 0.0f;
+            if (tile_row < A_cols && col < B_cols)
+                Bs[local_row][local_col] = B[tile_row * B_cols + col];
+            else
+                Bs[local_row][local_col] = 0.0f;
+            // Wait until the whole tile is loaded before any work-item reads it.
+            barrier(CLK_LOCAL_MEM_FENCE);
+            for (int k = 0; k < TILE_SIZE; ++k)
+                sum += As[local_row][k] * Bs[k][local_col];
+            // Wait until every work-item is done with this tile before the next load overwrites it.
+            barrier(CLK_LOCAL_MEM_FENCE);
+        }
+        if (row < A_rows && col < B_cols) {
+            C[row * B_cols + col] = sum;
         }
     }
 )";