diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index f0a29b8e15..919704929b 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -59,6 +59,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed - Fixed an issue causing a wrong environment plugin to be selected when `accelerator=tpu` and `devices > 1` ([#16806](https://github.com/Lightning-AI/lightning/pull/16806)) +- Fixed parsing of defaults for `--accelerator` and `--precision` in Fabric CLI when `accelerator` and `precision` are set to non-default values in the code ([#16818](https://github.com/Lightning-AI/lightning/pull/16818)) ## [1.9.2] - 2023-02-15 diff --git a/src/lightning/fabric/cli.py b/src/lightning/fabric/cli.py index 4671a75da7..cd715b79cf 100644 --- a/src/lightning/fabric/cli.py +++ b/src/lightning/fabric/cli.py @@ -56,7 +56,7 @@ if _CLICK_AVAILABLE: @click.option( "--accelerator", type=click.Choice(_SUPPORTED_ACCELERATORS), - default="cpu", + default=None, help="The hardware accelerator to run on.", ) @click.option( @@ -108,7 +108,7 @@ if _CLICK_AVAILABLE: @click.option( "--precision", type=click.Choice(get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_STR_ALIAS)), - default="32-true", + default=None, help=( "Double precision (``64-true`` or ``64``), full precision (``32-true`` or ``64``), " "half precision (``16-mixed`` or ``16``) or bfloat16 precision (``bf16-mixed`` or ``bf16``)" @@ -133,12 +133,14 @@ def _set_env_variables(args: Namespace) -> None: The Fabric connector will parse the arguments set here. """ os.environ["LT_CLI_USED"] = "1" - os.environ["LT_ACCELERATOR"] = str(args.accelerator) + if args.accelerator is not None: + os.environ["LT_ACCELERATOR"] = str(args.accelerator) if args.strategy is not None: os.environ["LT_STRATEGY"] = str(args.strategy) os.environ["LT_DEVICES"] = str(args.devices) os.environ["LT_NUM_NODES"] = str(args.num_nodes) - os.environ["LT_PRECISION"] = str(args.precision) + if args.precision is not None: + os.environ["LT_PRECISION"] = str(args.precision) def _get_num_processes(accelerator: str, devices: str) -> int: diff --git a/tests/tests_fabric/test_cli.py b/tests/tests_fabric/test_cli.py index 051df16528..7f249a0c95 100644 --- a/tests/tests_fabric/test_cli.py +++ b/tests/tests_fabric/test_cli.py @@ -39,11 +39,11 @@ def test_cli_env_vars_defaults(monkeypatch, fake_script): _run_model.main([fake_script]) assert e.value.code == 0 assert os.environ["LT_CLI_USED"] == "1" - assert os.environ["LT_ACCELERATOR"] == "cpu" + assert "LT_ACCELERATOR" not in os.environ assert "LT_STRATEGY" not in os.environ assert os.environ["LT_DEVICES"] == "1" assert os.environ["LT_NUM_NODES"] == "1" - assert os.environ["LT_PRECISION"] == "32-true" + assert "LT_PRECISION" not in os.environ @pytest.mark.parametrize("accelerator", ["cpu", "gpu", "cuda", pytest.param("mps", marks=RunIf(mps=True))]) diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py index b29fb92376..c31094406c 100644 --- a/tests/tests_fabric/test_connector.py +++ b/tests/tests_fabric/test_connector.py @@ -813,11 +813,14 @@ def test_strategy_str_passed_being_case_insensitive(_, strategy, strategy_cls): assert isinstance(connector.strategy, strategy_cls) -@pytest.mark.parametrize("precision", ["64-true", "32-true", "16-mixed", "bf16-mixed"]) +@pytest.mark.parametrize("precision", [None, "64-true", "32-true", "16-mixed", "bf16-mixed"]) @mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=1) def test_precision_from_environment(_, precision): """Test that the precision input can be set through the environment variable.""" - with mock.patch.dict(os.environ, {"LT_PRECISION": precision}): + env_vars = {} + if precision is not None: + env_vars["LT_PRECISION"] = precision + with mock.patch.dict(os.environ, env_vars): connector = _Connector(accelerator="cuda") # need to use cuda, because AMP not available on CPU assert isinstance(connector.precision, Precision) @@ -825,6 +828,7 @@ def test_precision_from_environment(_, precision): @pytest.mark.parametrize( "accelerator, strategy, expected_accelerator, expected_strategy", [ + (None, None, CPUAccelerator, SingleDeviceStrategy), ("cpu", None, CPUAccelerator, SingleDeviceStrategy), ("cpu", "ddp", CPUAccelerator, DDPStrategy), pytest.param("mps", None, MPSAccelerator, SingleDeviceStrategy, marks=RunIf(mps=True)), @@ -836,7 +840,9 @@ def test_precision_from_environment(_, precision): ) def test_accelerator_strategy_from_environment(accelerator, strategy, expected_accelerator, expected_strategy): """Test that the accelerator and strategy input can be set through the environment variables.""" - env_vars = {"LT_ACCELERATOR": accelerator} + env_vars = {} + if accelerator is not None: + env_vars["LT_ACCELERATOR"] = accelerator if strategy is not None: env_vars["LT_STRATEGY"] = strategy