From b0cd9daf25d8422f4fb4406dbbf3543b47c27945 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Thu, 29 Apr 2021 01:52:28 +0200 Subject: [PATCH] Simplify backbone_image_classifier example (#7246) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí Co-authored-by: Adrian Wälchli Co-authored-by: ananthsub --- .../backbone_image_classifier.py | 19 ++++++------------- requirements/extra.txt | 2 +- 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/pl_examples/basic_examples/backbone_image_classifier.py b/pl_examples/basic_examples/backbone_image_classifier.py index b20ffffcdd..53a24dfdb2 100644 --- a/pl_examples/basic_examples/backbone_image_classifier.py +++ b/pl_examples/basic_examples/backbone_image_classifier.py @@ -17,6 +17,7 @@ MNIST backbone image classifier example. To run: python backbone_image_classifier.py --trainer.max_epochs=50 """ +from typing import Optional import torch from torch.nn import functional as F @@ -66,11 +67,13 @@ class LitClassifier(pl.LightningModule): def __init__( self, - backbone, + backbone: Optional[Backbone] = None, learning_rate: float = 0.0001, ): super().__init__() - self.save_hyperparameters() + self.save_hyperparameters(ignore=['backbone']) + if backbone is None: + backbone = Backbone() self.backbone = backbone def forward(self, x): @@ -124,18 +127,8 @@ class MyDataModule(pl.LightningDataModule): return DataLoader(self.mnist_test, batch_size=self.batch_size) -class MyLightningCLI(LightningCLI): - - def add_arguments_to_parser(self, parser): - parser.add_class_arguments(Backbone, 'model.backbone') - - def instantiate_model(self): - self.config_init['model']['backbone'] = Backbone(**self.config['model']['backbone']) - super().instantiate_model() - - def cli_main(): - cli = MyLightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234) + cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234) result = cli.trainer.test(cli.model, datamodule=cli.datamodule) print(result) diff --git a/requirements/extra.txt b/requirements/extra.txt index 98c4948125..89b16b1095 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -7,4 +7,4 @@ torchtext>=0.5 # onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 -jsonargparse[signatures]>=3.10.1 +jsonargparse[signatures]>=3.11.0