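# NOTE(review): this file originally began with the bare lines "Newer" / "Older"
# (diff-viewer navigation residue, not Python); kept here as a comment so the module parses.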
import flash
import torch
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

from hydra_flash.utils import show_predictions
def main():
    """Fine-tune an image classifier on the hymenoptera (ants vs. bees) data.

    Steps:
      1. Download the dataset into ``./data`` and build a training/validation
         DataModule from its folders.
      2. Build an ``ImageClassifier`` with a ``resnet18`` backbone.
      3. Fine-tune it for 3 epochs with the ``"freeze"`` strategy.
      4. Predict labels for three validation images and display them with
         ``show_predictions``.
    """
    # 1. Create the DataModule
    download_data(
        "https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data"
    )
    datamodule = ImageClassificationData.from_folders(
        train_folder="data/hymenoptera_data/train/",
        val_folder="data/hymenoptera_data/val/",
        batch_size=4,
        transform_kwargs={
            "image_size": (196, 196),
            # Standard ImageNet per-channel normalization statistics.
            "mean": (0.485, 0.456, 0.406),
            "std": (0.229, 0.224, 0.225),
        },
    )

    # 2. Build the task
    model = ImageClassifier(backbone="resnet18", labels=datamodule.labels)

    # 3. Create the trainer and finetune the model
    # Use every visible CUDA device; a CPU-only machine yields gpus=0.
    # NOTE(review): `gpus=` was deprecated in Lightning 1.7 and removed in 2.0;
    # switch to `accelerator=`/`devices=` if the Lightning version is upgraded.
    trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
    trainer.finetune(model, datamodule=datamodule, strategy="freeze")

    # 4. Predict what's on a few images! ants or bees?
    # A separate DataModule holding only the files to predict on.
    predict_datamodule = ImageClassificationData.from_files(
        predict_files=[
            "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
            "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg",
            "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg",
        ],
        batch_size=3,
    )
    predictions = trainer.predict(
        model, datamodule=predict_datamodule, output="labels"
    )
    # trainer.predict returns a list of per-batch outputs; with batch_size=3
    # all three files land in the first batch.
    show_predictions(
        predict_datamodule.predict_dataset, predictions[0], show=True, save=False
    )
# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()