Raise environment variable collision errors only when Fabric CLI is used (#17679)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
e6b7f1383c
commit
00909ba3ff
|
@ -70,6 +70,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Enable precision autocast for LightningModule step methods in Fabric ([#17439](https://github.com/Lightning-AI/lightning/pull/17439))
|
||||
|
||||
|
||||
- Fabric argument validation now only raises an error if conflicting settings are set through the CLI ([#17679](https://github.com/Lightning-AI/lightning/pull/17679))
|
||||
|
||||
|
||||
### Deprecated
|
||||
|
||||
- Deprecated the `DDPStrategy.is_distributed` property. This strategy is distributed by definition ([#17381](https://github.com/Lightning-AI/lightning/pull/17381))
|
||||
|
|
|
@ -533,14 +533,12 @@ class _Connector:
|
|||
if env_value is None:
|
||||
return current
|
||||
|
||||
if env_value is not None and env_value != str(current) and str(current) != str(default):
|
||||
if env_value is not None and env_value != str(current) and str(current) != str(default) and _is_using_cli():
|
||||
raise ValueError(
|
||||
f"Your code has `Fabric({name}={current!r}, ...)` but it conflicts with the value "
|
||||
f"`--{name}={env_value}` set through the CLI. "
|
||||
" Remove it either from the CLI or from the Lightning Fabric object."
|
||||
)
|
||||
if env_value is None:
|
||||
return current
|
||||
return env_value
|
||||
|
||||
|
||||
|
@ -561,3 +559,7 @@ def _convert_precision_to_unified_args(precision: _PRECISION_INPUT) -> _PRECISIO
|
|||
)
|
||||
precision = _PRECISION_INPUT_STR_ALIAS_CONVERSION[precision]
|
||||
return cast(_PRECISION_INPUT_STR, precision)
|
||||
|
||||
|
||||
def _is_using_cli() -> bool:
    """Report whether the ``LT_CLI_USED`` environment variable is set to a non-zero integer.

    The variable is read from ``os.environ`` on every call; when it is absent it
    is treated as ``"0"``, so the result is ``False``.

    Returns:
        ``True`` if the variable parses to a non-zero integer, ``False`` otherwise.

    Raises:
        ValueError: If the variable is set to text that ``int`` cannot parse.
    """
    raw_flag = os.environ.get("LT_CLI_USED", "0")
    return int(raw_flag) != 0
|
||||
|
|
|
@ -32,7 +32,7 @@ from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
|
|||
|
||||
from lightning.fabric.plugins import Precision # avoid circular imports: # isort: split
|
||||
from lightning.fabric.accelerators.accelerator import Accelerator
|
||||
from lightning.fabric.connector import _Connector, _PLUGIN_INPUT, _PRECISION_INPUT
|
||||
from lightning.fabric.connector import _Connector, _is_using_cli, _PLUGIN_INPUT, _PRECISION_INPUT
|
||||
from lightning.fabric.strategies import DeepSpeedStrategy, FSDPStrategy, SingleDeviceStrategy, Strategy, XLAStrategy
|
||||
from lightning.fabric.strategies.launchers import _MultiProcessingLauncher, _XLALauncher
|
||||
from lightning.fabric.strategies.strategy import _Sharded, TBroadcast
|
||||
|
@ -888,7 +888,3 @@ class Fabric:
|
|||
|
||||
if any(not isinstance(dl, DataLoader) for dl in dataloaders):
|
||||
raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.")
|
||||
|
||||
|
||||
def _is_using_cli() -> bool:
    """Tell whether ``LT_CLI_USED`` in the process environment holds a non-zero integer.

    A missing variable falls back to ``"0"`` and therefore yields ``False``.
    A value that ``int`` cannot parse raises ``ValueError``.
    """
    cli_flag = int(os.environ.get("LT_CLI_USED", "0"))
    return cli_flag != 0
|
||||
|
|
|
@ -876,27 +876,32 @@ def test_devices_from_environment(*_):
|
|||
|
||||
def test_arguments_from_environment_collision():
|
||||
"""Test that the connector raises an error when the CLI settings conflict with settings in the code."""
|
||||
with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"}), pytest.raises(
|
||||
|
||||
# Do not raise an error about collisions unless the CLI was used
|
||||
with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"}):
|
||||
_Connector(accelerator="cuda")
|
||||
|
||||
with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu", "LT_CLI_USED": "1"}), pytest.raises(
|
||||
ValueError, match="`Fabric\\(accelerator='cuda', ...\\)` but .* `--accelerator=cpu`"
|
||||
):
|
||||
_Connector(accelerator="cuda")
|
||||
|
||||
with mock.patch.dict(os.environ, {"LT_STRATEGY": "ddp"}), pytest.raises(
|
||||
with mock.patch.dict(os.environ, {"LT_STRATEGY": "ddp", "LT_CLI_USED": "1"}), pytest.raises(
|
||||
ValueError, match="`Fabric\\(strategy='ddp_spawn', ...\\)` but .* `--strategy=ddp`"
|
||||
):
|
||||
_Connector(strategy="ddp_spawn")
|
||||
|
||||
with mock.patch.dict(os.environ, {"LT_DEVICES": "2"}), pytest.raises(
|
||||
with mock.patch.dict(os.environ, {"LT_DEVICES": "2", "LT_CLI_USED": "1"}), pytest.raises(
|
||||
ValueError, match="`Fabric\\(devices=3, ...\\)` but .* `--devices=2`"
|
||||
):
|
||||
_Connector(devices=3)
|
||||
|
||||
with mock.patch.dict(os.environ, {"LT_NUM_NODES": "3"}), pytest.raises(
|
||||
with mock.patch.dict(os.environ, {"LT_NUM_NODES": "3", "LT_CLI_USED": "1"}), pytest.raises(
|
||||
ValueError, match="`Fabric\\(num_nodes=2, ...\\)` but .* `--num_nodes=3`"
|
||||
):
|
||||
_Connector(num_nodes=2)
|
||||
|
||||
with mock.patch.dict(os.environ, {"LT_PRECISION": "16-mixed"}), pytest.raises(
|
||||
with mock.patch.dict(os.environ, {"LT_PRECISION": "16-mixed", "LT_CLI_USED": "1"}), pytest.raises(
|
||||
ValueError, match="`Fabric\\(precision='64-true', ...\\)` but .* `--precision=16-mixed`"
|
||||
):
|
||||
_Connector(precision="64-true")
|
||||
|
|
Loading…
Reference in New Issue