Simplify logger connector access (#8318)

This commit is contained in:
Carlos Mocholí 2021-07-07 14:13:30 +02:00 committed by GitHub
parent d73c32ab51
commit 9877265887
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 18 additions and 19 deletions

View File

@ -629,7 +629,7 @@ class ModelCheckpoint(Callback):
self._fs.makedirs(self.dirpath, exist_ok=True)
def _add_backward_monitor_support(self, trainer: 'pl.Trainer') -> None:
metrics = trainer.logger_connector.callback_metrics
metrics = trainer.callback_metrics
deprecation_warning = False
if self.monitor is None and 'val_loss' in metrics:
@ -648,7 +648,7 @@ class ModelCheckpoint(Callback):
)
def _validate_monitor_key(self, trainer: 'pl.Trainer') -> None:
metrics = trainer.logger_connector.callback_metrics
metrics = trainer.callback_metrics
# validate metric
if self.monitor is not None and not self._is_valid_monitor_key(metrics):
@ -678,7 +678,7 @@ class ModelCheckpoint(Callback):
return filepath
def _monitor_candidates(self, trainer: 'pl.Trainer', epoch: int, step: int) -> Dict[str, _METRIC]:
monitor_candidates = deepcopy(trainer.logger_connector.callback_metrics)
monitor_candidates = deepcopy(trainer.callback_metrics)
monitor_candidates.update(epoch=epoch, step=step)
return monitor_candidates

View File

@ -64,10 +64,10 @@ class OptimizerConnector:
monitor_key, monitor_val = None, None
if lr_scheduler['reduce_on_plateau']:
monitor_key = lr_scheduler['monitor']
monitor_val = self.trainer.logger_connector.callback_metrics.get(monitor_key)
monitor_val = self.trainer.callback_metrics.get(monitor_key)
if monitor_val is None:
if lr_scheduler.get('strict', True):
avail_metrics = list(self.trainer.logger_connector.callback_metrics.keys())
avail_metrics = list(self.trainer.callback_metrics)
raise MisconfigurationException(
f'ReduceLROnPlateau conditioned on metric {monitor_key}'
f' which is not available. Available metrics are: {avail_metrics}.'

View File

@ -395,7 +395,7 @@ def test_log_works_in_val_callback(tmpdir):
for fx, attrs in cb.logged_arguments.items():
should_include = attrs["prog_bar"] and attrs["on_step"] ^ attrs["on_epoch"]
is_included = fx in trainer.logger_connector.progress_bar_metrics
is_included = fx in trainer.progress_bar_metrics
assert is_included if should_include else not is_included
@ -529,7 +529,7 @@ def test_log_works_in_test_callback(tmpdir):
for fx, attrs in cb.funcs_attr.items():
should_include = attrs["prog_bar"] and attrs["on_step"] ^ attrs["on_epoch"]
is_included = fx in trainer.logger_connector.progress_bar_metrics
is_included = fx in trainer.progress_bar_metrics
assert is_included if should_include else not is_included

View File

@ -451,7 +451,7 @@ def test_log_works_in_train_callback(tmpdir):
for fx, attrs in cb.logged_arguments.items():
should_include = attrs["prog_bar"] and attrs["on_step"] ^ attrs["on_epoch"]
is_included = fx in trainer.logger_connector.progress_bar_metrics
is_included = fx in trainer.progress_bar_metrics
assert is_included if should_include else not is_included
@ -590,7 +590,6 @@ def test_logging_in_callbacks_with_log_function(tmpdir):
def on_train_epoch_end(self, trainer, pl_module, outputs):
self.log("on_train_epoch_end", 6)
self.callback_metrics = trainer.logger_connector.callback_metrics
model = BoringModel()
trainer = Trainer(
@ -757,9 +756,9 @@ def test_sanity_metrics_are_reset(tmpdir):
def training_step(self, batch, batch_idx):
loss = super().training_step(batch, batch_idx)
if batch_idx == 0:
assert self.trainer.logger_connector._progress_bar_metrics == {}
assert self.trainer.logger_connector._logged_metrics == {}
assert self.trainer.logger_connector._callback_metrics == {}
assert self.trainer.progress_bar_metrics == {}
assert self.trainer.logged_metrics == {}
assert self.trainer.callback_metrics == {}
self.log("train_loss", loss, prog_bar=True, logger=True)
return loss

View File

@ -143,8 +143,8 @@ def test__training_step__epoch_end__flow_scalar(tmpdir):
assert model.training_epoch_end_called
# assert epoch end metrics were added
assert len(trainer.logger_connector.callback_metrics) == 0
assert len(trainer.logger_connector.progress_bar_metrics) == 0
assert len(trainer.callback_metrics) == 0
assert len(trainer.progress_bar_metrics) == 0
trainer.state.stage = RunningStage.TRAINING
# make sure training outputs what is expected
@ -221,8 +221,8 @@ def test__training_step__step_end__epoch_end__flow_scalar(tmpdir):
assert model.training_epoch_end_called
# assert epoch end metrics were added
assert len(trainer.logger_connector.callback_metrics) == 0
assert len(trainer.logger_connector.progress_bar_metrics) == 0
assert len(trainer.callback_metrics) == 0
assert len(trainer.progress_bar_metrics) == 0
trainer.state.stage = RunningStage.TRAINING
# make sure training outputs what is expected

View File

@ -575,8 +575,8 @@ def test_step_with_optimizer_closure(tmpdir):
trainer.fit(model)
assert trainer.dev_debugger.count_events('backward_call') == limit_train_batches * 2
assert trainer.logger_connector.progress_bar_metrics["train_loss_step"] == model._losses[-1]
assert trainer.logger_connector.progress_bar_metrics["train_loss_epoch"] == torch.stack(model._losses).mean()
assert trainer.progress_bar_metrics["train_loss_step"] == model._losses[-1]
assert trainer.progress_bar_metrics["train_loss_epoch"] == torch.stack(model._losses).mean()
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})

View File

@ -373,7 +373,7 @@ def test_model_checkpoint_options(tmpdir, save_top_k, save_last, expected_files)
for i, loss in enumerate(losses):
trainer.fit_loop.current_epoch = i
trainer.fit_loop.global_step = i
trainer.logger_connector.callback_metrics.update({"checkpoint_on": loss})
trainer.callback_metrics.update({"checkpoint_on": loss})
checkpoint_callback.on_validation_end(trainer, trainer.lightning_module)
file_lists = set(os.listdir(tmpdir))