diff --git a/hydra_flash/train.py b/hydra_flash/train.py
index b384567e0823b1152c624d1b0303c394f2466b1b..7fe5e085950e98b7491fb777226b8488bc64dac1 100644
--- a/hydra_flash/train.py
+++ b/hydra_flash/train.py
@@ -16,17 +16,21 @@ def main(cfg):
     # Preprocess : download and/or unzip data
     call(cfg.preprocess)
 
+    # Create datamodule from downloaded files
     datamodule = instantiate(cfg.datamodule)
 
+    # Create model
     model_kwargs = {}
     for kw in cfg.model.complete:
         model_kwargs[kw] = getattr(datamodule, kw)
     del cfg.model.complete
     model = instantiate(cfg.model, **model_kwargs)
 
+    # Finetune model
     trainer = instantiate(cfg.trainer)
     trainer.finetune(model, datamodule=datamodule, strategy="freeze")
 
+    # Predict and show predictions
     predict_datamodule = instantiate(cfg.predict_datamodule)
     predictions = trainer.predict(model, datamodule=predict_datamodule, output="labels")