Avoid torchscript export for Metric forward (#4428)

* Update metric.py

* add test

* Update CHANGELOG.md

* Update test_metric_lightning.py

* Update test_metric_lightning.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
ananthsub authored 2020-11-03 14:02:02 -08:00, committed by GitHub
parent ee414d25be
commit 5d08559c03
3 changed files with 43 additions and 2 deletions

CHANGELOG.md

@@ -19,7 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added timeout for `tpu_device_exists` to ensure process does not hang indefinitely ([#4340](https://github.com/PyTorchLightning/pytorch-lightning/pull/4340))
- Added global step indexing to the checkpoint name for a better sub-epoch checkpointing experience ([#3807](https://github.com/PyTorchLightning/pytorch-lightning/pull/3807))

### Changed
@@ -47,6 +47,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed AMP unscale for `on_after_backward` ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439))
- Fixed TorchScript export when module includes Metrics ([#4428](https://github.com/PyTorchLightning/pytorch-lightning/pull/4428))

## [1.0.4] - 2020-10-27

metric.py

@@ -145,6 +145,7 @@ class Metric(nn.Module, ABC):
        self._defaults[name] = deepcopy(default)
        self._reductions[name] = dist_reduce_fx

    @torch.jit.unused
    def forward(self, *args, **kwargs):
        """
        Automatically calls ``update()``. Returns the metric value over inputs if ``compute_on_step`` is True.

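For context: `@torch.jit.unused` tells the TorchScript compiler to skip the decorated method and replace it with a stub that raises if invoked from scripted code. Recursive scripting compiles the `forward` of every submodule, even ones never called, which is why a module merely holding a Metric previously failed to export. A minimal self-contained sketch of the mechanism (the `ToyMetric`/`Model` names are illustrative, not from this PR):

import torch
import torch.nn as nn


class ToyMetric(nn.Module):
    # stand-in for a Lightning Metric: forward takes *args/**kwargs and does
    # Python-side bookkeeping that TorchScript cannot compile
    @torch.jit.unused
    def forward(self, *args, **kwargs):
        return args[0].sum()


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)
        self.metric = ToyMetric()  # held as a submodule, never called in forward

    def forward(self, x):
        return self.layer(x)


# scripting succeeds because the compiler skips ToyMetric.forward;
# calling self.metric(...) from scripted code would raise at runtime instead
scripted = torch.jit.script(Model())
print(scripted(torch.randn(4, 32)).shape)  # torch.Size([4, 2])
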
test_metric_lightning.py

@@ -1,5 +1,6 @@
import os
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.metrics import Metric
from tests.base.boring_model import BoringModel
@@ -78,3 +79,41 @@ def test_metric_lightning_log(tmpdir):
    logged = trainer.logged_metrics
    assert torch.allclose(torch.tensor(logged["sum"]), model.sum)


def test_scriptable(tmpdir):
    class TestModel(BoringModel):
        def __init__(self):
            super().__init__()
            # the metric is not used in the module's `forward`
            # so the module should be exportable to TorchScript
            self.metric = SumMetric()
            self.sum = 0.0

        def training_step(self, batch, batch_idx):
            x = batch
            self.metric(x.sum())
            self.sum += x.sum()
            self.log("sum", self.metric, on_epoch=True, on_step=False)
            return self.step(x)

    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=1,
        log_every_n_steps=1,
        weights_summary=None,
        logger=False,
        checkpoint_callback=False,
    )
    trainer.fit(model)

    rand_input = torch.randn(10, 32)
    script_model = model.to_torchscript()

    # test that we can still do inference
    output = model(rand_input)
    script_output = script_model(rand_input)
    assert torch.allclose(output, script_output)
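
One natural follow-up, not part of the test itself: the object returned by `to_torchscript()` is a regular `ScriptModule`, so it can be serialized and reloaded with `torch.jit.save`/`torch.jit.load`. A brief sketch, assuming `script_model`, `rand_input`, and `tmpdir` from the test above; the file name is arbitrary:

import os
import torch

path = os.path.join(str(tmpdir), "scripted.pt")  # arbitrary file name
torch.jit.save(script_model, path)  # serialize the compiled module
loaded = torch.jit.load(path)       # reload without needing the Python class
assert torch.allclose(loaded(rand_input), script_model(rand_input))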