update tests for v2 (#11486)

This commit is contained in:
Jv Kyle Eclarin 2022-01-16 11:13:07 -05:00 committed by GitHub
parent f97359a8c2
commit 5dc8002d46
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 5 deletions

View File

@ -168,7 +168,13 @@ def test_simple_profiler_distributed_files(tmpdir):
profiler = SimpleProfiler(dirpath=tmpdir, filename="profiler")
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir, fast_dev_run=2, strategy="ddp_spawn", num_processes=2, profiler=profiler, logger=False
default_root_dir=tmpdir,
fast_dev_run=2,
strategy="ddp_spawn",
accelerator="cpu",
devices=2,
profiler=profiler,
logger=False,
)
trainer.fit(model)
trainer.validate(model)
@ -307,7 +313,8 @@ def test_pytorch_profiler_trainer_ddp(tmpdir, pytorch_profiler):
limit_val_batches=5,
profiler=pytorch_profiler,
strategy="ddp",
gpus=2,
accelerator="gpu",
devices=2,
)
trainer.fit(model)
expected = {"[Strategy]DDPStrategy.validation_step"}
@ -429,7 +436,7 @@ def test_pytorch_profiler_nested_emit_nvtx(tmpdir):
profiler = PyTorchProfiler(use_cuda=True, emit_nvtx=True)
model = BoringModel()
trainer = Trainer(fast_dev_run=True, profiler=profiler, gpus=1)
trainer = Trainer(fast_dev_run=True, profiler=profiler, accelerator="gpu", devices=1)
trainer.fit(model)

View File

@ -33,7 +33,7 @@ if _TPU_AVAILABLE:
def test_xla_profiler_instance(tmpdir):
model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, profiler="xla", tpu_cores=8)
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, profiler="xla", accelerator="tpu", devices=8)
assert isinstance(trainer.profiler, XLAProfiler)
trainer.fit(model)
@ -48,7 +48,7 @@ def test_xla_profiler_prog_capture(tmpdir):
def train_worker():
model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=4, profiler="xla", tpu_cores=8)
trainer = Trainer(default_root_dir=tmpdir, max_epochs=4, profiler="xla", accelerator="tpu", devices=8)
trainer.fit(model)