Fix support for torch Module type hints in LightningCLI (#7807)

* Fixed support for torch Module type hints in LightningCLI

* - Fix issue with serializing values when type hint is Any.
- Run unit test only on newer torchvision versions in which the base class is Module.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Minor change

* Update CHANGELOG.md

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
Mauricio Villegas 2021-06-04 07:43:43 +02:00 committed by GitHub
parent 36770b22fd
commit f34584001c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 58 additions and 1 deletions

View File

@ -185,6 +185,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a bug where `precision=64` with `accelerator='ddp_spawn'` would throw a pickle error ([#6924](https://github.com/PyTorchLightning/pytorch-lightning/pull/6924))
- Fixed support for `torch.nn.Module` type hints in `LightningCLI` ([#7807](https://github.com/PyTorchLightning/pytorch-lightning/pull/7807))
## [1.3.2] - 2021-05-18
### Changed

View File

@ -7,4 +7,4 @@ torchtext>=0.5
# onnx>=1.7.0
onnxruntime>=1.3.0
hydra-core>=1.0
jsonargparse[signatures]>=3.12.0
jsonargparse[signatures]>=3.13.1

View File

@ -20,18 +20,26 @@ import sys
from argparse import Namespace
from contextlib import redirect_stdout
from io import StringIO
from typing import List, Optional
from unittest import mock
import pytest
import torch
import yaml
from packaging import version
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.cli import LightningArgumentParser, LightningCLI, SaveConfigCallback
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
from tests.helpers import BoringDataModule, BoringModel
torchvision_version = version.parse('0')
if _TORCHVISION_AVAILABLE:
torchvision_version = version.parse(__import__('torchvision').__version__)
@mock.patch('argparse.ArgumentParser.parse_args')
def test_default_args(mock_argparse, tmpdir):
@ -443,3 +451,49 @@ def test_lightning_cli_submodules(tmpdir):
assert cli.model.submodule2 == cli.config_init['model']['submodule2']
assert isinstance(cli.config_init['model']['submodule1'], BoringModel)
assert isinstance(cli.config_init['model']['submodule2'], BoringModel)
@pytest.mark.skipif(torchvision_version < version.parse('0.8.0'), reason='torchvision>=0.8.0 is required')
def test_lightning_cli_torch_modules(tmpdir):
class TestModule(BoringModel):
def __init__(
self,
activation: torch.nn.Module = None,
transform: Optional[List[torch.nn.Module]] = None,
):
super().__init__()
self.activation = activation
self.transform = transform
config = """model:
activation:
class_path: torch.nn.LeakyReLU
init_args:
negative_slope: 0.2
transform:
- class_path: torchvision.transforms.Resize
init_args:
size: 64
- class_path: torchvision.transforms.CenterCrop
init_args:
size: 64
"""
config_path = tmpdir / 'config.yaml'
with open(config_path, 'w') as f:
f.write(config)
cli_args = [
f'--trainer.default_root_dir={tmpdir}',
'--trainer.max_epochs=1',
f'--config={str(config_path)}',
]
with mock.patch('sys.argv', ['any.py'] + cli_args):
cli = LightningCLI(TestModule)
assert isinstance(cli.model.activation, torch.nn.LeakyReLU)
assert cli.model.activation.negative_slope == 0.2
assert len(cli.model.transform) == 2
assert all(isinstance(v, torch.nn.Module) for v in cli.model.transform)