Re-enable Lite CLI on Windows + PyTorch 1.13 (#15645)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
This commit is contained in:
Adrian Wälchli 2022-12-19 11:50:08 +01:00 committed by GitHub
parent 39d27f6370
commit 8a727c6243
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 1 addition and 31 deletions

View File

@ -20,7 +20,6 @@ from lightning_utilities.core.imports import RequirementCache
from lightning_lite.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator
from lightning_lite.utilities.device_parser import _parse_gpu_ids
from lightning_lite.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_13
_log = logging.getLogger(__name__)
@ -148,15 +147,6 @@ def _get_num_processes(accelerator: str, devices: str) -> int:
def _torchrun_launch(args: Namespace, script_args: List[str]) -> None:
"""This will invoke `torchrun` programmatically to launch the given script in new processes."""
if _IS_WINDOWS and _TORCH_GREATER_EQUAL_1_13: # pragma: no cover
# TODO: remove once import issue is resolved: https://github.com/pytorch/pytorch/issues/85427
_log.error(
"On the Windows platform, this launcher is currently only supported on torch < 1.13 due to a bug"
" upstream: https://github.com/pytorch/pytorch/issues/85427"
)
raise SystemExit(1)
import torch.distributed.run as torchrun
if args.strategy == "dp":

View File

@ -16,21 +16,10 @@ from unittest import mock
from unittest.mock import Mock
import pytest
import torch.distributed.run
from tests_lite.helpers.runif import RunIf
from lightning_lite.cli import _run_model
from lightning_lite.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_13
if not (_IS_WINDOWS and _TORCH_GREATER_EQUAL_1_13):
import torch.distributed.run
def skip_windows_pt_1_13():
# https://github.com/pytorch/pytorch/issues/85427
return pytest.mark.skipif(
condition=(_IS_WINDOWS and _TORCH_GREATER_EQUAL_1_13),
reason="Torchelastic import bug in 1.13 affecting Windows",
)
@pytest.fixture
@ -40,7 +29,6 @@ def fake_script(tmp_path):
return str(script)
@skip_windows_pt_1_13()
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_cli_env_vars_defaults(monkeypatch, fake_script):
monkeypatch.setattr(torch.distributed, "run", Mock())
@ -55,7 +43,6 @@ def test_cli_env_vars_defaults(monkeypatch, fake_script):
assert os.environ["LT_PRECISION"] == "32"
@skip_windows_pt_1_13()
@pytest.mark.parametrize("accelerator", ["cpu", "gpu", "cuda", pytest.param("mps", marks=RunIf(mps=True))])
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
@mock.patch("lightning_lite.accelerators.cuda.num_cuda_devices", return_value=2)
@ -67,7 +54,6 @@ def test_cli_env_vars_accelerator(_, accelerator, monkeypatch, fake_script):
assert os.environ["LT_ACCELERATOR"] == accelerator
@skip_windows_pt_1_13()
@pytest.mark.parametrize("strategy", ["dp", "ddp", "deepspeed"])
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
@mock.patch("lightning_lite.accelerators.cuda.num_cuda_devices", return_value=2)
@ -79,7 +65,6 @@ def test_cli_env_vars_strategy(_, strategy, monkeypatch, fake_script):
assert os.environ["LT_STRATEGY"] == strategy
@skip_windows_pt_1_13()
@pytest.mark.parametrize("devices", ["1", "2", "0,", "1,0", "-1"])
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
@mock.patch("lightning_lite.accelerators.cuda.num_cuda_devices", return_value=2)
@ -92,7 +77,6 @@ def test_cli_env_vars_devices_cuda(_, devices, monkeypatch, fake_script):
@RunIf(mps=True)
@skip_windows_pt_1_13()
@pytest.mark.parametrize("accelerator", ["mps", "gpu"])
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_cli_env_vars_devices_mps(accelerator, monkeypatch, fake_script):
@ -103,7 +87,6 @@ def test_cli_env_vars_devices_mps(accelerator, monkeypatch, fake_script):
assert os.environ["LT_DEVICES"] == "1"
@skip_windows_pt_1_13()
@pytest.mark.parametrize("num_nodes", ["1", "2", "3"])
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_cli_env_vars_num_nodes(num_nodes, monkeypatch, fake_script):
@ -114,7 +97,6 @@ def test_cli_env_vars_num_nodes(num_nodes, monkeypatch, fake_script):
assert os.environ["LT_NUM_NODES"] == num_nodes
@skip_windows_pt_1_13()
@pytest.mark.parametrize("precision", ["64", "32", "16", "bf16"])
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_cli_env_vars_precision(precision, monkeypatch, fake_script):
@ -125,7 +107,6 @@ def test_cli_env_vars_precision(precision, monkeypatch, fake_script):
assert os.environ["LT_PRECISION"] == precision
@skip_windows_pt_1_13()
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_cli_torchrun_defaults(monkeypatch, fake_script):
torchrun_mock = Mock()
@ -145,7 +126,6 @@ def test_cli_torchrun_defaults(monkeypatch, fake_script):
)
@skip_windows_pt_1_13()
@pytest.mark.parametrize(
"devices,expected",
[