From e65619c79d12af3d85944332fed942a183f5cf27 Mon Sep 17 00:00:00 2001 From: Victor Joos <victor.joos@uclouvain.be> Date: Tue, 28 Jun 2022 16:05:06 +0200 Subject: [PATCH] comments --- hydra_flash/train.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/hydra_flash/train.py b/hydra_flash/train.py index b384567..7fe5e08 100644 --- a/hydra_flash/train.py +++ b/hydra_flash/train.py @@ -16,17 +16,21 @@ def main(cfg): # Preprocess : download and/or unzip data call(cfg.preprocess) + # Create datamodule from downloaded files datamodule = instantiate(cfg.datamodule) + # Create model model_kwargs = {} for kw in cfg.model.complete: model_kwargs[kw] = getattr(datamodule, kw) del cfg.model.complete model = instantiate(cfg.model, **model_kwargs) + # Finetune model trainer = instantiate(cfg.trainer) trainer.finetune(model, datamodule=datamodule, strategy="freeze") + # Predict and show predictions predict_datamodule = instantiate(cfg.predict_datamodule) predictions = trainer.predict(model, datamodule=predict_datamodule, output="labels") -- GitLab