Squeeze tensor while logging (#14489)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Author: Rohit Gupta, 2022-09-05 19:31:51 +05:30 (committed by GitHub)
parent 6773df9387
commit ce702fd40e
3 changed files with 19 additions and 7 deletions


@@ -136,6 +136,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue to keep downscaling the batch size in case there hasn't been even a single successful optimal batch size with `mode="power"` ([#14372](https://github.com/Lightning-AI/lightning/pull/14372))
+- Squeezed tensor values when logging with `LightningModule.log` ([#14489](https://github.com/Lightning-AI/lightning/pull/14489))
 - Fixed `WandbLogger` `save_dir` is not set after creation ([#14326](https://github.com/Lightning-AI/lightning/pull/14326))
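
In user-facing terms, the entry above means that a single-element tensor passed to `self.log` is now stored as a 0-dim scalar rather than a 1-D tensor. A minimal sketch of that behavior in plain PyTorch, independent of Lightning:

```python
import torch

logged = torch.tensor([1.2])   # what a user might pass, e.g. self.log("foo", logged)
stored = logged.squeeze()      # what ends up in trainer.callback_metrics after this fix
assert logged.ndim == 1 and stored.ndim == 0
```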


@@ -423,8 +423,7 @@ class LightningModule(
                 " but it should not contain information about `dataloader_idx`"
             )
 
-        value = apply_to_collection(value, numbers.Number, self.__to_tensor)
-        apply_to_collection(value, torch.Tensor, self.__check_numel_1, name)
+        value = apply_to_collection(value, (torch.Tensor, numbers.Number), self.__to_tensor, name)
 
         if self.trainer._logger_connector.should_reset_tensors(self._current_fx_name):
            # if we started a new epoch (running its first batch) the hook name has changed
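
The hunk above replaces two `apply_to_collection` passes (number-to-tensor conversion, then a separate numel check) with a single callable that handles both tensors and plain numbers and receives the metric name as an extra positional argument. A rough, self-contained sketch of that routing, assuming `apply_to_collection` is importable from `pytorch_lightning.utilities.apply_func` and using a hypothetical `to_scalar_tensor` stand-in for the private helper:

```python
import numbers

import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection


def to_scalar_tensor(value, name):
    # Hypothetical stand-in for the private __to_tensor helper, not the library code.
    value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
    if value.numel() != 1:
        raise ValueError(f"`{name}` values must have a single element")
    return value.squeeze()


metrics = {"loss": torch.tensor([0.25]), "acc": 0.9}
# Extra positional args after the callable (here the name) are forwarded to it per element.
converted = apply_to_collection(metrics, (torch.Tensor, numbers.Number), to_scalar_tensor, "metrics")
assert all(t.ndim == 0 for t in converted.values())
```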
@@ -556,16 +555,15 @@ class LightningModule(
     def __check_allowed(v: Any, name: str, value: Any) -> None:
         raise ValueError(f"`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged")
 
-    def __to_tensor(self, value: numbers.Number) -> Tensor:
-        return torch.tensor(value, device=self.device)
-
-    @staticmethod
-    def __check_numel_1(value: Tensor, name: str) -> None:
+    def __to_tensor(self, value: Union[torch.Tensor, numbers.Number], name: str) -> Tensor:
+        value = torch.tensor(value, device=self.device)
         if not torch.numel(value) == 1:
             raise ValueError(
                 f"`self.log({name}, {value})` was called, but the tensor must have a single element."
                 f" You can try doing `self.log({name}, {value}.mean())`"
             )
+        value = value.squeeze()
+        return value
 
     def log_grad_norm(self, grad_norm_dict: Dict[str, float]) -> None:
         """Override this method to change the default behaviour of ``log_grad_norm``.


@@ -29,6 +29,7 @@ from pytorch_lightning import callbacks, Trainer
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar
 from pytorch_lightning.core.module import LightningModule
 from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomDictDataset
+from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests_pytorch.helpers.runif import RunIf
@@ -836,3 +837,13 @@ def test_log_on_train_start(mock_log_metrics, tmpdir):
     assert mock_log_metrics.mock_calls == [call(metrics={"foo": 123.0, "epoch": 0}, step=0)]
     assert trainer.max_epochs > 1
+
+
+def test_unsqueezed_tensor_logging():
+    model = BoringModel()
+    trainer = Trainer()
+    trainer.state.stage = RunningStage.TRAINING
+    model._current_fx_name = "training_step"
+    model.trainer = trainer
+    model.log("foo", torch.Tensor([1.2]))
+    assert trainer.callback_metrics["foo"].ndim == 0
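
The new test drives `model.log` directly by attaching a bare `Trainer` and setting the stage and hook name by hand. A heavier, end-to-end variant of the same check could look like the sketch below; `SqueezeCheckModel` is made up for illustration, and `fast_dev_run` keeps the run to a single batch:

```python
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel


class SqueezeCheckModel(BoringModel):
    def training_step(self, batch, batch_idx):
        out = super().training_step(batch, batch_idx)
        self.log("foo", torch.tensor([1.2]))  # 1-element, 1-D tensor
        return out


trainer = Trainer(fast_dev_run=True, logger=False, enable_checkpointing=False)
trainer.fit(SqueezeCheckModel())
assert trainer.callback_metrics["foo"].ndim == 0
```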