From 1754774ed33520dd7571e9e8acbdb5d8837f7c36 Mon Sep 17 00:00:00 2001
From: Victor Joos <victor.joos@uclouvain.be>
Date: Tue, 28 Jun 2022 16:01:32 +0200
Subject: [PATCH] cleanup visualizer

---
 hydra_flash/train.py | 22 +++-------------------
 hydra_flash/utils.py | 23 +++++++++++++++++++++++
 2 files changed, 26 insertions(+), 19 deletions(-)

diff --git a/hydra_flash/train.py b/hydra_flash/train.py
index 612d811..6214831 100644
--- a/hydra_flash/train.py
+++ b/hydra_flash/train.py
@@ -1,11 +1,11 @@
-from attr import has
 import hydra
 from hydra.utils import call, instantiate as hydra_instantiate
 from functools import partial
-import matplotlib.pyplot as plt
 from omegaconf import OmegaConf
 import logging
 
+from hydra_flash.utils import show_predictions
+
 log = logging.getLogger(__name__)
 
 instantiate = partial(hydra_instantiate, _convert_="all")
@@ -30,23 +30,7 @@ def main(cfg):
     predict_datamodule = instantiate(cfg.predict_datamodule)
     predictions = trainer.predict(model, datamodule=predict_datamodule, output="labels")
 
-    for i, (image, pred) in enumerate(
-        zip(predict_datamodule.predict_dataset, predictions[0])
-    ):
-        image = image["input"]
-        if hasattr(image, "shape"):
-            image = image.permute(1, 2, 0)
-        fig = plt.figure()
-        plt.imshow(image)
-        if isinstance(pred, str) or isinstance(pred, list):
-            plt.title(f"{pred}")
-        else:
-            plt.imshow(pred, cmap="tab20", alpha=0.5)
-
-        if cfg.save:
-            plt.savefig(f"{i}.png")
-        if cfg.show:
-            plt.show()
+    show_predictions(predict_datamodule.predict_dataset, predictions[0])
 
 
 if __name__ == "__main__":
diff --git a/hydra_flash/utils.py b/hydra_flash/utils.py
index 2a6b34c..1cb8953 100644
--- a/hydra_flash/utils.py
+++ b/hydra_flash/utils.py
@@ -1,5 +1,28 @@
 from pathlib import Path
+import matplotlib.pyplot as plt
+import numpy as np
 
 
 def resolver(root, file_id):
     return Path(root) / f"{file_id}.jpg"
+
+
+def show_predictions(images, predictions, show, save):
+    for i, (image, pred) in enumerate(zip(images, predictions)):
+        image = image["input"]
+        if hasattr(image, "shape"):
+            image = image.permute(1, 2, 0)
+        fig = plt.figure()
+        plt.imshow(image)
+        if isinstance(pred, str):
+            plt.title(f"{pred}")
+        elif isinstance(pred, list):
+            if np.array(pred).ndim == 2:
+                plt.imshow(pred, cmap="tab20", alpha=0.5)
+            else:
+                plt.title(f"{pred}")
+
+        if save:
+            plt.savefig(f"{i}.png")
+        if show:
+            plt.show()
-- 
GitLab