Update test_pruning.py to use `devices` instead of `gpus` or `ipus` (#11341)
This commit is contained in:
parent
b56d8677ad
commit
3e0569fccc
|
@ -64,7 +64,8 @@ def train_with_pruning_callback(
|
|||
pruning_fn="l1_unstructured",
|
||||
use_lottery_ticket_hypothesis=False,
|
||||
strategy=None,
|
||||
gpus=None,
|
||||
accelerator="cpu",
|
||||
devices=1,
|
||||
num_processes=1,
|
||||
):
|
||||
model = TestModel()
|
||||
|
@ -113,7 +114,8 @@ def train_with_pruning_callback(
|
|||
limit_val_batches=2,
|
||||
max_epochs=10,
|
||||
strategy=strategy,
|
||||
gpus=gpus,
|
||||
accelerator=accelerator,
|
||||
devices=devices,
|
||||
num_processes=num_processes,
|
||||
callbacks=pruning,
|
||||
)
|
||||
|
@ -169,13 +171,16 @@ def test_pruning_callback_ddp(tmpdir, parameters_to_prune, use_global_unstructur
|
|||
parameters_to_prune=parameters_to_prune,
|
||||
use_global_unstructured=use_global_unstructured,
|
||||
strategy="ddp",
|
||||
gpus=2,
|
||||
accelerator="gpu",
|
||||
devices=2,
|
||||
)
|
||||
|
||||
|
||||
@RunIf(min_gpus=2, skip_windows=True)
|
||||
def test_pruning_callback_ddp_spawn(tmpdir):
|
||||
train_with_pruning_callback(tmpdir, use_global_unstructured=True, strategy="ddp_spawn", gpus=2)
|
||||
train_with_pruning_callback(
|
||||
tmpdir, use_global_unstructured=True, strategy="ddp_spawn", accelerator="gpu", devices=2
|
||||
)
|
||||
|
||||
|
||||
@RunIf(skip_windows=True, skip_49370=True)
|
||||
|
|
Loading…
Reference in New Issue