Add `batch_size`, `rank_zero_only` arguments for `log_dict` to match `log` (#8628)

This commit is contained in:
Elad Segal 2021-08-04 01:05:34 +03:00 committed by GitHub
parent 98319f83bf
commit 08fba96b6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 0 deletions

View File

@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `state_id` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886)) - Added `state_id` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))
- Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628))
- Added check for unique GPU ids ([#8666](https://github.com/PyTorchLightning/pytorch-lightning/pull/8666)) - Added check for unique GPU ids ([#8666](https://github.com/PyTorchLightning/pytorch-lightning/pull/8666))

View File

@ -479,6 +479,8 @@ class LightningModule(
sync_dist_op: Optional[Any] = None, # todo: Remove in 1.6 sync_dist_op: Optional[Any] = None, # todo: Remove in 1.6
sync_dist_group: Optional[Any] = None, sync_dist_group: Optional[Any] = None,
add_dataloader_idx: bool = True, add_dataloader_idx: bool = True,
batch_size: Optional[int] = None,
rank_zero_only: Optional[bool] = None,
) -> None: ) -> None:
""" """
Log a dictionary of values at once. Log a dictionary of values at once.
@ -502,6 +504,10 @@ class LightningModule(
add_dataloader_idx: if True, appends the index of the current dataloader to add_dataloader_idx: if True, appends the index of the current dataloader to
the name (when using multiple). If False, user needs to give unique names for the name (when using multiple). If False, user needs to give unique names for
each dataloader to not mix values each dataloader to not mix values
batch_size: Current batch_size. This will be directly inferred from the loaded batch,
but some data structures might need to explicitly provide it.
rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization which
would produce a deadlock as not all processes would perform this log call.
""" """
for k, v in dictionary.items(): for k, v in dictionary.items():
self.log( self.log(
@ -519,6 +525,8 @@ class LightningModule(
tbptt_pad_token=tbptt_pad_token, tbptt_pad_token=tbptt_pad_token,
tbptt_reduce_fx=tbptt_reduce_fx, tbptt_reduce_fx=tbptt_reduce_fx,
add_dataloader_idx=add_dataloader_idx, add_dataloader_idx=add_dataloader_idx,
batch_size=batch_size,
rank_zero_only=rank_zero_only,
) )
@staticmethod @staticmethod