[FEAT] DDP: Create DDPLauncher (#4515)
* test
* poc
* add simpler test for ddp
* typo
* resolve pep8
* try coverage testing
* trying to add coverage inside ddp
* resolve flake8
* update
* forgot coverage
* move .coveragerc
* update rcfile path
* update
* test
* update
* adding description
* add DDPLauncher decorator
* add undecorated
* push update
* update ddp testing
* Update tests/backends/launcher.py
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
* Update tests/backends/launcher.py
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
* update on comments
* update on comments
* resolve comments
* resolve isort

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
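In short, the commit adds a `DDPLauncher.run` decorator for parametrized DDP tests. A minimal usage sketch, adapted from the docstring and tests added below (the test name and body here are illustrative, not part of the diff):

from tests.backends.launcher import DDPLauncher

# [tokens] in the command line are substituted with every combination of the
# kwargs values, so this runs the test once for ddp and once for ddp_spawn.
@DDPLauncher.run("--max_epochs [max_epochs] --gpus 2 --accelerator [accelerator]",
                 max_epochs=["1"],
                 accelerator=["ddp", "ddp_spawn"])
def test_example(tmpdir, args=None):
    # ... exercise a model with the parsed `args` ...
    return '1'  # saved to `ddp.result` and checked by the launcher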
parent 6e788d2dc6
commit 4803f681b0
tests/backends/launcher.py
@@ -0,0 +1,218 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import itertools
import os
import subprocess
import sys
from argparse import ArgumentParser
from inspect import isclass, isfunction, ismethod
from pathlib import Path
from subprocess import TimeoutExpired
from time import time
from typing import Callable, Dict, Optional

import coverage
import torch

import pytorch_lightning
from pytorch_lightning import Trainer, seed_everything
from tests.base import EvalModelTemplate


def import_from(module, name):
    """Import `name` from `module` and return it."""
    module = __import__(module, fromlist=[name])
    return getattr(module, name)


def call_training_script(cli_args: str, tmpdir: str, env: Dict, timeout: int = 20):
    """Run this file as a coverage-tracked subprocess with the given CLI arguments."""
    file = Path(__file__).absolute()
    cli_args = cli_args.split(' ') if cli_args else []
    cli_args += ['--tmpdir', str(tmpdir)]
    command = [sys.executable, '-m', 'coverage', 'run', str(file)] + cli_args

    # need to set the PYTHONPATH in case pytorch_lightning was not installed into the environment
    env['PYTHONPATH'] = f'{pytorch_lightning.__file__}:{env.get("PYTHONPATH", "")}'

    # for running in ddp mode, we need to launch a separate process or pytest will get stuck
    p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)

    try:
        std, err = p.communicate(timeout=timeout)
        err = str(err.decode("utf-8"))
        if 'Exception' in err or 'Error' in err:
            raise Exception(err)
    except TimeoutExpired:
        p.kill()
        std, err = p.communicate()
    return std, err


def create_cmd_lines(cmd_line, **kwargs):
    """Resolve the `[token]` placeholders in `cmd_line` with the product of the kwargs values."""
    keys = sorted(kwargs.keys())
    values_comb = itertools.product(*(kwargs[k] for k in keys))
    cmd_lines = []
    for combi in values_comb:
        temp_cmd_line = cmd_line[::]
        for key_idx, v in enumerate(combi):
            k = keys[key_idx]
            temp_cmd_line = temp_cmd_line.replace(f"[{k}]", str(v))
        cmd_lines.append(temp_cmd_line)
    cmd_lines = list(set(cmd_lines))
    return cmd_lines


def undecorated(o):
    """Remove all decorators from a function, method or class"""
    # class decorator
    if type(o) is type:
        return o

    try:
        closure = o.__closure__
    except AttributeError:
        return

    if closure:
        for cell in closure:
            # avoid infinite recursion
            if cell.cell_contents is o:
                continue

            # check if the contents look like a decorator; in that case
            # we need to go one level down, otherwise it might just be a
            # different closed-over variable, which we can ignore.

            # Note: this favors supporting decorators defined without
            # @wraps to the detriment of function/method/class closures
            if looks_like_a_decorator(cell.cell_contents):
                undecd = undecorated(cell.cell_contents)
                if undecd:
                    return undecd
            else:
                return o
    else:
        return o


def looks_like_a_decorator(a):
    return isfunction(a) or ismethod(a) or isclass(a)
class DDPLauncher:
    """
    This class is used to launch DDP-related tests.
    It provides a simple decorator to run your test. See below for an explanation and an example:

    Example:

        # The decorator will read cmd_line + arguments provided as kwargs.

        @DDPLauncher.run("--max_epochs [max_epochs] --gpus 2 --accelerator [accelerator]",
                         max_epochs=["1"],
                         accelerator=["ddp", "ddp_spawn"])
        def test_cli_to_pass(tmpdir, args=None):
            ... do something with args + BoringModel
            return '1'

    Explanation:

    1 - DDPLauncher.run receives a command line in which tokens are marked by [].
        It emulates pytest's `parametrize` and generates one resolved cmd_line for each
        element of the product of the provided kwargs values.
        For the example above, it generates 2 cmd_lines and runs them:
            1: --max_epochs 1 --gpus 2 --accelerator ddp
            2: --max_epochs 1 --gpus 2 --accelerator ddp_spawn

    2 - For each cmd_line, the launcher saves the module and function name of your decorated test
        in environment variables and runs the cmd_line on itself, i.e. tests/backends/launcher.py.

    3 - When running {ENV} python tests/backends/launcher.py {your_resolved_cmd_line},
        the script starts in `__name__ == "__main__"`, where argparse is used to parse
        your command line. The parsed args are provided to the `main` function.

    4 - The `main` function extracts the module and function name of your decorated test
        from the os.environ variables, dynamically imports your function, undecorates it,
        and runs result = your_test_func(tmpdir, args=args).

    5 - The `main` function saves a `ddp.result` object, which is read by
        the launcher to make sure your test ran correctly.
    """

    @staticmethod
    def run_from_cmd_line(cli_args: str = None,
                          func_to_run: Optional[Callable] = None,
                          tmpdir: Optional[str] = None,
                          timeout: int = 20):
        env = os.environ.copy()
        env["PL_CURRENT_TEST_MODULE"] = str(func_to_run.__module__)
        env["PL_CURRENT_TEST_NAME"] = str(func_to_run.__name__)
        return call_training_script(cli_args, tmpdir, env, timeout=timeout)

    @staticmethod
    def run(cmd_line, **kwargs):
        cmd_lines = create_cmd_lines(cmd_line, **kwargs)

        def inner(func):
            @functools.wraps(func)
            def func_wrapper(*args, **kwargs):
                tmpdir = kwargs.get("tmpdir")
                for cmd_line in cmd_lines:
                    print(f"Launching {func.__name__} with {cmd_line}")
                    std, err = DDPLauncher.run_from_cmd_line(cmd_line, func, tmpdir, timeout=20)
                    print(std)
                    print(err)
                    # make sure the test ran properly
                    result_path = os.path.join(tmpdir, 'ddp.result')
                    result = torch.load(result_path)
                    # verify the file wrote the expected outputs
                    assert result['status'] == 'complete'
            return func_wrapper
        return inner
def main(args):
    # set PL_IN_LAUNCHER so re-launched tests can detect they run inside the launcher
    os.environ["PL_IN_LAUNCHER"] = '1'
    env = os.environ.copy()

    # load the test function from its module and name
    func = import_from(env["PL_CURRENT_TEST_MODULE"], env["PL_CURRENT_TEST_NAME"])

    # undecorate the function
    func = undecorated(func)

    # run the function and gather the result
    result = func(args.tmpdir, args=args)

    # save the result so the launcher can verify the run
    if result is not None:
        result = {'status': 'complete', 'result': result}
        file_path = os.path.join(args.tmpdir, 'ddp.result')
        torch.save(result, file_path)


if __name__ == "__main__":
    seed_everything(1234)

    # parse arguments
    parser = ArgumentParser(add_help=False)
    parser = Trainer.add_argparse_args(parser)
    parser.add_argument('--tmpdir')
    parser.set_defaults(gpus=2)

    # launch main process
    main(parser.parse_args())
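For reference, here is a self-contained sketch of the token expansion that `create_cmd_lines` performs; the `expand` helper name is illustrative, not part of this commit:

# Standalone copy of the create_cmd_lines expansion logic for illustration.
import itertools

def expand(cmd_line, **kwargs):
    keys = sorted(kwargs.keys())
    lines = set()
    # one resolved command line per element of the product of the kwargs values
    for combo in itertools.product(*(kwargs[k] for k in keys)):
        line = cmd_line
        for k, v in zip(keys, combo):
            line = line.replace(f"[{k}]", str(v))
        lines.add(line)
    return sorted(lines)

print(expand("--max_epochs [max_epochs] --accelerator [accelerator]",
             max_epochs=["1"], accelerator=["ddp", "ddp_spawn"]))
# ['--max_epochs 1 --accelerator ddp', '--max_epochs 1 --accelerator ddp_spawn']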
@@ -11,10 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from time import sleep

import pytest
import torch

from tests.backends import ddp_model
from tests.backends.launcher import DDPLauncher
from tests.utilities.dist import call_training_script
@@ -68,3 +72,36 @@ def test_multi_gpu_model_ddp_fit_test(tmpdir, cli_args):
    model_outs = result['result']
    for out in model_outs:
        assert out['test_acc'] > 0.90


# START: test_cli ddp test
@pytest.mark.skipif(os.getenv("PL_IN_LAUNCHER", '0') == '1', reason="test runs only in DDPLauncher")
def internal_test_cli(tmpdir, args=None):
    """
    This test verifies we can call a function using its test_cli name.
    """
    return 1


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_cli(tmpdir):
    DDPLauncher.run_from_cmd_line("--max_epochs 1 --gpus 2 --accelerator ddp", internal_test_cli, tmpdir)
    # load the results of the script
    result_path = os.path.join(tmpdir, 'ddp.result')
    result = torch.load(result_path)
    # verify the file wrote the expected outputs
    assert result['status'] == 'complete'
    assert str(result['result']) == '1'
# END: test_cli ddp test


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@DDPLauncher.run("--max_epochs [max_epochs] --gpus 2 --accelerator [accelerator]",
                 max_epochs=["1"],
                 accelerator=["ddp", "ddp_spawn"])
def test_cli_to_pass(tmpdir, args=None):
    """
    This test verifies we can call a function using its test_cli name.
    """
    return '1'
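To make the mechanism above concrete, here is a minimal sketch of the environment-variable round trip the launcher relies on (steps 2-4 of the DDPLauncher docstring); the `record`/`reload` helper names are illustrative, not part of this commit:

import os

def record(func, env):
    # parent process: remember where the decorated test lives
    env["PL_CURRENT_TEST_MODULE"] = func.__module__
    env["PL_CURRENT_TEST_NAME"] = func.__name__

def reload(env):
    # re-launched process: import the test back from its module path,
    # mirroring import_from in launcher.py; the result is then passed
    # through undecorated() before being called.
    module = __import__(env["PL_CURRENT_TEST_MODULE"],
                        fromlist=[env["PL_CURRENT_TEST_NAME"]])
    return getattr(module, env["PL_CURRENT_TEST_NAME"])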
@@ -13,9 +13,9 @@
# limitations under the License.
import os
import subprocess
import sys
from pathlib import Path
from subprocess import TimeoutExpired

import pytorch_lightning