debug
This commit is contained in:
parent
27e88fde31
commit
1e17bf76aa
|
@ -42,7 +42,7 @@ def test_cpu_restore_training():
|
|||
exp.save()
|
||||
|
||||
trainer_options = dict(
|
||||
max_nb_epochs=1,
|
||||
max_nb_epochs=2,
|
||||
val_check_interval=0.50,
|
||||
val_percent_check=0.2,
|
||||
train_percent_check=0.2,
|
||||
|
@ -53,7 +53,7 @@ def test_cpu_restore_training():
|
|||
# fit model
|
||||
trainer = Trainer(**trainer_options)
|
||||
result = trainer.fit(model)
|
||||
real_global_step = trainer.global_step
|
||||
real_global_epoch = trainer.current_epoch
|
||||
|
||||
# traning complete
|
||||
assert result == 1, 'amp + ddp model failed to complete'
|
||||
|
@ -86,7 +86,7 @@ def test_cpu_restore_training():
|
|||
|
||||
# set the epoch start hook so we can predict before the model does the full training
|
||||
def assert_pred_same():
|
||||
assert trainer.global_step == real_global_step and trainer.global_step > 0
|
||||
assert trainer.current_epoch == real_global_epoch and trainer.real_global_epoch > 0
|
||||
|
||||
# predict with loaded model to make sure answers are the same
|
||||
trainer.model.eval()
|
||||
|
|
Loading…
Reference in New Issue