Skip to content
Extraits de code Groupes Projets
Valider 998db5d1 rédigé par JordanHanotiaux's avatar JordanHanotiaux
Parcourir les fichiers

Update matrix_opencl.cpp

parent 22bca3f8
Aucune branche associée trouvée
Aucune étiquette associée trouvée
Aucune requête de fusion associée trouvée
...@@ -76,6 +76,7 @@ const std::string kernel_source_transpose = R"( ...@@ -76,6 +76,7 @@ const std::string kernel_source_transpose = R"(
B[output_idx] = A[input_idx]; B[output_idx] = A[input_idx];
} }
)"; )";
// NAIVE
/*const std::string kernel_source_matrix_mul = R"( /*const std::string kernel_source_matrix_mul = R"(
__kernel void matrix_mul(__global const float* A, __global const float* B, __global float* C, int A_rows, int A_cols, int B_cols) { __kernel void matrix_mul(__global const float* A, __global const float* B, __global float* C, int A_rows, int A_cols, int B_cols) {
int row = get_global_id(0); int row = get_global_id(0);
...@@ -85,6 +86,8 @@ const std::string kernel_source_transpose = R"( ...@@ -85,6 +86,8 @@ const std::string kernel_source_transpose = R"(
} }
} }
)";*/ )";*/
// FASTER
const std::string kernel_source_matrix_mul = R"( const std::string kernel_source_matrix_mul = R"(
__kernel void matrix_mul(__global const float* A, __kernel void matrix_mul(__global const float* A,
__global const float* B, __global const float* B,
...@@ -98,7 +101,7 @@ const std::string kernel_source_matrix_mul = R"( ...@@ -98,7 +101,7 @@ const std::string kernel_source_matrix_mul = R"(
int nloc = get_local_size(0); int nloc = get_local_size(0);
float tmp; float tmp;
float Awrk[10000]; float Awrk[4096];
for (k = 0; k < K; k++) for (k = 0; k < K; k++)
Awrk[k] = A[i * K + k]; Awrk[k] = A[i * K + k];
...@@ -344,6 +347,7 @@ MatrixCL MatrixCL::operator+(const MatrixCL& other) const { ...@@ -344,6 +347,7 @@ MatrixCL MatrixCL::operator+(const MatrixCL& other) const {
return result; return result;
} }
// NAIVE VERSION
/*MatrixCL MatrixCL::operator*(const MatrixCL& other) const { /*MatrixCL MatrixCL::operator*(const MatrixCL& other) const {
if (cols_ != other.rows_) if (cols_ != other.rows_)
throw std::runtime_error("Matrix dimension error."); throw std::runtime_error("Matrix dimension error.");
...@@ -366,6 +370,7 @@ MatrixCL MatrixCL::operator+(const MatrixCL& other) const { ...@@ -366,6 +370,7 @@ MatrixCL MatrixCL::operator+(const MatrixCL& other) const {
return result; return result;
}*/ }*/
// FASTER VERSION
MatrixCL MatrixCL::operator*(const MatrixCL& other) const { MatrixCL MatrixCL::operator*(const MatrixCL& other) const {
if (cols_ != other.rows_) if (cols_ != other.rows_)
throw std::runtime_error("Matrix dimension error."); throw std::runtime_error("Matrix dimension error.");
......
0% Chargement en cours ou .
You are about to add 0 people to the discussion. Proceed with caution.
Terminez d'abord l'édition de ce message.
Veuillez vous inscrire ou vous pour commenter