update tests for v2 (#11487)
This commit is contained in:
parent
5dc8002d46
commit
3dee2759ee
|
@ -82,7 +82,8 @@ def test_callback_batch_on_device(tmpdir):
|
|||
limit_val_batches=1,
|
||||
limit_test_batches=1,
|
||||
limit_predict_batches=1,
|
||||
gpus=1,
|
||||
accelerator="gpu",
|
||||
devices=1,
|
||||
callbacks=[batch_callback],
|
||||
)
|
||||
trainer.fit(model)
|
||||
|
|
|
@ -100,7 +100,12 @@ def test_memory_consumption_validation(tmpdir):
|
|||
|
||||
torch.cuda.empty_cache()
|
||||
trainer = Trainer(
|
||||
gpus=1, default_root_dir=tmpdir, fast_dev_run=2, move_metrics_to_cpu=True, enable_model_summary=False
|
||||
accelerator="gpu",
|
||||
devices=1,
|
||||
default_root_dir=tmpdir,
|
||||
fast_dev_run=2,
|
||||
move_metrics_to_cpu=True,
|
||||
enable_model_summary=False,
|
||||
)
|
||||
trainer.fit(BoringLargeBatchModel())
|
||||
|
||||
|
|
Loading…
Reference in New Issue