Squeeze tensor while logging (#14489)
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
6773df9387
commit
ce702fd40e
|
@ -136,6 +136,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed an issue to keep downscaling the batch size in case there hasn't been even a single successful optimal batch size with `mode="power"` ([#14372](https://github.com/Lightning-AI/lightning/pull/14372))
|
||||
|
||||
|
||||
- Squeezed tensor values when logging with `LightningModule.log` ([#14489](https://github.com/Lightning-AI/lightning/pull/14489))
|
||||
|
||||
|
||||
- Fixed `WandbLogger` `save_dir` is not set after creation ([#14326](https://github.com/Lightning-AI/lightning/pull/14326))
|
||||
|
||||
|
||||
|
|
|
@ -423,8 +423,7 @@ class LightningModule(
|
|||
" but it should not contain information about `dataloader_idx`"
|
||||
)
|
||||
|
||||
value = apply_to_collection(value, numbers.Number, self.__to_tensor)
|
||||
apply_to_collection(value, torch.Tensor, self.__check_numel_1, name)
|
||||
value = apply_to_collection(value, (torch.Tensor, numbers.Number), self.__to_tensor, name)
|
||||
|
||||
if self.trainer._logger_connector.should_reset_tensors(self._current_fx_name):
|
||||
# if we started a new epoch (running its first batch) the hook name has changed
|
||||
|
@ -556,16 +555,15 @@ class LightningModule(
|
|||
def __check_allowed(v: Any, name: str, value: Any) -> None:
|
||||
raise ValueError(f"`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged")
|
||||
|
||||
def __to_tensor(self, value: numbers.Number) -> Tensor:
|
||||
return torch.tensor(value, device=self.device)
|
||||
|
||||
@staticmethod
|
||||
def __check_numel_1(value: Tensor, name: str) -> None:
|
||||
def __to_tensor(self, value: Union[torch.Tensor, numbers.Number], name: str) -> Tensor:
|
||||
value = torch.tensor(value, device=self.device)
|
||||
if not torch.numel(value) == 1:
|
||||
raise ValueError(
|
||||
f"`self.log({name}, {value})` was called, but the tensor must have a single element."
|
||||
f" You can try doing `self.log({name}, {value}.mean())`"
|
||||
)
|
||||
value = value.squeeze()
|
||||
return value
|
||||
|
||||
def log_grad_norm(self, grad_norm_dict: Dict[str, float]) -> None:
|
||||
"""Override this method to change the default behaviour of ``log_grad_norm``.
|
||||
|
|
|
@ -29,6 +29,7 @@ from pytorch_lightning import callbacks, Trainer
|
|||
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar
|
||||
from pytorch_lightning.core.module import LightningModule
|
||||
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomDictDataset
|
||||
from pytorch_lightning.trainer.states import RunningStage
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from tests_pytorch.helpers.runif import RunIf
|
||||
|
||||
|
@ -836,3 +837,13 @@ def test_log_on_train_start(mock_log_metrics, tmpdir):
|
|||
|
||||
assert mock_log_metrics.mock_calls == [call(metrics={"foo": 123.0, "epoch": 0}, step=0)]
|
||||
assert trainer.max_epochs > 1
|
||||
|
||||
|
||||
def test_unsqueezed_tensor_logging():
|
||||
model = BoringModel()
|
||||
trainer = Trainer()
|
||||
trainer.state.stage = RunningStage.TRAINING
|
||||
model._current_fx_name = "training_step"
|
||||
model.trainer = trainer
|
||||
model.log("foo", torch.Tensor([1.2]))
|
||||
assert trainer.callback_metrics["foo"].ndim == 0
|
||||
|
|
Loading…
Reference in New Issue