Skip to content
Extraits de code Groupes Projets
Valider 3f71332a rédigé par Victor Joos de ter Beerst's avatar Victor Joos de ter Beerst
Parcourir les fichiers

Back to basics

parent 606aec82
Aucune branche associée trouvée
Aucune étiquette associée trouvée
Aucune requête de fusion associée trouvée
Affichage de
avec 34 ajouts et 129 suppressions
Fichier supprimé
defaults:
- experiment: hymenoptera
preprocess:
_target_: flash.core.data.utils.download_data
url: "https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip"
path: "/tmp"
datamodule:
_target_: flash.image.ImageClassificationData.from_folders
train_folder: ${preprocess.path}/hymenoptera_data/train/
val_folder: ${preprocess.path}/hymenoptera_data/val/
batch_size: 4
num_workers: 12
transform_kwargs:
image_size: [196, 196]
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
model:
_target_: flash.image.ImageClassifier
backbone: resnet18
trainer:
_target_: flash.Trainer
max_epochs: 3
gpus: 1
predict_datamodule:
_target_: flash.image.ImageClassificationData.from_files
predict_files:
- "${preprocess.path}/hymenoptera_data/val/bees/65038344_52a45d090d.jpg"
- "${preprocess.path}/hymenoptera_data/val/bees/590318879_68cf112861.jpg"
- "${preprocess.path}/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg"
batch_size: 3
show: true
save: false
defaults:
- _self_
- server: local
- preprocess: hymenoptera
- datamodule: hymenoptera
- model: resnet18
- datamodule@predict_datamodule: hymenoptera_test
trainer:
_target_: flash.Trainer
max_epochs: 3
gpus: 1
_target_: flash.image.ImageClassificationData.from_folders
train_folder: ${preprocess.path}/hymenoptera_data/train/
val_folder: ${preprocess.path}/hymenoptera_data/val/
batch_size: 4
num_workers: 12
transform_kwargs:
image_size: [196, 196]
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
_target_: flash.image.ImageClassificationData.from_files
predict_files:
- "${preprocess.path}/hymenoptera_data/val/bees/65038344_52a45d090d.jpg"
- "${preprocess.path}/hymenoptera_data/val/bees/590318879_68cf112861.jpg"
- "${preprocess.path}/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg"
batch_size: 3
_target_: flash.image.ImageClassifier
backbone: gernet_s
learning_rate: 1.e-3
_target_: flash.image.ImageClassifier
backbone: resnet18
learning_rate: 1.e-3
_target_: flash.core.data.utils.download_data
path: ${path}
defaults:
- download
url: "https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip"
# @package _global_
show: true
save: true
path: "/tmp"
# @package _global_
defaults:
- override /hydra/launcher: submitit_slurm
show: false
save: true
preprocess:
path: ${oc.env:TEMPDIR}
gpus: [0]
hydra:
launcher:
name: hydra_test
timeout_min: 360
cpus_per_task: 12
tasks_per_node: 1
signal_delay_s: 120
max_num_timeout: 20
mem_gb: 32
nodes: 1
array_parallelism: 4
partition: gpu
gres: "gpu:1g.10gb:1"
import hydra
from hydra.utils import call, instantiate as hydra_instantiate
from functools import partial
from omegaconf import OmegaConf
import logging
import torch
import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier
from hydra_flash.utils import show_predictions
log = logging.getLogger(__name__)
instantiate = partial(hydra_instantiate, _convert_="all")
@hydra.main(version_base="1.1", config_path="conf", config_name="config2")
def main(cfg):
# Preprocess : download and/or unzip data
call(cfg.preprocess)
# Create datamodule from downloaded files
datamodule = instantiate(cfg.datamodule)
def main():
# 1. Create the DataModule
download_data(
"https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data"
)
# Create model
model = instantiate(
cfg.model, labels=datamodule.labels, multi_label=datamodule.multi_label
datamodule = ImageClassificationData.from_folders(
train_folder="data/hymenoptera_data/train/",
val_folder="data/hymenoptera_data/val/",
batch_size=4,
transform_kwargs={
"image_size": (196, 196),
"mean": (0.485, 0.456, 0.406),
"std": (0.229, 0.224, 0.225),
},
)
# Finetune model
trainer = instantiate(cfg.trainer)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
# 2. Build the task
model = ImageClassifier(backbone="resnet18", labels=datamodule.labels)
# Predict and show predictions
predict_datamodule = instantiate(cfg.predict_datamodule)
predictions = trainer.predict(model, datamodule=predict_datamodule, output="labels")
# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
show_predictions(
predict_datamodule.predict_dataset, predictions[0], cfg.show, cfg.save
# 4. Predict what's on a few images! ants or bees?
datamodule = ImageClassificationData.from_files(
predict_files=[
"data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
"data/hymenoptera_data/val/bees/590318879_68cf112861.jpg",
"data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg",
],
batch_size=3,
)
predictions = trainer.predict(model, datamodule=datamodule, output="labels")
show_predictions(datamodule.predict_dataset, predictions[0], show=True, save=False)
if __name__ == "__main__":
......
0% Chargement en cours ou .
You are about to add 0 people to the discussion. Proceed with caution.
Terminez d'abord l'édition de ce message.
Veuillez vous inscrire ou vous pour commenter