added clean slurm save load test
This commit is contained in:
parent
f5a01edfb8
commit
8e3a0443c7
|
@ -182,6 +182,7 @@ def test_cpu_slurm_save_load():
|
||||||
assert trainer.global_step == real_global_step and trainer.global_step != 20000000
|
assert trainer.global_step == real_global_step and trainer.global_step != 20000000
|
||||||
|
|
||||||
# predict with loaded model to make sure answers are the same
|
# predict with loaded model to make sure answers are the same
|
||||||
|
trainer.model.eval()
|
||||||
new_pred = trainer.model(x)
|
new_pred = trainer.model(x)
|
||||||
assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1
|
assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue