diff --git a/hydra_flash/train.py b/hydra_flash/train.py index b384567e0823b1152c624d1b0303c394f2466b1b..7fe5e085950e98b7491fb777226b8488bc64dac1 100644 --- a/hydra_flash/train.py +++ b/hydra_flash/train.py @@ -16,17 +16,21 @@ def main(cfg): # Preprocess : download and/or unzip data call(cfg.preprocess) + # Create datamodule from downloaded files datamodule = instantiate(cfg.datamodule) + # Create model model_kwargs = {} for kw in cfg.model.complete: model_kwargs[kw] = getattr(datamodule, kw) del cfg.model.complete model = instantiate(cfg.model, **model_kwargs) + # Finetune model trainer = instantiate(cfg.trainer) trainer.finetune(model, datamodule=datamodule, strategy="freeze") + # Predict and show predictions predict_datamodule = instantiate(cfg.predict_datamodule) predictions = trainer.predict(model, datamodule=predict_datamodule, output="labels")