# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import threading
from typing import List
from unittest.mock import Mock

import lightning.fabric
import pytest
import torch.distributed
from lightning.fabric.accelerators import XLAAccelerator
from lightning.fabric.strategies.launchers.subprocess_script import _ChildProcessObserver
from lightning.fabric.utilities.distributed import _distributed_is_initialized

if sys.version_info >= (3, 9):
    from concurrent.futures.process import _ExecutorManagerThread


@pytest.fixture(autouse=True)
def preserve_global_rank_variable():
    """Snapshots ``rank_zero_only.rank`` before each test and writes it back afterwards.

    Both the Fabric and the ``lightning_utilities`` flavours of ``rank_zero_only`` are covered. A function whose
    ``rank`` attribute was missing (or ``None``) before the test is left untouched on teardown.

    """
    from lightning.fabric.utilities.rank_zero import rank_zero_only as rank_zero_only_fabric
    from lightning_utilities.core.rank_zero import rank_zero_only as rank_zero_only_utilities

    snapshot = [(fn, getattr(fn, "rank", None)) for fn in (rank_zero_only_fabric, rank_zero_only_utilities)]
    yield
    for fn, previous_rank in snapshot:
        if previous_rank is None:
            # nothing was recorded for this function, so there is nothing to put back
            continue
        fn.rank = previous_rank


@pytest.fixture(autouse=True)
def restore_env_variables():
    """Restores ``os.environ`` after each test and fails if the test introduced unexpected variables."""
    original_env = os.environ.copy()
    yield
    # names that exist now but did not exist before the test started
    introduced = os.environ.keys() - original_env.keys()

    # put the environment back exactly as it was before asserting, so a failure here does not leak either
    os.environ.clear()
    os.environ.update(original_env)

    # these are currently known leakers - ideally these would not be allowed
    # TODO(fabric): this list can be trimmed, maybe PL's too after moving tests
    allowlist = {
        "CUDA_DEVICE_ORDER",
        "LOCAL_RANK",
        "NODE_RANK",
        "WORLD_SIZE",
        "MASTER_ADDR",
        "MASTER_PORT",
        "PL_GLOBAL_SEED",
        "PL_SEED_WORKERS",
        "RANK",  # set by DeepSpeed
        "CUDA_MODULE_LOADING",  # leaked by PyTorch
        "CRC32C_SW_MODE",  # set by tensorboardX
        "OMP_NUM_THREADS",  # set by our launchers
        # set by torchdynamo
        "TRITON_CACHE_DIR",
    }
    leaked_vars = introduced - allowlist
    assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}"


@pytest.fixture(autouse=True)
def teardown_process_group():
    """Destroys the distributed process group after the test if one is still initialized."""
    yield
    if not _distributed_is_initialized():
        return
    torch.distributed.destroy_process_group()


@pytest.fixture(autouse=True)
def thread_police_duuu_daaa_duuu_daaa():
    """Inspects threads that appeared during the test and stops, joins, tolerates or rejects each of them."""
    threads_at_setup = set(threading.enumerate())
    yield
    threads_at_teardown = set(threading.enumerate())

    if XLAAccelerator.is_available():
        # Ignore the check when running XLA tests for now
        return

    for thread in threads_at_teardown - threads_at_setup:
        # prefer a `stop` attribute and fall back to `exit` only when `stop` is missing or falsy
        stopper = getattr(thread, "stop", None)
        if not stopper:
            stopper = getattr(thread, "exit", None)

        if thread.daemon and callable(stopper):
            # A daemon thread would anyway be stopped at the end of a program
            # We do it preemptively here to reduce the risk of interactions with other tests that run after
            stopper()
            assert not thread.is_alive()
            continue

        if isinstance(thread, _ChildProcessObserver):
            thread.join(timeout=10)
            continue

        if thread.name == "QueueFeederThread":  # tensorboardX
            thread.join(timeout=20)
            continue

        # `_ExecutorManagerThread` is only imported on Python >= 3.9, hence the version guard comes first
        is_executor_thread = (
            sys.version_info >= (3, 9) and isinstance(thread, _ExecutorManagerThread)
        ) or "ThreadPoolExecutor-" in thread.name
        if is_executor_thread:
            # probably `torch.compile`, can't narrow it down further
            continue

        raise AssertionError(f"Test left zombie thread: {thread}")


@pytest.fixture()
def reset_deterministic_algorithm():
    """Drops ``CUBLAS_WORKSPACE_CONFIG`` and turns torch deterministic algorithms off once the test is done."""
    yield
    os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None)
    torch.use_deterministic_algorithms(False)


@pytest.fixture()
def reset_cudnn_benchmark():
    """Sets ``torch.backends.cudnn.benchmark`` back to ``False`` once the test is done."""
    yield
    torch.backends.cudnn.benchmark = False


def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None:
    """Patches every Fabric module-level ``_XLA_AVAILABLE`` flag to ``value`` and stubs the ``torch_xla`` modules.

    Args:
        monkeypatch: The pytest monkeypatch helper used to apply (and later undo) the patches.
        value: The value assigned to each ``_XLA_AVAILABLE`` flag.

    """
    modules_with_flag = (
        lightning.fabric.accelerators.xla,
        lightning.fabric.plugins.environments.xla,
        lightning.fabric.plugins.precision.xla,
        lightning.fabric.plugins.io.xla,
        lightning.fabric.strategies.single_xla,
        lightning.fabric.strategies.xla_fsdp,
        lightning.fabric.strategies.launchers.xla,
    )
    for module in modules_with_flag:
        monkeypatch.setattr(module, "_XLA_AVAILABLE", value)

    # each stubbed module gets its own `Mock` instance
    stubbed_module_names = (
        "torch_xla",
        "torch_xla.core.xla_model",
        "torch_xla.experimental",
        "torch_xla.distributed.fsdp.wrap",
    )
    for module_name in stubbed_module_names:
        monkeypatch.setitem(sys.modules, module_name, Mock())


@pytest.fixture()
def xla_available(monkeypatch: pytest.MonkeyPatch) -> None:
    """Fixture form of :func:`mock_xla_available` using its default ``value``."""
    mock_xla_available(monkeypatch)


def mock_tpu_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None:
    """Applies :func:`mock_xla_available` and additionally patches ``XLAAccelerator`` availability and count.

    Args:
        monkeypatch: The pytest monkeypatch helper used to apply (and later undo) the patches.
        value: Forwarded to :func:`mock_xla_available` and returned by the patched ``is_available``.

    """
    mock_xla_available(monkeypatch, value)
    accelerator_cls = lightning.fabric.accelerators.xla.XLAAccelerator
    monkeypatch.setattr(accelerator_cls, "is_available", lambda: value)
    # the patched device count is always 8, independent of `value`
    monkeypatch.setattr(accelerator_cls, "auto_device_count", lambda *_: 8)


@pytest.fixture()
def tpu_available(monkeypatch: pytest.MonkeyPatch) -> None:
    """Fixture form of :func:`mock_tpu_available` using its default ``value``."""
    mock_tpu_available(monkeypatch)


@pytest.fixture()
def caplog(caplog):
    """Workaround for https://github.com/pytest-dev/pytest/issues/3697.

    Setting ``filterwarnings`` with pytest breaks ``caplog`` when ``not logger.propagate``. The
    ``lightning.fabric`` logger is therefore forced to propagate for the duration of the test, and its previous
    ``propagate`` value is written back afterwards.

    """
    import logging

    fabric_logger = logging.getLogger("lightning.fabric")
    previous_propagate = fabric_logger.propagate
    fabric_logger.propagate = True
    yield caplog
    fabric_logger.propagate = previous_propagate


def pytest_collection_modifyitems(items: List[pytest.Function], config: pytest.Config) -> None:
    """An adaptation of `tests/tests_pytorch/conftest.py::pytest_collection_modifyitems`

    For every enabled ``PL_RUN_*`` environment variable, ``items`` is pruned in place: tests carrying a ``skip``
    marker are dropped, and so are tests without a ``skipif`` marker whose kwargs contain a truthy entry for the
    matching option. Afterwards every remaining item gets a ``filterwarnings`` marker that turns
    ``LightningDeprecationWarning`` into an error.

    Args:
        items: The collected test items; modified in place.
        config: The pytest config, used for the verbosity level and the terminal writer.

    """

    def _flag_is_set(env_var: str) -> bool:
        # an option counts as enabled only when its variable is exactly "1"
        return os.getenv(env_var, "0") == "1"

    initial_size = len(items)
    conditions = []
    filtered = 0
    skipped = 0

    options = dict(
        standalone="PL_RUN_STANDALONE_TESTS",
        min_cuda_gpus="PL_RUN_CUDA_TESTS",
        tpu="PL_RUN_TPU_TESTS",
    )
    if _flag_is_set(options["standalone"]) and _flag_is_set(options["min_cuda_gpus"]):
        # special case: we don't have a CPU job for standalone tests, so we shouldn't run only cuda tests.
        # by deleting the key, we avoid filtering out the CPU tests
        del options["min_cuda_gpus"]

    for kwarg, env_var in options.items():
        # this will compute the intersection of all tests selected per environment variable
        if not _flag_is_set(env_var):
            continue
        conditions.append(env_var)
        # walk the indices backwards so that deleting an entry never shifts one we still have to visit
        for index in range(len(items) - 1, -1, -1):
            markers = items[index].own_markers
            if any(marker.name == "skip" for marker in markers):
                # the test was going to be skipped anyway, filter it out
                del items[index]
                skipped += 1
            elif not any(marker.name == "skipif" and marker.kwargs.get(kwarg) for marker in markers):
                # the test has no `skipif` marker with a truthy `kwarg` entry (i.e. presumably no
                # `@RunIf(kwarg=True)`), so it does not belong to this selection: filter it out
                del items[index]
                filtered += 1

    if config.option.verbose >= 0 and (filtered or skipped):
        message = (
            f"\nThe number of tests has been filtered from {initial_size} to {initial_size - filtered} after the"
            f" filters {conditions}.\n{skipped} tests are marked as unconditional skips.\nIn total, {len(items)} tests"
            " will run.\n"
        )
        writer = config.get_terminal_writer()
        # oh yeah, branded pytest messages
        writer.write(message, flush=True, bold=True, purple=True)

    # error out on our deprecation warnings - ensures the code and tests are kept up-to-date
    deprecation_error = pytest.mark.filterwarnings(
        "error::lightning.fabric.utilities.rank_zero.LightningDeprecationWarning",
    )
    for item in items:
        item.add_marker(deprecation_error)