# Copyright The PyTorch Lightning team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import subprocess import sys from pathlib import Path from subprocess import TimeoutExpired import pytorch_lightning def call_training_script(module_file, cli_args, method, tmpdir, timeout=60, as_module=False): file = Path(module_file.__file__).absolute() cli_args = cli_args.split(" ") if cli_args else [] cli_args += ["--tmpdir", str(tmpdir)] cli_args += ["--trainer_method", method] file_args = ["-m", module_file.__spec__.name] if as_module else [str(file)] command = [sys.executable] + file_args + cli_args # need to set the PYTHONPATH in case pytorch_lightning was not installed into the environment env = os.environ.copy() env["PYTHONPATH"] = env.get("PYTHONPATH", "") + f"{pytorch_lightning.__file__}:" # for running in ddp mode, we need to lauch it's own process or pytest will get stuck p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env) try: std, err = p.communicate(timeout=timeout) err = str(err.decode("utf-8")) if "Exception" in err: raise Exception(err) except TimeoutExpired: p.kill() std, err = p.communicate() return std, err