Simplify backbone_image_classifier example (#7246)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
This commit is contained in:
Mauricio Villegas 2021-04-29 01:52:28 +02:00 committed by GitHub
parent 7a48db591e
commit b0cd9daf25
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 14 deletions

View File

@ -17,6 +17,7 @@ MNIST backbone image classifier example.
To run:
python backbone_image_classifier.py --trainer.max_epochs=50
"""
from typing import Optional
import torch
from torch.nn import functional as F
@ -66,11 +67,13 @@ class LitClassifier(pl.LightningModule):
def __init__(
self,
backbone,
backbone: Optional[Backbone] = None,
learning_rate: float = 0.0001,
):
super().__init__()
self.save_hyperparameters()
self.save_hyperparameters(ignore=['backbone'])
if backbone is None:
backbone = Backbone()
self.backbone = backbone
def forward(self, x):
@ -124,18 +127,8 @@ class MyDataModule(pl.LightningDataModule):
return DataLoader(self.mnist_test, batch_size=self.batch_size)
class MyLightningCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.add_class_arguments(Backbone, 'model.backbone')
def instantiate_model(self):
self.config_init['model']['backbone'] = Backbone(**self.config['model']['backbone'])
super().instantiate_model()
def cli_main():
cli = MyLightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234)
cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234)
result = cli.trainer.test(cli.model, datamodule=cli.datamodule)
print(result)

View File

@ -7,4 +7,4 @@ torchtext>=0.5
# onnx>=1.7.0
onnxruntime>=1.3.0
hydra-core>=1.0
jsonargparse[signatures]>=3.10.1
jsonargparse[signatures]>=3.11.0