ref: fixes logging for eval steps (#3763)
* fixes logging for eval steps
parent 5ec00ccd28
commit 7c61fc7c27
@@ -54,13 +54,14 @@ class LitClassifier(pl.LightningModule):
         x, y = batch
         y_hat = self.backbone(x)
         loss = F.cross_entropy(y_hat, y)
         self.log('train_loss', loss, on_epoch=True)
         return loss
 
     def validation_step(self, batch, batch_idx):
         x, y = batch
         y_hat = self.backbone(x)
         loss = F.cross_entropy(y_hat, y)
-        self.log('valid_loss', loss)
+        self.log('valid_loss', loss, on_step=True)
 
     def test_step(self, batch, batch_idx):
         x, y = batch
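The change above switches the example's validation logging to per-step granularity: with on_step=True, self.log records valid_loss at each validation batch rather than only as an epoch aggregate. A minimal sketch of the full module for context (the default backbone, the test_step body, and the Adam optimizer are illustrative assumptions, not part of this diff):

import torch
from torch import nn
from torch.nn import functional as F
import pytorch_lightning as pl


class LitClassifier(pl.LightningModule):
    def __init__(self, backbone=None, lr=1e-3):
        super().__init__()
        # illustrative default backbone; the real example wires its own in
        self.backbone = backbone or nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
        self.lr = lr

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.backbone(x)
        loss = F.cross_entropy(y_hat, y)
        # on_epoch=True additionally logs the epoch-level aggregate
        self.log('train_loss', loss, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.backbone(x)
        loss = F.cross_entropy(y_hat, y)
        # after this commit: record the loss at each eval step
        self.log('valid_loss', loss, on_step=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.backbone(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('test_loss', loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)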
@@ -296,6 +296,9 @@ class ModelCheckpoint(Callback):
             raise ValueError(".save_function() not set")
 
     def check_monitor_top_k(self, current) -> bool:
+        if current is None:
+            return False
+
         if self.save_top_k == -1:
             return True
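The new guard makes check_monitor_top_k return False outright when the monitored value is missing, which lets the caller drop its own None handling (see the hunk at -486 below). A simplified, pure-function sketch of the decision logic, with the callback's attributes passed in as parameters (treating best_k_models as a mapping of checkpoint path to monitored float is an assumption about the surrounding state):

def check_monitor_top_k(current, best_k_models, save_top_k, mode='min'):
    # after this commit: a metric that was never produced never ranks
    if current is None:
        return False
    # save_top_k == -1 means keep every checkpoint
    if save_top_k == -1:
        return True
    # while fewer than k checkpoints are kept, always save
    if len(best_k_models) < save_top_k:
        return True
    # otherwise the new value must beat the worst of the kept models
    worst = max(best_k_models.values()) if mode == 'min' else min(best_k_models.values())
    return current < worst if mode == 'min' else current > worst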
@@ -421,7 +424,7 @@ class ModelCheckpoint(Callback):
         if self.monitor is None and 'checkpoint_on' in metrics:
             self.monitor = 'checkpoint_on'
 
-        if self.save_top_k is None:
+        if self.save_top_k is None and self.monitor is not None:
             self.save_top_k = 1
 
     def _validate_monitor_key(self, trainer):
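Previously save_top_k was unconditionally defaulted to 1; now the default only applies when a monitor exists, since without a monitored metric there is nothing to rank checkpoints by. A hedged sketch of the resolution rule as a standalone function (resolve_save_top_k is a hypothetical name, not part of the callback's API):

def resolve_save_top_k(save_top_k, monitor):
    # default to keeping only the best checkpoint, but only when a
    # monitored metric exists to define "best"
    if save_top_k is None and monitor is not None:
        return 1
    return save_top_k


assert resolve_save_top_k(None, 'valid_loss') == 1
assert resolve_save_top_k(None, None) is None    # no metric, no ranking
assert resolve_save_top_k(3, 'valid_loss') == 3  # explicit values pass through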
@@ -486,15 +489,7 @@ class ModelCheckpoint(Callback):
         if not isinstance(current, torch.Tensor) and current is not None:
             current = torch.tensor(current, device=pl_module.device)
 
-        if current is None:
-            m = f"Can save best model only with {self.monitor} available, skipping."
-            if self.monitor == 'checkpoint_on':
-                m = (
-                    'No checkpoint_on found. HINT: Did you set it in '
-                    'EvalResult(checkpoint_on=tensor) or TrainResult(checkpoint_on=tensor)?'
-                )
-            rank_zero_warn(m, RuntimeWarning)
-        elif self.check_monitor_top_k(current):
+        if self.check_monitor_top_k(current):
             self._update_best_and_save(filepath, current, epoch, trainer, pl_module)
         elif self.verbose:
             rank_zero_info(
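With the None case folded into check_monitor_top_k, the save path no longer needs a dedicated warning branch: a missing metric now skips the save silently instead of emitting a RuntimeWarning with the old EvalResult/TrainResult hint. A toy reduction demonstrating the post-commit behaviour (_Stub is hypothetical, standing in for the callback's state):

class _Stub:
    save_top_k = 1
    best_k_models = {}

    def check_monitor_top_k(self, current):
        # mirrors the guarded method from the -296 hunk above
        if current is None:
            return False
        if self.save_top_k == -1:
            return True
        return len(self.best_k_models) < self.save_top_k


stub = _Stub()
assert stub.check_monitor_top_k(None) is False  # was: warn and skip; now: just skip
assert stub.check_monitor_top_k(0.42) is True   # normal save path proceeds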
@@ -157,6 +157,9 @@ class LoggerConnector:
         # track the final results for the dataloader
         self.eval_loop_results.append(deepcopy(self.callback_metrics))
 
+        # actually log
+        self.log_metrics(logger_metrics, {}, step=self.trainer.global_step)
+
     def __rename_keys_by_dataloader_idx(self, metrics, dataloader_idx, num_loaders):
         if num_loaders == 1:
             return metrics
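This hunk is the heart of the fix: after tracking the per-dataloader results, the connector now actually forwards the eval metrics to the logger, stamped with the trainer's global_step so validation points land on the same x-axis as the training metrics. A toy reduction of the pattern (TinyLoggerConnector and its print-based log_metrics are illustrative stand-ins, not Lightning internals):

from copy import deepcopy


class TinyLoggerConnector:
    def __init__(self, trainer):
        self.trainer = trainer
        self.callback_metrics = {}
        self.eval_loop_results = []

    def track_and_log_eval(self, logger_metrics):
        # track the final results for the dataloader
        self.eval_loop_results.append(deepcopy(self.callback_metrics))
        # actually log -- the call this commit adds, keyed to the
        # training step at which evaluation ran
        self.log_metrics(logger_metrics, {}, step=self.trainer.global_step)

    def log_metrics(self, metrics, grad_norm_dict, step=None):
        print(f'step={step}: {metrics}')  # stand-in for the real logger hookup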