Fix wandb test writing artifacts to cwd (#15551)

This commit is contained in:
Adrian Wälchli 2022-11-08 20:13:49 +01:00 committed by GitHub
parent e33d09a1a8
commit f2449ac5ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 5 deletions

View File

@ -198,7 +198,7 @@ def test_wandb_log_model(wandb, monkeypatch, tmpdir):
model = BoringModel()
# test log_model=True
logger = WandbLogger(log_model=True)
logger = WandbLogger(save_dir=tmpdir, log_model=True)
logger.experiment.id = "1"
logger.experiment.name = "run_name"
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
@ -208,7 +208,7 @@ def test_wandb_log_model(wandb, monkeypatch, tmpdir):
# test log_model='all'
wandb.init().log_artifact.reset_mock()
wandb.init.reset_mock()
logger = WandbLogger(log_model="all")
logger = WandbLogger(save_dir=tmpdir, log_model="all")
logger.experiment.id = "1"
logger.experiment.name = "run_name"
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
@ -218,7 +218,7 @@ def test_wandb_log_model(wandb, monkeypatch, tmpdir):
# test log_model=False
wandb.init().log_artifact.reset_mock()
wandb.init.reset_mock()
logger = WandbLogger(log_model=False)
logger = WandbLogger(save_dir=tmpdir, log_model=False)
logger.experiment.id = "1"
logger.experiment.name = "run_name"
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
@ -229,7 +229,7 @@ def test_wandb_log_model(wandb, monkeypatch, tmpdir):
wandb.init().log_artifact.reset_mock()
wandb.init.reset_mock()
wandb.Artifact.reset_mock()
logger = WandbLogger(log_model=True)
logger = WandbLogger(save_dir=tmpdir, log_model=True)
logger.experiment.id = "1"
logger.experiment.name = "run_name"
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
@ -265,7 +265,7 @@ def test_wandb_log_model_with_score(wandb, monkeypatch, tmpdir):
wandb.init().log_artifact.reset_mock()
wandb.init.reset_mock()
wandb.Artifact.reset_mock()
logger = WandbLogger(log_model=True)
logger = WandbLogger(save_dir=tmpdir, log_model=True)
logger.experiment.id = "1"
logger.experiment.name = "run_name"
checkpoint_callback = ModelCheckpoint(monitor="step")