Add `batch_size`, `rank_zero_only` arguments for `log_dict` to match `log` (#8628)

Elad Segal, 2021-08-04 01:05:34 +03:00, committed by GitHub
parent 98319f83bf
commit 08fba96b6c
2 changed files with 11 additions and 0 deletions

CHANGELOG.md

@@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `state_id` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))
+- Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628))
 - Added check for unique GPU ids ([#8666](https://github.com/PyTorchLightning/pytorch-lightning/pull/8666))

pytorch_lightning/core/lightning.py

@@ -479,6 +479,8 @@ class LightningModule(
         sync_dist_op: Optional[Any] = None,  # todo: Remove in 1.6
         sync_dist_group: Optional[Any] = None,
         add_dataloader_idx: bool = True,
+        batch_size: Optional[int] = None,
+        rank_zero_only: Optional[bool] = None,
     ) -> None:
         """
         Log a dictionary of values at once.
@@ -502,6 +504,10 @@ class LightningModule(
             add_dataloader_idx: if True, appends the index of the current dataloader to
                 the name (when using multiple). If False, user needs to give unique names for
                 each dataloader to not mix values
+            batch_size: Current batch_size. This will be directly inferred from the loaded batch,
+                but some data structures might need to explicitly provide it.
+            rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization which
+                would produce a deadlock as not all processes would perform this log call.
         """
         for k, v in dictionary.items():
             self.log(
@@ -519,6 +525,8 @@ class LightningModule(
                 tbptt_pad_token=tbptt_pad_token,
                 tbptt_reduce_fx=tbptt_reduce_fx,
                 add_dataloader_idx=add_dataloader_idx,
+                batch_size=batch_size,
+                rank_zero_only=rank_zero_only,
             )

    @staticmethod
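
For context, a minimal usage sketch (not part of this commit) of the two new `log_dict` arguments from a `training_step`; the model, layer sizes, and metric name here are hypothetical:

```python
# Hypothetical example of the new `log_dict` arguments added in #8628.
import torch
from torch import nn
import pytorch_lightning as pl


class LitRegressor(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.mse_loss(self(x), y)
        self.log_dict(
            {"train_loss": loss},
            # Explicit batch size, for batches whose size cannot be
            # inferred automatically (e.g. custom batch objects).
            batch_size=x.size(0),
            # Log only on rank 0; skips the cross-process sync that
            # would deadlock when other ranks never reach this call.
            rank_zero_only=True,
        )
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```

As the diff shows, both arguments are simply forwarded to `self.log` for every key in the dictionary, so they behave exactly as they do on `log`.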