diff --git a/hydra_flash/train.py b/hydra_flash/train.py
index 612d811dd7ebaba7421e061bb882b5f07f4fd24c..62148319a2f31c147a347d5360a43fab010ffc34 100644
--- a/hydra_flash/train.py
+++ b/hydra_flash/train.py
@@ -1,11 +1,11 @@
-from attr import has
 import hydra
 from hydra.utils import call, instantiate as hydra_instantiate
 from functools import partial
-import matplotlib.pyplot as plt
 from omegaconf import OmegaConf
 import logging
 
+from hydra_flash.utils import show_predictions
+
 log = logging.getLogger(__name__)
 
 instantiate = partial(hydra_instantiate, _convert_="all")
@@ -30,23 +30,7 @@ def main(cfg):
     predict_datamodule = instantiate(cfg.predict_datamodule)
     predictions = trainer.predict(model, datamodule=predict_datamodule, output="labels")
 
-    for i, (image, pred) in enumerate(
-        zip(predict_datamodule.predict_dataset, predictions[0])
-    ):
-        image = image["input"]
-        if hasattr(image, "shape"):
-            image = image.permute(1, 2, 0)
-        fig = plt.figure()
-        plt.imshow(image)
-        if isinstance(pred, str) or isinstance(pred, list):
-            plt.title(f"{pred}")
-        else:
-            plt.imshow(pred, cmap="tab20", alpha=0.5)
-
-        if cfg.save:
-            plt.savefig(f"{i}.png")
-        if cfg.show:
-            plt.show()
+    show_predictions(predict_datamodule.predict_dataset, predictions[0])
 
 
 if __name__ == "__main__":
diff --git a/hydra_flash/utils.py b/hydra_flash/utils.py
index 2a6b34c30425bc7147e05d77451541190c51b89f..1cb8953b998ac93dff563ec749950094145c5bfa 100644
--- a/hydra_flash/utils.py
+++ b/hydra_flash/utils.py
@@ -1,5 +1,28 @@
 from pathlib import Path
+import matplotlib.pyplot as plt
+import numpy as np
 
 
 def resolver(root, file_id):
     return Path(root) / f"{file_id}.jpg"
+
+
+def show_predictions(images, predictions, show, save):
+    for i, (image, pred) in enumerate(zip(images, predictions)):
+        image = image["input"]
+        if hasattr(image, "shape"):
+            image = image.permute(1, 2, 0)
+        fig = plt.figure()
+        plt.imshow(image)
+        if isinstance(pred, str):
+            plt.title(f"{pred}")
+        elif isinstance(pred, list):
+            if np.array(pred).ndim == 2:
+                plt.imshow(pred, cmap="tab20", alpha=0.5)
+            else:
+                plt.title(f"{pred}")
+
+        if save:
+            plt.savefig(f"{i}.png")
+        if show:
+            plt.show()