lightning/tests/tests_app/conftest.py

103 lines
3.3 KiB
Python
Raw Normal View History

import os
import shutil
import threading
from datetime import datetime
from pathlib import Path
import psutil
import py
import pytest
from lightning_app.storage.path import _storage_root_dir
from lightning_app.utilities.component import _set_context
from lightning_app.utilities.packaging import cloud_compute
from lightning_app.utilities.packaging.app_config import _APP_CONFIG_FILENAME
from lightning_app.utilities.state import AppState
# Module-level side effect: this runs as soon as this conftest module is imported.
# NOTE(review): presumably tells lightning_app that the app is already dispatched — confirm.
os.environ["LIGHTNING_DISPATCHED"] = "1"
def pytest_sessionfinish(session, exitstatus):
    """Pytest hook that gets called after the whole test run has finished, right before returning the exit status
    to the system.

    Kills every (recursive) child process of the current process, except the ``multiprocessing``
    resource tracker, and then makes a non-blocking ``join`` attempt on every thread other than the
    one running this hook.

    Args:
        session: The pytest session object (unused here).
        exitstatus: The exit status pytest is about to return (unused here).
    """
    # kill all the processes and threads created by parent
    # TODO this isn't great. We should have each test doing its own cleanup
    current_process = psutil.Process()
    for child in current_process.children(recursive=True):
        try:
            # Only the command line is inspected, so ask psutil for that single attribute
            # instead of collecting every attribute of every child process.
            params = child.as_dict(attrs=["cmdline"]) or {}
            cmd_lines = params.get("cmdline", [])
            # we shouldn't kill the resource tracker from multiprocessing. If we do,
            # `atexit` will throw as it uses resource tracker to try to clean up
            if cmd_lines and "resource_tracker" in cmd_lines[-1]:
                continue
            child.kill()
        except psutil.NoSuchProcess:
            # The child already exited on its own; there is nothing left to kill.
            pass

    main_thread = threading.current_thread()
    for t in threading.enumerate():
        if t is not main_thread:
            # A timeout of 0 means this never blocks waiting for the thread to finish.
            t.join(0)
@pytest.fixture(scope="function", autouse=True)
def cleanup():
    """Autouse, function-scoped fixture that tidies up after every test.

    Once the test has run it clears ``_LightningAppRef._app_instance``, removes the ``./storage``,
    storage-root and ``./.shared`` directories (ignoring errors), deletes the app config file if it
    exists, and resets the context with ``_set_context(None)``.
    """
    from lightning_app.utilities.app_helpers import _LightningAppRef

    def _remove_tree(directory):
        # Best-effort removal: errors (e.g. a missing directory) are ignored.
        shutil.rmtree(directory, ignore_errors=True)

    yield

    _LightningAppRef._app_instance = None

    _remove_tree("./storage")
    _remove_tree(_storage_root_dir())
    _remove_tree("./.shared")

    config_file_present = os.path.isfile(_APP_CONFIG_FILENAME)
    if config_file_present:
        os.remove(_APP_CONFIG_FILENAME)

    _set_context(None)
@pytest.fixture(scope="function", autouse=True)
def clear_app_state_state_variables():
    """Reset module- and class-level globals after each test so tests cannot interfere with each other.

    After the test body has run, ``_STATE`` and ``_LAST_STATE`` in ``lightning_app.utilities.state``
    are set to ``None``, ``AppState._MY_AFFILIATION`` is set to an empty tuple, and
    ``cloud_compute._CLOUD_COMPUTE_STORE`` is cleared when that attribute exists.
    """
    yield

    # Imported only at teardown time, after the test has finished.
    import lightning_app.utilities.state as state_module

    for global_name in ("_STATE", "_LAST_STATE"):
        setattr(state_module, global_name, None)

    AppState._MY_AFFILIATION = ()

    if not hasattr(cloud_compute, "_CLOUD_COMPUTE_STORE"):
        return
    cloud_compute._CLOUD_COMPUTE_STORE.clear()
@pytest.fixture
def another_tmpdir(tmp_path: Path) -> py.path.local:
    """Return a ``py.path.local`` pointing at a timestamp-named path underneath ``tmp_path``.

    The last path component is the current local time formatted as ``%m-%d-%Y-%H-%M-%S``.
    No directory-creation call is made here; only the path object is built.
    """
    timestamp = datetime.now().strftime("%m-%d-%Y-%H-%M-%S")
    return py.path.local(os.path.join(tmp_path, timestamp))
@pytest.fixture
def caplog(caplog):
    """Workaround for https://github.com/pytest-dev/pytest/issues/3697.

    Setting ``filterwarnings`` with pytest breaks ``caplog`` when ``not logger.propagate``.

    For the duration of the test, ``propagate`` is forced to ``True`` on the root logger and on
    every already-registered logger whose name starts with ``lightning_app``; the previous values
    are put back once the test is done.
    """
    import logging

    root = logging.getLogger()
    saved_root_flag = root.propagate
    root.propagate = True

    # Remember the current flag of every known ``lightning_app*`` logger before touching any of them.
    saved_flags = {}
    for logger_name in logging.root.manager.loggerDict:
        if logger_name.startswith("lightning_app"):
            saved_flags[logger_name] = logging.getLogger(logger_name).propagate

    for logger_name in saved_flags:
        logging.getLogger(logger_name).propagate = True

    yield caplog

    # Restore the flags recorded above.
    root.propagate = saved_root_flag
    for logger_name, previous_flag in saved_flags.items():
        logging.getLogger(logger_name).propagate = previous_flag