Skip to content
Extraits de code Groupes Projets
Valider e65619c7 rédigé par Victor Joos de ter Beerst's avatar Victor Joos de ter Beerst
Parcourir les fichiers

comments

parent e1206f3a
Aucune branche associée trouvée
Aucune étiquette associée trouvée
Aucune requête de fusion associée trouvée
......@@ -16,17 +16,21 @@ def main(cfg):
# Preprocess : download and/or unzip data
call(cfg.preprocess)
# Create datamodule from downloaded files
datamodule = instantiate(cfg.datamodule)
# Create model
model_kwargs = {}
for kw in cfg.model.complete:
model_kwargs[kw] = getattr(datamodule, kw)
del cfg.model.complete
model = instantiate(cfg.model, **model_kwargs)
# Finetune model
trainer = instantiate(cfg.trainer)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
# Predict and show predictions
predict_datamodule = instantiate(cfg.predict_datamodule)
predictions = trainer.predict(model, datamodule=predict_datamodule, output="labels")
......
0% Chargement en cours ou .
You are about to add 0 people to the discussion. Proceed with caution.
Terminez d'abord l'édition de ce message.
Veuillez vous inscrire ou vous pour commenter