diff --git a/tests/backends/ddp_model.py b/tests/backends/ddp_model.py index b625d8cc98..32b30c0553 100644 --- a/tests/backends/ddp_model.py +++ b/tests/backends/ddp_model.py @@ -14,20 +14,24 @@ """ Runs either `.fit()` or `.test()` on a single node across multiple gpus. """ +import os from argparse import ArgumentParser +import tests as pl_tests from pytorch_lightning import Trainer, seed_everything from tests.base import EvalModelTemplate -import os + import torch def main(): seed_everything(1234) + parser = ArgumentParser(add_help=False) parser = Trainer.add_argparse_args(parser) parser.add_argument('--trainer_method', default='fit') parser.add_argument('--tmpdir') + parser.add_argument('--workdir') parser.set_defaults(gpus=2) parser.set_defaults(distributed_backend="ddp") args = parser.parse_args() @@ -38,14 +42,26 @@ def main(): result = {} if args.trainer_method == 'fit': trainer.fit(model) - result = {'status': 'complete', 'method': args.trainer_method, 'result': None} + result = { + 'status': 'complete', + 'method': args.trainer_method, + 'result': None + } if args.trainer_method == 'test': result = trainer.test(model) - result = {'status': 'complete', 'method': args.trainer_method, 'result': result} + result = { + 'status': 'complete', + 'method': args.trainer_method, + 'result': result + } if args.trainer_method == 'fit_test': trainer.fit(model) result = trainer.test(model) - result = {'status': 'complete', 'method': args.trainer_method, 'result': result} + result = { + 'status': 'complete', + 'method': args.trainer_method, + 'result': result + } if len(result) > 0: file_path = os.path.join(args.tmpdir, 'ddp.result') diff --git a/tests/utilities/distributed.py b/tests/utilities/distributed.py index f6b9a686b2..80c0246ce6 100644 --- a/tests/utilities/distributed.py +++ b/tests/utilities/distributed.py @@ -29,11 +29,10 @@ def call_training_script(module_file, cli_args, method, tmpdir, timeout=60): # need to set the PYTHONPATH in case pytorch_lightning was not installed into the environment env = os.environ.copy() - env['PYTHONPATH'] = f'{pytorch_lightning.__file__}:' + env.get('PYTHONPATH', '') + env['PYTHONPATH'] = env.get('PYTHONPATH', '') + f'{pytorch_lightning.__file__}:' # for running in ddp mode, we need to lauch it's own process or pytest will get stuck p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env) - try: std, err = p.communicate(timeout=timeout) err = str(err.decode("utf-8")) @@ -42,5 +41,4 @@ def call_training_script(module_file, cli_args, method, tmpdir, timeout=60): except TimeoutExpired: p.kill() std, err = p.communicate() - return std, err