diff --git a/hydra_flash/train.py b/hydra_flash/train.py index 62148319a2f31c147a347d5360a43fab010ffc34..90fc12415313ba54abf2653d890b0f26d5ddeaad 100644 --- a/hydra_flash/train.py +++ b/hydra_flash/train.py @@ -30,7 +30,9 @@ def main(cfg): predict_datamodule = instantiate(cfg.predict_datamodule) predictions = trainer.predict(model, datamodule=predict_datamodule, output="labels") - show_predictions(predict_datamodule.predict_dataset, predictions[0]) + show_predictions( + predict_datamodule.predict_dataset, predictions[0], cfg.show, cfg.save + ) if __name__ == "__main__":