diff --git a/tests/test_cpu_models.py b/tests/test_cpu_models.py index 62f3040441..1b7a2573d7 100644 --- a/tests/test_cpu_models.py +++ b/tests/test_cpu_models.py @@ -299,6 +299,8 @@ def test_tbptt_cpu_model(): """ testing_utils.reset_seed() + save_dir = testing_utils.init_save_dir() + truncated_bptt_steps = 2 sequence_size = 30 batch_size = 30 @@ -366,6 +368,8 @@ def test_tbptt_cpu_model(): assert result == 1, 'training failed to complete' + testing_utils.clear_save_dir() + def test_single_gpu_model(): """