Commit a101ebdd authored by JordanHanotiaux

update

parent 84e5bc64
@@ -216,20 +216,9 @@ DistributedMatrix multiply(const Matrix& left, const DistributedMatrix& right) {
 Matrix DistributedMatrix::multiplyTransposed(const DistributedMatrix &other) const {
     Matrix result = (*this).localData * other.getLocalData().transpose();
-    int localSize = result.numRows() * result.numCols();
-    std::vector<int> counts(numProcesses);
-    std::vector<int> displacements(numProcesses);
-    MPI_Allgather(&localSize, 1, MPI_INT, counts.data(), 1, MPI_INT, MPI_COMM_WORLD);
-    displacements[0] = 0;
-    for (int i = 1; i < numProcesses; ++i)
-        displacements[i] = displacements[i - 1] + counts[i - 1];
     std::vector<double> buffer(this->globalRows * other.globalRows);
     MPI_Allreduce(
         result.getData().data(),
         buffer.data(),
@@ -250,6 +239,13 @@ Matrix DistributedMatrix::multiplyTransposed(const DistributedMatrix &other) const {
     return fullMatrix;
 }
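Note: after this commit, multiplyTransposed keeps only the local partial product and merges the partials across ranks with a single MPI_Allreduce (presumably with MPI_SUM; the remaining arguments are elided in this hunk), instead of exchanging sizes and offsets with MPI_Allgather. A minimal standalone sketch of that reduction pattern, with a plain std::vector standing in for the project's Matrix type:

#include <mpi.h>
#include <cstdio>
#include <vector>

// Illustration only (not the project's code): each rank owns one partial
// term of a 2x2 product; MPI_Allreduce with MPI_SUM leaves the summed
// result on every rank at once.
int main(int argc, char** argv) {
    MPI_Init(&argc, &argv);
    int rank;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    std::vector<double> partial = {1.0 * rank, 2.0, 3.0, 4.0};  // local term
    std::vector<double> full(partial.size(), 0.0);              // global sum
    MPI_Allreduce(partial.data(), full.data(),
                  static_cast<int>(partial.size()),
                  MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);

    if (rank == 0)
        std::printf("full[0] = %g\n", full[0]);

    MPI_Finalize();
    return 0;
}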
+void sync_matrix(Matrix *matrix, int rank, int src) {
+    // Synchronously broadcast the matrix from process `src` to all other processes.
+    MPI_Bcast(matrix->getData().data(), matrix->numRows() * matrix->numCols(), MPI_DOUBLE, src, MPI_COMM_WORLD);
+}
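The new sync_matrix helper is a thin wrapper around MPI_Bcast: the matrix held by rank `src` overwrites the same-sized matrix on every other rank, e.g. to replicate one rank's freshly initialized weights. (The `rank` parameter is unused in the body shown above.) A sketch of the same pattern, again with std::vector standing in for the project's Matrix type:

#include <mpi.h>
#include <vector>

// Illustration only: rank `src` holds the authoritative values and
// MPI_Bcast copies them into the matching buffer on all other ranks.
void sync_values(std::vector<double>& values, int src) {
    MPI_Bcast(values.data(), static_cast<int>(values.size()),
              MPI_DOUBLE, src, MPI_COMM_WORLD);
}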
@@ -430,6 +430,63 @@ void testMultiplyTransposed() {
     }
 }
 
+void test_distributed_mlp_training() {
+    int rank, size;
+    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+    MPI_Comm_size(MPI_COMM_WORLD, &size);
+
+    // Print info about the MPI environment
+    if (rank == 0) {
+        std::cout << "Running with " << size << " MPI processes." << std::endl;
+    }
+
+    // Create a simple XOR dataset: inputs are the columns of X (row 2 is a
+    // constant 1.0, presumably a bias input), targets are the columns of Y
+    Matrix X(3, 4);
+    Matrix Y(1, 4);
+
+    // Sample 0: (0, 0) -> 0
+    X.set(0, 0, 0.0);
+    X.set(1, 0, 0.0);
+    X.set(2, 0, 1.0);
+    Y.set(0, 0, 0.0);
+
+    // Sample 1: (0, 1) -> 1
+    X.set(0, 1, 0.0);
+    X.set(1, 1, 1.0);
+    X.set(2, 1, 1.0);
+    Y.set(0, 1, 1.0);
+
+    // Sample 2: (1, 0) -> 1
+    X.set(0, 2, 1.0);
+    X.set(1, 2, 0.0);
+    X.set(2, 2, 1.0);
+    Y.set(0, 2, 1.0);
+
+    // Sample 3: (1, 1) -> 0
+    X.set(0, 3, 1.0);
+    X.set(1, 3, 1.0);
+    X.set(2, 3, 1.0);
+    Y.set(0, 3, 0.0);
+
+    // Distribute the data across the MPI processes
+    Dataset data = Dataset(DistributedMatrix(X, size), DistributedMatrix(Y, size));
+
+    // Create and train the model (presumably: 3 inputs, 128 hidden units,
+    // 1 output, learning rate 0.1)
+    MLP model(3, 128, 1, 0.1);
+    if (rank == 0) {
+        std::cout << "Training distributed MLP for XOR problem..." << std::endl;
+    }
+    model.train(data, 5000);
+
+    if (rank == 0) {
+        std::cout << "Distributed MLP training test completed." << std::endl;
+    }
+}
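The test hands the full X and Y to DistributedMatrix(X, size), so each rank presumably ends up with a block of the four sample columns. A hypothetical sketch of one plausible column-block split; the actual partitioning lives inside DistributedMatrix and is not shown in this diff:

#include <algorithm>
#include <cstdio>

// Hypothetical block partition of `cols` sample columns over `size` ranks;
// illustration only, not DistributedMatrix's actual scheme.
void print_partition(int cols, int size) {
    int base = cols / size, rem = cols % size;
    for (int r = 0; r < size; ++r) {
        int start = r * base + std::min(r, rem);     // first column for rank r
        int end = start + base + (r < rem ? 1 : 0);  // one past the last
        std::printf("rank %d: columns [%d, %d)\n", r, start, end);
    }
}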
 int main(int argc, char** argv) {
     // Initialize MPI
     int initialized;
@@ -457,7 +514,7 @@ int main(int argc, char** argv) {
     testGather();
     testGetAndSet();
     testCopyConstructor();
-    //test_distributed_mlp_training();
+    test_distributed_mlp_training();
 
     if (rank == 0) {
         std::cout << "All tests passed successfully!" << std::endl;