Support calling fit and test scripts using "python -m" module syntax with DDP (#8073)

Co-authored-by: Nisheeth Lahoti <nisheeth@rephrase.ai>
nisheethlahoti 2021-06-23 08:12:04 +05:30 committed by GitHub
parent b378806b6c
commit 06f8349291
4 changed files with 33 additions and 19 deletions


@@ -108,6 +108,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Add support for overriding `optimizer_zero_grad` and `optimizer_step` when using accumulate_grad_batches ([#7980](https://github.com/PyTorchLightning/pytorch-lightning/pull/7980))
+- Add support for calling scripts using the module syntax (`python -m package.script`) ([#8073](https://github.com/PyTorchLightning/pytorch-lightning/pull/8073))

 ### Changed


@@ -18,6 +18,7 @@ import sys
 from time import sleep
 from typing import Any, Dict, List, Optional, Union

+import __main__
 import numpy as np
 import torch
 import torch.distributed as torch_distrib
@@ -155,19 +156,25 @@ class DDPPlugin(ParallelPlugin):
         os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank())
         os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank())

-        # when user is using hydra find the absolute path
-        path_lib = os.path.abspath if not _HYDRA_AVAILABLE else to_absolute_path
+        # Check if the current calling command looked like `python a/b/c.py` or `python -m a.b.c`
+        # See https://docs.python.org/3/reference/import.html#main-spec
+        if __main__.__spec__ is None:  # pragma: no-cover
+            # Script called as `python a/b/c.py`
+            # when user is using hydra find the absolute path
+            path_lib = os.path.abspath if not _HYDRA_AVAILABLE else to_absolute_path

-        # pull out the commands used to run the script and resolve the abs file path
-        command = sys.argv
-        try:
-            full_path = path_lib(command[0])
-        except Exception:
-            full_path = os.path.abspath(command[0])
+            # pull out the commands used to run the script and resolve the abs file path
+            command = sys.argv
+            try:
+                full_path = path_lib(command[0])
+            except Exception:
+                full_path = os.path.abspath(command[0])

-        command[0] = full_path
-        # use the same python interpreter and actually running
-        command = [sys.executable] + command
+            command[0] = full_path
+            # use the same python interpreter and actually running
+            command = [sys.executable] + command
+        else:  # Script called as `python -m a.b.c`
+            command = [sys.executable, "-m", __main__.__spec__.name] + sys.argv[1:]

         # the visible devices tell us how many GPUs we want to use.
         # when the trainer script was called the device has already been scoped by the time
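
For context on the branch added above: per the main-spec docs linked in the diff, `__main__.__spec__` is populated only when the interpreter is launched with `-m`, so checking it tells the DDP plugin how to respawn the process. A minimal probe (hypothetical file name, not part of this commit) shows both cases:

# probe.py -- hypothetical illustration, not part of this commit
import __main__

if __main__.__spec__ is None:
    # Launched as `python probe.py`: there is no module spec, so only
    # sys.argv[0] (a file path) identifies the entry point, and it must
    # be resolved to an absolute path before respawning.
    print("launched from a file path:", __main__.__file__)
else:
    # Launched as `python -m probe`: the spec records the dotted module
    # name, which child processes can replay directly with `-m`.
    print("launched as a module:", __main__.__spec__.name)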


@@ -32,9 +32,10 @@ CLI_ARGS = '--max_epochs 1 --gpus 2 --accelerator ddp'

 @RunIf(min_gpus=2)
-def test_multi_gpu_model_ddp_fit_only(tmpdir):
+@pytest.mark.parametrize("as_module", [True, False])
+def test_multi_gpu_model_ddp_fit_only(tmpdir, as_module):
     # call the script
-    call_training_script(ddp_model, CLI_ARGS, 'fit', tmpdir, timeout=120)
+    call_training_script(ddp_model, CLI_ARGS, 'fit', tmpdir, timeout=120, as_module=as_module)

     # load the results of the script
     result_path = os.path.join(tmpdir, 'ddp.result')
@@ -45,9 +46,10 @@ def test_multi_gpu_model_ddp_fit_only(tmpdir):

 @RunIf(min_gpus=2)
-def test_multi_gpu_model_ddp_test_only(tmpdir):
+@pytest.mark.parametrize("as_module", [True, False])
+def test_multi_gpu_model_ddp_test_only(tmpdir, as_module):
     # call the script
-    call_training_script(ddp_model, CLI_ARGS, 'test', tmpdir)
+    call_training_script(ddp_model, CLI_ARGS, 'test', tmpdir, as_module=as_module)

     # load the results of the script
     result_path = os.path.join(tmpdir, 'ddp.result')
@@ -58,9 +60,10 @@ def test_multi_gpu_model_ddp_test_only(tmpdir):

 @RunIf(min_gpus=2)
-def test_multi_gpu_model_ddp_fit_test(tmpdir):
+@pytest.mark.parametrize("as_module", [True, False])
+def test_multi_gpu_model_ddp_fit_test(tmpdir, as_module):
     # call the script
-    call_training_script(ddp_model, CLI_ARGS, 'fit_test', tmpdir, timeout=20)
+    call_training_script(ddp_model, CLI_ARGS, 'fit_test', tmpdir, timeout=20, as_module=as_module)

     # load the results of the script
     result_path = os.path.join(tmpdir, 'ddp.result')
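
The `as_module` parametrization above makes each of these DDP script tests run twice, once per launch style. A self-contained sketch of that pattern, with a throwaway tiny_script.py standing in for the real DDP test script:

import subprocess
import sys
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    # A throwaway stand-in for the DDP training script used by the tests.
    script = Path(tmp) / "tiny_script.py"
    script.write_text("import sys\nprint('argv:', sys.argv[1:])\n")

    for cmd in (
        [sys.executable, str(script), "--flag"],          # as_module=False
        [sys.executable, "-m", "tiny_script", "--flag"],  # as_module=True
    ):
        # cwd=tmp puts tiny_script on the module search path for `-m`.
        result = subprocess.run(cmd, capture_output=True, text=True, cwd=tmp)
        print(cmd[1:], "->", result.stdout.strip())

Both invocations should print the same argv, which is exactly the parity the parametrized tests rely on.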


@@ -20,12 +20,13 @@ from subprocess import TimeoutExpired

 import pytorch_lightning

-def call_training_script(module_file, cli_args, method, tmpdir, timeout=60):
+def call_training_script(module_file, cli_args, method, tmpdir, timeout=60, as_module=False):
     file = Path(module_file.__file__).absolute()
     cli_args = cli_args.split(' ') if cli_args else []
     cli_args += ['--tmpdir', str(tmpdir)]
     cli_args += ['--trainer_method', method]
-    command = [sys.executable, str(file)] + cli_args
+    file_args = ["-m", module_file.__spec__.name] if as_module else [str(file)]
+    command = [sys.executable] + file_args + cli_args

     # need to set the PYTHONPATH in case pytorch_lightning was not installed into the environment
     env = os.environ.copy()
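
The updated helper, then, only changes how the entry point is named: an absolute file path by default, or `-m` plus the dotted name from the module's `__spec__` when `as_module=True`. A minimal sketch of that selection logic, using the stdlib `json.tool` module as a stand-in for the test script:

import json.tool
import sys
from pathlib import Path

def build_command(module_file, cli_args, as_module=False):
    # Same selection as call_training_script: target the entry point by
    # absolute file path, or by dotted module name for `python -m`.
    file = Path(module_file.__file__).absolute()
    file_args = ["-m", module_file.__spec__.name] if as_module else [str(file)]
    return [sys.executable] + file_args + cli_args

print(build_command(json.tool, ["--help"]))
# e.g. ['/usr/bin/python3', '/usr/lib/python3.9/json/tool.py', '--help']
print(build_command(json.tool, ["--help"], as_module=True))
# e.g. ['/usr/bin/python3', '-m', 'json.tool', '--help']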