Set accelerator through CLI only if set explicitly (#16818)

This commit is contained in:
Adrian Wälchli 2023-02-20 14:45:06 +01:00 committed by GitHub
parent 65e66814f8
commit 0e4ca7c286
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 18 additions and 9 deletions

View File

@ -59,6 +59,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed
- Fixed an issue causing a wrong environment plugin to be selected when `accelerator=tpu` and `devices > 1` ([#16806](https://github.com/Lightning-AI/lightning/pull/16806))
- Fixed parsing of defaults for `--accelerator` and `--precision` in Fabric CLI when `accelerator` and `precision` are set to non-default values in the code ([#16818](https://github.com/Lightning-AI/lightning/pull/16818))
## [1.9.2] - 2023-02-15

View File

@ -56,7 +56,7 @@ if _CLICK_AVAILABLE:
@click.option(
"--accelerator",
type=click.Choice(_SUPPORTED_ACCELERATORS),
default="cpu",
default=None,
help="The hardware accelerator to run on.",
)
@click.option(
@ -108,7 +108,7 @@ if _CLICK_AVAILABLE:
@click.option(
"--precision",
type=click.Choice(get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_STR_ALIAS)),
default="32-true",
default=None,
help=(
"Double precision (``64-true`` or ``64``), full precision (``32-true`` or ``64``), "
"half precision (``16-mixed`` or ``16``) or bfloat16 precision (``bf16-mixed`` or ``bf16``)"
@ -133,12 +133,14 @@ def _set_env_variables(args: Namespace) -> None:
The Fabric connector will parse the arguments set here.
"""
os.environ["LT_CLI_USED"] = "1"
os.environ["LT_ACCELERATOR"] = str(args.accelerator)
if args.accelerator is not None:
os.environ["LT_ACCELERATOR"] = str(args.accelerator)
if args.strategy is not None:
os.environ["LT_STRATEGY"] = str(args.strategy)
os.environ["LT_DEVICES"] = str(args.devices)
os.environ["LT_NUM_NODES"] = str(args.num_nodes)
os.environ["LT_PRECISION"] = str(args.precision)
if args.precision is not None:
os.environ["LT_PRECISION"] = str(args.precision)
def _get_num_processes(accelerator: str, devices: str) -> int:

View File

@ -39,11 +39,11 @@ def test_cli_env_vars_defaults(monkeypatch, fake_script):
_run_model.main([fake_script])
assert e.value.code == 0
assert os.environ["LT_CLI_USED"] == "1"
assert os.environ["LT_ACCELERATOR"] == "cpu"
assert "LT_ACCELERATOR" not in os.environ
assert "LT_STRATEGY" not in os.environ
assert os.environ["LT_DEVICES"] == "1"
assert os.environ["LT_NUM_NODES"] == "1"
assert os.environ["LT_PRECISION"] == "32-true"
assert "LT_PRECISION" not in os.environ
@pytest.mark.parametrize("accelerator", ["cpu", "gpu", "cuda", pytest.param("mps", marks=RunIf(mps=True))])

View File

@ -813,11 +813,14 @@ def test_strategy_str_passed_being_case_insensitive(_, strategy, strategy_cls):
assert isinstance(connector.strategy, strategy_cls)
@pytest.mark.parametrize("precision", ["64-true", "32-true", "16-mixed", "bf16-mixed"])
@pytest.mark.parametrize("precision", [None, "64-true", "32-true", "16-mixed", "bf16-mixed"])
@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=1)
def test_precision_from_environment(_, precision):
"""Test that the precision input can be set through the environment variable."""
with mock.patch.dict(os.environ, {"LT_PRECISION": precision}):
env_vars = {}
if precision is not None:
env_vars["LT_PRECISION"] = precision
with mock.patch.dict(os.environ, env_vars):
connector = _Connector(accelerator="cuda") # need to use cuda, because AMP not available on CPU
assert isinstance(connector.precision, Precision)
@ -825,6 +828,7 @@ def test_precision_from_environment(_, precision):
@pytest.mark.parametrize(
"accelerator, strategy, expected_accelerator, expected_strategy",
[
(None, None, CPUAccelerator, SingleDeviceStrategy),
("cpu", None, CPUAccelerator, SingleDeviceStrategy),
("cpu", "ddp", CPUAccelerator, DDPStrategy),
pytest.param("mps", None, MPSAccelerator, SingleDeviceStrategy, marks=RunIf(mps=True)),
@ -836,7 +840,9 @@ def test_precision_from_environment(_, precision):
)
def test_accelerator_strategy_from_environment(accelerator, strategy, expected_accelerator, expected_strategy):
"""Test that the accelerator and strategy input can be set through the environment variables."""
env_vars = {"LT_ACCELERATOR": accelerator}
env_vars = {}
if accelerator is not None:
env_vars["LT_ACCELERATOR"] = accelerator
if strategy is not None:
env_vars["LT_STRATEGY"] = strategy