Avoid torchscript export for Metric forward (#4428)
* Update metric.py * add test * Update CHANGELOG.md * Update test_metric_lightning.py * Update test_metric_lightning.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
parent
ee414d25be
commit
5d08559c03
|
@ -19,7 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Added timeout for `tpu_device_exists` to ensure process does not hang indefinitely ([#4340](https://github.com/PyTorchLightning/pytorch-lightning/pull/4340))
|
||||
|
||||
- Added global step indexing to the checkpoint name for a better sub-epoch checkpointing experience ([#3807](https://github.com/PyTorchLightning/pytorch-lightning/pull/3807))
|
||||
- Added global step indexing to the checkpoint name for a better sub-epoch checkpointing experience ([#3807](https://github.com/PyTorchLightning/pytorch-lightning/pull/3807))
|
||||
|
||||
### Changed
|
||||
|
||||
|
@ -47,6 +47,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Fixed AMP unscale for `on_after_backward` ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439))
|
||||
|
||||
- Fixed TorchScript export when module includes Metrics ([#4428](https://github.com/PyTorchLightning/pytorch-lightning/pull/4428))
|
||||
|
||||
## [1.0.4] - 2020-10-27
|
||||
|
||||
|
|
|
@ -145,6 +145,7 @@ class Metric(nn.Module, ABC):
|
|||
self._defaults[name] = deepcopy(default)
|
||||
self._reductions[name] = dist_reduce_fx
|
||||
|
||||
@torch.jit.unused
|
||||
def forward(self, *args, **kwargs):
|
||||
"""
|
||||
Automatically calls ``update()``. Returns the metric value over inputs if ``compute_on_step`` is True.
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import torch
|
||||
import os
|
||||
|
||||
import torch
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.metrics import Metric
|
||||
from tests.base.boring_model import BoringModel
|
||||
|
@ -78,3 +79,41 @@ def test_metric_lightning_log(tmpdir):
|
|||
|
||||
logged = trainer.logged_metrics
|
||||
assert torch.allclose(torch.tensor(logged["sum"]), model.sum)
|
||||
|
||||
|
||||
def test_scriptable(tmpdir):
|
||||
class TestModel(BoringModel):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# the metric is not used in the module's `forward`
|
||||
# so the module should be exportable to TorchScript
|
||||
self.metric = SumMetric()
|
||||
self.sum = 0.0
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x = batch
|
||||
self.metric(x.sum())
|
||||
self.sum += x.sum()
|
||||
self.log("sum", self.metric, on_epoch=True, on_step=False)
|
||||
return self.step(x)
|
||||
|
||||
model = TestModel()
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
limit_train_batches=2,
|
||||
limit_val_batches=2,
|
||||
max_epochs=1,
|
||||
log_every_n_steps=1,
|
||||
weights_summary=None,
|
||||
logger=False,
|
||||
checkpoint_callback=False,
|
||||
)
|
||||
trainer.fit(model)
|
||||
rand_input = torch.randn(10, 32)
|
||||
|
||||
script_model = model.to_torchscript()
|
||||
|
||||
# test that we can still do inference
|
||||
output = model(rand_input)
|
||||
script_output = script_model(rand_input)
|
||||
assert torch.allclose(output, script_output)
|
||||
|
|
Loading…
Reference in New Issue