From 3c3bff5e6ed7efe77929be6da6162af81b6e46b5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Wed, 11 Jan 2023 16:29:51 +0100
Subject: [PATCH] Fabric: Remove `_Connector.is_distributed` (#16327)

---
 src/lightning_fabric/connector.py             | 20 --------------------
 src/lightning_fabric/fabric.py                |  4 ++--
 src/lightning_fabric/strategies/ddp.py        |  4 ----
 src/lightning_fabric/strategies/dp.py         |  4 ++++
 src/lightning_fabric/strategies/fsdp.py       |  6 +-----
 src/lightning_fabric/strategies/parallel.py   | 18 +++++++-----------
 src/lightning_fabric/strategies/single_tpu.py |  4 ----
 src/lightning_fabric/strategies/xla.py        | 10 +++-------
 tests/tests_fabric/test_fabric.py             |  4 ++--
 9 files changed, 19 insertions(+), 55 deletions(-)

diff --git a/src/lightning_fabric/connector.py b/src/lightning_fabric/connector.py
index 1b0d960e84..208fb9f00d 100644
--- a/src/lightning_fabric/connector.py
+++ b/src/lightning_fabric/connector.py
@@ -43,8 +43,6 @@ from lightning_fabric.plugins.precision.double import DoublePrecision
 from lightning_fabric.plugins.precision.fsdp import FSDPPrecision
 from lightning_fabric.plugins.precision.precision import _PRECISION_INPUT, _PRECISION_INPUT_INT, _PRECISION_INPUT_STR
 from lightning_fabric.strategies import (
-    DDPShardedStrategy,
-    DDPStrategy,
     DeepSpeedStrategy,
     SingleDeviceStrategy,
     SingleTPUStrategy,
@@ -547,21 +545,3 @@ class _Connector:
         if env_value is None:
             return current
         return env_value
-
-    @property
-    def is_distributed(self) -> bool:
-        # TODO: deprecate this property
-        # Used for custom plugins.
-        # Custom plugins should implement is_distributed property.
-        if hasattr(self.strategy, "is_distributed") and not isinstance(self.accelerator, TPUAccelerator):
-            return self.strategy.is_distributed
-        distributed_strategy = (
-            DDPStrategy,
-            DDPShardedStrategy,
-            DeepSpeedStrategy,
-            XLAStrategy,
-        )
-        is_distributed = isinstance(self.strategy, distributed_strategy)
-        if isinstance(self.accelerator, TPUAccelerator):
-            is_distributed |= self.strategy.is_distributed
-        return is_distributed
diff --git a/src/lightning_fabric/fabric.py b/src/lightning_fabric/fabric.py
index 8428549417..de569c93f2 100644
--- a/src/lightning_fabric/fabric.py
+++ b/src/lightning_fabric/fabric.py
@@ -69,7 +69,7 @@ class Fabric:
         accelerator: The hardware to run on. Possible choices are:
             ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
         strategy: Strategy for how to run across multiple devices. Possible choices are:
-            ``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"ddp_sharded"``.
+            ``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"fsdp"``, ``"ddp_sharded"``.
         devices: Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``.
             The value applies per node.
         num_nodes: Number of GPU nodes for distributed training.
@@ -673,7 +673,7 @@ class Fabric:
 
     def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool:
         return (
-            self._connector.is_distributed
+            getattr(self.strategy, "distributed_sampler_kwargs", None) is not None
             and not isinstance(dataloader.sampler, DistributedSampler)
             and not has_iterable_dataset(dataloader)
         )
diff --git a/src/lightning_fabric/strategies/ddp.py b/src/lightning_fabric/strategies/ddp.py
index c27980f0af..b17545bc0b 100644
--- a/src/lightning_fabric/strategies/ddp.py
+++ b/src/lightning_fabric/strategies/ddp.py
@@ -81,10 +81,6 @@ class DDPStrategy(ParallelStrategy):
         assert self.parallel_devices is not None
         return self.parallel_devices[self.local_rank]
 
-    @property
-    def is_distributed(self) -> bool:
-        return True
-
     @property
     def num_nodes(self) -> int:
         return self._num_nodes
diff --git a/src/lightning_fabric/strategies/dp.py b/src/lightning_fabric/strategies/dp.py
index 8a697b7e2e..1fcc2b4c67 100644
--- a/src/lightning_fabric/strategies/dp.py
+++ b/src/lightning_fabric/strategies/dp.py
@@ -50,6 +50,10 @@ class DataParallelStrategy(ParallelStrategy):
         assert self.parallel_devices is not None
         return self.parallel_devices[0]
 
+    @property
+    def distributed_sampler_kwargs(self) -> None:
+        return None
+
     def setup_module(self, module: Module) -> DataParallel:
         """Wraps the given model into a :class:`~torch.nn.parallel.DataParallel` module."""
         return DataParallel(module=module, device_ids=self.parallel_devices)
diff --git a/src/lightning_fabric/strategies/fsdp.py b/src/lightning_fabric/strategies/fsdp.py
index 7fe400179e..748568a16e 100644
--- a/src/lightning_fabric/strategies/fsdp.py
+++ b/src/lightning_fabric/strategies/fsdp.py
@@ -133,10 +133,6 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
         assert self.parallel_devices is not None
         return self.parallel_devices[self.local_rank]
 
-    @property
-    def is_distributed(self) -> bool:
-        return True
-
     @property
     def num_nodes(self) -> int:
         return self._num_nodes
@@ -150,7 +146,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
         return len(self.parallel_devices) if self.parallel_devices is not None else 0
 
     @property
-    def distributed_sampler_kwargs(self) -> Dict:
+    def distributed_sampler_kwargs(self) -> Dict[str, Any]:
         return dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
 
     @property
diff --git a/src/lightning_fabric/strategies/parallel.py b/src/lightning_fabric/strategies/parallel.py
index c1166e8225..28a5c6fc23 100644
--- a/src/lightning_fabric/strategies/parallel.py
+++ b/src/lightning_fabric/strategies/parallel.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from abc import ABC, abstractmethod
+from abc import ABC
 from typing import Any, Dict, List, Optional
 
 import torch
@@ -41,11 +41,6 @@ class ParallelStrategy(Strategy, ABC):
         self.parallel_devices = parallel_devices
         self.cluster_environment: Optional[ClusterEnvironment] = cluster_environment
 
-    @property
-    @abstractmethod
-    def root_device(self) -> torch.device:
-        """Return the root device."""
-
     @property
     def global_rank(self) -> int:
         return self.cluster_environment.global_rank() if self.cluster_environment is not None else 0
@@ -75,11 +70,12 @@
         self._parallel_devices = parallel_devices
 
     @property
-    def distributed_sampler_kwargs(self) -> Dict[str, Any]:
-        return dict(
-            num_replicas=len(self.parallel_devices) if self.parallel_devices is not None else 0,
-            rank=self.global_rank,
-        )
+    def distributed_sampler_kwargs(self) -> Optional[Dict[str, Any]]:
+        """Arguments for the ``DistributedSampler``.
+
+        If this property is not defined, or it returns ``None``, then the ``DistributedSampler`` will not be used.
+        """
+        return {"num_replicas": self.world_size, "rank": self.global_rank}
 
     def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
         """Perform a all_gather on all processes."""
diff --git a/src/lightning_fabric/strategies/single_tpu.py b/src/lightning_fabric/strategies/single_tpu.py
index 78a3c64c5c..48fb6de79c 100644
--- a/src/lightning_fabric/strategies/single_tpu.py
+++ b/src/lightning_fabric/strategies/single_tpu.py
@@ -49,10 +49,6 @@ class SingleTPUStrategy(SingleDeviceStrategy):
     def checkpoint_io(self, io: Optional[CheckpointIO]) -> None:
         self._checkpoint_io = io
 
-    @property
-    def is_distributed(self) -> bool:
-        return False
-
     @classmethod
     def register_strategies(cls, strategy_registry: Dict) -> None:
         strategy_registry.register("single_tpu", cls, description=f"{cls.__class__.__name__}")
diff --git a/src/lightning_fabric/strategies/xla.py b/src/lightning_fabric/strategies/xla.py
index cb3acd53c8..08c232e13c 100644
--- a/src/lightning_fabric/strategies/xla.py
+++ b/src/lightning_fabric/strategies/xla.py
@@ -88,11 +88,7 @@ class XLAStrategy(ParallelStrategy):
         self._checkpoint_io = io
 
     @property
-    def distributed_sampler_kwargs(self) -> Dict[str, int]:
-        return dict(num_replicas=self.world_size, rank=self.global_rank)
-
-    @property
-    def is_distributed(self) -> bool:
+    def _is_distributed(self) -> bool:
         import torch_xla.core.xla_env_vars as xenv
 
         # HOST_WORLD_SIZE is not set outside the xmp.spawn process
@@ -145,13 +141,13 @@ class XLAStrategy(ParallelStrategy):
         return output
 
     def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
-        if self.is_distributed:
+        if self._is_distributed:
             import torch_xla.core.xla_model as xm
 
             xm.rendezvous(name)
 
     def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
-        if not self.is_distributed:
+        if not self._is_distributed:
             return obj
         buffer = io.BytesIO()
         torch.save(obj, buffer)
diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py
index 467f18e45c..a6f295ab10 100644
--- a/tests/tests_fabric/test_fabric.py
+++ b/tests/tests_fabric/test_fabric.py
@@ -479,7 +479,7 @@ def test_setup_dataloaders_replace_custom_sampler(strategy):
 
     # explicitly asking to replace when a custom sampler is already configured raises an exception
     fabric = EmptyFabric(accelerator="cpu", strategy=strategy, devices=2)
-    if fabric._connector.is_distributed:
+    if hasattr(fabric.strategy, "distributed_sampler_kwargs"):
        with pytest.raises(TypeError, match="You seem to have configured a sampler in your DataLoader"):
             fabric.setup_dataloaders(dataloader, replace_sampler=True)
 
@@ -504,7 +504,7 @@
 def test_setup_dataloaders_replace_standard_sampler(shuffle, strategy):
     """Test that Fabric replaces the default samplers with DistributedSampler automatically."""
     fabric = EmptyFabric(accelerator="cpu", strategy=strategy, devices=2)
-    is_distributed = fabric._connector.is_distributed
+    is_distributed = hasattr(fabric.strategy, "distributed_sampler_kwargs")
     fabric_dataloader = fabric.setup_dataloaders(DataLoader(range(3), shuffle=shuffle))
     assert not is_distributed or isinstance(fabric_dataloader.sampler, DistributedSampler)
 
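
Below is a minimal, self-contained sketch (not part of the patch above) of the contract this change moves Fabric to: a strategy opts into sampler replacement by exposing a non-None `distributed_sampler_kwargs`, so the old `_Connector.is_distributed` flag reduces to a `getattr` check on the strategy. The toy strategy classes are hypothetical stand-ins, and the `has_iterable_dataset` guard from the real `_requires_distributed_sampler` is omitted for brevity.

# Hypothetical sketch -- not part of the patch. Mirrors the new opt-in contract:
# a strategy advertises DistributedSampler support by returning sampler kwargs,
# opts out by returning None (e.g. DataParallelStrategy), and single-device
# strategies simply do not define the attribute at all.
from typing import Any, Dict, Optional

from torch.utils.data import DataLoader, DistributedSampler


class ToySingleDeviceStrategy:
    """Stand-in for a single-device strategy: no `distributed_sampler_kwargs`."""


class ToyDDPLikeStrategy:
    """Stand-in for a ParallelStrategy subclass using the new default kwargs."""

    world_size = 2
    global_rank = 0

    @property
    def distributed_sampler_kwargs(self) -> Optional[Dict[str, Any]]:
        return {"num_replicas": self.world_size, "rank": self.global_rank}


def requires_distributed_sampler(strategy: Any, dataloader: DataLoader) -> bool:
    # Same shape as Fabric._requires_distributed_sampler after the patch,
    # minus the iterable-dataset guard: replace the sampler only when the
    # strategy provides kwargs and the user has not already set one.
    return (
        getattr(strategy, "distributed_sampler_kwargs", None) is not None
        and not isinstance(dataloader.sampler, DistributedSampler)
    )


if __name__ == "__main__":
    loader = DataLoader(range(8))
    assert not requires_distributed_sampler(ToySingleDeviceStrategy(), loader)
    assert requires_distributed_sampler(ToyDDPLikeStrategy(), loader)

Under this contract, `DataParallelStrategy` stays non-distributed by overriding the property to return `None`, while other `ParallelStrategy` subclasses inherit `{"num_replicas": world_size, "rank": global_rank}` by default, which is why the per-strategy `is_distributed` flags could be deleted.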