lightning/tests/conftest.py

109 lines
3.7 KiB
Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import threading
from functools import partial, wraps
from http.server import SimpleHTTPRequestHandler
from pathlib import Path
import pytest
import torch.distributed
import torch.multiprocessing as mp
from tests import _PATH_DATASETS
@pytest.fixture(scope="session")
def datadir():
    """Session-scoped fixture returning the datasets location as a ``Path``.

    The location is taken from ``tests._PATH_DATASETS``.
    """
    datasets_root = Path(_PATH_DATASETS)
    return datasets_root
@pytest.fixture(scope="function", autouse=True)
def preserve_global_rank_variable():
    """Snapshot ``rank_zero_only.rank`` before each test and put it back afterwards.

    If the attribute is missing (or ``None``) before the test, nothing is
    restored once the test has finished.
    """
    # imported lazily, inside the fixture, as in the original implementation
    from pytorch_lightning.utilities.distributed import rank_zero_only

    saved_rank = getattr(rank_zero_only, "rank", None)
    yield
    if saved_rank is None:
        # no value was recorded up front, so there is nothing to write back
        return
    rank_zero_only.rank = saved_rank
@pytest.fixture(scope="function", autouse=True)
def restore_env_variables():
    """Roll ``os.environ`` back to its pre-test contents after every test.

    Variables added, changed or removed by the test are all undone: the
    environment is emptied and then refilled from the snapshot.
    """
    snapshot = dict(os.environ)
    yield
    # wipe first so that variables introduced by the test disappear too
    os.environ.clear()
    os.environ.update(snapshot)
@pytest.fixture(scope="function", autouse=True)
def teardown_process_group():
    """Destroy the ``torch.distributed`` process group, if one exists, after each test."""
    yield
    dist = torch.distributed
    # nothing to clean up when distributed is unavailable or was never initialized
    if not dist.is_available() or not dist.is_initialized():
        return
    dist.destroy_process_group()
def pytest_configure(config):
    """Register the custom ``spawn`` marker.

    Args:
        config: the pytest config object; a description of the marker is
            appended to its ``markers`` ini value via ``addinivalue_line``.
    """
    spawn_marker = "spawn: spawn test in a separate process using torch.multiprocessing.spawn"
    config.addinivalue_line("markers", spawn_marker)
def _run_spawned_test(process_idx, testfunction, testargs):
    """Entry point executed in the process started by ``mp.spawn``.

    ``mp.spawn`` invokes its target as ``fn(i, *args)``. The leading process
    index is not needed by the test and is ignored; the test function is
    called with its resolved fixture values.

    Kept at module level so it can be pickled and sent to the spawned process.

    Args:
        process_idx: index of the spawned process, supplied by ``mp.spawn``.
        testfunction: the test callable to execute.
        testargs: tuple of positional arguments for ``testfunction``.
    """
    testfunction(*testargs)


@pytest.hookimpl(tryfirst=True)
def pytest_pyfunc_call(pyfuncitem):
    """Run tests marked with ``spawn`` in a separate process.

    For items carrying the ``spawn`` marker, the test function and its fixture
    values (in the order of the item's argument names) are handed to
    ``torch.multiprocessing.spawn``. A failure in the child surfaces as an
    exception raised by ``mp.spawn`` in this process.

    Previously ``functools.wraps`` was passed as the spawn target, which only
    built a partial object in the child and never called the test, so marked
    tests always passed without running.

    Args:
        pyfuncitem: the pytest item for the test function being called.

    Returns:
        ``True`` when the test was executed here (which tells pytest not to
        call it again), otherwise ``None`` so the default call proceeds.
    """
    if not pyfuncitem.get_closest_marker("spawn"):
        return None

    testfunction = pyfuncitem.obj
    funcargs = pyfuncitem.funcargs
    testargs = tuple(funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames)

    mp.spawn(_run_spawned_test, (testfunction, testargs))
    return True
@pytest.fixture
def tmpdir_server(tmpdir):
    """Serve the contents of ``tmpdir`` over HTTP for the duration of a test.

    A threading HTTP server is bound to ``("localhost", 0)`` and run in a
    daemon thread. The fixture yields the server's ``server_address`` and
    shuts the server down once the test is done.
    """
    served_root = str(tmpdir)

    if sys.version_info < (3, 7):
        # Python 3.6: SimpleHTTPRequestHandler has no ``directory`` argument and
        # ThreadingHTTPServer does not exist yet, so both are emulated here.
        from http.server import HTTPServer
        from socketserver import ThreadingMixIn

        class Handler(SimpleHTTPRequestHandler):
            def translate_path(self, path):
                # the stock implementation resolves against the current working directory
                cwd_based = super().translate_path(path)
                # strip the cwd prefix ...
                relative = os.path.relpath(cwd_based, os.getcwd())
                # ... and re-root the request under the served directory
                return os.path.join(served_root, relative)

        class ThreadingHTTPServer(ThreadingMixIn, HTTPServer):
            daemon_threads = True

    else:
        from http.server import ThreadingHTTPServer

        Handler = partial(SimpleHTTPRequestHandler, directory=served_root)

    with ThreadingHTTPServer(("localhost", 0), Handler) as httpd:
        # daemon thread: it must not keep the interpreter alive after the main thread exits
        worker = threading.Thread(target=httpd.serve_forever, daemon=True)
        worker.start()
        yield httpd.server_address
        httpd.shutdown()