// Unit tests for DistributedMatrix, intended to be run under MPI (e.g. mpirun -np N).
#include "distributedmatrix.hpp"
#include "matrix.hpp"
#include "mlp_sgd_distributed.cpp"
#include <mpi.h>
#include <cassert>
#include <cmath>
#include <exception>
#include <functional>
#include <iostream>
#include <stdexcept>
// Helper function to test if two doubles are approximately equal
// Reports whether two doubles lie strictly within `epsilon` of each other.
// The tolerance is absolute (not scaled by magnitude). If either value is NaN
// the difference is NaN, the comparison is false, and the values are
// reported as unequal.
bool approxEqual(double a, double b, double epsilon = 1e-10) {
    const double difference = std::abs(a - b);
    return difference < epsilon;
}
// Helper function to test if two matrices are approximately equal
// Element-wise approximate equality of two matrices.
// Returns false immediately when the shapes differ. Otherwise every pair
// of corresponding entries must satisfy approxEqual(..., epsilon).
bool matricesEqual(const Matrix& a, const Matrix& b, double epsilon = 1e-10) {
    const bool sameShape = a.numRows() == b.numRows() && a.numCols() == b.numCols();
    if (!sameShape) {
        return false;
    }
    for (int row = 0; row < a.numRows(); ++row) {
        for (int col = 0; col < a.numCols(); ++col) {
            const bool close = approxEqual(a.get(row, col), b.get(row, col), epsilon);
            if (!close) {
                return false;
            }
        }
    }
    return true;
}
// Test multiplyTransposed
void testMultiplyTransposed() {
int rank;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
// Create test matrices
Matrix matrix1Full(256, 256);
Matrix matrix2Full(256, 256);
for (int i = 0; i < 256; i++) {
for (int j = 0; j < 256; j++) {
for (int i = 0; i < 256; i++) {
for (int j = 0; j < 256; j++) {
matrix2Full.set(i, j, i * 5 + j + 2);
}
}
// Create distributed matrices
int numProcs;
MPI_Comm_size(MPI_COMM_WORLD, &numProcs);
DistributedMatrix matrix1(matrix1Full, numProcs);
DistributedMatrix matrix2(matrix2Full, numProcs);
// Compute expected result
Matrix expectedMatrix = matrix1Full * matrix2Full.transpose();
// Compute A * B^T
Matrix result = matrix1.multiplyTransposed(matrix2);
// Check
assert(matricesEqual(result, expectedMatrix, 1e-8));
if (rank == 0) {
std::cout << "MultiplyTransposed test passed!" << std::endl;
}
}
int main(int argc, char** argv) {
// Initialize MPI
int initialized;
MPI_Initialized(&initialized);
if (!initialized) {
MPI_Init(&argc, &argv);
}
int rank;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
if (rank == 0) {
std::cout << "Starting DistributedMatrix tests..." << std::endl;
}
try {
// Run tests
if (rank == 0) {
std::cout << "All tests passed successfully!" << std::endl;
}
}
catch (std::exception& e) {
if (rank == 0) {
std::cerr << "Test failed with exception: " << e.what() << std::endl;
}
MPI_Abort(MPI_COMM_WORLD, 1);
}
// Finalize MPI if we initialized it
// int finalized;
// MPI_Finalized(&finalized);
// if (!finalized && initialized) {
// MPI_Finalize();
// }
MPI_Finalize();
return 0;
}