Simplify logger connector access (#8318)
parent d73c32ab51
commit 9877265887
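For context, the shorter access paths used throughout this diff (trainer.callback_metrics, trainer.logged_metrics, trainer.progress_bar_metrics) rely on Trainer-level properties that forward to the logger connector. Below is a minimal, self-contained sketch of that delegation pattern; the stand-in names _LoggerConnector and _Trainer are illustrative only and not code from this PR.

from typing import Any, Dict


class _LoggerConnector:
    """Stand-in for the internal connector that owns the metric dictionaries."""

    def __init__(self) -> None:
        self.callback_metrics: Dict[str, Any] = {}
        self.logged_metrics: Dict[str, Any] = {}
        self.progress_bar_metrics: Dict[str, Any] = {}


class _Trainer:
    """Stand-in Trainer that forwards metric access to its connector."""

    def __init__(self) -> None:
        self.logger_connector = _LoggerConnector()

    @property
    def callback_metrics(self) -> Dict[str, Any]:
        return self.logger_connector.callback_metrics

    @property
    def logged_metrics(self) -> Dict[str, Any]:
        return self.logger_connector.logged_metrics

    @property
    def progress_bar_metrics(self) -> Dict[str, Any]:
        return self.logger_connector.progress_bar_metrics


if __name__ == "__main__":
    trainer = _Trainer()
    trainer.logger_connector.callback_metrics["val_loss"] = 0.5
    # the shorter access path this commit switches callers to
    assert trainer.callback_metrics["val_loss"] == 0.5

The real Trainer is assumed to expose equivalent properties, which is why callers can drop the logger_connector hop without changing behavior.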
@@ -629,7 +629,7 @@ class ModelCheckpoint(Callback):
         self._fs.makedirs(self.dirpath, exist_ok=True)
 
     def _add_backward_monitor_support(self, trainer: 'pl.Trainer') -> None:
-        metrics = trainer.logger_connector.callback_metrics
+        metrics = trainer.callback_metrics
         deprecation_warning = False
 
         if self.monitor is None and 'val_loss' in metrics:
@@ -648,7 +648,7 @@ class ModelCheckpoint(Callback):
             )
 
     def _validate_monitor_key(self, trainer: 'pl.Trainer') -> None:
-        metrics = trainer.logger_connector.callback_metrics
+        metrics = trainer.callback_metrics
 
         # validate metric
         if self.monitor is not None and not self._is_valid_monitor_key(metrics):
@@ -678,7 +678,7 @@ class ModelCheckpoint(Callback):
         return filepath
 
     def _monitor_candidates(self, trainer: 'pl.Trainer', epoch: int, step: int) -> Dict[str, _METRIC]:
-        monitor_candidates = deepcopy(trainer.logger_connector.callback_metrics)
+        monitor_candidates = deepcopy(trainer.callback_metrics)
         monitor_candidates.update(epoch=epoch, step=step)
         return monitor_candidates

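The ModelCheckpoint hunks above all read the same metrics dictionary through the new access path. As a hedged illustration, not part of this diff, any other callback can do the same; MetricPrinter is a hypothetical name and the sketch assumes pytorch_lightning is installed.

from pytorch_lightning.callbacks import Callback


class MetricPrinter(Callback):
    """Hypothetical callback reading metrics straight off the trainer."""

    def on_validation_end(self, trainer, pl_module):
        # same dictionary ModelCheckpoint consults above, reached without
        # going through trainer.logger_connector
        metrics = trainer.callback_metrics
        if "val_loss" in metrics:
            print(f"val_loss={metrics['val_loss']}")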
@@ -64,10 +64,10 @@ class OptimizerConnector:
                 monitor_key, monitor_val = None, None
                 if lr_scheduler['reduce_on_plateau']:
                     monitor_key = lr_scheduler['monitor']
-                    monitor_val = self.trainer.logger_connector.callback_metrics.get(monitor_key)
+                    monitor_val = self.trainer.callback_metrics.get(monitor_key)
                     if monitor_val is None:
                         if lr_scheduler.get('strict', True):
-                            avail_metrics = list(self.trainer.logger_connector.callback_metrics.keys())
+                            avail_metrics = list(self.trainer.callback_metrics)
                             raise MisconfigurationException(
                                 f'ReduceLROnPlateau conditioned on metric {monitor_key}'
                                 f' which is not available. Available metrics are: {avail_metrics}.'
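The monitor key this lookup resolves comes from the scheduler configuration a LightningModule returns from configure_optimizers. Below is a hedged sketch of that configuration, assuming only torch is installed; configure_optimizers_sketch is an illustrative helper name, not code from this PR.

import torch


def configure_optimizers_sketch(model: torch.nn.Module) -> dict:
    """Hypothetical stand-in for a LightningModule.configure_optimizers body."""
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            # must be a key present in trainer.callback_metrics,
            # e.g. logged via self.log("val_loss", ...)
            "monitor": "val_loss",
            # with strict=True, a missing key raises the
            # MisconfigurationException shown above
            "strict": True,
        },
    }


if __name__ == "__main__":
    config = configure_optimizers_sketch(torch.nn.Linear(4, 2))
    print(config["lr_scheduler"]["monitor"])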
@@ -395,7 +395,7 @@ def test_log_works_in_val_callback(tmpdir):
 
     for fx, attrs in cb.logged_arguments.items():
         should_include = attrs["prog_bar"] and attrs["on_step"] ^ attrs["on_epoch"]
-        is_included = fx in trainer.logger_connector.progress_bar_metrics
+        is_included = fx in trainer.progress_bar_metrics
         assert is_included if should_include else not is_included


@@ -529,7 +529,7 @@ def test_log_works_in_test_callback(tmpdir):
 
     for fx, attrs in cb.funcs_attr.items():
         should_include = attrs["prog_bar"] and attrs["on_step"] ^ attrs["on_epoch"]
-        is_included = fx in trainer.logger_connector.progress_bar_metrics
+        is_included = fx in trainer.progress_bar_metrics
         assert is_included if should_include else not is_included


@@ -451,7 +451,7 @@ def test_log_works_in_train_callback(tmpdir):
 
     for fx, attrs in cb.logged_arguments.items():
         should_include = attrs["prog_bar"] and attrs["on_step"] ^ attrs["on_epoch"]
-        is_included = fx in trainer.logger_connector.progress_bar_metrics
+        is_included = fx in trainer.progress_bar_metrics
         assert is_included if should_include else not is_included


@@ -590,7 +590,6 @@ def test_logging_in_callbacks_with_log_function(tmpdir):
 
         def on_train_epoch_end(self, trainer, pl_module, outputs):
             self.log("on_train_epoch_end", 6)
-            self.callback_metrics = trainer.logger_connector.callback_metrics
 
     model = BoringModel()
     trainer = Trainer(
@@ -757,9 +756,9 @@ def test_sanity_metrics_are_reset(tmpdir):
         def training_step(self, batch, batch_idx):
             loss = super().training_step(batch, batch_idx)
             if batch_idx == 0:
-                assert self.trainer.logger_connector._progress_bar_metrics == {}
-                assert self.trainer.logger_connector._logged_metrics == {}
-                assert self.trainer.logger_connector._callback_metrics == {}
+                assert self.trainer.progress_bar_metrics == {}
+                assert self.trainer.logged_metrics == {}
+                assert self.trainer.callback_metrics == {}
             self.log("train_loss", loss, prog_bar=True, logger=True)
             return loss


@@ -143,8 +143,8 @@ def test__training_step__epoch_end__flow_scalar(tmpdir):
     assert model.training_epoch_end_called
 
     # assert epoch end metrics were added
-    assert len(trainer.logger_connector.callback_metrics) == 0
-    assert len(trainer.logger_connector.progress_bar_metrics) == 0
+    assert len(trainer.callback_metrics) == 0
+    assert len(trainer.progress_bar_metrics) == 0
 
     trainer.state.stage = RunningStage.TRAINING
     # make sure training outputs what is expected
@@ -221,8 +221,8 @@ def test__training_step__step_end__epoch_end__flow_scalar(tmpdir):
     assert model.training_epoch_end_called
 
     # assert epoch end metrics were added
-    assert len(trainer.logger_connector.callback_metrics) == 0
-    assert len(trainer.logger_connector.progress_bar_metrics) == 0
+    assert len(trainer.callback_metrics) == 0
+    assert len(trainer.progress_bar_metrics) == 0
 
     trainer.state.stage = RunningStage.TRAINING
     # make sure training outputs what is expected
@@ -575,8 +575,8 @@ def test_step_with_optimizer_closure(tmpdir):
 
     trainer.fit(model)
     assert trainer.dev_debugger.count_events('backward_call') == limit_train_batches * 2
-    assert trainer.logger_connector.progress_bar_metrics["train_loss_step"] == model._losses[-1]
-    assert trainer.logger_connector.progress_bar_metrics["train_loss_epoch"] == torch.stack(model._losses).mean()
+    assert trainer.progress_bar_metrics["train_loss_step"] == model._losses[-1]
+    assert trainer.progress_bar_metrics["train_loss_epoch"] == torch.stack(model._losses).mean()
 
 
 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
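The assertions above depend on the metric names Lightning produces when a value is logged with both on_step and on_epoch enabled. A hedged sketch of that logging call in a hypothetical module (LossLoggingModule is illustrative, not code from this PR; assumes torch and pytorch_lightning are installed):

import torch
from pytorch_lightning import LightningModule


class LossLoggingModule(LightningModule):
    """Hypothetical module showing the logging call behind the assertions above."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        # on_step=True and on_epoch=True publish both "train_loss_step" and
        # "train_loss_epoch"; prog_bar=True places them in
        # trainer.progress_bar_metrics
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)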
@@ -373,7 +373,7 @@ def test_model_checkpoint_options(tmpdir, save_top_k, save_last, expected_files)
     for i, loss in enumerate(losses):
         trainer.fit_loop.current_epoch = i
         trainer.fit_loop.global_step = i
-        trainer.logger_connector.callback_metrics.update({"checkpoint_on": loss})
+        trainer.callback_metrics.update({"checkpoint_on": loss})
         checkpoint_callback.on_validation_end(trainer, trainer.lightning_module)
 
     file_lists = set(os.listdir(tmpdir))