diff --git a/hydra_flash/conf/config.yaml b/hydra_flash/conf/config.yaml index 815a2e7cc98ff5afb792fe53264431e78f078ce3..7fc55fbcb9f4bd7b28423fb0612ab409c20f3b04 100644 --- a/hydra_flash/conf/config.yaml +++ b/hydra_flash/conf/config.yaml @@ -1,16 +1,2 @@ defaults: - - _self_ - - server: local - - datamodule: hymenoptera - - model: resnet18 - - datamodule@predict_datamodule: hymenoptera_test - -trainer: - _target_: flash.Trainer - max_epochs: 3 - gpus: 1 - -preprocess: - _target_: flash.core.data.utils.download_data - url: "https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip" - path: "./data" + - experiment: hymenoptera diff --git a/hydra_flash/conf/datamodule/hymenoptera_test.yaml b/hydra_flash/conf/datamodule/hymenoptera_test.yaml index 0213bdc36400f2abb201e7f5c92a802071fb0916..39a2c6fcfbab51e6bf07085ff6dfb5a89ce52260 100644 --- a/hydra_flash/conf/datamodule/hymenoptera_test.yaml +++ b/hydra_flash/conf/datamodule/hymenoptera_test.yaml @@ -1,6 +1,6 @@ _target_: flash.image.ImageClassificationData.from_files predict_files: - - "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg" - - "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg" - - "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg" + - "${preprocess.path}/hymenoptera_data/val/bees/65038344_52a45d090d.jpg" + - "${preprocess.path}/hymenoptera_data/val/bees/590318879_68cf112861.jpg" + - "${preprocess.path}/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg" batch_size: 3 diff --git a/hydra_flash/conf/model/resnet18.yaml b/hydra_flash/conf/model/resnet18.yaml index 14f88f4f1d6d19dfd50ee06d22c4e64bfa1113c1..7c0f8576c8f1545946dd9d58b423254645692a36 100644 --- a/hydra_flash/conf/model/resnet18.yaml +++ b/hydra_flash/conf/model/resnet18.yaml @@ -1,3 +1,6 @@ _target_: flash.image.ImageClassifier backbone: resnet18 learning_rate: 1.e-3 +complete: + - labels + - multi_label diff --git a/hydra_flash/conf/server/local.yaml b/hydra_flash/conf/server/local.yaml index c6439317590b7cf5a348f4588f9c0e4bb1539cf3..12a894c982e33b5b3559c8a1f70a669e31bbc345 100644 --- a/hydra_flash/conf/server/local.yaml +++ b/hydra_flash/conf/server/local.yaml @@ -1 +1,4 @@ -name: local +# @package _global_ +show: true +save: true +path: "/tmp" diff --git a/hydra_flash/conf/server/mb.yaml b/hydra_flash/conf/server/mb.yaml index 66f02e78c96f11c9e24571ad1be93c04017587fe..be4a2f456f4012d91a846a1cbf0292e76e2d043f 100644 --- a/hydra_flash/conf/server/mb.yaml +++ b/hydra_flash/conf/server/mb.yaml @@ -2,7 +2,11 @@ defaults: - override /hydra/launcher: submitit_slurm -name: manneback +show: false +save: true +preprocess: + path: ${oc.env:TEMPDIR} + gpus: [0] hydra: launcher: diff --git a/hydra_flash/train.py b/hydra_flash/train.py index fa5aa5faa13d7b59efbd505bc009b016e8fb0291..612d811dd7ebaba7421e061bb882b5f07f4fd24c 100644 --- a/hydra_flash/train.py +++ b/hydra_flash/train.py @@ -1,6 +1,9 @@ +from attr import has import hydra from hydra.utils import call, instantiate as hydra_instantiate from functools import partial +import matplotlib.pyplot as plt +from omegaconf import OmegaConf import logging log = logging.getLogger(__name__) @@ -8,21 +11,42 @@ log = logging.getLogger(__name__) instantiate = partial(hydra_instantiate, _convert_="all") -@hydra.main(version_base=None, config_path="conf", config_name="config") +@hydra.main(version_base="1.1", config_path="conf", config_name="config") def main(cfg): # Preprocess : download and/or unzip data call(cfg.preprocess) - + print(OmegaConf.to_container(cfg.datamodule, resolve=True)) datamodule = instantiate(cfg.datamodule) - model = instantiate(cfg.model, labels=datamodule.labels) + model_kwargs = {} + for kw in cfg.model.complete: + model_kwargs[kw] = getattr(datamodule, kw) + del cfg.model.complete + model = instantiate(cfg.model, **model_kwargs) trainer = instantiate(cfg.trainer) trainer.finetune(model, datamodule=datamodule, strategy="freeze") predict_datamodule = instantiate(cfg.predict_datamodule) predictions = trainer.predict(model, datamodule=predict_datamodule, output="labels") - log.info(f"predictions : {predictions}") + + for i, (image, pred) in enumerate( + zip(predict_datamodule.predict_dataset, predictions[0]) + ): + image = image["input"] + if hasattr(image, "shape"): + image = image.permute(1, 2, 0) + fig = plt.figure() + plt.imshow(image) + if isinstance(pred, str) or isinstance(pred, list): + plt.title(f"{pred}") + else: + plt.imshow(pred, cmap="tab20", alpha=0.5) + + if cfg.save: + plt.savefig(f"{i}.png") + if cfg.show: + plt.show() if __name__ == "__main__":