diff --git a/hydra_flash/train.py b/hydra_flash/train.py index 612d811dd7ebaba7421e061bb882b5f07f4fd24c..62148319a2f31c147a347d5360a43fab010ffc34 100644 --- a/hydra_flash/train.py +++ b/hydra_flash/train.py @@ -1,11 +1,11 @@ -from attr import has import hydra from hydra.utils import call, instantiate as hydra_instantiate from functools import partial -import matplotlib.pyplot as plt from omegaconf import OmegaConf import logging +from hydra_flash.utils import show_predictions + log = logging.getLogger(__name__) instantiate = partial(hydra_instantiate, _convert_="all") @@ -30,23 +30,7 @@ def main(cfg): predict_datamodule = instantiate(cfg.predict_datamodule) predictions = trainer.predict(model, datamodule=predict_datamodule, output="labels") - for i, (image, pred) in enumerate( - zip(predict_datamodule.predict_dataset, predictions[0]) - ): - image = image["input"] - if hasattr(image, "shape"): - image = image.permute(1, 2, 0) - fig = plt.figure() - plt.imshow(image) - if isinstance(pred, str) or isinstance(pred, list): - plt.title(f"{pred}") - else: - plt.imshow(pred, cmap="tab20", alpha=0.5) - - if cfg.save: - plt.savefig(f"{i}.png") - if cfg.show: - plt.show() + show_predictions(predict_datamodule.predict_dataset, predictions[0]) if __name__ == "__main__": diff --git a/hydra_flash/utils.py b/hydra_flash/utils.py index 2a6b34c30425bc7147e05d77451541190c51b89f..1cb8953b998ac93dff563ec749950094145c5bfa 100644 --- a/hydra_flash/utils.py +++ b/hydra_flash/utils.py @@ -1,5 +1,28 @@ from pathlib import Path +import matplotlib.pyplot as plt +import numpy as np def resolver(root, file_id): return Path(root) / f"{file_id}.jpg" + + +def show_predictions(images, predictions, show, save): + for i, (image, pred) in enumerate(zip(images, predictions)): + image = image["input"] + if hasattr(image, "shape"): + image = image.permute(1, 2, 0) + fig = plt.figure() + plt.imshow(image) + if isinstance(pred, str): + plt.title(f"{pred}") + elif isinstance(pred, list): + if np.array(pred).ndim == 2: + plt.imshow(pred, cmap="tab20", alpha=0.5) + else: + plt.title(f"{pred}") + + if save: + plt.savefig(f"{i}.png") + if show: + plt.show()