diff --git a/tests/tests_pytorch/loggers/test_wandb.py b/tests/tests_pytorch/loggers/test_wandb.py index 3ce1401efa..69ce210108 100644 --- a/tests/tests_pytorch/loggers/test_wandb.py +++ b/tests/tests_pytorch/loggers/test_wandb.py @@ -198,7 +198,7 @@ def test_wandb_log_model(wandb, monkeypatch, tmpdir): model = BoringModel() # test log_model=True - logger = WandbLogger(log_model=True) + logger = WandbLogger(save_dir=tmpdir, log_model=True) logger.experiment.id = "1" logger.experiment.name = "run_name" trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3) @@ -208,7 +208,7 @@ def test_wandb_log_model(wandb, monkeypatch, tmpdir): # test log_model='all' wandb.init().log_artifact.reset_mock() wandb.init.reset_mock() - logger = WandbLogger(log_model="all") + logger = WandbLogger(save_dir=tmpdir, log_model="all") logger.experiment.id = "1" logger.experiment.name = "run_name" trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3) @@ -218,7 +218,7 @@ def test_wandb_log_model(wandb, monkeypatch, tmpdir): # test log_model=False wandb.init().log_artifact.reset_mock() wandb.init.reset_mock() - logger = WandbLogger(log_model=False) + logger = WandbLogger(save_dir=tmpdir, log_model=False) logger.experiment.id = "1" logger.experiment.name = "run_name" trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3) @@ -229,7 +229,7 @@ def test_wandb_log_model(wandb, monkeypatch, tmpdir): wandb.init().log_artifact.reset_mock() wandb.init.reset_mock() wandb.Artifact.reset_mock() - logger = WandbLogger(log_model=True) + logger = WandbLogger(save_dir=tmpdir, log_model=True) logger.experiment.id = "1" logger.experiment.name = "run_name" trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3) @@ -265,7 +265,7 @@ def test_wandb_log_model_with_score(wandb, monkeypatch, tmpdir): wandb.init().log_artifact.reset_mock() wandb.init.reset_mock() wandb.Artifact.reset_mock() - logger = WandbLogger(log_model=True) + logger = WandbLogger(save_dir=tmpdir, log_model=True) logger.experiment.id = "1" logger.experiment.name = "run_name" checkpoint_callback = ModelCheckpoint(monitor="step")