diff --git a/CHANGELOG.md b/CHANGELOG.md index a30ddc6530..3227bd9055 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -185,6 +185,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where `precision=64` with `accelerator='ddp_spawn'` would throw a pickle error ([#6924](https://github.com/PyTorchLightning/pytorch-lightning/pull/6924)) +- Fixed support for `torch.nn.Module` type hints in `LightningCLI` ([#7807](https://github.com/PyTorchLightning/pytorch-lightning/pull/7807)) + + ## [1.3.2] - 2021-05-18 ### Changed diff --git a/requirements/extra.txt b/requirements/extra.txt index c41f464ef3..cb9515beef 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -7,4 +7,4 @@ torchtext>=0.5 # onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 -jsonargparse[signatures]>=3.12.0 +jsonargparse[signatures]>=3.13.1 diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 5780a83e75..c1eabca5d6 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -20,18 +20,26 @@ import sys from argparse import Namespace from contextlib import redirect_stdout from io import StringIO +from typing import List, Optional from unittest import mock import pytest +import torch import yaml +from packaging import version from pytorch_lightning import LightningDataModule, LightningModule, Trainer from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint from pytorch_lightning.plugins.environments import SLURMEnvironment from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.cli import LightningArgumentParser, LightningCLI, SaveConfigCallback +from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE from tests.helpers import BoringDataModule, BoringModel +torchvision_version = version.parse('0') +if _TORCHVISION_AVAILABLE: + torchvision_version = version.parse(__import__('torchvision').__version__) + @mock.patch('argparse.ArgumentParser.parse_args') def test_default_args(mock_argparse, tmpdir): @@ -443,3 +451,49 @@ def test_lightning_cli_submodules(tmpdir): assert cli.model.submodule2 == cli.config_init['model']['submodule2'] assert isinstance(cli.config_init['model']['submodule1'], BoringModel) assert isinstance(cli.config_init['model']['submodule2'], BoringModel) + + +@pytest.mark.skipif(torchvision_version < version.parse('0.8.0'), reason='torchvision>=0.8.0 is required') +def test_lightning_cli_torch_modules(tmpdir): + + class TestModule(BoringModel): + + def __init__( + self, + activation: torch.nn.Module = None, + transform: Optional[List[torch.nn.Module]] = None, + ): + super().__init__() + self.activation = activation + self.transform = transform + + config = """model: + activation: + class_path: torch.nn.LeakyReLU + init_args: + negative_slope: 0.2 + transform: + - class_path: torchvision.transforms.Resize + init_args: + size: 64 + - class_path: torchvision.transforms.CenterCrop + init_args: + size: 64 + """ + config_path = tmpdir / 'config.yaml' + with open(config_path, 'w') as f: + f.write(config) + + cli_args = [ + f'--trainer.default_root_dir={tmpdir}', + '--trainer.max_epochs=1', + f'--config={str(config_path)}', + ] + + with mock.patch('sys.argv', ['any.py'] + cli_args): + cli = LightningCLI(TestModule) + + assert isinstance(cli.model.activation, torch.nn.LeakyReLU) + assert cli.model.activation.negative_slope == 0.2 + assert len(cli.model.transform) == 2 + assert all(isinstance(v, torch.nn.Module) for v in cli.model.transform)