Simplify backbone_image_classifier example (#7246)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
This commit is contained in:
parent
7a48db591e
commit
b0cd9daf25
|
@ -17,6 +17,7 @@ MNIST backbone image classifier example.
|
|||
To run:
|
||||
python backbone_image_classifier.py --trainer.max_epochs=50
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
@ -66,11 +67,13 @@ class LitClassifier(pl.LightningModule):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
backbone,
|
||||
backbone: Optional[Backbone] = None,
|
||||
learning_rate: float = 0.0001,
|
||||
):
|
||||
super().__init__()
|
||||
self.save_hyperparameters()
|
||||
self.save_hyperparameters(ignore=['backbone'])
|
||||
if backbone is None:
|
||||
backbone = Backbone()
|
||||
self.backbone = backbone
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -124,18 +127,8 @@ class MyDataModule(pl.LightningDataModule):
|
|||
return DataLoader(self.mnist_test, batch_size=self.batch_size)
|
||||
|
||||
|
||||
class MyLightningCLI(LightningCLI):
|
||||
|
||||
def add_arguments_to_parser(self, parser):
|
||||
parser.add_class_arguments(Backbone, 'model.backbone')
|
||||
|
||||
def instantiate_model(self):
|
||||
self.config_init['model']['backbone'] = Backbone(**self.config['model']['backbone'])
|
||||
super().instantiate_model()
|
||||
|
||||
|
||||
def cli_main():
|
||||
cli = MyLightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234)
|
||||
cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234)
|
||||
result = cli.trainer.test(cli.model, datamodule=cli.datamodule)
|
||||
print(result)
|
||||
|
||||
|
|
|
@ -7,4 +7,4 @@ torchtext>=0.5
|
|||
# onnx>=1.7.0
|
||||
onnxruntime>=1.3.0
|
||||
hydra-core>=1.0
|
||||
jsonargparse[signatures]>=3.10.1
|
||||
jsonargparse[signatures]>=3.11.0
|
||||
|
|
Loading…
Reference in New Issue