Avoid in-place ops during logging result updates (#11401)
Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>

commit f5bbc2cf17 (parent 221091afc4)

@@ -414,6 +414,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed wrong typehint for `Trainer.lightning_optimizers` ([#11155](https://github.com/PyTorchLightning/pytorch-lightning/pull/11155))

+- Fixed type promotion when tensors of higher category than float are logged ([#11401](https://github.com/PyTorchLightning/pytorch-lightning/pull/11401))

 - Fixed the lr-scheduler state not being dumped to checkpoint when using the deepspeed strategy ([#11307](https://github.com/PyTorchLightning/pytorch-lightning/pull/11307))

@@ -216,6 +216,7 @@ class _ResultMetric(Metric, DeviceDtypeModuleMixin):
             # do not set a dtype in case the default dtype was changed
             self.add_state("value", torch.tensor(default), dist_reduce_fx=torch.sum)
             if self.meta.is_mean_reduction:
+                self.cumulated_batch_size: torch.Tensor
                 self.add_state("cumulated_batch_size", torch.tensor(0), dist_reduce_fx=torch.sum)
         # this is defined here only because upstream is missing the type annotation
         self._forward_cache: Optional[Any] = None
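
The `# do not set a dtype` comment above relies on PyTorch factory calls following the global default dtype, so the `value` state starts as whatever `torch.get_default_dtype()` returns. A minimal sketch of that behaviour (illustration only, not part of the diff):

```python
import torch

# tensors created without an explicit dtype follow the global default
print(torch.tensor(0.0).dtype)        # torch.float32 (the initial default)

torch.set_default_dtype(torch.double)
print(torch.tensor(0.0).dtype)        # torch.float64 (follows the new default)

torch.set_default_dtype(torch.float)  # restore the default, as the existing dtype test does
```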

@@ -241,14 +242,13 @@ class _ResultMetric(Metric, DeviceDtypeModuleMixin):

             # perform accumulation with reduction
             if self.meta.is_mean_reduction:
-                self.value += value.mean() * batch_size
-                # `Metric.add_state` does not work well with mypy, mypy doesn't know this is a `Tensor`
-                # we could add an assertion, but this is a hot code path
-                self.cumulated_batch_size += batch_size  # type: ignore[operator]
+                # do not use `+=` as it doesn't do type promotion
+                self.value = self.value + value.mean() * batch_size
+                self.cumulated_batch_size = self.cumulated_batch_size + batch_size
             elif self.meta.is_max_reduction or self.meta.is_min_reduction:
                 self.value = self.meta.reduce_fx(self.value, value.mean())
             elif self.meta.is_sum_reduction:
-                self.value += value.mean()
+                self.value = self.value + value.mean()
             else:
                 value = cast(Metric, value)
                 self.value = value
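
The rationale for the replacement above: PyTorch in-place arithmetic casts the result back to the dtype of the tensor being modified, while out-of-place arithmetic applies the usual type-promotion rules, so a `float32` accumulator silently stays `float32` under `+=` even when a `float64` value is logged. A standalone sketch of the difference (variable names are illustrative, not from the PR):

```python
import torch

acc = torch.tensor(0.0)                      # float32 accumulator, playing the role of `self.value`
val = torch.tensor(1.0, dtype=torch.double)  # a double value being logged

acc += val                 # in-place: the result is cast back to the accumulator's float32
print(acc.dtype)           # torch.float32

acc = acc + val            # out-of-place: the usual promotion rules apply
print(acc.dtype)           # torch.float64
```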

@@ -590,6 +590,26 @@ def test_metric_result_respects_dtype(floating_dtype):
     torch.set_default_dtype(torch.float)


+@pytest.mark.parametrize("reduce_fx", ("mean", sum))
+def test_metric_result_dtype_promotion(reduce_fx):
+    metadata = _Metadata("foo", "bar", reduce_fx=reduce_fx)
+    metadata.sync = _Sync()
+    rm = _ResultMetric(metadata, is_tensor=True)
+    assert rm.value.dtype == torch.float
+
+    # log a double
+    rm.update(torch.tensor(0, dtype=torch.double), 1)
+    # `rm.value.dtype` is promoted
+    assert rm.value.dtype == torch.double
+    # log a float
+    rm.update(torch.tensor(0, dtype=torch.float), 1)
+    # the previous dtype stays
+    assert rm.value.dtype == torch.double
+
+    total = rm.compute()
+    assert total.dtype == torch.double
+
+
 @pytest.mark.parametrize(["reduce_fx", "expected"], [(max, -2), (min, 2)])
 def test_result_metric_max_min(reduce_fx, expected):
     metadata = _Metadata("foo", "bar", reduce_fx=reduce_fx)
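
As a side note (not part of the change), the dtype the new test expects is exactly what PyTorch's promotion helpers report for mixing `float32` and `float64` operands; a quick check:

```python
import torch

print(torch.promote_types(torch.float, torch.double))                               # torch.float64
print(torch.result_type(torch.tensor(0.0), torch.tensor(0.0, dtype=torch.double)))  # torch.float64
```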