Update test_pruning.py to use `devices` instead of `gpus` or `ipus` (#11339)
This commit is contained in:
parent
4710a8128b
commit
b56d8677ad
|
@ -44,7 +44,8 @@ def test_device_stats_gpu_from_torch(tmpdir):
|
|||
max_epochs=2,
|
||||
limit_train_batches=7,
|
||||
log_every_n_steps=1,
|
||||
gpus=1,
|
||||
accelerator="gpu",
|
||||
devices=1,
|
||||
callbacks=[device_stats],
|
||||
logger=DebugLogger(tmpdir),
|
||||
enable_checkpointing=False,
|
||||
|
@ -73,7 +74,8 @@ def test_device_stats_gpu_from_nvidia(tmpdir):
|
|||
max_epochs=2,
|
||||
limit_train_batches=7,
|
||||
log_every_n_steps=1,
|
||||
gpus=1,
|
||||
accelerator="gpu",
|
||||
devices=1,
|
||||
callbacks=[device_stats],
|
||||
logger=DebugLogger(tmpdir),
|
||||
enable_checkpointing=False,
|
||||
|
@ -101,7 +103,8 @@ def test_device_stats_monitor_tpu(tmpdir):
|
|||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
limit_train_batches=1,
|
||||
tpu_cores=8,
|
||||
accelerator="tpu",
|
||||
devices=8,
|
||||
log_every_n_steps=1,
|
||||
callbacks=[device_stats],
|
||||
logger=DebugLogger(tmpdir),
|
||||
|
|
Loading…
Reference in New Issue