Raise environment variable collision errors only when Fabric CLI is used (#17679)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Adrian Wälchli 2023-05-23 01:12:26 +02:00 committed by GitHub
parent e6b7f1383c
commit 00909ba3ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 19 additions and 13 deletions

View File

@ -70,6 +70,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Enable precision autocast for LightningModule step methods in Fabric ([#17439](https://github.com/Lightning-AI/lightning/pull/17439))
- Fabric argument validation now only raises an error if conflicting settings are set through the CLI ([#17679](https://github.com/Lightning-AI/lightning/pull/17679))
### Deprecated
- Deprecated the `DDPStrategy.is_distributed` property. This strategy is distributed by definition ([#17381](https://github.com/Lightning-AI/lightning/pull/17381))

View File

@ -533,14 +533,12 @@ class _Connector:
if env_value is None:
return current
if env_value is not None and env_value != str(current) and str(current) != str(default):
if env_value is not None and env_value != str(current) and str(current) != str(default) and _is_using_cli():
raise ValueError(
f"Your code has `Fabric({name}={current!r}, ...)` but it conflicts with the value "
f"`--{name}={env_value}` set through the CLI. "
" Remove it either from the CLI or from the Lightning Fabric object."
)
if env_value is None:
return current
return env_value
@ -561,3 +559,7 @@ def _convert_precision_to_unified_args(precision: _PRECISION_INPUT) -> _PRECISIO
)
precision = _PRECISION_INPUT_STR_ALIAS_CONVERSION[precision]
return cast(_PRECISION_INPUT_STR, precision)
def _is_using_cli() -> bool:
    """Tell whether the ``LT_CLI_USED`` environment variable marks the CLI as in use.

    The variable is read as an integer and defaults to ``"0"`` when it is not set.
    Any non-zero integer counts as "in use".

    Returns:
        ``True`` if ``LT_CLI_USED`` parses to a non-zero integer, ``False`` otherwise.

    Raises:
        ValueError: If the variable is set to text that ``int()`` cannot parse.
    """
    # NOTE(review): presumably exported by the Fabric CLI launcher before it runs the user script - confirm.
    raw_flag = os.environ.get("LT_CLI_USED", "0")
    return int(raw_flag) != 0

View File

@ -32,7 +32,7 @@ from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.plugins import Precision # avoid circular imports: # isort: split
from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.connector import _Connector, _PLUGIN_INPUT, _PRECISION_INPUT
from lightning.fabric.connector import _Connector, _is_using_cli, _PLUGIN_INPUT, _PRECISION_INPUT
from lightning.fabric.strategies import DeepSpeedStrategy, FSDPStrategy, SingleDeviceStrategy, Strategy, XLAStrategy
from lightning.fabric.strategies.launchers import _MultiProcessingLauncher, _XLALauncher
from lightning.fabric.strategies.strategy import _Sharded, TBroadcast
@ -888,7 +888,3 @@ class Fabric:
if any(not isinstance(dl, DataLoader) for dl in dataloaders):
raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.")
def _is_using_cli() -> bool:
    """Return ``True`` when the ``LT_CLI_USED`` environment variable is a non-zero integer.

    An unset variable is treated as ``"0"``, so the result is ``False``. A value that
    ``int()`` cannot parse raises ``ValueError``.
    """
    cli_flag = int(os.environ.get("LT_CLI_USED", "0"))
    return cli_flag != 0

View File

@ -876,27 +876,32 @@ def test_devices_from_environment(*_):
def test_arguments_from_environment_collision():
"""Test that the connector raises an error when the CLI settings conflict with settings in the code."""
with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"}), pytest.raises(
# Do not raise an error about collisions unless the CLI was used
with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"}):
_Connector(accelerator="cuda")
with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu", "LT_CLI_USED": "1"}), pytest.raises(
ValueError, match="`Fabric\\(accelerator='cuda', ...\\)` but .* `--accelerator=cpu`"
):
_Connector(accelerator="cuda")
with mock.patch.dict(os.environ, {"LT_STRATEGY": "ddp"}), pytest.raises(
with mock.patch.dict(os.environ, {"LT_STRATEGY": "ddp", "LT_CLI_USED": "1"}), pytest.raises(
ValueError, match="`Fabric\\(strategy='ddp_spawn', ...\\)` but .* `--strategy=ddp`"
):
_Connector(strategy="ddp_spawn")
with mock.patch.dict(os.environ, {"LT_DEVICES": "2"}), pytest.raises(
with mock.patch.dict(os.environ, {"LT_DEVICES": "2", "LT_CLI_USED": "1"}), pytest.raises(
ValueError, match="`Fabric\\(devices=3, ...\\)` but .* `--devices=2`"
):
_Connector(devices=3)
with mock.patch.dict(os.environ, {"LT_NUM_NODES": "3"}), pytest.raises(
with mock.patch.dict(os.environ, {"LT_NUM_NODES": "3", "LT_CLI_USED": "1"}), pytest.raises(
ValueError, match="`Fabric\\(num_nodes=2, ...\\)` but .* `--num_nodes=3`"
):
_Connector(num_nodes=2)
with mock.patch.dict(os.environ, {"LT_PRECISION": "16-mixed"}), pytest.raises(
with mock.patch.dict(os.environ, {"LT_PRECISION": "16-mixed", "LT_CLI_USED": "1"}), pytest.raises(
ValueError, match="`Fabric\\(precision='64-true', ...\\)` but .* `--precision=16-mixed`"
):
_Connector(precision="64-true")