This commit is contained in:
William Falcon 2019-08-07 08:01:33 -04:00
parent 27e88fde31
commit 1e17bf76aa
1 changed files with 3 additions and 3 deletions

View File

@ -42,7 +42,7 @@ def test_cpu_restore_training():
exp.save()
trainer_options = dict(
max_nb_epochs=1,
max_nb_epochs=2,
val_check_interval=0.50,
val_percent_check=0.2,
train_percent_check=0.2,
@ -53,7 +53,7 @@ def test_cpu_restore_training():
# fit model
trainer = Trainer(**trainer_options)
result = trainer.fit(model)
real_global_step = trainer.global_step
real_global_epoch = trainer.current_epoch
# traning complete
assert result == 1, 'amp + ddp model failed to complete'
@ -86,7 +86,7 @@ def test_cpu_restore_training():
# set the epoch start hook so we can predict before the model does the full training
def assert_pred_same():
assert trainer.global_step == real_global_step and trainer.global_step > 0
assert trainer.current_epoch == real_global_epoch and trainer.real_global_epoch > 0
# predict with loaded model to make sure answers are the same
trainer.model.eval()