Fixed PYTHONPATH for ddp test model (#4528)
* Fixed PYTHONPATH for ddp test model * Removed debug calls * Apply suggestions from code review Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
This commit is contained in:
parent
4bb3a080c5
commit
16fa4ed1e5
|
@ -14,20 +14,24 @@
|
|||
"""
|
||||
Runs either `.fit()` or `.test()` on a single node across multiple gpus.
|
||||
"""
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
|
||||
import tests as pl_tests
|
||||
from pytorch_lightning import Trainer, seed_everything
|
||||
from tests.base import EvalModelTemplate
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def main():
|
||||
seed_everything(1234)
|
||||
|
||||
parser = ArgumentParser(add_help=False)
|
||||
parser = Trainer.add_argparse_args(parser)
|
||||
parser.add_argument('--trainer_method', default='fit')
|
||||
parser.add_argument('--tmpdir')
|
||||
parser.add_argument('--workdir')
|
||||
parser.set_defaults(gpus=2)
|
||||
parser.set_defaults(distributed_backend="ddp")
|
||||
args = parser.parse_args()
|
||||
|
@ -38,14 +42,26 @@ def main():
|
|||
result = {}
|
||||
if args.trainer_method == 'fit':
|
||||
trainer.fit(model)
|
||||
result = {'status': 'complete', 'method': args.trainer_method, 'result': None}
|
||||
result = {
|
||||
'status': 'complete',
|
||||
'method': args.trainer_method,
|
||||
'result': None
|
||||
}
|
||||
if args.trainer_method == 'test':
|
||||
result = trainer.test(model)
|
||||
result = {'status': 'complete', 'method': args.trainer_method, 'result': result}
|
||||
result = {
|
||||
'status': 'complete',
|
||||
'method': args.trainer_method,
|
||||
'result': result
|
||||
}
|
||||
if args.trainer_method == 'fit_test':
|
||||
trainer.fit(model)
|
||||
result = trainer.test(model)
|
||||
result = {'status': 'complete', 'method': args.trainer_method, 'result': result}
|
||||
result = {
|
||||
'status': 'complete',
|
||||
'method': args.trainer_method,
|
||||
'result': result
|
||||
}
|
||||
|
||||
if len(result) > 0:
|
||||
file_path = os.path.join(args.tmpdir, 'ddp.result')
|
||||
|
|
|
@ -29,11 +29,10 @@ def call_training_script(module_file, cli_args, method, tmpdir, timeout=60):
|
|||
|
||||
# need to set the PYTHONPATH in case pytorch_lightning was not installed into the environment
|
||||
env = os.environ.copy()
|
||||
env['PYTHONPATH'] = f'{pytorch_lightning.__file__}:' + env.get('PYTHONPATH', '')
|
||||
env['PYTHONPATH'] = env.get('PYTHONPATH', '') + f'{pytorch_lightning.__file__}:'
|
||||
|
||||
# for running in ddp mode, we need to launch its own process or pytest will get stuck
|
||||
p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)
|
||||
|
||||
try:
|
||||
std, err = p.communicate(timeout=timeout)
|
||||
err = str(err.decode("utf-8"))
|
||||
|
@ -42,5 +41,4 @@ def call_training_script(module_file, cli_args, method, tmpdir, timeout=60):
|
|||
except TimeoutExpired:
|
||||
p.kill()
|
||||
std, err = p.communicate()
|
||||
|
||||
return std, err
|
||||
|
|
Loading…
Reference in New Issue