From ce702fd40e5301cdb27a17bc9c22aac1b8299a41 Mon Sep 17 00:00:00 2001
From: Rohit Gupta
Date: Mon, 5 Sep 2022 19:31:51 +0530
Subject: [PATCH] Squeeze tensor while logging (#14489)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Adrian Wälchli
---
 src/pytorch_lightning/CHANGELOG.md               |  3 +++
 src/pytorch_lightning/core/module.py             | 12 +++++-------
 .../trainer/logging_/test_train_loop_logging.py  | 11 +++++++++++
 3 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index ba8bf05f49..cae69427ac 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -136,6 +136,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed an issue to keep downscaling the batch size in case there hasn't been even a single successful optimal batch size with `mode="power"` ([#14372](https://github.com/Lightning-AI/lightning/pull/14372))
 
+- Squeezed tensor values when logging with `LightningModule.log` ([#14489](https://github.com/Lightning-AI/lightning/pull/14489))
+
+
 - Fixed `WandbLogger` `save_dir` is not set after creation ([#14326](https://github.com/Lightning-AI/lightning/pull/14326))
 
 
diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py
index a479beadc7..a8fea8c210 100644
--- a/src/pytorch_lightning/core/module.py
+++ b/src/pytorch_lightning/core/module.py
@@ -423,8 +423,7 @@ class LightningModule(
                 " but it should not contain information about `dataloader_idx`"
             )
 
-        value = apply_to_collection(value, numbers.Number, self.__to_tensor)
-        apply_to_collection(value, torch.Tensor, self.__check_numel_1, name)
+        value = apply_to_collection(value, (torch.Tensor, numbers.Number), self.__to_tensor, name)
 
         if self.trainer._logger_connector.should_reset_tensors(self._current_fx_name):
             # if we started a new epoch (running its first batch) the hook name has changed
@@ -556,16 +555,15 @@ class LightningModule(
     def __check_allowed(v: Any, name: str, value: Any) -> None:
         raise ValueError(f"`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged")
 
-    def __to_tensor(self, value: numbers.Number) -> Tensor:
-        return torch.tensor(value, device=self.device)
-
-    @staticmethod
-    def __check_numel_1(value: Tensor, name: str) -> None:
+    def __to_tensor(self, value: Union[torch.Tensor, numbers.Number], name: str) -> Tensor:
+        value = torch.tensor(value, device=self.device)
         if not torch.numel(value) == 1:
             raise ValueError(
                 f"`self.log({name}, {value})` was called, but the tensor must have a single element."
                 f" You can try doing `self.log({name}, {value}.mean())`"
             )
+        value = value.squeeze()
+        return value
 
     def log_grad_norm(self, grad_norm_dict: Dict[str, float]) -> None:
         """Override this method to change the default behaviour of ``log_grad_norm``.
diff --git a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py
index 85ed3d8e34..cd7f83ddc7 100644
--- a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py
+++ b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py
@@ -29,6 +29,7 @@ from pytorch_lightning import callbacks, Trainer
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar
 from pytorch_lightning.core.module import LightningModule
 from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomDictDataset
+from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests_pytorch.helpers.runif import RunIf
 
@@ -836,3 +837,13 @@ def test_log_on_train_start(mock_log_metrics, tmpdir):
 
     assert mock_log_metrics.mock_calls == [call(metrics={"foo": 123.0, "epoch": 0}, step=0)]
     assert trainer.max_epochs > 1
+
+
+def test_unsqueezed_tensor_logging():
+    model = BoringModel()
+    trainer = Trainer()
+    trainer.state.stage = RunningStage.TRAINING
+    model._current_fx_name = "training_step"
+    model.trainer = trainer
+    model.log("foo", torch.Tensor([1.2]))
+    assert trainer.callback_metrics["foo"].ndim == 0
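
Below is a minimal sketch of the user-facing behavior this patch introduces, for illustration only. The `SqueezeDemo` subclass and the `one_element` metric name are hypothetical, and the sketch assumes a pytorch_lightning build that includes #14489:

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel


class SqueezeDemo(BoringModel):
    def training_step(self, batch, batch_idx):
        out = super().training_step(batch, batch_idx)
        # A one-element tensor of shape (1,) is now squeezed to a
        # 0-dim scalar tensor before being stored for logging.
        self.log("one_element", torch.tensor([1.2]))
        # Tensors with more than one element still raise a ValueError,
        # whose message suggests a reduction such as `.mean()`:
        # self.log("bad", torch.tensor([1.0, 2.0]))  # would raise
        return out


if __name__ == "__main__":
    trainer = Trainer(fast_dev_run=1, logger=False, enable_checkpointing=False)
    trainer.fit(SqueezeDemo())
    # Mirrors the new test's assertion: the logged value is 0-dim.
    assert trainer.callback_metrics["one_element"].ndim == 0

Squeezing to a 0-dim tensor keeps consumers such as `trainer.callback_metrics` and metric monitors shape-consistent whether the user logged a Python number, a scalar tensor, or a shape-(1,) tensor; it also lets the former `__check_numel_1` validation fold into `__to_tensor`, so a single `apply_to_collection` pass handles both conversion and validation.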