diff --git a/hydra_flash/train.py b/hydra_flash/train.py
index 62148319a2f31c147a347d5360a43fab010ffc34..90fc12415313ba54abf2653d890b0f26d5ddeaad 100644
--- a/hydra_flash/train.py
+++ b/hydra_flash/train.py
@@ -30,7 +30,9 @@ def main(cfg):
     predict_datamodule = instantiate(cfg.predict_datamodule)
     predictions = trainer.predict(model, datamodule=predict_datamodule, output="labels")
 
-    show_predictions(predict_datamodule.predict_dataset, predictions[0])
+    show_predictions(
+        predict_datamodule.predict_dataset, predictions[0], cfg.show, cfg.save
+    )
 
 
 if __name__ == "__main__":