From 1754774ed33520dd7571e9e8acbdb5d8837f7c36 Mon Sep 17 00:00:00 2001 From: Victor Joos <victor.joos@uclouvain.be> Date: Tue, 28 Jun 2022 16:01:32 +0200 Subject: [PATCH] cleanup visualizer --- hydra_flash/train.py | 22 +++------------------- hydra_flash/utils.py | 23 +++++++++++++++++++++++ 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/hydra_flash/train.py b/hydra_flash/train.py index 612d811..6214831 100644 --- a/hydra_flash/train.py +++ b/hydra_flash/train.py @@ -1,11 +1,11 @@ -from attr import has import hydra from hydra.utils import call, instantiate as hydra_instantiate from functools import partial -import matplotlib.pyplot as plt from omegaconf import OmegaConf import logging +from hydra_flash.utils import show_predictions + log = logging.getLogger(__name__) instantiate = partial(hydra_instantiate, _convert_="all") @@ -30,23 +30,7 @@ def main(cfg): predict_datamodule = instantiate(cfg.predict_datamodule) predictions = trainer.predict(model, datamodule=predict_datamodule, output="labels") - for i, (image, pred) in enumerate( - zip(predict_datamodule.predict_dataset, predictions[0]) - ): - image = image["input"] - if hasattr(image, "shape"): - image = image.permute(1, 2, 0) - fig = plt.figure() - plt.imshow(image) - if isinstance(pred, str) or isinstance(pred, list): - plt.title(f"{pred}") - else: - plt.imshow(pred, cmap="tab20", alpha=0.5) - - if cfg.save: - plt.savefig(f"{i}.png") - if cfg.show: - plt.show() + show_predictions(predict_datamodule.predict_dataset, predictions[0]) if __name__ == "__main__": diff --git a/hydra_flash/utils.py b/hydra_flash/utils.py index 2a6b34c..1cb8953 100644 --- a/hydra_flash/utils.py +++ b/hydra_flash/utils.py @@ -1,5 +1,28 @@ from pathlib import Path +import matplotlib.pyplot as plt +import numpy as np def resolver(root, file_id): return Path(root) / f"{file_id}.jpg" + + +def show_predictions(images, predictions, show, save): + for i, (image, pred) in enumerate(zip(images, predictions)): + image = image["input"] + if hasattr(image, "shape"): + image = image.permute(1, 2, 0) + fig = plt.figure() + plt.imshow(image) + if isinstance(pred, str): + plt.title(f"{pred}") + elif isinstance(pred, list): + if np.array(pred).ndim == 2: + plt.imshow(pred, cmap="tab20", alpha=0.5) + else: + plt.title(f"{pred}") + + if save: + plt.savefig(f"{i}.png") + if show: + plt.show() -- GitLab