// distributedtests.cpp
// MPI test driver for DistributedMatrix.
    #include "distributedmatrix.hpp"
    #include "matrix.hpp"
    #include "mlp_sgd_distributed.cpp"
    #include <mpi.h>
    #include <cassert>
    #include <cmath>
    #include <functional>
    #include <iostream>
    #include <stdexcept>
    
    
    // ---- Comparison helpers ----
    // Helper function to test if two doubles are approximately equal.
    //
    // Returns true when a and b are exactly equal, or when |a - b| < epsilon
    // (absolute tolerance). The exact-equality shortcut matters for matching
    // infinities: inf - inf is NaN, which would otherwise compare as "not equal".
    // A NaN operand never compares equal to anything, including another NaN.
    bool approxEqual(double a, double b, double epsilon = 1e-10) {
        if (a == b) {
            return true;
        }
        return std::abs(a - b) < epsilon;
    }
    
    // Compares two matrices entry by entry.
    //
    // Returns false when the row or column counts differ, or as soon as one
    // pair of corresponding entries fails approxEqual with the given epsilon.
    // Returns true when every entry pair passes.
    bool matricesEqual(const Matrix& a, const Matrix& b, double epsilon = 1e-10) {
        const bool sameShape =
            a.numRows() == b.numRows() && a.numCols() == b.numCols();
        if (!sameShape) {
            return false;
        }

        const int rowCount = a.numRows();
        const int colCount = a.numCols();
        for (int row = 0; row < rowCount; ++row) {
            for (int col = 0; col < colCount; ++col) {
                const bool close = approxEqual(a.get(row, col), b.get(row, col), epsilon);
                if (!close) {
                    return false;
                }
            }
        }

        return true;
    }
    
    
    // ---- Tests ----
    // Test multiplyTransposed
    void testMultiplyTransposed() {
        int rank;
        MPI_Comm_rank(MPI_COMM_WORLD, &rank);
        
        // Create test matrices
    
    JordanHanotiaux's avatar
    JordanHanotiaux a validé
        Matrix matrix1Full(256, 256);
        Matrix matrix2Full(256, 256);
        for (int i = 0; i < 256; i++) {
            for (int j = 0; j < 256; j++) {
    
    JordanHanotiaux's avatar
    JordanHanotiaux a validé
                matrix1Full.set(i, j, i * 5 + j + 1);
            }
        }
    
    JordanHanotiaux's avatar
    JordanHanotiaux a validé
        for (int i = 0; i < 256; i++) {
            for (int j = 0; j < 256; j++) {
    
    JordanHanotiaux's avatar
    JordanHanotiaux a validé
                matrix2Full.set(i, j, i * 5 + j + 2);
            }
        }
        
        // Create distributed matrices
        int numProcs;
        MPI_Comm_size(MPI_COMM_WORLD, &numProcs);
        DistributedMatrix matrix1(matrix1Full, numProcs);
        DistributedMatrix matrix2(matrix2Full, numProcs);
        
        // Compute expected result
        Matrix expectedMatrix = matrix1Full * matrix2Full.transpose();
    
    
        // Compute A * B^T
        Matrix result = matrix1.multiplyTransposed(matrix2);
    
    JordanHanotiaux's avatar
    JordanHanotiaux a validé
        
        // Check
        assert(matricesEqual(result, expectedMatrix, 1e-8));
        
        if (rank == 0) {
            std::cout << "MultiplyTransposed test passed!" << std::endl;
        }
    }
    
    
    // ---- Entry point ----
    int main(int argc, char** argv) {
        // Initialize MPI
        int initialized;
        MPI_Initialized(&initialized);
        if (!initialized) {
            MPI_Init(&argc, &argv);
        }
        
        int rank;
        MPI_Comm_rank(MPI_COMM_WORLD, &rank);
        
        if (rank == 0) {
            std::cout << "Starting DistributedMatrix tests..." << std::endl;
        }
        
        try {
            // Run tests
    
    JordanHanotiaux's avatar
    JordanHanotiaux a validé
            testMultiplyTransposed();
    
    JordanHanotiaux's avatar
    JordanHanotiaux a validé
            
            if (rank == 0) {
                std::cout << "All tests passed successfully!" << std::endl;
            }
        } 
        catch (std::exception& e) {
            if (rank == 0) {
                std::cerr << "Test failed with exception: " << e.what() << std::endl;
            }
            MPI_Abort(MPI_COMM_WORLD, 1);
        }
        
        // Finalize MPI if we initialized it
        // int finalized;
        // MPI_Finalized(&finalized);
        // if (!finalized && initialized) {
        //     MPI_Finalize();
        // }
    
        MPI_Finalize();
        return 0;
    }