Fixed PYTHONPATH for ddp test model (#4528)

* Fixed PYTHONPATH for ddp test model

* Removed debug calls

* Apply suggestions from code review

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
This commit is contained in:
Gianluca Scarpellini 2020-12-05 21:09:47 +01:00 committed by GitHub
parent 4bb3a080c5
commit 16fa4ed1e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 7 deletions

View File

@ -14,20 +14,24 @@
"""
Runs either `.fit()` or `.test()` on a single node across multiple gpus.
"""
import os
from argparse import ArgumentParser
import tests as pl_tests
from pytorch_lightning import Trainer, seed_everything
from tests.base import EvalModelTemplate
import os
import torch
def main():
seed_everything(1234)
parser = ArgumentParser(add_help=False)
parser = Trainer.add_argparse_args(parser)
parser.add_argument('--trainer_method', default='fit')
parser.add_argument('--tmpdir')
parser.add_argument('--workdir')
parser.set_defaults(gpus=2)
parser.set_defaults(distributed_backend="ddp")
args = parser.parse_args()
@ -38,14 +42,26 @@ def main():
result = {}
if args.trainer_method == 'fit':
trainer.fit(model)
result = {'status': 'complete', 'method': args.trainer_method, 'result': None}
result = {
'status': 'complete',
'method': args.trainer_method,
'result': None
}
if args.trainer_method == 'test':
result = trainer.test(model)
result = {'status': 'complete', 'method': args.trainer_method, 'result': result}
result = {
'status': 'complete',
'method': args.trainer_method,
'result': result
}
if args.trainer_method == 'fit_test':
trainer.fit(model)
result = trainer.test(model)
result = {'status': 'complete', 'method': args.trainer_method, 'result': result}
result = {
'status': 'complete',
'method': args.trainer_method,
'result': result
}
if len(result) > 0:
file_path = os.path.join(args.tmpdir, 'ddp.result')

View File

@ -29,11 +29,10 @@ def call_training_script(module_file, cli_args, method, tmpdir, timeout=60):
# need to set the PYTHONPATH in case pytorch_lightning was not installed into the environment
env = os.environ.copy()
env['PYTHONPATH'] = f'{pytorch_lightning.__file__}:' + env.get('PYTHONPATH', '')
env['PYTHONPATH'] = env.get('PYTHONPATH', '') + f'{pytorch_lightning.__file__}:'
# for running in ddp mode, we need to lauch it's own process or pytest will get stuck
p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)
try:
std, err = p.communicate(timeout=timeout)
err = str(err.decode("utf-8"))
@ -42,5 +41,4 @@ def call_training_script(module_file, cli_args, method, tmpdir, timeout=60):
except TimeoutExpired:
p.kill()
std, err = p.communicate()
return std, err