Always run validation inside the training loop epoch (#7357)
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
037a71b156
commit
311d9fe67e
|
@ -62,6 +62,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
||||||
|
|
||||||
- Changed `clip_grad_norm` to use `torch.nn.utils.clip_grad_norm_` ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
|
- Changed `clip_grad_norm` to use `torch.nn.utils.clip_grad_norm_` ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
|
||||||
|
|
||||||
|
|
||||||
|
- Validation is now always run inside the training epoch scope ([#7357](https://github.com/PyTorchLightning/pytorch-lightning/pull/7357))
|
||||||
|
|
||||||
|
|
||||||
- Refactored Loops
|
- Refactored Loops
|
||||||
* Moved attributes `global_step`, `current_epoch`, `max/min_steps`, `max/min_epochs`, `batch_idx`, and `total_batch_idx` to TrainLoop ([#7437](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
|
* Moved attributes `global_step`, `current_epoch`, `max/min_steps`, `max/min_epochs`, `batch_idx`, and `total_batch_idx` to TrainLoop ([#7437](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
|
||||||
* Refactored result handling in training loop ([#7506](https://github.com/PyTorchLightning/pytorch-lightning/pull/7506))
|
* Refactored result handling in training loop ([#7506](https://github.com/PyTorchLightning/pytorch-lightning/pull/7506))
|
||||||
|
|
|
@ -928,7 +928,7 @@ class Trainer(
|
||||||
self.state.stage = None
|
self.state.stage = None
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT:
|
def _run_evaluation(self) -> _EVALUATE_OUTPUT:
|
||||||
if not (self.evaluating or self.sanity_checking):
|
if not (self.evaluating or self.sanity_checking):
|
||||||
rank_zero_warn(
|
rank_zero_warn(
|
||||||
f"`trainer._run_evaluation()` was called but the running stage is set to {self.state.stage}."
|
f"`trainer._run_evaluation()` was called but the running stage is set to {self.state.stage}."
|
||||||
|
@ -1010,17 +1010,6 @@ class Trainer(
|
||||||
# hook
|
# hook
|
||||||
self.evaluation_loop.on_evaluation_epoch_end()
|
self.evaluation_loop.on_evaluation_epoch_end()
|
||||||
|
|
||||||
# update epoch-level lr_schedulers
|
|
||||||
if on_epoch:
|
|
||||||
self.optimizer_connector.update_learning_rates(
|
|
||||||
interval='epoch',
|
|
||||||
opt_indices=[
|
|
||||||
opt_idx for opt_idx, _ in self.train_loop.get_active_optimizers(
|
|
||||||
batch_idx=(self.train_loop.total_batch_idx - 1)
|
|
||||||
) # Select the optimizers which were used in the last batch of the epoch
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# log epoch metrics
|
# log epoch metrics
|
||||||
eval_loop_results = self.logger_connector.get_evaluate_epoch_results()
|
eval_loop_results = self.logger_connector.get_evaluate_epoch_results()
|
||||||
|
|
||||||
|
|
|
@ -479,7 +479,6 @@ class TrainLoop:
|
||||||
train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
|
train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
|
||||||
dataloader_idx = 0
|
dataloader_idx = 0
|
||||||
batch_idx = None
|
batch_idx = None
|
||||||
is_last_batch = None
|
|
||||||
|
|
||||||
for batch_idx, (batch, is_last_batch) in train_dataloader:
|
for batch_idx, (batch, is_last_batch) in train_dataloader:
|
||||||
self.batch_idx = batch_idx
|
self.batch_idx = batch_idx
|
||||||
|
@ -529,16 +528,13 @@ class TrainLoop:
|
||||||
|
|
||||||
self.total_batch_idx += 1
|
self.total_batch_idx += 1
|
||||||
|
|
||||||
max_steps_reached = (
|
|
||||||
self.max_steps is not None and self.max_steps <= self.global_step + 1
|
|
||||||
and self._accumulated_batches_reached()
|
|
||||||
)
|
|
||||||
if max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(is_last_batch):
|
|
||||||
break
|
|
||||||
|
|
||||||
# progress global step according to grads progress
|
# progress global step according to grads progress
|
||||||
self.increment_accumulated_grad_global_step()
|
self.increment_accumulated_grad_global_step()
|
||||||
|
|
||||||
|
max_steps_reached = (self.max_steps is not None and self.max_steps <= self.global_step)
|
||||||
|
if max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(is_last_batch):
|
||||||
|
break
|
||||||
|
|
||||||
if batch_idx is None:
|
if batch_idx is None:
|
||||||
# dataloader/iterator did not produce a batch
|
# dataloader/iterator did not produce a batch
|
||||||
return
|
return
|
||||||
|
@ -546,27 +542,24 @@ class TrainLoop:
|
||||||
# handle epoch_output on epoch end
|
# handle epoch_output on epoch end
|
||||||
self.on_train_epoch_end(epoch_output)
|
self.on_train_epoch_end(epoch_output)
|
||||||
|
|
||||||
|
# the global step is manually decreased here due to backwards compatibility with existing loggers
|
||||||
|
# as they expect that the same step is used when logging epoch end metrics even when the batch loop has
|
||||||
|
# finished. this means the attribute does not exactly track the number of optimizer steps applied.
|
||||||
|
# TODO(@carmocca): deprecate and rename so users don't get confused
|
||||||
|
self.global_step -= 1
|
||||||
# log epoch metrics
|
# log epoch metrics
|
||||||
self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output)
|
self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output)
|
||||||
|
self.global_step += 1
|
||||||
|
|
||||||
should_check_val = self._should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
|
|
||||||
should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
|
|
||||||
should_train_only = self.trainer.disable_validation or should_skip_eval
|
|
||||||
|
|
||||||
# update epoch level lr_schedulers if no val loop outside train loop is triggered
|
|
||||||
if not should_check_val or should_train_only:
|
|
||||||
self.update_lr_schedulers('epoch')
|
self.update_lr_schedulers('epoch')
|
||||||
|
|
||||||
if should_train_only:
|
did_train_only = self.trainer.disable_validation or self.trainer.evaluation_loop.should_skip_evaluation(
|
||||||
|
self.trainer.num_val_batches
|
||||||
|
)
|
||||||
|
if did_train_only:
|
||||||
|
self.global_step -= 1
|
||||||
self.check_checkpoint_callback(True)
|
self.check_checkpoint_callback(True)
|
||||||
|
self.global_step += 1
|
||||||
if should_check_val:
|
|
||||||
self.trainer.validating = True
|
|
||||||
self.trainer._run_evaluation(on_epoch=True)
|
|
||||||
self.trainer.training = True
|
|
||||||
|
|
||||||
if batch_output.signal != -1:
|
|
||||||
self.increment_accumulated_grad_global_step()
|
|
||||||
|
|
||||||
def on_train_epoch_end(self, epoch_output: List[List[List[Result]]]) -> None:
|
def on_train_epoch_end(self, epoch_output: List[List[List[Result]]]) -> None:
|
||||||
# inform logger the batch loop has finished
|
# inform logger the batch loop has finished
|
||||||
|
@ -882,7 +875,7 @@ class TrainLoop:
|
||||||
is_final_batch = self._num_training_batches_reached()
|
is_final_batch = self._num_training_batches_reached()
|
||||||
return not (accumulation_done or is_final_batch)
|
return not (accumulation_done or is_final_batch)
|
||||||
|
|
||||||
def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bool = False) -> bool:
|
def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
|
||||||
""" Decide if we should run validation. """
|
""" Decide if we should run validation. """
|
||||||
if not self.trainer.enable_validation:
|
if not self.trainer.enable_validation:
|
||||||
return False
|
return False
|
||||||
|
@ -893,26 +886,19 @@ class TrainLoop:
|
||||||
|
|
||||||
# val_check_batch is inf for iterable datasets with no length defined
|
# val_check_batch is inf for iterable datasets with no length defined
|
||||||
is_infinite_dataset = self.trainer.val_check_batch == float('inf')
|
is_infinite_dataset = self.trainer.val_check_batch == float('inf')
|
||||||
if on_epoch and is_last_batch and is_infinite_dataset:
|
if is_last_batch and is_infinite_dataset:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if self.trainer.should_stop:
|
if self.trainer.should_stop:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch
|
# TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch
|
||||||
is_val_check_batch = False
|
is_val_check_batch = is_last_batch
|
||||||
if isinstance(self.trainer.limit_train_batches, int) and self.trainer.val_check_batch == float('inf'):
|
if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
|
||||||
is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
|
is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
|
||||||
elif self.trainer.val_check_batch != float('inf'):
|
elif self.trainer.val_check_batch != float('inf'):
|
||||||
is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
|
is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
|
||||||
|
return is_val_check_batch
|
||||||
# Note: num_training_batches is also inf for iterable datasets with no length defined
|
|
||||||
epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0
|
|
||||||
|
|
||||||
if on_epoch:
|
|
||||||
return is_val_check_batch and epoch_end_val_check
|
|
||||||
else:
|
|
||||||
return is_val_check_batch and not epoch_end_val_check
|
|
||||||
|
|
||||||
def _build_kwargs(self, batch, batch_idx, opt_idx, hiddens):
|
def _build_kwargs(self, batch, batch_idx, opt_idx, hiddens):
|
||||||
# enable not needing to add opt_idx to training_step
|
# enable not needing to add opt_idx to training_step
|
||||||
|
|
|
@ -83,8 +83,6 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
|
||||||
call.on_after_backward(trainer, model),
|
call.on_after_backward(trainer, model),
|
||||||
call.on_train_batch_end(trainer, model, ANY, ANY, 2, 0),
|
call.on_train_batch_end(trainer, model, ANY, ANY, 2, 0),
|
||||||
call.on_batch_end(trainer, model),
|
call.on_batch_end(trainer, model),
|
||||||
call.on_train_epoch_end(trainer, model, ANY),
|
|
||||||
call.on_epoch_end(trainer, model),
|
|
||||||
call.on_validation_start(trainer, model),
|
call.on_validation_start(trainer, model),
|
||||||
call.on_epoch_start(trainer, model),
|
call.on_epoch_start(trainer, model),
|
||||||
call.on_validation_epoch_start(trainer, model),
|
call.on_validation_epoch_start(trainer, model),
|
||||||
|
@ -94,6 +92,8 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
|
||||||
call.on_epoch_end(trainer, model),
|
call.on_epoch_end(trainer, model),
|
||||||
call.on_validation_end(trainer, model),
|
call.on_validation_end(trainer, model),
|
||||||
call.on_save_checkpoint(trainer, model), # should take ANY but we are inspecting signature for BC
|
call.on_save_checkpoint(trainer, model), # should take ANY but we are inspecting signature for BC
|
||||||
|
call.on_train_epoch_end(trainer, model, ANY),
|
||||||
|
call.on_epoch_end(trainer, model),
|
||||||
call.on_train_end(trainer, model),
|
call.on_train_end(trainer, model),
|
||||||
call.on_fit_end(trainer, model),
|
call.on_fit_end(trainer, model),
|
||||||
call.teardown(trainer, model, 'fit'),
|
call.teardown(trainer, model, 'fit'),
|
||||||
|
|
|
@ -169,7 +169,7 @@ def test_early_stopping_patience_train(
|
||||||
model.validation_step = None
|
model.validation_step = None
|
||||||
|
|
||||||
early_stop_callback = EarlyStopping(
|
early_stop_callback = EarlyStopping(
|
||||||
monitor="train_loss", patience=patience, verbose=True, check_on_train_epoch_end=validation_step_none
|
monitor="train_loss", patience=patience, verbose=True, check_on_train_epoch_end=True
|
||||||
)
|
)
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
default_root_dir=tmpdir,
|
default_root_dir=tmpdir,
|
||||||
|
|
|
@ -60,9 +60,8 @@ class LogInTwoMethods(BoringModel):
|
||||||
"validation_step_none,val_dataloaders_none,monitor",
|
"validation_step_none,val_dataloaders_none,monitor",
|
||||||
[
|
[
|
||||||
(False, False, 'val_log'),
|
(False, False, 'val_log'),
|
||||||
(False, False, 'train_log_epoch'),
|
|
||||||
(True, False, 'train_log_epoch'),
|
(True, False, 'train_log_epoch'),
|
||||||
(False, True, 'train_log_epoch'),
|
(False, True, 'val_log'),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize('reduce_lr_on_plateau', [False, True])
|
@pytest.mark.parametrize('reduce_lr_on_plateau', [False, True])
|
||||||
|
@ -76,7 +75,7 @@ def test_model_checkpoint_score_and_ckpt(
|
||||||
max_epochs = 3
|
max_epochs = 3
|
||||||
limit_train_batches = 5
|
limit_train_batches = 5
|
||||||
limit_val_batches = 7
|
limit_val_batches = 7
|
||||||
lr = 1e-1
|
lr, gamma = 1e-1, 2
|
||||||
|
|
||||||
class CustomBoringModel(BoringModel):
|
class CustomBoringModel(BoringModel):
|
||||||
|
|
||||||
|
@ -106,7 +105,7 @@ def test_model_checkpoint_score_and_ckpt(
|
||||||
'strict': True,
|
'strict': True,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
|
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)
|
||||||
|
|
||||||
return [optimizer], [lr_scheduler]
|
return [optimizer], [lr_scheduler]
|
||||||
|
|
||||||
|
@ -153,9 +152,12 @@ def test_model_checkpoint_score_and_ckpt(
|
||||||
assert mc_specific_data['current_score'] == score
|
assert mc_specific_data['current_score'] == score
|
||||||
|
|
||||||
if not reduce_lr_on_plateau:
|
if not reduce_lr_on_plateau:
|
||||||
lr_scheduler_specific_data = chk['lr_schedulers'][0]
|
actual_step_count = chk['lr_schedulers'][0]['_step_count']
|
||||||
assert lr_scheduler_specific_data['_step_count'] == epoch + 2
|
actual_lr = chk['lr_schedulers'][0]['_last_lr'][0]
|
||||||
assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + 1))
|
# if validation_step_none, the checkpoint gets saved after the learning rate update
|
||||||
|
# so we need to increase the count by one
|
||||||
|
assert actual_step_count == epoch + 1 + validation_step_none
|
||||||
|
assert actual_lr == lr * gamma**(epoch + validation_step_none)
|
||||||
|
|
||||||
assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
|
assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
|
||||||
assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)
|
assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)
|
||||||
|
@ -180,23 +182,21 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(
|
||||||
max_epochs = 3
|
max_epochs = 3
|
||||||
limit_train_batches = 12
|
limit_train_batches = 12
|
||||||
limit_val_batches = 7
|
limit_val_batches = 7
|
||||||
lr = 1e-1
|
lr, gamma = 1e-1, 2
|
||||||
monitor = 'val_log'
|
monitor = 'val_log'
|
||||||
per_epoch_steps = int(limit_train_batches * val_check_interval)
|
per_val_train_batches = int(limit_train_batches * val_check_interval)
|
||||||
per_epoch_call_count = limit_train_batches // per_epoch_steps
|
per_epoch_val_checks, leftover_train_batches = divmod(limit_train_batches, per_val_train_batches)
|
||||||
left_over_steps = limit_train_batches % per_epoch_steps
|
|
||||||
|
|
||||||
class CustomBoringModel(BoringModel):
|
class CustomBoringModel(BoringModel):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.val_logs = torch.randn(per_epoch_call_count * max_epochs, limit_val_batches)
|
self.val_logs = torch.randn(per_epoch_val_checks * max_epochs, limit_val_batches)
|
||||||
self.val_loop_count = 0
|
self.val_loop_count = 0
|
||||||
|
|
||||||
def validation_step(self, batch, batch_idx):
|
def validation_step(self, batch, batch_idx):
|
||||||
log_value = self.val_logs[self.val_loop_count, batch_idx]
|
log_value = self.val_logs[self.val_loop_count, batch_idx]
|
||||||
self.log('val_log', log_value)
|
self.log('val_log', log_value)
|
||||||
self.log('epoch', self.current_epoch, on_epoch=True)
|
|
||||||
return super().validation_step(batch, batch_idx)
|
return super().validation_step(batch, batch_idx)
|
||||||
|
|
||||||
def validation_epoch_end(self, outputs):
|
def validation_epoch_end(self, outputs):
|
||||||
|
@ -213,7 +213,7 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(
|
||||||
'strict': True,
|
'strict': True,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
|
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)
|
||||||
|
|
||||||
return [optimizer], [lr_scheduler]
|
return [optimizer], [lr_scheduler]
|
||||||
|
|
||||||
|
@ -241,26 +241,27 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(
|
||||||
|
|
||||||
# on_train_end ckpt callback is called which creates an additional ckpt in case no ckpt is created at the
|
# on_train_end ckpt callback is called which creates an additional ckpt in case no ckpt is created at the
|
||||||
# end of epoch, thus if val_check_interval doesn't align with the training steps we create an additional ckpt
|
# end of epoch, thus if val_check_interval doesn't align with the training steps we create an additional ckpt
|
||||||
additional_ckpt, additional_ckpt_path = 0, None
|
additional_ckpt, additional_ckpt_path = False, None
|
||||||
if not epoch_aligned:
|
if not epoch_aligned:
|
||||||
additional_ckpt_path = [f for f in ckpt_files if 'v1' in f.stem][0]
|
additional_ckpt_path = [f for f in ckpt_files if 'v1' in f.stem][0]
|
||||||
additional_ckpt = 1
|
additional_ckpt = True
|
||||||
|
|
||||||
additional_ckpt = 1 if not epoch_aligned else 0
|
assert len(ckpt_files) == len(scores) + additional_ckpt == per_epoch_val_checks * max_epochs + additional_ckpt
|
||||||
assert len(ckpt_files) == len(scores) + additional_ckpt == per_epoch_call_count * max_epochs + additional_ckpt
|
|
||||||
assert len(lr_scheduler_debug) == max_epochs
|
assert len(lr_scheduler_debug) == max_epochs
|
||||||
|
|
||||||
def _make_assertions(epoch, ix, add=''):
|
def _make_assertions(epoch, ix, version=''):
|
||||||
global_ix = ix + per_epoch_call_count * epoch
|
global_ix = ix + per_epoch_val_checks * epoch
|
||||||
|
duplicated = bool(version)
|
||||||
|
|
||||||
score = scores[global_ix]
|
score = scores[global_ix]
|
||||||
expected_score = getattr(model, f'{monitor}s')[global_ix].mean().item()
|
expected_score = getattr(model, f'{monitor}s')[global_ix].mean().item()
|
||||||
expected_filename = f'{monitor}={score:.4f}-epoch={epoch}{add}.ckpt'
|
expected_filename = f'{monitor}={score:.4f}-epoch={epoch}{version}.ckpt'
|
||||||
assert math.isclose(score, expected_score, rel_tol=1e-4)
|
assert math.isclose(score, expected_score, rel_tol=1e-4)
|
||||||
|
|
||||||
chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
|
chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
|
||||||
assert chk['epoch'] == epoch + 1
|
assert chk['epoch'] == epoch + 1
|
||||||
epoch_num = epoch + (1 if add else 0)
|
epoch_num = epoch + duplicated
|
||||||
expected_global_step = per_epoch_steps * (global_ix + 1) + (left_over_steps * epoch_num)
|
expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch_num)
|
||||||
assert chk['global_step'] == expected_global_step
|
assert chk['global_step'] == expected_global_step
|
||||||
|
|
||||||
mc_specific_data = chk['callbacks'][type(checkpoint)]
|
mc_specific_data = chk['callbacks'][type(checkpoint)]
|
||||||
|
@ -269,15 +270,15 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(
|
||||||
assert mc_specific_data['current_score'] == score
|
assert mc_specific_data['current_score'] == score
|
||||||
|
|
||||||
if not reduce_lr_on_plateau:
|
if not reduce_lr_on_plateau:
|
||||||
lr_scheduler_specific_data = chk['lr_schedulers'][0]
|
actual_step_count = chk['lr_schedulers'][0]['_step_count']
|
||||||
did_update = 1 if (ix + 1 == per_epoch_call_count) and (epoch_aligned or add) else 0
|
actual_lr = chk['lr_schedulers'][0]['_last_lr'][0]
|
||||||
assert lr_scheduler_specific_data['_step_count'] == epoch + 1 + did_update
|
assert actual_step_count == epoch + 1 + duplicated
|
||||||
assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + did_update))
|
assert actual_lr == lr * gamma**(epoch + duplicated)
|
||||||
|
|
||||||
return score
|
return score
|
||||||
|
|
||||||
for epoch in range(max_epochs):
|
for epoch in range(max_epochs):
|
||||||
for i in range(per_epoch_call_count):
|
for i in range(per_epoch_val_checks):
|
||||||
score = _make_assertions(epoch, i)
|
score = _make_assertions(epoch, i)
|
||||||
|
|
||||||
assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
|
assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
|
||||||
|
@ -285,9 +286,7 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(
|
||||||
|
|
||||||
# check the ckpt file saved on_train_end
|
# check the ckpt file saved on_train_end
|
||||||
if additional_ckpt_path:
|
if additional_ckpt_path:
|
||||||
epoch = max_epochs - 1
|
_make_assertions(max_epochs - 1, per_epoch_val_checks - 1, version='-v1')
|
||||||
i = per_epoch_call_count - 1
|
|
||||||
_make_assertions(epoch, i, add='-v1')
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
|
@pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
|
||||||
|
|
|
@ -469,9 +469,6 @@ def test_trainer_model_hook_system_fit(tmpdir):
|
||||||
'on_epoch_start',
|
'on_epoch_start',
|
||||||
'on_train_epoch_start',
|
'on_train_epoch_start',
|
||||||
*(model.train_batch * train_batches),
|
*(model.train_batch * train_batches),
|
||||||
'training_epoch_end',
|
|
||||||
'on_train_epoch_end',
|
|
||||||
'on_epoch_end',
|
|
||||||
'on_validation_model_eval',
|
'on_validation_model_eval',
|
||||||
'on_validation_start',
|
'on_validation_start',
|
||||||
'on_epoch_start',
|
'on_epoch_start',
|
||||||
|
@ -483,6 +480,9 @@ def test_trainer_model_hook_system_fit(tmpdir):
|
||||||
'on_save_checkpoint',
|
'on_save_checkpoint',
|
||||||
'on_validation_end',
|
'on_validation_end',
|
||||||
'on_validation_model_train',
|
'on_validation_model_train',
|
||||||
|
'training_epoch_end',
|
||||||
|
'on_train_epoch_end',
|
||||||
|
'on_epoch_end',
|
||||||
'on_train_end',
|
'on_train_end',
|
||||||
'on_fit_end',
|
'on_fit_end',
|
||||||
'teardown_fit',
|
'teardown_fit',
|
||||||
|
|
|
@ -141,4 +141,4 @@ def test_should_stop_mid_epoch(tmpdir):
|
||||||
|
|
||||||
assert trainer.current_epoch == 0
|
assert trainer.current_epoch == 0
|
||||||
assert trainer.global_step == 5
|
assert trainer.global_step == 5
|
||||||
assert model.validation_called_at == (0, 4) # TODO(@carmocca): should be 5 - will be fixed in next PR
|
assert model.validation_called_at == (0, 4)
|
||||||
|
|
Loading…
Reference in New Issue