2022-06-30 20:43:04 +00:00
import os
import shutil
2022-12-19 23:12:55 +00:00
import signal
2022-10-28 13:26:08 +00:00
import threading
2022-06-30 20:43:04 +00:00
from datetime import datetime
from pathlib import Path
2022-12-19 23:12:55 +00:00
from threading import Thread
2022-06-30 20:43:04 +00:00
import psutil
import py
import pytest
2022-10-28 13:57:35 +00:00
from lightning_app.storage.path import _storage_root_dir
2022-12-19 23:12:55 +00:00
from lightning_app.utilities.app_helpers import _collect_child_process_pids
2022-06-30 20:43:04 +00:00
from lightning_app.utilities.component import _set_context
2022-10-04 19:46:44 +00:00
from lightning_app.utilities.packaging import cloud_compute
2022-06-30 20:43:04 +00:00
from lightning_app.utilities.packaging.app_config import _APP_CONFIG_FILENAME
from lightning_app.utilities.state import AppState
2022-11-09 20:46:31 +00:00
os.environ["LIGHTNING_DISPATCHED"] = "1"

# HACK: bound how long a blocking ``Thread.join`` can wait so that a stuck background
# thread cannot hang the whole test session. ``Thread._wait_for_tstate_lock`` is a private
# CPython helper that is not present on every interpreter version (it is gone in 3.13+),
# so the patch is only installed when the helper exists.
_MAX_TSTATE_LOCK_WAIT = 1  # seconds

original_method = getattr(Thread, "_wait_for_tstate_lock", None)


def fn(self, block=True, timeout=-1):
    """Bounded replacement for ``Thread._wait_for_tstate_lock``.

    Blocking waits are limited to ``_MAX_TSTATE_LOCK_WAIT`` seconds; a smaller
    non-negative timeout requested by the caller (e.g. ``join(0)``) is kept as-is.

    Non-blocking polls (``block=False``, issued by ``Thread.is_alive``) are forwarded
    without forcing a timeout: ``Lock.acquire`` raises ``ValueError`` when a timeout is
    combined with ``blocking=False``.
    """
    if timeout is None:
        # ``None`` was this wrapper's old default; the wrapped helper uses -1 for "no timeout".
        timeout = -1
    if block and (timeout < 0 or timeout > _MAX_TSTATE_LOCK_WAIT):
        timeout = _MAX_TSTATE_LOCK_WAIT
    return original_method(self, block, timeout)


if original_method is not None:
    Thread._wait_for_tstate_lock = fn
2022-06-30 20:43:04 +00:00
# NOTE(review): the bare date/time lines in this block are not Python (they look like
# annotation residue, e.g. from `git blame`) and must be deleted; indentation is also lost.
def pytest_sessionfinish(session, exitstatus):
"""Pytest hook that gets called after the whole test run has finished, right before returning the exit status
to the system.

Best-effort teardown. As far as this copy shows, it inspects every recursive child process of the
current process (skipping the multiprocessing resource tracker), iterates over all non-main threads,
and finally sends SIGTERM to every pid returned by ``_collect_child_process_pids(os.getpid())``.

NOTE(review): several statements (marked below) are missing from this copy and must be restored
from version control; the intended kill of each child is described only by the comment below.
"""
# kill all the processes and threads created by parent
# TODO this isn't great. We should have each test doing its own cleanup
current_process = psutil.Process()
for child in current_process.children(recursive=True):
2022-11-10 11:21:51 +00:00
# NOTE(review): the ``try:`` matching ``except psutil.NoSuchProcess`` below is missing here.
params = child.as_dict() or {}
cmd_lines = params.get("cmdline", [])
# we shouldn't kill the resource tracker from multiprocessing. If we do,
# `atexit` will throw as it uses resource tracker to try to clean up
# NOTE(review): the body of this ``if`` (presumably a ``continue``) and the statement that
# kills ``child`` after it are missing from this copy.
if cmd_lines and "resource_tracker" in cmd_lines[-1]:
# psutil raises NoSuchProcess when the child exits before it can be inspected.
# NOTE(review): the handler body is missing from this copy.
except psutil.NoSuchProcess:
2022-06-30 20:43:04 +00:00
2022-10-28 13:26:08 +00:00
main_thread = threading.current_thread()
for t in threading.enumerate():
# NOTE(review): the action taken for each non-main thread is missing from this copy.
if t is not main_thread:
2022-12-19 23:12:55 +00:00
# NOTE(review): ``os.kill`` can raise ``ProcessLookupError`` for a pid that has already
# exited; consider ignoring that error so session teardown cannot fail here.
for child_pid in _collect_child_process_pids(os.getpid()):
os.kill(child_pid, signal.SIGTERM)
2022-06-30 20:43:04 +00:00
# NOTE(review): the bare date/time lines in this block are not Python (they look like
# annotation residue, e.g. from `git blame`) and must be deleted; indentation is also lost.
@pytest.fixture(scope="function", autouse=True)
def cleanup():
"""Reset the global app reference and delete on-disk artifacts between tests.

Autouse, function-scoped fixture. It sets ``_LightningAppRef._app_instance`` to ``None``,
removes ``./storage``, the directory returned by ``_storage_root_dir()`` and ``./.shared``
(``rmtree`` errors are ignored), then checks whether the app config file exists.

NOTE(review): no ``yield`` is visible in this copy, so as shown everything runs during test
setup; lines may be missing - confirm against version control.
"""
from lightning_app.utilities.app_helpers import _LightningAppRef
_LightningAppRef._app_instance = None
shutil.rmtree("./storage", ignore_errors=True)
2022-10-28 13:57:35 +00:00
shutil.rmtree(_storage_root_dir(), ignore_errors=True)
2022-06-30 20:43:04 +00:00
shutil.rmtree("./.shared", ignore_errors=True)
# NOTE(review): the body of this ``if`` is missing from this copy (presumably it deletes
# ``_APP_CONFIG_FILENAME``) - restore it from version control.
if os.path.isfile(_APP_CONFIG_FILENAME):
# NOTE(review): the bare date/time line in this block is not Python (it looks like
# annotation residue, e.g. from `git blame`) and must be deleted; indentation is also lost.
@pytest.fixture(scope="function", autouse=True)
def clear_app_state_state_variables():
"""Resets global variables in order to prevent interference between tests.

Autouse, function-scoped fixture. It sets ``lightning_app.utilities.state._STATE`` and
``lightning_app.utilities.state._LAST_STATE`` to ``None``, then checks whether the
``cloud_compute`` module has a ``_CLOUD_COMPUTE_STORE`` attribute.

NOTE(review): no ``yield`` is visible in this copy, so as shown this runs during test setup;
lines may be missing - confirm against version control.
"""
import lightning_app.utilities.state
lightning_app.utilities.state._STATE = None
lightning_app.utilities.state._LAST_STATE = None
2022-10-20 14:18:06 +00:00
# NOTE(review): the body of this ``if`` is missing from this copy (presumably it resets
# ``cloud_compute._CLOUD_COMPUTE_STORE``) - restore it from version control.
if hasattr(cloud_compute, "_CLOUD_COMPUTE_STORE"):
2022-06-30 20:43:04 +00:00
def another_tmpdir(tmp_path: Path) -> py.path.local:
    """Return a ``py.path.local`` for a timestamp-named sub-path of ``tmp_path``.

    The sub-path is named after the current local time formatted as
    ``MM-DD-YYYY-HH-MM-SS``. Nothing is created on disk here.
    """
    # NOTE(review): no fixture decorator is visible for this function in this copy;
    # presumably it is registered with ``@pytest.fixture`` - confirm.
    subdir_name = datetime.now().strftime("%m-%d-%Y-%H-%M-%S")
    return py.path.local(os.path.join(tmp_path, subdir_name))
2022-11-24 14:09:25 +00:00
def caplog(caplog):
    """Workaround for https://github.com/pytest-dev/pytest/issues/3697.

    Setting ``filterwarnings`` with pytest breaks ``caplog`` when ``not logger.propagate``.
    While the wrapped ``caplog`` is in use, the root logger and every already-registered
    logger whose name starts with ``lightning_app`` are forced to propagate. Their previous
    ``propagate`` flags are restored when the generator is resumed.

    Yields:
        The ``caplog`` object that was passed in, unchanged.
    """
    # NOTE(review): no fixture decorator is visible for this function in this copy;
    # presumably it is registered with ``@pytest.fixture`` - confirm.
    import logging

    root_logger = logging.getLogger()
    root_propagate = root_logger.propagate
    root_logger.propagate = True

    # Snapshot the registry first: other threads may register loggers while we iterate.
    registered_names = list(logging.root.manager.loggerDict)
    propagation_dict = {
        name: logging.getLogger(name).propagate
        for name in registered_names
        if name.startswith("lightning_app")
    }
    for name in propagation_dict:
        logging.getLogger(name).propagate = True

    yield caplog

    root_logger.propagate = root_propagate
    for name, propagate in propagation_dict.items():
        logging.getLogger(name).propagate = propagate