Avoid interactions through test artifacts (#19821)

Adrian Wälchli 2024-04-28 17:56:40 +02:00 committed by GitHub
parent 5e0e02b79e
commit 29136332d6
40 changed files with 154 additions and 98 deletions
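The commit has two parts: new autouse leave_no_artifacts_behind fixtures in the fabric and pytorch conftest.py files (shown in the first hunks below), which snapshot the tests tree before and after every test and fail if new files appear, and updates across the test suite so that tests write only into pytest's tmp_path. The recurring fix looks roughly like this sketch (hypothetical test name; the exact Trainer arguments vary per test):

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel


def test_writes_only_into_tmp_path(tmp_path):
    trainer = Trainer(
        default_root_dir=tmp_path,   # checkpoints and logs land in the per-test temp directory
        logger=False,                # or logger=CSVLogger(tmp_path); avoids a lightning_logs/ dir in the repo
        enable_checkpointing=False,  # skip checkpoint files when the test does not need them
        max_steps=1,
    )
    trainer.fit(BoringModel())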

View File

@ -14,6 +14,7 @@
import os
import sys
import threading
from pathlib import Path
from typing import List
from unittest.mock import Mock
@ -185,6 +186,17 @@ def caplog(caplog):
lightning_logger.propagate = propagate
@pytest.fixture(autouse=True)
def leave_no_artifacts_behind():
tests_root = Path(__file__).parent.parent
files_before = {p for p in tests_root.rglob("*") if "__pycache__" not in p.parts}
yield
files_after = {p for p in tests_root.rglob("*") if "__pycache__" not in p.parts}
difference = files_after - files_before
difference = {str(f.relative_to(tests_root)) for f in difference}
assert not difference, f"Test left artifacts behind: {difference}"
def pytest_collection_modifyitems(items: List[pytest.Function], config: pytest.Config) -> None:
"""An adaptation of `tests/tests_pytorch/conftest.py::pytest_collection_modifyitems`"""
initial_size = len(items)
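For illustration only (not part of the commit): with the autouse fixture above in place, a test module inside the tests tree that writes next to itself instead of into tmp_path now fails at teardown. A hypothetical offender:

from pathlib import Path


def test_leaves_an_artifact():
    # Writing next to the test module instead of into tmp_path ...
    (Path(__file__).parent / "stray_output.txt").write_text("oops")
    # ... makes the fixture's teardown assertion fail with
    # "Test left artifacts behind: ..." listing the new file relative to the tests root.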

View File

@ -13,24 +13,19 @@
# limitations under the License.
import os
import warnings
from pathlib import Path
import pytest
_TEST_ROOT = os.path.dirname(__file__)
_PROJECT_ROOT = os.path.dirname(_TEST_ROOT)
_TEMP_PATH = os.path.join(_PROJECT_ROOT, "test_temp")
_PATH_DATASETS = os.path.join(_PROJECT_ROOT, "Datasets")
_PATH_LEGACY = os.path.join(_PROJECT_ROOT, "legacy")
_TEST_ROOT = Path(__file__).parent.parent
_PROJECT_ROOT = _TEST_ROOT.parent
_PATH_DATASETS = _PROJECT_ROOT / "Datasets"
_PATH_LEGACY = _TEST_ROOT / "legacy"
# todo: this setting of `PYTHONPATH` may not be used by other envs like Conda for importing packages
if _PROJECT_ROOT not in os.getenv("PYTHONPATH", ""):
if str(_PROJECT_ROOT) not in os.getenv("PYTHONPATH", ""):
splitter = ":" if os.environ.get("PYTHONPATH", "") else ""
os.environ["PYTHONPATH"] = f'{_PROJECT_ROOT}{splitter}{os.environ.get("PYTHONPATH", "")}'
if not os.path.isdir(_TEMP_PATH):
os.mkdir(_TEMP_PATH)
# Ignore cleanup warnings from pytest (rarely happens due to a race condition when executing pytest in parallel)
warnings.filterwarnings("ignore", category=pytest.PytestWarning, message=r".*\(rm_rf\) error removing.*")

View File

@ -447,9 +447,10 @@ def test_rich_progress_bar_padding():
@RunIf(rich=True)
def test_rich_progress_bar_can_be_pickled():
def test_rich_progress_bar_can_be_pickled(tmp_path):
bar = RichProgressBar()
trainer = Trainer(
default_root_dir=tmp_path,
callbacks=[bar],
max_epochs=1,
limit_train_batches=1,

View File

@ -550,9 +550,10 @@ def test_tqdm_progress_bar_print_disabled(tqdm_write, mock_print, tmp_path):
tqdm_write.assert_not_called()
def test_tqdm_progress_bar_can_be_pickled():
def test_tqdm_progress_bar_can_be_pickled(tmp_path):
bar = TQDMProgressBar()
trainer = Trainer(
default_root_dir=tmp_path,
callbacks=[bar],
max_epochs=1,
limit_train_batches=1,

View File

@ -162,7 +162,7 @@ def test_device_stats_monitor_warning_when_psutil_not_available(monkeypatch, tmp
monkeypatch.setattr(imports, "_PSUTIL_AVAILABLE", False)
monitor = DeviceStatsMonitor()
trainer = Trainer(logger=CSVLogger(tmp_path))
trainer = Trainer(accelerator="cpu", logger=CSVLogger(tmp_path))
assert trainer.strategy.root_device == torch.device("cpu")
with pytest.raises(ModuleNotFoundError, match="psutil` is not installed"):
monitor.setup(trainer, Mock(), "fit")

View File

@ -113,7 +113,7 @@ def test_finetuning_callback_warning(tmp_path):
trainer.fit(model)
assert model.backbone.has_been_used
trainer = Trainer(max_epochs=3)
trainer = Trainer(default_root_dir=tmp_path, max_epochs=3)
trainer.fit(model, ckpt_path=chk.last_model_path)
@ -245,7 +245,7 @@ def test_base_finetuning_internal_optimizer_metadata(tmp_path):
model = FreezeModel()
cb = OnEpochLayerFinetuning()
trainer = Trainer(max_epochs=10, callbacks=[cb])
trainer = Trainer(default_root_dir=tmp_path, max_epochs=10, callbacks=[cb])
with pytest.raises(IndexError, match="index 6 is out of range"):
trainer.fit(model, ckpt_path=chk.last_model_path)

View File

@ -35,7 +35,7 @@ def test_prediction_writer_invalid_write_interval():
DummyPredictionWriter("something")
def test_prediction_writer_hook_call_intervals():
def test_prediction_writer_hook_call_intervals(tmp_path):
"""Test that the `write_on_batch_end` and `write_on_epoch_end` hooks get invoked based on the defined interval."""
DummyPredictionWriter.write_on_batch_end = Mock()
DummyPredictionWriter.write_on_epoch_end = Mock()
@ -44,7 +44,7 @@ def test_prediction_writer_hook_call_intervals():
model = BoringModel()
cb = DummyPredictionWriter("batch_and_epoch")
trainer = Trainer(limit_predict_batches=4, callbacks=cb)
trainer = Trainer(default_root_dir=tmp_path, logger=False, limit_predict_batches=4, callbacks=cb)
results = trainer.predict(model, dataloaders=dataloader)
assert len(results) == 4
assert cb.write_on_batch_end.call_count == 4
@ -54,7 +54,7 @@ def test_prediction_writer_hook_call_intervals():
DummyPredictionWriter.write_on_epoch_end.reset_mock()
cb = DummyPredictionWriter("batch_and_epoch")
trainer = Trainer(limit_predict_batches=4, callbacks=cb)
trainer = Trainer(default_root_dir=tmp_path, logger=False, limit_predict_batches=4, callbacks=cb)
trainer.predict(model, dataloaders=dataloader, return_predictions=False)
assert cb.write_on_batch_end.call_count == 4
assert cb.write_on_epoch_end.call_count == 1
@ -63,7 +63,7 @@ def test_prediction_writer_hook_call_intervals():
DummyPredictionWriter.write_on_epoch_end.reset_mock()
cb = DummyPredictionWriter("batch")
trainer = Trainer(limit_predict_batches=4, callbacks=cb)
trainer = Trainer(default_root_dir=tmp_path, logger=False, limit_predict_batches=4, callbacks=cb)
trainer.predict(model, dataloaders=dataloader, return_predictions=False)
assert cb.write_on_batch_end.call_count == 4
assert cb.write_on_epoch_end.call_count == 0
@ -72,21 +72,21 @@ def test_prediction_writer_hook_call_intervals():
DummyPredictionWriter.write_on_epoch_end.reset_mock()
cb = DummyPredictionWriter("epoch")
trainer = Trainer(limit_predict_batches=4, callbacks=cb)
trainer = Trainer(default_root_dir=tmp_path, logger=False, limit_predict_batches=4, callbacks=cb)
trainer.predict(model, dataloaders=dataloader, return_predictions=False)
assert cb.write_on_batch_end.call_count == 0
assert cb.write_on_epoch_end.call_count == 1
@pytest.mark.parametrize("num_workers", [0, 2])
def test_prediction_writer_batch_indices(num_workers):
def test_prediction_writer_batch_indices(num_workers, tmp_path):
DummyPredictionWriter.write_on_batch_end = Mock()
DummyPredictionWriter.write_on_epoch_end = Mock()
dataloader = DataLoader(RandomDataset(32, 64), batch_size=4, num_workers=num_workers)
model = BoringModel()
writer = DummyPredictionWriter("batch_and_epoch")
trainer = Trainer(limit_predict_batches=4, callbacks=writer)
trainer = Trainer(default_root_dir=tmp_path, logger=False, limit_predict_batches=4, callbacks=writer)
trainer.predict(model, dataloaders=dataloader)
writer.write_on_batch_end.assert_has_calls([
@ -101,7 +101,7 @@ def test_prediction_writer_batch_indices(num_workers):
])
def test_batch_level_batch_indices():
def test_batch_level_batch_indices(tmp_path):
"""Test that batch_indices are returned when `return_predictions=False`."""
DummyPredictionWriter.write_on_batch_end = Mock()
@ -112,7 +112,7 @@ def test_batch_level_batch_indices():
writer = DummyPredictionWriter("batch")
model = CustomBoringModel()
dataloader = DataLoader(RandomDataset(32, 64), batch_size=4)
trainer = Trainer(limit_predict_batches=4, callbacks=writer)
trainer = Trainer(default_root_dir=tmp_path, logger=False, limit_predict_batches=4, callbacks=writer)
trainer.predict(model, dataloaders=dataloader, return_predictions=False)
writer.write_on_batch_end.assert_has_calls([

View File

@ -190,7 +190,7 @@ def test_pruning_callback_ddp_cpu(tmp_path):
@pytest.mark.parametrize("resample_parameters", [False, True])
def test_pruning_lth_callable(tmp_path, resample_parameters: bool):
def test_pruning_lth_callable(tmp_path, resample_parameters):
model = TestModel()
class ModelPruningTestCallback(ModelPruning):
@ -206,7 +206,7 @@ def test_pruning_lth_callable(tmp_path, resample_parameters: bool):
curr, curr_name = self._parameters_to_prune[i]
assert name == curr_name
actual, expected = getattr(curr, name).data, getattr(copy, name).data
allclose = torch.allclose(actual, expected)
allclose = torch.allclose(actual.cpu(), expected)
assert not allclose if self._resample_parameters else allclose
pruning = ModelPruningTestCallback(
@ -310,7 +310,13 @@ def test_permanent_when_model_is_saved_multiple_times(
ckpt_callback = ModelCheckpoint(
monitor="test", save_top_k=2, save_last=True, save_on_train_epoch_end=save_on_train_epoch_end
)
trainer = Trainer(callbacks=[pruning_callback, ckpt_callback], max_epochs=3, enable_progress_bar=False)
trainer = Trainer(
default_root_dir=tmp_path,
logger=False,
callbacks=[pruning_callback, ckpt_callback],
max_epochs=3,
enable_progress_bar=False,
)
with caplog.at_level(INFO):
trainer.fit(model)

View File

@ -213,6 +213,8 @@ def test_trainer_spike_detection_integration(tmp_path, global_rank_spike, num_de
cb.should_raise = spike_value is None or finite_only or spike_value == float("inf")
trainer = Trainer(
default_root_dir=tmp_path,
logger=False,
callbacks=[cb],
accelerator="cpu",
devices=num_devices,

View File

@ -26,24 +26,24 @@ from lightning.pytorch.utilities.exceptions import MisconfigurationException
from tests_pytorch.helpers.runif import RunIf
def test_trainer_flag(caplog):
def test_trainer_flag(caplog, tmp_path):
class TestModel(BoringModel):
def on_fit_start(self):
raise SystemExit()
trainer = Trainer(max_time={"seconds": 1337})
trainer = Trainer(default_root_dir=tmp_path, logger=False, max_time={"seconds": 1337})
with pytest.raises(SystemExit):
trainer.fit(TestModel())
timer = [c for c in trainer.callbacks if isinstance(c, Timer)][0]
assert timer._duration == 1337
trainer = Trainer(max_time={"seconds": 1337}, callbacks=[Timer()])
trainer = Trainer(default_root_dir=tmp_path, logger=False, max_time={"seconds": 1337}, callbacks=[Timer()])
with pytest.raises(SystemExit), caplog.at_level(level=logging.INFO):
trainer.fit(TestModel())
assert "callbacks list already contains a Timer" in caplog.text
# Make sure max_time still honored even if max_epochs == -1
trainer = Trainer(max_time={"seconds": 1}, max_epochs=-1)
trainer = Trainer(default_root_dir=tmp_path, logger=False, max_time={"seconds": 1}, max_epochs=-1)
with pytest.raises(SystemExit):
trainer.fit(TestModel())
timer = [c for c in trainer.callbacks if isinstance(c, Timer)][0]

View File

@ -24,7 +24,7 @@ from tests_pytorch.helpers.runif import RunIf
def test_disabled_checkpointing():
# no callback
trainer = Trainer(max_epochs=3, enable_checkpointing=False)
trainer = Trainer(logger=False, max_epochs=3, enable_checkpointing=False)
assert not trainer.checkpoint_callbacks
trainer.fit(BoringModel())
assert not trainer.checkpoint_callbacks

View File

@ -308,6 +308,17 @@ def single_process_pg():
os.environ.update(orig_environ)
@pytest.fixture(autouse=True)
def leave_no_artifacts_behind():
tests_root = Path(__file__).parent.parent
files_before = {p for p in tests_root.rglob("*") if "__pycache__" not in p.parts}
yield
files_after = {p for p in tests_root.rglob("*") if "__pycache__" not in p.parts}
difference = files_after - files_before
difference = {str(f.relative_to(tests_root)) for f in difference}
assert not difference, f"Test left artifacts behind: {difference}"
def pytest_collection_modifyitems(items: List[pytest.Function], config: pytest.Config) -> None:
initial_size = len(items)
conditions = []

View File

@ -452,11 +452,12 @@ def test_define_as_dataclass():
@RunIf(skip_windows=True) # TODO: all durations are 0 on Windows
def test_datamodule_hooks_are_profiled():
def test_datamodule_hooks_are_profiled(tmp_path):
"""Test that `LightningDataModule` hooks are profiled."""
def get_trainer():
return Trainer(
default_root_dir=tmp_path,
max_steps=1,
limit_val_batches=0,
profiler="simple",

View File

@ -23,6 +23,8 @@ from lightning.pytorch.loops.optimization.automatic import Closure
from lightning.pytorch.tuner.tuning import Tuner
from torch.optim import SGD, Adam, Optimizer
from tests_pytorch.helpers.runif import RunIf
@pytest.mark.parametrize("auto", [True, False])
def test_lightning_optimizer(tmp_path, auto):
@ -232,6 +234,7 @@ def test_lightning_optimizer_automatic_optimization_optimizer_step(tmp_path):
assert sgd["zero_grad"].call_count == limit_train_batches
@RunIf(mps=False) # mps does not support LBFGS
def test_lightning_optimizer_automatic_optimization_lbfgs_zero_grad(tmp_path):
"""Test zero_grad is called the same number of times as LBFGS requires for reevaluation of the loss in
automatic_optimization."""

View File

@ -395,7 +395,7 @@ def result_collection_reload(default_root_dir, accelerator="auto", devices=1, **
@pytest.mark.parametrize(
"kwargs",
[
{},
pytest.param({}, marks=RunIf(mps=False)),
pytest.param({"strategy": "ddp", "accelerator": "gpu", "devices": 1}, marks=RunIf(min_cuda_gpus=1)),
pytest.param(
{"strategy": "ddp", "accelerator": "gpu", "devices": 2}, marks=RunIf(min_cuda_gpus=2, standalone=True)

View File

@ -13,6 +13,8 @@ from tests_pytorch.helpers.runif import RunIf
def create_boring_checkpoint(tmp_path, model, accelerator="cuda"):
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="checkpoint")
trainer = pl.Trainer(
default_root_dir=tmp_path,
logger=False,
devices=1,
accelerator=accelerator,
max_epochs=1,

View File

@ -39,14 +39,6 @@ class MNIST(Dataset):
download: If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
Examples:
>>> dataset = MNIST(".", download=True)
>>> len(dataset)
60000
>>> torch.bincount(dataset.targets)
tensor([5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949])
"""
RESOURCES = (
@ -141,15 +133,6 @@ class TrialMNIST(MNIST):
digits: list selected MNIST digits/classes
kwargs: Same as MNIST
Examples:
>>> dataset = TrialMNIST(".", download=True)
>>> len(dataset)
300
>>> sorted(set([d.item() for d in dataset.targets]))
[0, 1, 2]
>>> torch.bincount(dataset.targets)
tensor([100, 100, 100])
"""
def __init__(self, root: str, num_samples: int = 100, digits: Optional[Sequence] = (0, 1, 2), **kwargs):

View File

@ -15,11 +15,25 @@ import pickle
import cloudpickle
import pytest
import torch
from tests_pytorch import _PATH_DATASETS
from tests_pytorch.helpers.datasets import MNIST, AverageDataset, TrialMNIST
def test_mnist(tmp_path):
dataset = MNIST(tmp_path, download=True)
assert len(dataset) == 60000
assert torch.bincount(dataset.targets).tolist() == [5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]
def test_trial_mnist(tmp_path):
dataset = TrialMNIST(tmp_path, download=True)
assert len(dataset) == 300
assert set(dataset.targets.tolist()) == {0, 1, 2}
assert torch.bincount(dataset.targets).tolist() == [100, 100, 100]
@pytest.mark.parametrize(
("dataset_cls", "args"),
[(MNIST, {"root": _PATH_DATASETS}), (TrialMNIST, {"root": _PATH_DATASETS}), (AverageDataset, {})],

View File

@ -18,23 +18,19 @@ from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loggers import TensorBoardLogger
from tests_pytorch import _TEMP_PATH
def get_default_logger(save_dir, version=None):
# set up logger object without actually saving logs
return TensorBoardLogger(save_dir, name="lightning_logs", version=version)
def get_data_path(expt_logger, path_dir=None):
def get_data_path(expt_logger, path_dir):
# some calls contain only the experiment, not the complete logger
# each logger has to have these attributes
name, version = expt_logger.name, expt_logger.version
# the other experiments...
if not path_dir:
path_dir = expt_logger.save_dir if hasattr(expt_logger, "save_dir") and expt_logger.save_dir else _TEMP_PATH
path_expt = os.path.join(path_dir, name, "version_%s" % version)
# check if the new sub-folder exists, a typical case for test-tube

View File

@ -70,8 +70,9 @@ def _instantiate_logger(logger_class, save_dir, **override_kwargs):
@mock.patch.dict(os.environ, {})
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
@pytest.mark.parametrize("logger_class", ALL_LOGGER_CLASSES)
def test_loggers_fit_test_all(logger_class, mlflow_mock, wandb_mock, comet_mock, neptune_mock, tmp_path):
def test_loggers_fit_test_all(logger_class, mlflow_mock, wandb_mock, comet_mock, neptune_mock, tmp_path, monkeypatch):
"""Verify that basic functionality of all loggers."""
monkeypatch.chdir(tmp_path)
class CustomModel(BoringModel):
def training_step(self, batch, batch_idx):
@ -116,12 +117,12 @@ def test_loggers_fit_test_all(logger_class, mlflow_mock, wandb_mock, comet_mock,
model = CustomModel()
trainer = Trainer(
default_root_dir=tmp_path,
max_epochs=1,
logger=logger,
limit_train_batches=1,
limit_val_batches=1,
log_every_n_steps=1,
default_root_dir=tmp_path,
)
trainer.fit(model)
trainer.test()
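Several logger tests additionally receive monkeypatch.chdir(tmp_path), since some loggers create directories relative to the current working directory. A minimal sketch of that pattern (hypothetical test name, with CSVLogger as a stand-in for the mocked third-party loggers):

from lightning.pytorch.loggers import CSVLogger


def test_logger_writes_relative_to_cwd(tmp_path, monkeypatch):
    monkeypatch.chdir(tmp_path)          # pytest restores the original cwd on teardown
    logger = CSVLogger(save_dir="logs")  # the relative "logs/" directory is created under tmp_path, not the repo
    logger.log_metrics({"loss": 0.1})
    logger.save()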

View File

@ -168,7 +168,7 @@ def test_metrics_reset_after_save(tmp_path):
# Mock the existence check, so we can simulate appending to the metrics file
"lightning.fabric.loggers.csv_logs._ExperimentWriter._check_log_dir_exists"
)
def test_append_metrics_file(tmp_path):
def test_append_metrics_file(_, tmp_path):
"""Test that the logger appends to the file instead of rewriting it on every save."""
logger = CSVLogger(tmp_path, name="test", version=0, flush_logs_every_n_steps=1)

View File

@ -149,15 +149,17 @@ def test_neptune_additional_methods(neptune_mock):
run_instance_mock.__getitem__().log.assert_called_once_with(torch.ones(1))
def test_neptune_leave_open_experiment_after_fit(neptune_mock, tmp_path):
def test_neptune_leave_open_experiment_after_fit(neptune_mock, tmp_path, monkeypatch):
"""Verify that neptune experiment was NOT closed after training."""
monkeypatch.chdir(tmp_path)
logger, run_instance_mock, _ = _get_logger_with_mocks(api_key="test", project="project")
_fit_and_test(logger=logger, model=BoringModel(), tmp_path=tmp_path)
assert run_instance_mock.stop.call_count == 0
def test_neptune_log_metrics_on_trained_model(neptune_mock, tmp_path):
def test_neptune_log_metrics_on_trained_model(neptune_mock, tmp_path, monkeypatch):
"""Verify that trained models do log data."""
monkeypatch.chdir(tmp_path)
class LoggingModel(BoringModel):
def on_validation_epoch_end(self):
@ -305,9 +307,10 @@ def test_get_full_model_names_from_exp_structure():
assert NeptuneLogger._get_full_model_names_from_exp_structure(input_dict, "foo/bar") == expected_keys
def test_inactive_run(neptune_mock, tmp_path):
def test_inactive_run(neptune_mock, tmp_path, monkeypatch):
from neptune.exceptions import InactiveRunException
monkeypatch.chdir(tmp_path)
logger, run_instance_mock, _ = _get_logger_with_mocks(api_key="test", project="project")
run_instance_mock.__setitem__.side_effect = InactiveRunException

View File

@ -30,6 +30,7 @@ def test_no_val_on_train_epoch_loop_restart(tmp_path):
"limit_train_batches": 1,
"limit_val_batches": 1,
"num_sanity_val_steps": 0,
"logger": False,
"enable_checkpointing": False,
}
trainer = Trainer(**trainer_kwargs)

View File

@ -258,7 +258,7 @@ def test_fit_twice(tmp_path):
def test_try_resume_from_non_existing_checkpoint(tmp_path):
"""Test that trying to resume from non-existing `ckpt_path` fails with an error."""
model = BoringModel()
trainer = Trainer()
trainer = Trainer(logger=False)
with pytest.raises(FileNotFoundError, match="Checkpoint file not found"):
trainer.fit(model, ckpt_path=str(tmp_path / "non_existing.ckpt"))

View File

@ -134,6 +134,7 @@ class DoublePrecisionBoringModelComplexBuffer(BoringModel):
return super().training_step(batch, batch_idx)
@RunIf(mps=False) # mps does not support float64
@pytest.mark.parametrize(
"boring_model",
[

View File

@ -37,7 +37,7 @@ def test_servable_module_validator():
@pytest.mark.flaky(reruns=3)
def test_servable_module_validator_with_trainer(tmp_path):
def test_servable_module_validator_with_trainer(tmp_path, mps_count_0):
callback = ServableModuleValidator()
trainer = Trainer(
default_root_dir=tmp_path,

View File

@ -194,14 +194,16 @@ class SimpleModel(BoringModel):
assert torch.equal(self.layer.weight.data, self.tied_layer.weight.data)
def test_memory_sharing_disabled():
def test_memory_sharing_disabled(tmp_path):
"""Test that the multiprocessing launcher disables memory sharing on model parameters and buffers to avoid race
conditions on model updates."""
model = SimpleModel()
assert not model.layer.weight.is_shared()
assert model.layer.weight.data_ptr() == model.tied_layer.weight.data_ptr()
trainer = Trainer(accelerator="cpu", devices=2, strategy="ddp_spawn", max_steps=0)
trainer = Trainer(
default_root_dir=tmp_path, logger=False, accelerator="cpu", devices=2, strategy="ddp_spawn", max_steps=0
)
trainer.fit(model)
@ -214,7 +216,7 @@ def test_check_for_missing_main_guard():
launcher.launch(function=Mock())
def test_fit_twice_raises():
def test_fit_twice_raises(mps_count_0):
model = BoringModel()
trainer = Trainer(
limit_train_batches=1,

View File

@ -284,10 +284,10 @@ class BoringZeroRedundancyOptimizerModel(BoringModel):
@RunIf(min_cuda_gpus=2, skip_windows=True)
@pytest.mark.parametrize("strategy", [pytest.param("ddp", marks=RunIf(standalone=True)), "ddp_spawn"])
def test_ddp_strategy_checkpoint_zero_redundancy_optimizer(tmp_path, strategy):
def test_ddp_strategy_checkpoint_zero_redundancy_optimizer(strategy, tmp_path):
"""Test to ensure that checkpoint is saved correctly when using zero redundancy optimizer."""
model = BoringZeroRedundancyOptimizerModel()
trainer = Trainer(accelerator="gpu", devices=2, strategy=strategy, max_steps=1)
trainer = Trainer(default_root_dir=tmp_path, accelerator="gpu", devices=2, strategy=strategy, max_steps=1)
trainer.fit(model)

View File

@ -630,7 +630,7 @@ def test_fsdp_strategy_save_optimizer_states(tmp_path, wrap_min_params):
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000])
def test_fsdp_strategy_load_optimizer_states(tmp_path, wrap_min_params):
def test_fsdp_strategy_load_optimizer_states(wrap_min_params, tmp_path):
"""Test to ensure that the full state dict and optimizer states can be load when using FSDP strategy.
Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all. If the DDP model
@ -694,14 +694,17 @@ def test_fsdp_strategy_load_optimizer_states(tmp_path, wrap_min_params):
("32-true", torch.float32),
],
)
def test_configure_model(precision, expected_dtype):
def test_configure_model(precision, expected_dtype, tmp_path):
"""Test that the module under configure_model gets moved to the right device and dtype."""
trainer = Trainer(
default_root_dir=tmp_path,
accelerator="cuda",
devices=2,
strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
precision=precision,
max_epochs=1,
enable_checkpointing=False,
logger=False,
)
class MyModel(BoringModel):
@ -899,7 +902,7 @@ def test_fsdp_lazy_load_full_state_dict(_, lazy_load_mock, torch_load_mock, tmp_
pytest.param("bf16-true", torch.bfloat16, marks=RunIf(bf16_cuda=True)),
],
)
def test_module_init_context(precision, expected_dtype):
def test_module_init_context(precision, expected_dtype, tmp_path):
"""Test that the module under the init-context gets moved to the right device and dtype."""
class Model(BoringModel):
@ -915,12 +918,15 @@ def test_module_init_context(precision, expected_dtype):
def _run_setup_assertions(empty_init, expected_device):
trainer = Trainer(
default_root_dir=tmp_path,
accelerator="cuda",
devices=2,
strategy=FSDPStrategy(auto_wrap_policy={torch.nn.Linear}),
precision=precision,
max_steps=1,
barebones=True,
enable_checkpointing=False,
logger=False,
)
with trainer.init_module(empty_init=empty_init):
model = Model()

View File

@ -40,7 +40,7 @@ def test_strategy_registry_with_deepspeed_strategies(strategy_name, init_params)
@RunIf(deepspeed=True)
@pytest.mark.parametrize("strategy", ["deepspeed", "deepspeed_stage_2_offload", "deepspeed_stage_3"])
def test_deepspeed_strategy_registry_with_trainer(tmp_path, strategy):
def test_deepspeed_strategy_registry_with_trainer(tmp_path, strategy, mps_count_0):
trainer = Trainer(default_root_dir=tmp_path, strategy=strategy, precision="16-mixed")
assert isinstance(trainer.strategy, DeepSpeedStrategy)

View File

@ -25,7 +25,7 @@ def test_passing_no_env_variables():
assert trainer.logger is not None
assert trainer.max_steps == -1
assert trainer.max_epochs is None
trainer = Trainer(logger=False, max_steps=1)
trainer = Trainer(max_steps=1, logger=False, enable_checkpointing=False)
trainer.fit(model)
assert trainer.logger is None
assert trainer.max_steps == 1
@ -49,7 +49,7 @@ def test_passing_env_variables_defaults():
@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "PL_TRAINER_DEVICES": "2"})
def test_passing_env_variables_devices(cuda_count_2):
def test_passing_env_variables_devices(cuda_count_2, mps_count_0):
"""Testing overwriting trainer arguments."""
trainer = Trainer()
assert trainer.num_devices == 2

View File

@ -36,7 +36,7 @@ def test_min_max_steps_epochs(tmp_path, min_epochs, max_epochs, min_steps, max_s
assert trainer.global_step == trainer.max_steps
def test_max_epochs_not_set_warning():
def test_max_epochs_not_set_warning(tmp_path):
"""Test that a warning is only emitted when `max_epochs` was not set by the user."""
class CustomModel(BoringModel):
@ -46,7 +46,7 @@ def test_max_epochs_not_set_warning():
match = "`max_epochs` was not set. Setting it to 1000 epochs."
model = CustomModel()
trainer = Trainer(max_epochs=None, limit_train_batches=1)
trainer = Trainer(logger=False, enable_checkpointing=False, max_epochs=None, limit_train_batches=1)
with pytest.warns(PossibleUserWarning, match=match):
trainer.fit(model)

View File

@ -37,7 +37,13 @@ def test_val_check_interval(tmp_path, max_epochs, denominator):
self.val_epoch_calls += 1
model = TestModel()
trainer = Trainer(max_epochs=max_epochs, val_check_interval=1 / denominator, logger=False)
trainer = Trainer(
default_root_dir=tmp_path,
enable_checkpointing=False,
logger=False,
max_epochs=max_epochs,
val_check_interval=1 / denominator,
)
trainer.fit(model)
assert model.train_epoch_calls == max_epochs
@ -107,6 +113,8 @@ def test_validation_check_interval_exceed_data_length_wrong():
trainer = Trainer(
limit_train_batches=10,
val_check_interval=100,
logger=False,
enable_checkpointing=False,
)
model = BoringModel()

View File

@ -11,7 +11,7 @@ from lightning.pytorch.demos.boring_classes import BoringModel
def test_backward_count_simple(torch_backward, num_steps):
"""Test that backward is called exactly once per step."""
model = BoringModel()
trainer = Trainer(max_steps=num_steps)
trainer = Trainer(max_steps=num_steps, logger=False, enable_checkpointing=False)
trainer.fit(model)
assert torch_backward.call_count == num_steps
@ -25,19 +25,21 @@ def test_backward_count_simple(torch_backward, num_steps):
def test_backward_count_with_grad_accumulation(torch_backward):
"""Test that backward is called the correct number of times when accumulating gradients."""
model = BoringModel()
trainer = Trainer(max_epochs=1, limit_train_batches=6, accumulate_grad_batches=2)
trainer = Trainer(
max_epochs=1, limit_train_batches=6, accumulate_grad_batches=2, logger=False, enable_checkpointing=False
)
trainer.fit(model)
assert torch_backward.call_count == 6
torch_backward.reset_mock()
trainer = Trainer(max_steps=6, accumulate_grad_batches=2)
trainer = Trainer(max_steps=6, accumulate_grad_batches=2, logger=False, enable_checkpointing=False)
trainer.fit(model)
assert torch_backward.call_count == 12
@patch("torch.Tensor.backward")
def test_backward_count_with_closure(torch_backward):
def test_backward_count_with_closure(torch_backward, tmp_path):
"""Using a closure (e.g. with LBFGS) should lead to no extra backward calls."""
class TestModel(BoringModel):
@ -45,12 +47,12 @@ def test_backward_count_with_closure(torch_backward):
return torch.optim.LBFGS(self.parameters(), lr=0.1)
model = TestModel()
trainer = Trainer(max_steps=5)
trainer = Trainer(max_steps=5, logger=False, enable_checkpointing=False)
trainer.fit(model)
assert torch_backward.call_count == 5
torch_backward.reset_mock()
trainer = Trainer(max_steps=5, accumulate_grad_batches=2)
trainer = Trainer(max_steps=5, accumulate_grad_batches=2, logger=False, enable_checkpointing=False)
trainer.fit(model)
assert torch_backward.call_count == 10

View File

@ -910,7 +910,7 @@ def test_manual_optimization_with_non_pytorch_scheduler(automatic_optimization):
return [optimizer], [scheduler]
model = Model()
trainer = Trainer(accelerator="cpu", max_epochs=0)
trainer = Trainer(accelerator="cpu", max_epochs=0, logger=False, enable_checkpointing=False)
if automatic_optimization:
with pytest.raises(MisconfigurationException, match="doesn't follow PyTorch's LRScheduler"):
trainer.fit(model)

View File

@ -36,7 +36,7 @@ def test_multiple_optimizers_automatic_optimization_raises():
model = TestModel()
model.automatic_optimization = True
trainer = pl.Trainer()
trainer = pl.Trainer(logger=False, enable_checkpointing=False)
with pytest.raises(RuntimeError, match="Remove the `optimizer_idx` argument from `training_step`"):
trainer.fit(model)
@ -47,7 +47,7 @@ def test_multiple_optimizers_automatic_optimization_raises():
model = TestModel()
model.automatic_optimization = True
trainer = pl.Trainer()
trainer = pl.Trainer(logger=False, enable_checkpointing=False)
with pytest.raises(RuntimeError, match="multiple optimizers is only supported with manual optimization"):
trainer.fit(model)

View File

@ -86,9 +86,9 @@ def test_num_stepping_batches_infinite_training():
@pytest.mark.parametrize("max_steps", [2, 100])
def test_num_stepping_batches_with_max_steps(max_steps):
def test_num_stepping_batches_with_max_steps(max_steps, tmp_path):
"""Test stepping batches with `max_steps`."""
trainer = Trainer(max_steps=max_steps)
trainer = Trainer(max_steps=max_steps, default_root_dir=tmp_path, logger=False, enable_checkpointing=False)
model = BoringModel()
trainer.fit(model)
assert trainer.estimated_stepping_batches == max_steps

View File

@ -679,7 +679,11 @@ def test_warning_with_small_dataloader_and_logging_interval(tmp_path):
with pytest.warns(UserWarning, match=r"The number of training batches \(1\) is smaller than the logging interval"):
trainer = Trainer(
default_root_dir=tmp_path, max_epochs=1, log_every_n_steps=2, limit_train_batches=1, logger=CSVLogger(".")
default_root_dir=tmp_path,
max_epochs=1,
log_every_n_steps=2,
limit_train_batches=1,
logger=CSVLogger(tmp_path),
)
trainer.fit(model)
@ -727,7 +731,7 @@ def test_warning_with_iterable_dataset_and_len(tmp_path):
@pytest.mark.parametrize("yield_at_all", [False, True])
def test_iterable_dataset_stop_iteration_at_epoch_beginning(yield_at_all):
def test_iterable_dataset_stop_iteration_at_epoch_beginning(yield_at_all, tmp_path):
"""Test that the training loop skips execution if the iterator is empty from the start."""
class TestDataset(IterableDataset):
@ -748,7 +752,8 @@ def test_iterable_dataset_stop_iteration_at_epoch_beginning(yield_at_all):
model = TestModel()
train_dataloader = DataLoader(TestDataset(model.gen), batch_size=2)
trainer = Trainer(
default_root_dir=os.getcwd(),
default_root_dir=tmp_path,
logger=False,
max_epochs=2,
enable_model_summary=False,
)

View File

@ -2032,7 +2032,7 @@ def test_trainer_calls_logger_finalize_on_exception(tmp_path):
@pytest.mark.parametrize("exception_type", [KeyboardInterrupt, RuntimeError])
def test_trainer_calls_strategy_on_exception(exception_type):
def test_trainer_calls_strategy_on_exception(exception_type, tmp_path):
"""Test that when an exception occurs, the Trainer lets the strategy process it."""
exception = exception_type("Test exception")
@ -2040,7 +2040,7 @@ def test_trainer_calls_strategy_on_exception(exception_type):
def on_fit_start(self):
raise exception
trainer = Trainer()
trainer = Trainer(default_root_dir=tmp_path)
with mock.patch("lightning.pytorch.strategies.strategy.Strategy.on_exception") as on_exception_mock, suppress(
Exception
):
@ -2049,7 +2049,7 @@ def test_trainer_calls_strategy_on_exception(exception_type):
@pytest.mark.parametrize("exception_type", [KeyboardInterrupt, RuntimeError])
def test_trainer_calls_datamodule_on_exception(exception_type):
def test_trainer_calls_datamodule_on_exception(exception_type, tmp_path):
"""Test that when an exception occurs, the Trainer lets the data module process it."""
exception = exception_type("Test exception")
@ -2059,7 +2059,7 @@ def test_trainer_calls_datamodule_on_exception(exception_type):
datamodule = BoringDataModule()
datamodule.on_exception = Mock()
trainer = Trainer()
trainer = Trainer(default_root_dir=tmp_path)
with suppress(Exception):
trainer.fit(ExceptionModel(), datamodule=datamodule)

View File

@ -438,7 +438,7 @@ def test_batch_size_finder_with_multiple_eval_dataloaders(tmp_path):
def val_dataloader(self):
return [super().val_dataloader(), super().val_dataloader()]
trainer = Trainer()
trainer = Trainer(logger=False, enable_checkpointing=False)
tuner = Tuner(trainer)
model = CustomModel()