diff --git a/hydra_flash/conf/__init__.py b/hydra_flash/conf/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/hydra_flash/conf/__pycache__/__init__.cpython-39.pyc b/hydra_flash/conf/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index ca90e872b9d96c841278572014626f0fe998d336..0000000000000000000000000000000000000000 Binary files a/hydra_flash/conf/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/hydra_flash/conf/config.yaml b/hydra_flash/conf/config.yaml deleted file mode 100644 index 7fc55fbcb9f4bd7b28423fb0612ab409c20f3b04..0000000000000000000000000000000000000000 --- a/hydra_flash/conf/config.yaml +++ /dev/null @@ -1,2 +0,0 @@ -defaults: - - experiment: hymenoptera diff --git a/hydra_flash/conf/config1.yaml b/hydra_flash/conf/config1.yaml deleted file mode 100644 index c83aafc41051b28ab978fd5fd603a1583584a90c..0000000000000000000000000000000000000000 --- a/hydra_flash/conf/config1.yaml +++ /dev/null @@ -1,35 +0,0 @@ -preprocess: - _target_: flash.core.data.utils.download_data - url: "https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip" - path: "/tmp" - -datamodule: - _target_: flash.image.ImageClassificationData.from_folders - train_folder: ${preprocess.path}/hymenoptera_data/train/ - val_folder: ${preprocess.path}/hymenoptera_data/val/ - batch_size: 4 - num_workers: 12 - transform_kwargs: - image_size: [196, 196] - mean: [0.485, 0.456, 0.406] - std: [0.229, 0.224, 0.225] - -model: - _target_: flash.image.ImageClassifier - backbone: resnet18 - -trainer: - _target_: flash.Trainer - max_epochs: 3 - gpus: 1 - -predict_datamodule: - _target_: flash.image.ImageClassificationData.from_files - predict_files: - - "${preprocess.path}/hymenoptera_data/val/bees/65038344_52a45d090d.jpg" - - "${preprocess.path}/hymenoptera_data/val/bees/590318879_68cf112861.jpg" - - "${preprocess.path}/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg" - batch_size: 3 - -show: true -save: false diff --git a/hydra_flash/conf/config2.yaml b/hydra_flash/conf/config2.yaml deleted file mode 100644 index 7de4f6293f0d2728c981ecf24e62483d1eecafb7..0000000000000000000000000000000000000000 --- a/hydra_flash/conf/config2.yaml +++ /dev/null @@ -1,12 +0,0 @@ -defaults: - - _self_ - - server: local - - preprocess: hymenoptera - - datamodule: hymenoptera - - model: resnet18 - - datamodule@predict_datamodule: hymenoptera_test - -trainer: - _target_: flash.Trainer - max_epochs: 3 - gpus: 1 diff --git a/hydra_flash/conf/datamodule/__init__.py b/hydra_flash/conf/datamodule/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/hydra_flash/conf/datamodule/hymenoptera.yaml b/hydra_flash/conf/datamodule/hymenoptera.yaml deleted file mode 100644 index a834992d8a6afe2272f90dfb25b92ec8ed233613..0000000000000000000000000000000000000000 --- a/hydra_flash/conf/datamodule/hymenoptera.yaml +++ /dev/null @@ -1,9 +0,0 @@ -_target_: flash.image.ImageClassificationData.from_folders -train_folder: ${preprocess.path}/hymenoptera_data/train/ -val_folder: ${preprocess.path}/hymenoptera_data/val/ -batch_size: 4 -num_workers: 12 -transform_kwargs: - image_size: [196, 196] - mean: [0.485, 0.456, 0.406] - std: [0.229, 0.224, 0.225] diff --git a/hydra_flash/conf/datamodule/hymenoptera_test.yaml b/hydra_flash/conf/datamodule/hymenoptera_test.yaml deleted file mode 100644 index 39a2c6fcfbab51e6bf07085ff6dfb5a89ce52260..0000000000000000000000000000000000000000 --- a/hydra_flash/conf/datamodule/hymenoptera_test.yaml +++ /dev/null @@ -1,6 +0,0 @@ -_target_: flash.image.ImageClassificationData.from_files -predict_files: - - "${preprocess.path}/hymenoptera_data/val/bees/65038344_52a45d090d.jpg" - - "${preprocess.path}/hymenoptera_data/val/bees/590318879_68cf112861.jpg" - - "${preprocess.path}/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg" -batch_size: 3 diff --git a/hydra_flash/conf/model/__init__.py b/hydra_flash/conf/model/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/hydra_flash/conf/model/gernet.yaml b/hydra_flash/conf/model/gernet.yaml deleted file mode 100644 index 400426c35f22ba6ebdbfc4995bf96b7e85e8d212..0000000000000000000000000000000000000000 --- a/hydra_flash/conf/model/gernet.yaml +++ /dev/null @@ -1,3 +0,0 @@ -_target_: flash.image.ImageClassifier -backbone: gernet_s -learning_rate: 1.e-3 diff --git a/hydra_flash/conf/model/resnet18.yaml b/hydra_flash/conf/model/resnet18.yaml deleted file mode 100644 index 14f88f4f1d6d19dfd50ee06d22c4e64bfa1113c1..0000000000000000000000000000000000000000 --- a/hydra_flash/conf/model/resnet18.yaml +++ /dev/null @@ -1,3 +0,0 @@ -_target_: flash.image.ImageClassifier -backbone: resnet18 -learning_rate: 1.e-3 diff --git a/hydra_flash/conf/preprocess/__init__.py b/hydra_flash/conf/preprocess/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/hydra_flash/conf/preprocess/download.yaml b/hydra_flash/conf/preprocess/download.yaml deleted file mode 100644 index d0692a41d80e34216bd59f378b9c6b99e4c08a14..0000000000000000000000000000000000000000 --- a/hydra_flash/conf/preprocess/download.yaml +++ /dev/null @@ -1,2 +0,0 @@ -_target_: flash.core.data.utils.download_data -path: ${path} diff --git a/hydra_flash/conf/preprocess/hymenoptera.yaml b/hydra_flash/conf/preprocess/hymenoptera.yaml deleted file mode 100644 index b5d48bef8b9368b6598a9719986f089c21f953f3..0000000000000000000000000000000000000000 --- a/hydra_flash/conf/preprocess/hymenoptera.yaml +++ /dev/null @@ -1,4 +0,0 @@ -defaults: - - download - -url: "https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip" diff --git a/hydra_flash/conf/server/__init__.py b/hydra_flash/conf/server/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/hydra_flash/conf/server/local.yaml b/hydra_flash/conf/server/local.yaml deleted file mode 100644 index 12a894c982e33b5b3559c8a1f70a669e31bbc345..0000000000000000000000000000000000000000 --- a/hydra_flash/conf/server/local.yaml +++ /dev/null @@ -1,4 +0,0 @@ -# @package _global_ -show: true -save: true -path: "/tmp" diff --git a/hydra_flash/conf/server/mb.yaml b/hydra_flash/conf/server/mb.yaml deleted file mode 100644 index be4a2f456f4012d91a846a1cbf0292e76e2d043f..0000000000000000000000000000000000000000 --- a/hydra_flash/conf/server/mb.yaml +++ /dev/null @@ -1,23 +0,0 @@ -# @package _global_ -defaults: - - override /hydra/launcher: submitit_slurm - -show: false -save: true -preprocess: - path: ${oc.env:TEMPDIR} - -gpus: [0] -hydra: - launcher: - name: hydra_test - timeout_min: 360 - cpus_per_task: 12 - tasks_per_node: 1 - signal_delay_s: 120 - max_num_timeout: 20 - mem_gb: 32 - nodes: 1 - array_parallelism: 4 - partition: gpu - gres: "gpu:1g.10gb:1" diff --git a/hydra_flash/train.py b/hydra_flash/train.py index 5770a5fe37c3029e5294a6bce282fb1b1c642ff0..2ba922b96e1c66cc0e29c98295aa30830a6520ce 100644 --- a/hydra_flash/train.py +++ b/hydra_flash/train.py @@ -1,40 +1,48 @@ -import hydra -from hydra.utils import call, instantiate as hydra_instantiate -from functools import partial -from omegaconf import OmegaConf import logging +import torch +import flash +from flash.core.data.utils import download_data +from flash.image import ImageClassificationData, ImageClassifier from hydra_flash.utils import show_predictions -log = logging.getLogger(__name__) -instantiate = partial(hydra_instantiate, _convert_="all") - - -@hydra.main(version_base="1.1", config_path="conf", config_name="config2") -def main(cfg): - # Preprocess : download and/or unzip data - call(cfg.preprocess) - - # Create datamodule from downloaded files - datamodule = instantiate(cfg.datamodule) +def main(): + # 1. Create the DataModule + download_data( + "https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data" + ) - # Create model - model = instantiate( - cfg.model, labels=datamodule.labels, multi_label=datamodule.multi_label + datamodule = ImageClassificationData.from_folders( + train_folder="data/hymenoptera_data/train/", + val_folder="data/hymenoptera_data/val/", + batch_size=4, + transform_kwargs={ + "image_size": (196, 196), + "mean": (0.485, 0.456, 0.406), + "std": (0.229, 0.224, 0.225), + }, ) - # Finetune model - trainer = instantiate(cfg.trainer) - trainer.finetune(model, datamodule=datamodule, strategy="freeze") + # 2. Build the task + model = ImageClassifier(backbone="resnet18", labels=datamodule.labels) - # Predict and show predictions - predict_datamodule = instantiate(cfg.predict_datamodule) - predictions = trainer.predict(model, datamodule=predict_datamodule, output="labels") + # 3. Create the trainer and finetune the model + trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count()) + trainer.finetune(model, datamodule=datamodule, strategy="freeze") - show_predictions( - predict_datamodule.predict_dataset, predictions[0], cfg.show, cfg.save + # 4. Predict what's on a few images! ants or bees? + datamodule = ImageClassificationData.from_files( + predict_files=[ + "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg", + "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg", + "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg", + ], + batch_size=3, ) + predictions = trainer.predict(model, datamodule=datamodule, output="labels") + + show_predictions(datamodule.predict_dataset, predictions[0], show=True, save=False) if __name__ == "__main__":