diff --git a/hydra_flash/conf/config.yaml b/hydra_flash/conf/config.yaml
index 815a2e7cc98ff5afb792fe53264431e78f078ce3..7fc55fbcb9f4bd7b28423fb0612ab409c20f3b04 100644
--- a/hydra_flash/conf/config.yaml
+++ b/hydra_flash/conf/config.yaml
@@ -1,16 +1,2 @@
 defaults:
-  - _self_
-  - server: local
-  - datamodule: hymenoptera
-  - model: resnet18
-  - datamodule@predict_datamodule: hymenoptera_test
-
-trainer:
-  _target_: flash.Trainer
-  max_epochs: 3
-  gpus: 1
-
-preprocess:
-  _target_: flash.core.data.utils.download_data
-  url: "https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip"
-  path: "./data"
+  - experiment: hymenoptera
diff --git a/hydra_flash/conf/datamodule/hymenoptera_test.yaml b/hydra_flash/conf/datamodule/hymenoptera_test.yaml
index 0213bdc36400f2abb201e7f5c92a802071fb0916..39a2c6fcfbab51e6bf07085ff6dfb5a89ce52260 100644
--- a/hydra_flash/conf/datamodule/hymenoptera_test.yaml
+++ b/hydra_flash/conf/datamodule/hymenoptera_test.yaml
@@ -1,6 +1,6 @@
 _target_: flash.image.ImageClassificationData.from_files
 predict_files:
-  - "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg"
-  - "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg"
-  - "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg"
+  - "${preprocess.path}/hymenoptera_data/val/bees/65038344_52a45d090d.jpg"
+  - "${preprocess.path}/hymenoptera_data/val/bees/590318879_68cf112861.jpg"
+  - "${preprocess.path}/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg"
 batch_size: 3
diff --git a/hydra_flash/conf/model/resnet18.yaml b/hydra_flash/conf/model/resnet18.yaml
index 14f88f4f1d6d19dfd50ee06d22c4e64bfa1113c1..7c0f8576c8f1545946dd9d58b423254645692a36 100644
--- a/hydra_flash/conf/model/resnet18.yaml
+++ b/hydra_flash/conf/model/resnet18.yaml
@@ -1,3 +1,6 @@
 _target_: flash.image.ImageClassifier
 backbone: resnet18
 learning_rate: 1.e-3
+complete:
+  - labels
+  - multi_label
diff --git a/hydra_flash/conf/server/local.yaml b/hydra_flash/conf/server/local.yaml
index c6439317590b7cf5a348f4588f9c0e4bb1539cf3..12a894c982e33b5b3559c8a1f70a669e31bbc345 100644
--- a/hydra_flash/conf/server/local.yaml
+++ b/hydra_flash/conf/server/local.yaml
@@ -1 +1,4 @@
-name: local
+# @package _global_
+show: true
+save: true
+path: "/tmp"
diff --git a/hydra_flash/conf/server/mb.yaml b/hydra_flash/conf/server/mb.yaml
index 66f02e78c96f11c9e24571ad1be93c04017587fe..be4a2f456f4012d91a846a1cbf0292e76e2d043f 100644
--- a/hydra_flash/conf/server/mb.yaml
+++ b/hydra_flash/conf/server/mb.yaml
@@ -2,7 +2,11 @@
 defaults:
   - override /hydra/launcher: submitit_slurm
 
-name: manneback
+show: false
+save: true
+preprocess:
+  path: ${oc.env:TEMPDIR}
+
 gpus: [0]
 hydra:
   launcher:
diff --git a/hydra_flash/train.py b/hydra_flash/train.py
index fa5aa5faa13d7b59efbd505bc009b016e8fb0291..612d811dd7ebaba7421e061bb882b5f07f4fd24c 100644
--- a/hydra_flash/train.py
+++ b/hydra_flash/train.py
@@ -1,6 +1,9 @@
+from attr import has
 import hydra
 from hydra.utils import call, instantiate as hydra_instantiate
 from functools import partial
+import matplotlib.pyplot as plt
+from omegaconf import OmegaConf
 import logging
 
 log = logging.getLogger(__name__)
@@ -8,21 +11,42 @@ log = logging.getLogger(__name__)
 instantiate = partial(hydra_instantiate, _convert_="all")
 
 
-@hydra.main(version_base=None, config_path="conf", config_name="config")
+@hydra.main(version_base="1.1", config_path="conf", config_name="config")
 def main(cfg):
     # Preprocess : download and/or unzip data
     call(cfg.preprocess)
-
+    print(OmegaConf.to_container(cfg.datamodule, resolve=True))
     datamodule = instantiate(cfg.datamodule)
 
-    model = instantiate(cfg.model, labels=datamodule.labels)
+    model_kwargs = {}
+    for kw in cfg.model.complete:
+        model_kwargs[kw] = getattr(datamodule, kw)
+    del cfg.model.complete
+    model = instantiate(cfg.model, **model_kwargs)
 
     trainer = instantiate(cfg.trainer)
     trainer.finetune(model, datamodule=datamodule, strategy="freeze")
 
     predict_datamodule = instantiate(cfg.predict_datamodule)
     predictions = trainer.predict(model, datamodule=predict_datamodule, output="labels")
-    log.info(f"predictions : {predictions}")
+
+    for i, (image, pred) in enumerate(
+        zip(predict_datamodule.predict_dataset, predictions[0])
+    ):
+        image = image["input"]
+        if hasattr(image, "shape"):
+            image = image.permute(1, 2, 0)
+        fig = plt.figure()
+        plt.imshow(image)
+        if isinstance(pred, str) or isinstance(pred, list):
+            plt.title(f"{pred}")
+        else:
+            plt.imshow(pred, cmap="tab20", alpha=0.5)
+
+        if cfg.save:
+            plt.savefig(f"{i}.png")
+        if cfg.show:
+            plt.show()
 
 
 if __name__ == "__main__":