ref: fixes logging for eval steps (#3763)

* fixes logging for eval steps

parent 5ec00ccd28
commit 7c61fc7c27
@@ -54,13 +54,14 @@ class LitClassifier(pl.LightningModule):
         x, y = batch
         y_hat = self.backbone(x)
         loss = F.cross_entropy(y_hat, y)
+        self.log('train_loss', loss, on_epoch=True)
         return loss
 
     def validation_step(self, batch, batch_idx):
         x, y = batch
         y_hat = self.backbone(x)
         loss = F.cross_entropy(y_hat, y)
-        self.log('valid_loss', loss)
+        self.log('valid_loss', loss, on_step=True)
 
     def test_step(self, batch, batch_idx):
         x, y = batch
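For context, a minimal self-contained sketch of what the two flags change. The module below is hypothetical; the `_step`/`_epoch` suffix behavior described in the comments reflects Lightning's documented `LightningModule.log` semantics around this release:

    import torch
    from torch import nn
    import torch.nn.functional as F
    import pytorch_lightning as pl

    class SketchClassifier(pl.LightningModule):
        """Hypothetical module; only the logging flags matter here."""

        def __init__(self):
            super().__init__()
            self.backbone = nn.Linear(28 * 28, 10)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = F.cross_entropy(self.backbone(x), y)
            # Training logs per-step by default; on_epoch=True additionally
            # logs the epoch-level aggregate (as 'train_loss_epoch').
            self.log('train_loss', loss, on_epoch=True)
            return loss

        def validation_step(self, batch, batch_idx):
            x, y = batch
            loss = F.cross_entropy(self.backbone(x), y)
            # Eval logs per-epoch by default; on_step=True additionally
            # logs the raw per-step values (as 'valid_loss_step').
            self.log('valid_loss', loss, on_step=True)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)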
@@ -296,6 +296,9 @@ class ModelCheckpoint(Callback):
             raise ValueError(".save_function() not set")
 
     def check_monitor_top_k(self, current) -> bool:
+        if current is None:
+            return False
+
         if self.save_top_k == -1:
             return True
 
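A standalone sketch of the guard's effect; the function name, defaults, and comparison logic below are illustrative assumptions, not the real method:

    import torch

    def check_monitor_top_k_sketch(current, save_top_k=1, best_k_models=None, mode_op=torch.lt):
        """Simplified stand-in for ModelCheckpoint.check_monitor_top_k."""
        # The added guard: a missing monitor value never ranks in the top k,
        # which lets the caller drop its own `current is None` branch
        # (see the -486,15 hunk below).
        if current is None:
            return False
        # save_top_k == -1 means "keep every checkpoint".
        if save_top_k == -1:
            return True
        best_k_models = best_k_models or {}
        if len(best_k_models) < save_top_k:
            return True
        # Otherwise compare against the current worst of the kept models.
        worst = max(best_k_models.values()) if mode_op is torch.lt else min(best_k_models.values())
        return bool(mode_op(current, worst))

    assert check_monitor_top_k_sketch(None) is False
    assert check_monitor_top_k_sketch(torch.tensor(0.1)) is True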
@@ -421,7 +424,7 @@ class ModelCheckpoint(Callback):
         if self.monitor is None and 'checkpoint_on' in metrics:
             self.monitor = 'checkpoint_on'
 
-        if self.save_top_k is None:
+        if self.save_top_k is None and self.monitor is not None:
             self.save_top_k = 1
 
     def _validate_monitor_key(self, trainer):
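The intent of this change as a tiny sketch (the function name and asserts are illustrative assumptions): without a monitored metric there is nothing to rank checkpoints by, so `save_top_k` should not silently default to 1.

    def resolve_save_top_k(save_top_k, monitor):
        """Assumed intent of the -421,7 hunk above."""
        if save_top_k is None and monitor is not None:
            save_top_k = 1
        return save_top_k

    assert resolve_save_top_k(None, 'valid_loss') == 1
    assert resolve_save_top_k(None, None) is None    # no monitor -> no top-k ranking
    assert resolve_save_top_k(3, 'valid_loss') == 3  # explicit values pass through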
@@ -486,15 +489,7 @@ class ModelCheckpoint(Callback):
         if not isinstance(current, torch.Tensor) and current is not None:
             current = torch.tensor(current, device=pl_module.device)
 
-        if current is None:
-            m = f"Can save best model only with {self.monitor} available, skipping."
-            if self.monitor == 'checkpoint_on':
-                m = (
-                    'No checkpoint_on found. HINT: Did you set it in '
-                    'EvalResult(checkpoint_on=tensor) or TrainResult(checkpoint_on=tensor)?'
-                )
-            rank_zero_warn(m, RuntimeWarning)
-        elif self.check_monitor_top_k(current):
+        if self.check_monitor_top_k(current):
             self._update_best_and_save(filepath, current, epoch, trainer, pl_module)
         elif self.verbose:
             rank_zero_info(
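A rough sketch of the resulting control flow, assuming the new `current is None` guard in `check_monitor_top_k`; the helper names here are hypothetical stand-ins:

    import torch

    def save_decision_sketch(current, check_top_k, verbose=True):
        """Assumed post-PR control flow of the checkpoint-saving step."""
        if not isinstance(current, torch.Tensor) and current is not None:
            current = torch.tensor(current)
        if check_top_k(current):
            return 'save'              # stands in for _update_best_and_save(...)
        if verbose:
            return 'log-not-in-top-k'  # stands in for the rank_zero_info branch
        return 'skip'

    # With the guard, a missing metric falls through to the verbose message
    # instead of raising a RuntimeWarning as the removed branch did:
    assert save_decision_sketch(None, lambda c: c is not None) == 'log-not-in-top-k'
    assert save_decision_sketch(0.5, lambda c: True) == 'save'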
@@ -157,6 +157,9 @@ class LoggerConnector:
             # track the final results for the dataloader
             self.eval_loop_results.append(deepcopy(self.callback_metrics))
 
+            # actually log
+            self.log_metrics(logger_metrics, {}, step=self.trainer.global_step)
+
     def __rename_keys_by_dataloader_idx(self, metrics, dataloader_idx, num_loaders):
         if num_loaders == 1:
             return metrics
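A minimal sketch of the added call's effect; the stub classes and method shapes are assumed, and only the `step=self.trainer.global_step` stamping is the point: eval metrics now reach the logger tagged with the trainer's global step, so they line up with the training curves.

    from copy import deepcopy

    class _StubTrainer:
        global_step = 1000

    class _StubLogger:
        def log_metrics(self, metrics, step=None):
            print(f'step={step}: {metrics}')

    class LoggerConnectorSketch:
        def __init__(self):
            self.trainer = _StubTrainer()
            self.logger = _StubLogger()
            self.callback_metrics = {'valid_loss': 0.42}
            self.eval_loop_results = []

        def on_eval_end(self, logger_metrics):
            # track the final results for the dataloader
            self.eval_loop_results.append(deepcopy(self.callback_metrics))
            # actually log, stamped with the trainer's global step
            self.log_metrics(logger_metrics, {}, step=self.trainer.global_step)

        def log_metrics(self, metrics, grad_norm_dict, step=None):
            self.logger.log_metrics(metrics, step=step)

    LoggerConnectorSketch().on_eval_end({'valid_loss': 0.42})
    # -> step=1000: {'valid_loss': 0.42}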