Set accelerator through CLI only if set explicitly (#16818)
This commit is contained in:
parent
65e66814f8
commit
0e4ca7c286
|
@ -59,6 +59,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
### Fixed
|
||||
|
||||
- Fixed an issue causing a wrong environment plugin to be selected when `accelerator=tpu` and `devices > 1` ([#16806](https://github.com/Lightning-AI/lightning/pull/16806))
|
||||
- Fixed parsing of defaults for `--accelerator` and `--precision` in Fabric CLI when `accelerator` and `precision` are set to non-default values in the code ([#16818](https://github.com/Lightning-AI/lightning/pull/16818))
|
||||
|
||||
|
||||
## [1.9.2] - 2023-02-15
|
||||
|
|
|
@ -56,7 +56,7 @@ if _CLICK_AVAILABLE:
|
|||
@click.option(
|
||||
"--accelerator",
|
||||
type=click.Choice(_SUPPORTED_ACCELERATORS),
|
||||
default="cpu",
|
||||
default=None,
|
||||
help="The hardware accelerator to run on.",
|
||||
)
|
||||
@click.option(
|
||||
|
@ -108,7 +108,7 @@ if _CLICK_AVAILABLE:
|
|||
@click.option(
|
||||
"--precision",
|
||||
type=click.Choice(get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_STR_ALIAS)),
|
||||
default="32-true",
|
||||
default=None,
|
||||
help=(
|
||||
"Double precision (``64-true`` or ``64``), full precision (``32-true`` or ``64``), "
|
||||
"half precision (``16-mixed`` or ``16``) or bfloat16 precision (``bf16-mixed`` or ``bf16``)"
|
||||
|
@ -133,12 +133,14 @@ def _set_env_variables(args: Namespace) -> None:
|
|||
The Fabric connector will parse the arguments set here.
|
||||
"""
|
||||
os.environ["LT_CLI_USED"] = "1"
|
||||
os.environ["LT_ACCELERATOR"] = str(args.accelerator)
|
||||
if args.accelerator is not None:
|
||||
os.environ["LT_ACCELERATOR"] = str(args.accelerator)
|
||||
if args.strategy is not None:
|
||||
os.environ["LT_STRATEGY"] = str(args.strategy)
|
||||
os.environ["LT_DEVICES"] = str(args.devices)
|
||||
os.environ["LT_NUM_NODES"] = str(args.num_nodes)
|
||||
os.environ["LT_PRECISION"] = str(args.precision)
|
||||
if args.precision is not None:
|
||||
os.environ["LT_PRECISION"] = str(args.precision)
|
||||
|
||||
|
||||
def _get_num_processes(accelerator: str, devices: str) -> int:
|
||||
|
|
|
@ -39,11 +39,11 @@ def test_cli_env_vars_defaults(monkeypatch, fake_script):
|
|||
_run_model.main([fake_script])
|
||||
assert e.value.code == 0
|
||||
assert os.environ["LT_CLI_USED"] == "1"
|
||||
assert os.environ["LT_ACCELERATOR"] == "cpu"
|
||||
assert "LT_ACCELERATOR" not in os.environ
|
||||
assert "LT_STRATEGY" not in os.environ
|
||||
assert os.environ["LT_DEVICES"] == "1"
|
||||
assert os.environ["LT_NUM_NODES"] == "1"
|
||||
assert os.environ["LT_PRECISION"] == "32-true"
|
||||
assert "LT_PRECISION" not in os.environ
|
||||
|
||||
|
||||
@pytest.mark.parametrize("accelerator", ["cpu", "gpu", "cuda", pytest.param("mps", marks=RunIf(mps=True))])
|
||||
|
|
|
@ -813,11 +813,14 @@ def test_strategy_str_passed_being_case_insensitive(_, strategy, strategy_cls):
|
|||
assert isinstance(connector.strategy, strategy_cls)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("precision", ["64-true", "32-true", "16-mixed", "bf16-mixed"])
|
||||
@pytest.mark.parametrize("precision", [None, "64-true", "32-true", "16-mixed", "bf16-mixed"])
|
||||
@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=1)
|
||||
def test_precision_from_environment(_, precision):
|
||||
"""Test that the precision input can be set through the environment variable."""
|
||||
with mock.patch.dict(os.environ, {"LT_PRECISION": precision}):
|
||||
env_vars = {}
|
||||
if precision is not None:
|
||||
env_vars["LT_PRECISION"] = precision
|
||||
with mock.patch.dict(os.environ, env_vars):
|
||||
connector = _Connector(accelerator="cuda") # need to use cuda, because AMP not available on CPU
|
||||
assert isinstance(connector.precision, Precision)
|
||||
|
||||
|
@ -825,6 +828,7 @@ def test_precision_from_environment(_, precision):
|
|||
@pytest.mark.parametrize(
|
||||
"accelerator, strategy, expected_accelerator, expected_strategy",
|
||||
[
|
||||
(None, None, CPUAccelerator, SingleDeviceStrategy),
|
||||
("cpu", None, CPUAccelerator, SingleDeviceStrategy),
|
||||
("cpu", "ddp", CPUAccelerator, DDPStrategy),
|
||||
pytest.param("mps", None, MPSAccelerator, SingleDeviceStrategy, marks=RunIf(mps=True)),
|
||||
|
@ -836,7 +840,9 @@ def test_precision_from_environment(_, precision):
|
|||
)
|
||||
def test_accelerator_strategy_from_environment(accelerator, strategy, expected_accelerator, expected_strategy):
|
||||
"""Test that the accelerator and strategy input can be set through the environment variables."""
|
||||
env_vars = {"LT_ACCELERATOR": accelerator}
|
||||
env_vars = {}
|
||||
if accelerator is not None:
|
||||
env_vars["LT_ACCELERATOR"] = accelerator
|
||||
if strategy is not None:
|
||||
env_vars["LT_STRATEGY"] = strategy
|
||||
|
||||
|
|
Loading…
Reference in New Issue