Re-enable Lite CLI on Windows + PyTorch 1.13 (#15645)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
This commit is contained in:
parent
39d27f6370
commit
8a727c6243
|
@ -20,7 +20,6 @@ from lightning_utilities.core.imports import RequirementCache
|
|||
|
||||
from lightning_lite.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator
|
||||
from lightning_lite.utilities.device_parser import _parse_gpu_ids
|
||||
from lightning_lite.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_13
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
@ -148,15 +147,6 @@ def _get_num_processes(accelerator: str, devices: str) -> int:
|
|||
|
||||
def _torchrun_launch(args: Namespace, script_args: List[str]) -> None:
|
||||
"""This will invoke `torchrun` programmatically to launch the given script in new processes."""
|
||||
|
||||
if _IS_WINDOWS and _TORCH_GREATER_EQUAL_1_13: # pragma: no cover
|
||||
# TODO: remove once import issue is resolved: https://github.com/pytorch/pytorch/issues/85427
|
||||
_log.error(
|
||||
"On the Windows platform, this launcher is currently only supported on torch < 1.13 due to a bug"
|
||||
" upstream: https://github.com/pytorch/pytorch/issues/85427"
|
||||
)
|
||||
raise SystemExit(1)
|
||||
|
||||
import torch.distributed.run as torchrun
|
||||
|
||||
if args.strategy == "dp":
|
||||
|
|
|
@ -16,21 +16,10 @@ from unittest import mock
|
|||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
import torch.distributed.run
|
||||
from tests_lite.helpers.runif import RunIf
|
||||
|
||||
from lightning_lite.cli import _run_model
|
||||
from lightning_lite.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_13
|
||||
|
||||
if not (_IS_WINDOWS and _TORCH_GREATER_EQUAL_1_13):
|
||||
import torch.distributed.run
|
||||
|
||||
|
||||
def skip_windows_pt_1_13():
|
||||
# https://github.com/pytorch/pytorch/issues/85427
|
||||
return pytest.mark.skipif(
|
||||
condition=(_IS_WINDOWS and _TORCH_GREATER_EQUAL_1_13),
|
||||
reason="Torchelastic import bug in 1.13 affecting Windows",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -40,7 +29,6 @@ def fake_script(tmp_path):
|
|||
return str(script)
|
||||
|
||||
|
||||
@skip_windows_pt_1_13()
|
||||
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
|
||||
def test_cli_env_vars_defaults(monkeypatch, fake_script):
|
||||
monkeypatch.setattr(torch.distributed, "run", Mock())
|
||||
|
@ -55,7 +43,6 @@ def test_cli_env_vars_defaults(monkeypatch, fake_script):
|
|||
assert os.environ["LT_PRECISION"] == "32"
|
||||
|
||||
|
||||
@skip_windows_pt_1_13()
|
||||
@pytest.mark.parametrize("accelerator", ["cpu", "gpu", "cuda", pytest.param("mps", marks=RunIf(mps=True))])
|
||||
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
|
||||
@mock.patch("lightning_lite.accelerators.cuda.num_cuda_devices", return_value=2)
|
||||
|
@ -67,7 +54,6 @@ def test_cli_env_vars_accelerator(_, accelerator, monkeypatch, fake_script):
|
|||
assert os.environ["LT_ACCELERATOR"] == accelerator
|
||||
|
||||
|
||||
@skip_windows_pt_1_13()
|
||||
@pytest.mark.parametrize("strategy", ["dp", "ddp", "deepspeed"])
|
||||
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
|
||||
@mock.patch("lightning_lite.accelerators.cuda.num_cuda_devices", return_value=2)
|
||||
|
@ -79,7 +65,6 @@ def test_cli_env_vars_strategy(_, strategy, monkeypatch, fake_script):
|
|||
assert os.environ["LT_STRATEGY"] == strategy
|
||||
|
||||
|
||||
@skip_windows_pt_1_13()
|
||||
@pytest.mark.parametrize("devices", ["1", "2", "0,", "1,0", "-1"])
|
||||
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
|
||||
@mock.patch("lightning_lite.accelerators.cuda.num_cuda_devices", return_value=2)
|
||||
|
@ -92,7 +77,6 @@ def test_cli_env_vars_devices_cuda(_, devices, monkeypatch, fake_script):
|
|||
|
||||
|
||||
@RunIf(mps=True)
|
||||
@skip_windows_pt_1_13()
|
||||
@pytest.mark.parametrize("accelerator", ["mps", "gpu"])
|
||||
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
|
||||
def test_cli_env_vars_devices_mps(accelerator, monkeypatch, fake_script):
|
||||
|
@ -103,7 +87,6 @@ def test_cli_env_vars_devices_mps(accelerator, monkeypatch, fake_script):
|
|||
assert os.environ["LT_DEVICES"] == "1"
|
||||
|
||||
|
||||
@skip_windows_pt_1_13()
|
||||
@pytest.mark.parametrize("num_nodes", ["1", "2", "3"])
|
||||
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
|
||||
def test_cli_env_vars_num_nodes(num_nodes, monkeypatch, fake_script):
|
||||
|
@ -114,7 +97,6 @@ def test_cli_env_vars_num_nodes(num_nodes, monkeypatch, fake_script):
|
|||
assert os.environ["LT_NUM_NODES"] == num_nodes
|
||||
|
||||
|
||||
@skip_windows_pt_1_13()
|
||||
@pytest.mark.parametrize("precision", ["64", "32", "16", "bf16"])
|
||||
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
|
||||
def test_cli_env_vars_precision(precision, monkeypatch, fake_script):
|
||||
|
@ -125,7 +107,6 @@ def test_cli_env_vars_precision(precision, monkeypatch, fake_script):
|
|||
assert os.environ["LT_PRECISION"] == precision
|
||||
|
||||
|
||||
@skip_windows_pt_1_13()
|
||||
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
|
||||
def test_cli_torchrun_defaults(monkeypatch, fake_script):
|
||||
torchrun_mock = Mock()
|
||||
|
@ -145,7 +126,6 @@ def test_cli_torchrun_defaults(monkeypatch, fake_script):
|
|||
)
|
||||
|
||||
|
||||
@skip_windows_pt_1_13()
|
||||
@pytest.mark.parametrize(
|
||||
"devices,expected",
|
||||
[
|
||||
|
|
Loading…
Reference in New Issue