Fix support for torch Module type hints in LightningCLI (#7807)
* Fixed support for torch Module type hints in LightningCLI * - Fix issue with serializing values when type hint is Any. - Run unit test only on newer torchvision versions in which the base class is Module. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Minor change * Update CHANGELOG.md Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
parent
36770b22fd
commit
f34584001c
|
@ -185,6 +185,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed a bug where `precision=64` with `accelerator='ddp_spawn'` would throw a pickle error ([#6924](https://github.com/PyTorchLightning/pytorch-lightning/pull/6924))
|
||||
|
||||
|
||||
- Fixed support for `torch.nn.Module` type hints in `LightningCLI` ([#7807](https://github.com/PyTorchLightning/pytorch-lightning/pull/7807))
|
||||
|
||||
|
||||
## [1.3.2] - 2021-05-18
|
||||
|
||||
### Changed
|
||||
|
|
|
@ -7,4 +7,4 @@ torchtext>=0.5
|
|||
# onnx>=1.7.0
|
||||
onnxruntime>=1.3.0
|
||||
hydra-core>=1.0
|
||||
jsonargparse[signatures]>=3.12.0
|
||||
jsonargparse[signatures]>=3.13.1
|
||||
|
|
|
@ -20,18 +20,26 @@ import sys
|
|||
from argparse import Namespace
|
||||
from contextlib import redirect_stdout
|
||||
from io import StringIO
|
||||
from typing import List, Optional
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import yaml
|
||||
from packaging import version
|
||||
|
||||
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
|
||||
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
|
||||
from pytorch_lightning.plugins.environments import SLURMEnvironment
|
||||
from pytorch_lightning.utilities import _TPU_AVAILABLE
|
||||
from pytorch_lightning.utilities.cli import LightningArgumentParser, LightningCLI, SaveConfigCallback
|
||||
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
|
||||
from tests.helpers import BoringDataModule, BoringModel
|
||||
|
||||
torchvision_version = version.parse('0')
|
||||
if _TORCHVISION_AVAILABLE:
|
||||
torchvision_version = version.parse(__import__('torchvision').__version__)
|
||||
|
||||
|
||||
@mock.patch('argparse.ArgumentParser.parse_args')
|
||||
def test_default_args(mock_argparse, tmpdir):
|
||||
|
@ -443,3 +451,49 @@ def test_lightning_cli_submodules(tmpdir):
|
|||
assert cli.model.submodule2 == cli.config_init['model']['submodule2']
|
||||
assert isinstance(cli.config_init['model']['submodule1'], BoringModel)
|
||||
assert isinstance(cli.config_init['model']['submodule2'], BoringModel)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torchvision_version < version.parse('0.8.0'), reason='torchvision>=0.8.0 is required')
|
||||
def test_lightning_cli_torch_modules(tmpdir):
|
||||
|
||||
class TestModule(BoringModel):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
activation: torch.nn.Module = None,
|
||||
transform: Optional[List[torch.nn.Module]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.activation = activation
|
||||
self.transform = transform
|
||||
|
||||
config = """model:
|
||||
activation:
|
||||
class_path: torch.nn.LeakyReLU
|
||||
init_args:
|
||||
negative_slope: 0.2
|
||||
transform:
|
||||
- class_path: torchvision.transforms.Resize
|
||||
init_args:
|
||||
size: 64
|
||||
- class_path: torchvision.transforms.CenterCrop
|
||||
init_args:
|
||||
size: 64
|
||||
"""
|
||||
config_path = tmpdir / 'config.yaml'
|
||||
with open(config_path, 'w') as f:
|
||||
f.write(config)
|
||||
|
||||
cli_args = [
|
||||
f'--trainer.default_root_dir={tmpdir}',
|
||||
'--trainer.max_epochs=1',
|
||||
f'--config={str(config_path)}',
|
||||
]
|
||||
|
||||
with mock.patch('sys.argv', ['any.py'] + cli_args):
|
||||
cli = LightningCLI(TestModule)
|
||||
|
||||
assert isinstance(cli.model.activation, torch.nn.LeakyReLU)
|
||||
assert cli.model.activation.negative_slope == 0.2
|
||||
assert len(cli.model.transform) == 2
|
||||
assert all(isinstance(v, torch.nn.Module) for v in cli.model.transform)
|
||||
|
|
Loading…
Reference in New Issue