From 08fba96b6c1c0aa8528ec7caff05c50635f61ba1 Mon Sep 17 00:00:00 2001 From: Elad Segal Date: Wed, 4 Aug 2021 01:05:34 +0300 Subject: [PATCH] Add `batch_size`, `rank_zero_only` arguments for `log_dict` to match `log` (#8628) --- CHANGELOG.md | 3 +++ pytorch_lightning/core/lightning.py | 8 ++++++++ 2 files changed, 11 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ff4a53ed8e..4c42eeb239 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `state_id` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886)) +- Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628)) + + - Added check for unique GPU ids ([#8666](https://github.com/PyTorchLightning/pytorch-lightning/pull/8666)) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 0d77761761..2bb15c9cfa 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -479,6 +479,8 @@ class LightningModule( sync_dist_op: Optional[Any] = None, # todo: Remove in 1.6 sync_dist_group: Optional[Any] = None, add_dataloader_idx: bool = True, + batch_size: Optional[int] = None, + rank_zero_only: Optional[bool] = None, ) -> None: """ Log a dictionary of values at once. @@ -502,6 +504,10 @@ class LightningModule( add_dataloader_idx: if True, appends the index of the current dataloader to the name (when using multiple). If False, user needs to give unique names for each dataloader to not mix values + batch_size: Current batch_size. This will be directly inferred from the loaded batch, + but some data structures might need to explicitly provide it. + rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization which + would produce a deadlock as not all processes would perform this log call. """ for k, v in dictionary.items(): self.log( @@ -519,6 +525,8 @@ class LightningModule( tbptt_pad_token=tbptt_pad_token, tbptt_reduce_fx=tbptt_reduce_fx, add_dataloader_idx=add_dataloader_idx, + batch_size=batch_size, + rank_zero_only=rank_zero_only, ) @staticmethod