lightning/tests/utilities/distributed.py

46 lines
1.7 KiB
Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
from pathlib import Path
from subprocess import TimeoutExpired
import pytorch_lightning
def call_training_script(module_file, cli_args, method, tmpdir, timeout=60, as_module=False):
file = Path(module_file.__file__).absolute()
cli_args = cli_args.split(" ") if cli_args else []
cli_args += ["--tmpdir", str(tmpdir)]
cli_args += ["--trainer_method", method]
file_args = ["-m", module_file.__spec__.name] if as_module else [str(file)]
command = [sys.executable] + file_args + cli_args
# need to set the PYTHONPATH in case pytorch_lightning was not installed into the environment
env = os.environ.copy()
env["PYTHONPATH"] = env.get("PYTHONPATH", "") + f"{pytorch_lightning.__file__}:"
# for running in ddp mode, we need to lauch it's own process or pytest will get stuck
p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)
try:
std, err = p.communicate(timeout=timeout)
err = str(err.decode("utf-8"))
if "Exception" in err:
raise Exception(err)
except TimeoutExpired:
p.kill()
std, err = p.communicate()
return std, err