diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index 613262dd4b..36169e0acb 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -101,7 +101,8 @@ class TensorBoardLogger(LightningLoggerBase): return self._experiment @rank_zero_only - def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: Union[Dict[str, Any], Namespace], + metrics: Optional[Dict[str, Any]] = None) -> None: params = self._convert_params(params) params = self._flatten_dict(params) sanitized_params = self._sanitize_params(params) @@ -114,7 +115,9 @@ class TensorBoardLogger(LightningLoggerBase): ) else: from torch.utils.tensorboard.summary import hparams - exp, ssi, sei = hparams(sanitized_params, {}) + if metrics is None: + metrics = {} + exp, ssi, sei = hparams(sanitized_params, metrics) writer = self.experiment._get_file_writer() writer.add_summary(exp) writer.add_summary(ssi)