2020-10-13 11:18:07 +00:00
|
|
|
# Copyright The PyTorch Lightning team.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
2020-10-01 12:15:23 +00:00
|
|
|
import os
|
|
|
|
import subprocess
|
|
|
|
import sys
|
|
|
|
from pathlib import Path
|
2020-11-20 10:17:46 +00:00
|
|
|
from subprocess import TimeoutExpired
|
2020-10-01 12:15:23 +00:00
|
|
|
|
|
|
|
import pytorch_lightning
|
|
|
|
|
|
|
|
|
|
|
|
def call_training_script(module_file, cli_args, method, tmpdir, timeout=60):
|
|
|
|
file = Path(module_file.__file__).absolute()
|
|
|
|
cli_args = cli_args.split(' ') if cli_args else []
|
|
|
|
cli_args += ['--tmpdir', str(tmpdir)]
|
|
|
|
cli_args += ['--trainer_method', method]
|
|
|
|
command = [sys.executable, str(file)] + cli_args
|
|
|
|
|
|
|
|
# need to set the PYTHONPATH in case pytorch_lightning was not installed into the environment
|
|
|
|
env = os.environ.copy()
|
2020-12-05 20:09:47 +00:00
|
|
|
env['PYTHONPATH'] = env.get('PYTHONPATH', '') + f'{pytorch_lightning.__file__}:'
|
2020-10-01 12:15:23 +00:00
|
|
|
|
|
|
|
# for running in ddp mode, we need to lauch it's own process or pytest will get stuck
|
|
|
|
p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)
|
|
|
|
try:
|
|
|
|
std, err = p.communicate(timeout=timeout)
|
|
|
|
err = str(err.decode("utf-8"))
|
|
|
|
if 'Exception' in err:
|
|
|
|
raise Exception(err)
|
|
|
|
except TimeoutExpired:
|
|
|
|
p.kill()
|
|
|
|
std, err = p.communicate()
|
|
|
|
return std, err
|