# Source: lightning/examples/app_hpo/pl_script.py (44 lines, 1.4 KiB, Python)

import argparse
import os
import pandas as pd
import torch
from flash import Trainer
from flash.image import ImageClassificationData, ImageClassifier
# Parse arguments provided by the Work.
parser = argparse.ArgumentParser()
# Required locations: the two input folders and the two output files.
parser.add_argument("--train_data_path", type=str, required=True)
parser.add_argument("--submission_path", type=str, required=True)
parser.add_argument("--test_data_path", type=str, required=True)
parser.add_argument("--best_model_path", type=str, required=True)
# Optional hyperparameters.
parser.add_argument("--backbone", type=str, default="resnet18")
parser.add_argument("--learning_rate", type=float, default=0.01)
args = parser.parse_args()

# Training data is loaded from the folder given by --train_data_path.
datamodule = ImageClassificationData.from_folders(
    train_folder=args.train_data_path,
    batch_size=8,
)

# Forward every tunable hyperparameter to the model; --learning_rate must be
# passed explicitly, otherwise the flag is parsed but has no effect on training.
model = ImageClassifier(
    datamodule.num_classes,
    backbone=args.backbone,
    learning_rate=args.learning_rate,
)
# NOTE(review): fast_dev_run=True presumably keeps this example run very short
# (see the Lightning Trainer docs) -- confirm before relying on the checkpoint.
trainer = Trainer(fast_dev_run=True)
trainer.fit(model, datamodule=datamodule)
trainer.save_checkpoint(args.best_model_path)

# Rebind ``datamodule`` to the prediction data found under --test_data_path.
datamodule = ImageClassificationData.from_folders(
    predict_folder=args.test_data_path,
    batch_size=8,
)

# A fresh Trainer is used for prediction with the model fitted above.
predictions = Trainer().predict(model, datamodule=datamodule)

# ``predictions`` is iterated as batches of per-sample mappings; each sample
# exposes its source file under ["metadata"]["filepath"] and its scores under
# ["preds"]. The label written out is the argmax index of those scores.
submission_data = [
    {
        "filename": os.path.basename(p["metadata"]["filepath"]),
        "label": torch.argmax(p["preds"]).item(),
    }
    for batch in predictions
    for p in batch
]

# Naming the columns explicitly keeps the CSV header present even when there
# are no predictions to write.
df = pd.DataFrame(submission_data, columns=["filename", "label"])
df.to_csv(args.submission_path, index=False)