# Skip to content
# Snippets | Groups | Projects
# train.py (1.55 KB)
# Newer | Older
#   • Learn to ignore specific revisions
  • import logging
    
    import torch
    
    import flash
    from flash.core.data.utils import download_data
    from flash.image import ImageClassificationData, ImageClassifier
    
    from hydra_flash.utils import show_predictions
    
    
    def main():
        """Fine-tune an ants-vs-bees image classifier and display a few predictions.

        Steps:
          1. Download the hymenoptera dataset into ``./data`` and build a
             train/val ``ImageClassificationData`` from its class folders.
          2. Build an ``ImageClassifier`` on a ``resnet18`` backbone using the
             labels discovered by the datamodule.
          3. Fine-tune for 3 epochs with the ``"freeze"`` strategy on every CUDA
             device reported by ``torch.cuda.device_count()`` (presumably CPU-only
             when that count is 0 — confirm against the installed Lightning version).
          4. Predict labels for three hand-picked validation images and show them
             with ``show_predictions`` (displayed, not saved).
        """
        # 1. Create the DataModule
        download_data(
            "https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data"
        )

        datamodule = ImageClassificationData.from_folders(
            train_folder="data/hymenoptera_data/train/",
            val_folder="data/hymenoptera_data/val/",
            batch_size=4,
            transform_kwargs={
                "image_size": (196, 196),
                # Per-channel normalization; these values match the standard
                # ImageNet statistics used by pretrained backbones.
                "mean": (0.485, 0.456, 0.406),
                "std": (0.229, 0.224, 0.225),
            },
        )

        # 2. Build the task
        model = ImageClassifier(backbone="resnet18", labels=datamodule.labels)

        # 3. Create the trainer and finetune the model
        trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
        trainer.finetune(model, datamodule=datamodule, strategy="freeze")

        # 4. Predict what's on a few images! ants or bees?
        datamodule = ImageClassificationData.from_files(
            predict_files=[
                "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
                "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg",
                "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg",
            ],
            # One batch holds all three files, so predictions[0] below covers them all.
            batch_size=3,
        )

        predictions = trainer.predict(model, datamodule=datamodule, output="labels")

        # predictions is indexed per batch; take the first (and only) batch.
        show_predictions(datamodule.predict_dataset, predictions[0], show=True, save=False)