diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index b77d8d64c6..ae9f27776c 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -70,6 +70,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Enable precision autocast for LightningModule step methods in Fabric ([#17439](https://github.com/Lightning-AI/lightning/pull/17439)) +- Fabric argument validation now only raises an error if conflicting settings are set through the CLI ([#17679](https://github.com/Lightning-AI/lightning/pull/17679)) + + ### Deprecated - Deprecated the `DDPStrategy.is_distributed` property. This strategy is distributed by definition ([#17381](https://github.com/Lightning-AI/lightning/pull/17381)) diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py index 63e2cc4f9f..f994181857 100644 --- a/src/lightning/fabric/connector.py +++ b/src/lightning/fabric/connector.py @@ -533,14 +533,12 @@ class _Connector: if env_value is None: return current - if env_value is not None and env_value != str(current) and str(current) != str(default): + if env_value is not None and env_value != str(current) and str(current) != str(default) and _is_using_cli(): raise ValueError( f"Your code has `Fabric({name}={current!r}, ...)` but it conflicts with the value " f"`--{name}={env_value}` set through the CLI. " " Remove it either from the CLI or from the Lightning Fabric object." ) - if env_value is None: - return current return env_value @@ -561,3 +559,7 @@ def _convert_precision_to_unified_args(precision: _PRECISION_INPUT) -> _PRECISIO ) precision = _PRECISION_INPUT_STR_ALIAS_CONVERSION[precision] return cast(_PRECISION_INPUT_STR, precision) + + +def _is_using_cli() -> bool: + return bool(int(os.environ.get("LT_CLI_USED", "0"))) diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py index 2c01a8c7fb..fca2b7274e 100644 --- a/src/lightning/fabric/fabric.py +++ b/src/lightning/fabric/fabric.py @@ -32,7 +32,7 @@ from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0 from lightning.fabric.plugins import Precision # avoid circular imports: # isort: split from lightning.fabric.accelerators.accelerator import Accelerator -from lightning.fabric.connector import _Connector, _PLUGIN_INPUT, _PRECISION_INPUT +from lightning.fabric.connector import _Connector, _is_using_cli, _PLUGIN_INPUT, _PRECISION_INPUT from lightning.fabric.strategies import DeepSpeedStrategy, FSDPStrategy, SingleDeviceStrategy, Strategy, XLAStrategy from lightning.fabric.strategies.launchers import _MultiProcessingLauncher, _XLALauncher from lightning.fabric.strategies.strategy import _Sharded, TBroadcast @@ -888,7 +888,3 @@ class Fabric: if any(not isinstance(dl, DataLoader) for dl in dataloaders): raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.") - - -def _is_using_cli() -> bool: - return bool(int(os.environ.get("LT_CLI_USED", "0"))) diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py index f6e1290c8e..d55fe1e70f 100644 --- a/tests/tests_fabric/test_connector.py +++ b/tests/tests_fabric/test_connector.py @@ -876,27 +876,32 @@ def test_devices_from_environment(*_): def test_arguments_from_environment_collision(): """Test that the connector raises an error when the CLI settings conflict with settings in the code.""" - with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"}), pytest.raises( + + # Do not raise an error about collisions unless the CLI was used + with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"}): + _Connector(accelerator="cuda") + + with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu", "LT_CLI_USED": "1"}), pytest.raises( ValueError, match="`Fabric\\(accelerator='cuda', ...\\)` but .* `--accelerator=cpu`" ): _Connector(accelerator="cuda") - with mock.patch.dict(os.environ, {"LT_STRATEGY": "ddp"}), pytest.raises( + with mock.patch.dict(os.environ, {"LT_STRATEGY": "ddp", "LT_CLI_USED": "1"}), pytest.raises( ValueError, match="`Fabric\\(strategy='ddp_spawn', ...\\)` but .* `--strategy=ddp`" ): _Connector(strategy="ddp_spawn") - with mock.patch.dict(os.environ, {"LT_DEVICES": "2"}), pytest.raises( + with mock.patch.dict(os.environ, {"LT_DEVICES": "2", "LT_CLI_USED": "1"}), pytest.raises( ValueError, match="`Fabric\\(devices=3, ...\\)` but .* `--devices=2`" ): _Connector(devices=3) - with mock.patch.dict(os.environ, {"LT_NUM_NODES": "3"}), pytest.raises( + with mock.patch.dict(os.environ, {"LT_NUM_NODES": "3", "LT_CLI_USED": "1"}), pytest.raises( ValueError, match="`Fabric\\(num_nodes=2, ...\\)` but .* `--num_nodes=3`" ): _Connector(num_nodes=2) - with mock.patch.dict(os.environ, {"LT_PRECISION": "16-mixed"}), pytest.raises( + with mock.patch.dict(os.environ, {"LT_PRECISION": "16-mixed", "LT_CLI_USED": "1"}), pytest.raises( ValueError, match="`Fabric\\(precision='64-true', ...\\)` but .* `--precision=16-mixed`" ): _Connector(precision="64-true")