From 5d08559c0331ae700c3de4f4df9ecfd2f8782c83 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Tue, 3 Nov 2020 14:02:02 -0800
Subject: [PATCH] Avoid torchscript export for Metric forward (#4428)

* Update metric.py

* add test

* Update CHANGELOG.md

* Update test_metric_lightning.py

* Update test_metric_lightning.py

Co-authored-by: Jirka Borovec
---
 CHANGELOG.md                           |  3 +-
 pytorch_lightning/metrics/metric.py    |  1 +
 tests/metrics/test_metric_lightning.py | 41 +++++++++++++++++++++++++-
 3 files changed, 43 insertions(+), 2 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8d9c3d8a1f..194d50ef37 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,7 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added timeout for `tpu_device_exists` to ensure process does not hang indefinitely ([#4340](https://github.com/PyTorchLightning/pytorch-lightning/pull/4340))
 
-- Added global step indexing to the checkpoint name for a better sub-epoch checkpointing experience ([#3807](https://github.com/PyTorchLightning/pytorch-lightning/pull/3807)) 
+- Added global step indexing to the checkpoint name for a better sub-epoch checkpointing experience ([#3807](https://github.com/PyTorchLightning/pytorch-lightning/pull/3807))
 
 ### Changed
 
@@ -47,6 +47,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed AMP unscale for `on_after_backward` ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439))
 
+- Fixed TorchScript export when module includes Metrics ([#4428](https://github.com/PyTorchLightning/pytorch-lightning/pull/4428))
 
 ## [1.0.4] - 2020-10-27
 
diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py
index f003e0d3da..b716817427 100644
--- a/pytorch_lightning/metrics/metric.py
+++ b/pytorch_lightning/metrics/metric.py
@@ -145,6 +145,7 @@ class Metric(nn.Module, ABC):
         self._defaults[name] = deepcopy(default)
         self._reductions[name] = dist_reduce_fx
 
+    @torch.jit.unused
     def forward(self, *args, **kwargs):
         """
         Automatically calls ``update()``. Returns the metric value over inputs if ``compute_on_step`` is True.
diff --git a/tests/metrics/test_metric_lightning.py b/tests/metrics/test_metric_lightning.py
index 7a860ea6c1..3c6938734b 100644
--- a/tests/metrics/test_metric_lightning.py
+++ b/tests/metrics/test_metric_lightning.py
@@ -1,5 +1,6 @@
-import torch
+import os
 
+import torch
 from pytorch_lightning import Trainer
 from pytorch_lightning.metrics import Metric
 from tests.base.boring_model import BoringModel
@@ -78,3 +79,41 @@ def test_metric_lightning_log(tmpdir):
 
     logged = trainer.logged_metrics
     assert torch.allclose(torch.tensor(logged["sum"]), model.sum)
+
+
+def test_scriptable(tmpdir):
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            # the metric is not used in the module's `forward`
+            # so the module should be exportable to TorchScript
+            self.metric = SumMetric()
+            self.sum = 0.0
+
+        def training_step(self, batch, batch_idx):
+            x = batch
+            self.metric(x.sum())
+            self.sum += x.sum()
+            self.log("sum", self.metric, on_epoch=True, on_step=False)
+            return self.step(x)
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=1,
+        log_every_n_steps=1,
+        weights_summary=None,
+        logger=False,
+        checkpoint_callback=False,
+    )
+    trainer.fit(model)
+    rand_input = torch.randn(10, 32)
+
+    script_model = model.to_torchscript()
+
+    # test that we can still do inference
+    output = model(rand_input)
+    script_output = script_model(rand_input)
+    assert torch.allclose(output, script_output)
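
For context, a minimal standalone sketch (not part of the patch; the Accumulator and Model classes below are illustrative stand-ins, not Lightning code) of the TorchScript behaviour the fix relies on: when a module is scripted, a method decorated with `torch.jit.unused` is excluded from compilation, so a module that merely holds a Metric-like submodule with a variadic `forward` can still be exported as long as that `forward` is never called from scripted code.

import torch
from torch import nn


class Accumulator(nn.Module):
    # Stand-in for a Metric: its `forward` takes *args/**kwargs, which TorchScript
    # cannot compile. `@torch.jit.unused` tells the compiler to skip it; calling it
    # from scripted code would raise instead.
    def __init__(self):
        super().__init__()
        self.register_buffer("total", torch.tensor(0.0))

    @torch.jit.unused
    def forward(self, *args, **kwargs):
        self.total += sum(arg.sum() for arg in args)
        return self.total


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)
        self.metric = Accumulator()  # held as a submodule, never used in `forward`

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x)


# Scripting succeeds because `Accumulator.forward` is excluded from compilation.
scripted = torch.jit.script(Model())
print(scripted(torch.randn(10, 32)).shape)  # torch.Size([10, 2])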