fix dtype conversion of example_input_array in model summary (#2510)

* fix dtype conversion

* changelog
Adrian Wälchli 2020-07-05 13:17:22 +02:00 committed by GitHub
parent b6507daf89
commit 6bfcfa8671
3 changed files with 34 additions and 8 deletions
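
For context: before this fix the model summary cast example_input_array to the model's parameter dtype before running the example forward pass, which breaks models whose inputs must keep a different dtype, e.g. integer indices feeding an nn.Embedding. A minimal illustrative sketch of the failure mode (not taken from the diff; it mirrors the MixedDtypeModel added in the tests below):

import torch
import torch.nn as nn

embed = nn.Embedding(10, 20)        # float32 parameters, but expects long indices as input
idx = torch.tensor([[0, 2, 1]])     # dtype: long, like the example_input_array below

embed(idx)                          # works
embed(idx.type(torch.float32))      # raises a RuntimeError: embedding indices must be an integer dtype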


@@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed using the same DDP python interpreter and actually running ([#2482](https://github.com/PyTorchLightning/pytorch-lightning/pull/2482))
+- Fixed model summary input type conversion for models that have input dtype different from model parameters ([#2510](https://github.com/PyTorchLightning/pytorch-lightning/pull/2510))

## [0.8.4] - 2020-07-01


@@ -208,7 +208,6 @@ class ModelSummary(object):
        input_ = model.example_input_array
        input_ = model.transfer_batch_to_device(input_, model.device)
-        input_ = apply_to_collection(input_, torch.Tensor, lambda x: x.type(model.dtype))

        if trainer is not None and trainer.use_amp:
            if NATIVE_AMP_AVALAIBLE:
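
The line removed above converted every tensor in the example input to model.dtype; after this change the input is only moved to the model's device via transfer_batch_to_device and its dtype is left untouched. Roughly, the dropped behaviour corresponds to this standalone sketch (plain recursion instead of Lightning's apply_to_collection; the helper name is made up for illustration):

import torch

def cast_to_model_dtype(batch, model_dtype=torch.float32):
    # Recursively cast every tensor in a (possibly nested) batch: the behaviour being removed.
    if isinstance(batch, torch.Tensor):
        return batch.type(model_dtype)
    if isinstance(batch, (list, tuple)):
        return type(batch)(cast_to_model_dtype(x, model_dtype) for x in batch)
    return batch

example = torch.tensor([[0, 2, 1], [3, 5, 3]])   # long indices, as in the MixedDtypeModel test below
print(cast_to_model_dtype(example).dtype)        # torch.float32, no longer valid input for an Embedding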


@@ -42,6 +42,19 @@ class UnorderedModel(LightningModule):
        return out


+class MixedDtypeModel(LightningModule):
+    """ The parameters and inputs of this model have different dtypes. """
+
+    def __init__(self):
+        super().__init__()
+        self.embed = nn.Embedding(10, 20)  # expects dtype long as input
+        self.reduce = nn.Linear(20, 1)  # dtype: float
+        self.example_input_array = torch.tensor([[0, 2, 1], [3, 5, 3]])  # dtype: long
+
+    def forward(self, x):
+        return self.reduce(self.embed(x))
+
+
@pytest.mark.parametrize(['mode'], [
    pytest.param(ModelSummary.MODE_FULL),
    pytest.param(ModelSummary.MODE_TOP),
@@ -59,15 +72,15 @@ def test_empty_model_summary_shapes(mode):
    pytest.param(ModelSummary.MODE_FULL),
    pytest.param(ModelSummary.MODE_TOP),
])
-@pytest.mark.parametrize(['device', 'dtype'], [
-    pytest.param(torch.device('cpu'), torch.double),
-    pytest.param(torch.device('cuda', 0), torch.float),
-    pytest.param(torch.device('cuda', 0), torch.float16),
+@pytest.mark.parametrize(['device'], [
+    pytest.param(torch.device('cpu')),
+    pytest.param(torch.device('cuda', 0)),
+    pytest.param(torch.device('cuda', 0)),
])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
-def test_linear_model_summary_shapes(device, dtype, mode):
+def test_linear_model_summary_shapes(device, mode):
    """ Test that the model summary correctly computes the input- and output shapes. """
-    model = UnorderedModel().type(dtype).to(device)
+    model = UnorderedModel().to(device)
    model.train()
    summary = model.summarize(mode=mode)
    assert summary.in_sizes == [
@@ -85,10 +98,23 @@ def test_linear_model_summary_shapes(device, dtype, mode):
        UNKNOWN_SIZE,
    ]
    assert model.training
-    assert model.dtype == dtype
    assert model.device == device


+def test_mixed_dtype_model_summary():
+    """ Test that the model summary works with models that have mixed input- and parameter dtypes. """
+    model = MixedDtypeModel()
+    summary = model.summarize()
+    assert summary.in_sizes == [
+        [2, 3],  # embed
+        [2, 3, 20],  # reduce
+    ]
+    assert summary.out_sizes == [
+        [2, 3, 20],  # embed
+        [2, 3, 1],  # reduce
+    ]
+
+
@pytest.mark.parametrize(['mode'], [
    pytest.param(ModelSummary.MODE_FULL),
    pytest.param(ModelSummary.MODE_TOP),