ref: fixes logging for eval steps (#3763)

* fixes logging for eval steps
This commit is contained in:
William Falcon 2020-10-01 02:31:11 -04:00 committed by GitHub
parent 5ec00ccd28
commit 7c61fc7c27
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 10 additions and 11 deletions

View File

@ -54,13 +54,14 @@ class LitClassifier(pl.LightningModule):
x, y = batch
y_hat = self.backbone(x)
loss = F.cross_entropy(y_hat, y)
self.log('train_loss', loss, on_epoch=True)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.backbone(x)
loss = F.cross_entropy(y_hat, y)
self.log('valid_loss', loss)
self.log('valid_loss', loss, on_step=True)
def test_step(self, batch, batch_idx):
x, y = batch

View File

@ -296,6 +296,9 @@ class ModelCheckpoint(Callback):
raise ValueError(".save_function() not set")
def check_monitor_top_k(self, current) -> bool:
if current is None:
return False
if self.save_top_k == -1:
return True
@ -421,7 +424,7 @@ class ModelCheckpoint(Callback):
if self.monitor is None and 'checkpoint_on' in metrics:
self.monitor = 'checkpoint_on'
if self.save_top_k is None:
if self.save_top_k is None and self.monitor is not None:
self.save_top_k = 1
def _validate_monitor_key(self, trainer):
@ -486,15 +489,7 @@ class ModelCheckpoint(Callback):
if not isinstance(current, torch.Tensor) and current is not None:
current = torch.tensor(current, device=pl_module.device)
if current is None:
m = f"Can save best model only with {self.monitor} available, skipping."
if self.monitor == 'checkpoint_on':
m = (
'No checkpoint_on found. HINT: Did you set it in '
'EvalResult(checkpoint_on=tensor) or TrainResult(checkpoint_on=tensor)?'
)
rank_zero_warn(m, RuntimeWarning)
elif self.check_monitor_top_k(current):
if self.check_monitor_top_k(current):
self._update_best_and_save(filepath, current, epoch, trainer, pl_module)
elif self.verbose:
rank_zero_info(

View File

@ -157,6 +157,9 @@ class LoggerConnector:
# track the final results for the dataloader
self.eval_loop_results.append(deepcopy(self.callback_metrics))
# actually log
self.log_metrics(logger_metrics, {}, step=self.trainer.global_step)
def __rename_keys_by_dataloader_idx(self, metrics, dataloader_idx, num_loaders):
if num_loaders == 1:
return metrics