Fabric: Remove `_Connector.is_distributed` (#16327)

Carlos Mocholí 2023-01-11 16:29:51 +01:00 committed by GitHub
parent 4bc2080c71
commit 3c3bff5e6e
9 changed files with 19 additions and 55 deletions
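
The change replaces the private `_Connector.is_distributed` flag with a convention on the strategies themselves: a strategy that wants Fabric to inject a `DistributedSampler` exposes a non-`None` `distributed_sampler_kwargs` property; everything else returns `None` or omits the property. A minimal sketch of the new decision logic, not taken from the commit itself (`strategy` stands in for `fabric.strategy`):

```python
from torch.utils.data import DataLoader, DistributedSampler

def requires_distributed_sampler(strategy, dataloader: DataLoader) -> bool:
    # Sharding strategies (DDP, FSDP, DeepSpeed, XLA) return a kwargs dict;
    # single-device and DataParallel strategies return None or lack the property.
    kwargs = getattr(strategy, "distributed_sampler_kwargs", None)
    # The real method additionally skips loaders backed by an IterableDataset.
    return kwargs is not None and not isinstance(dataloader.sampler, DistributedSampler)
```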

View File

@@ -43,8 +43,6 @@ from lightning_fabric.plugins.precision.double import DoublePrecision
from lightning_fabric.plugins.precision.fsdp import FSDPPrecision
from lightning_fabric.plugins.precision.precision import _PRECISION_INPUT, _PRECISION_INPUT_INT, _PRECISION_INPUT_STR
from lightning_fabric.strategies import (
DDPShardedStrategy,
DDPStrategy,
DeepSpeedStrategy,
SingleDeviceStrategy,
SingleTPUStrategy,
@@ -547,21 +545,3 @@ class _Connector:
if env_value is None:
return current
return env_value
@property
def is_distributed(self) -> bool:
# TODO: deprecate this property
# Used for custom plugins.
# Custom plugins should implement is_distributed property.
if hasattr(self.strategy, "is_distributed") and not isinstance(self.accelerator, TPUAccelerator):
return self.strategy.is_distributed
distributed_strategy = (
DDPStrategy,
DDPShardedStrategy,
DeepSpeedStrategy,
XLAStrategy,
)
is_distributed = isinstance(self.strategy, distributed_strategy)
if isinstance(self.accelerator, TPUAccelerator):
is_distributed |= self.strategy.is_distributed
return is_distributed

View File

@@ -69,7 +69,7 @@ class Fabric:
accelerator: The hardware to run on. Possible choices are:
``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
strategy: Strategy for how to run across multiple devices. Possible choices are:
``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"ddp_sharded"``.
``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"fsdp"``, ``"ddp_sharded"``.
devices: Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``.
The value applies per node.
num_nodes: Number of GPU nodes for distributed training.
@@ -673,7 +673,7 @@ class Fabric:
def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool:
return (
self._connector.is_distributed
getattr(self.strategy, "distributed_sampler_kwargs", None) is not None
and not isinstance(dataloader.sampler, DistributedSampler)
and not has_iterable_dataset(dataloader)
)

View File

@@ -81,10 +81,6 @@ class DDPStrategy(ParallelStrategy):
assert self.parallel_devices is not None
return self.parallel_devices[self.local_rank]
@property
def is_distributed(self) -> bool:
return True
@property
def num_nodes(self) -> int:
return self._num_nodes

View File

@@ -50,6 +50,10 @@ class DataParallelStrategy(ParallelStrategy):
assert self.parallel_devices is not None
return self.parallel_devices[0]
@property
def distributed_sampler_kwargs(self) -> None:
return None
def setup_module(self, module: Module) -> DataParallel:
"""Wraps the given model into a :class:`~torch.nn.parallel.DataParallel` module."""
return DataParallel(module=module, device_ids=self.parallel_devices)
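
Returning `None` here is how a strategy opts out under the new contract: DataParallel replicates every batch across devices instead of sharding the dataset, so the user's sampler should be kept. A quick illustration, assuming `DataParallelStrategy` is importable from `lightning_fabric.strategies` and constructible with no arguments:

```python
from lightning_fabric.strategies import DataParallelStrategy

# DataParallel does not shard data, so it opts out of sampler replacement.
assert DataParallelStrategy().distributed_sampler_kwargs is None
```

Fabric's `_requires_distributed_sampler` (see the hunk above) then returns `False` and leaves the dataloader untouched.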

View File

@@ -133,10 +133,6 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
assert self.parallel_devices is not None
return self.parallel_devices[self.local_rank]
@property
def is_distributed(self) -> bool:
return True
@property
def num_nodes(self) -> int:
return self._num_nodes
@@ -150,7 +146,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
return len(self.parallel_devices) if self.parallel_devices is not None else 0
@property
def distributed_sampler_kwargs(self) -> Dict:
def distributed_sampler_kwargs(self) -> Dict[str, Any]:
return dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
@property

View File

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from abc import ABC
from typing import Any, Dict, List, Optional
import torch
@@ -41,11 +41,6 @@ class ParallelStrategy(Strategy, ABC):
self.parallel_devices = parallel_devices
self.cluster_environment: Optional[ClusterEnvironment] = cluster_environment
@property
@abstractmethod
def root_device(self) -> torch.device:
"""Return the root device."""
@property
def global_rank(self) -> int:
return self.cluster_environment.global_rank() if self.cluster_environment is not None else 0
@@ -75,11 +70,12 @@
self._parallel_devices = parallel_devices
@property
def distributed_sampler_kwargs(self) -> Dict[str, Any]:
return dict(
num_replicas=len(self.parallel_devices) if self.parallel_devices is not None else 0,
rank=self.global_rank,
)
def distributed_sampler_kwargs(self) -> Optional[Dict[str, Any]]:
"""Arguments for the ``DistributedSampler``.
If this method is not defined, or it returns ``None``, then the ``DistributedSampler`` will not be used.
"""
return {"num_replicas": self.world_size, "rank": self.global_rank}
def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
"""Perform a all_gather on all processes."""

View File

@@ -49,10 +49,6 @@ class SingleTPUStrategy(SingleDeviceStrategy):
def checkpoint_io(self, io: Optional[CheckpointIO]) -> None:
self._checkpoint_io = io
@property
def is_distributed(self) -> bool:
return False
@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
strategy_registry.register("single_tpu", cls, description=f"{cls.__class__.__name__}")

View File

@@ -88,11 +88,7 @@ class XLAStrategy(ParallelStrategy):
self._checkpoint_io = io
@property
def distributed_sampler_kwargs(self) -> Dict[str, int]:
return dict(num_replicas=self.world_size, rank=self.global_rank)
@property
def is_distributed(self) -> bool:
def _is_distributed(self) -> bool:
import torch_xla.core.xla_env_vars as xenv
# HOST_WORLD_SIZE is not set outside the xmp.spawn process
@@ -145,13 +141,13 @@ class XLAStrategy(ParallelStrategy):
return output
def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
if self.is_distributed:
if self._is_distributed:
import torch_xla.core.xla_model as xm
xm.rendezvous(name)
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
if not self.is_distributed:
if not self._is_distributed:
return obj
buffer = io.BytesIO()
torch.save(obj, buffer)

View File

@@ -479,7 +479,7 @@ def test_setup_dataloaders_replace_custom_sampler(strategy):
# explicitly asking to replace when a custom sampler is already configured raises an exception
fabric = EmptyFabric(accelerator="cpu", strategy=strategy, devices=2)
if fabric._connector.is_distributed:
if hasattr(fabric.strategy, "distributed_sampler_kwargs"):
with pytest.raises(TypeError, match="You seem to have configured a sampler in your DataLoader"):
fabric.setup_dataloaders(dataloader, replace_sampler=True)
@@ -504,7 +504,7 @@ def test_setup_dataloaders_replace_custom_sampler(strategy):
def test_setup_dataloaders_replace_standard_sampler(shuffle, strategy):
"""Test that Fabric replaces the default samplers with DistributedSampler automatically."""
fabric = EmptyFabric(accelerator="cpu", strategy=strategy, devices=2)
is_distributed = fabric._connector.is_distributed
is_distributed = hasattr(fabric.strategy, "distributed_sampler_kwargs")
fabric_dataloader = fabric.setup_dataloaders(DataLoader(range(3), shuffle=shuffle))
assert not is_distributed or isinstance(fabric_dataloader.sampler, DistributedSampler)