Fabric: Remove `_Connector.is_distributed` (#16327)
This commit is contained in:
parent
4bc2080c71
commit
3c3bff5e6e
@@ -43,8 +43,6 @@ from lightning_fabric.plugins.precision.double import DoublePrecision
from lightning_fabric.plugins.precision.fsdp import FSDPPrecision
from lightning_fabric.plugins.precision.precision import _PRECISION_INPUT, _PRECISION_INPUT_INT, _PRECISION_INPUT_STR
from lightning_fabric.strategies import (
    DDPShardedStrategy,
    DDPStrategy,
    DeepSpeedStrategy,
    SingleDeviceStrategy,
    SingleTPUStrategy,

@@ -547,21 +545,3 @@ class _Connector:
        if env_value is None:
            return current
        return env_value
-
-    @property
-    def is_distributed(self) -> bool:
-        # TODO: deprecate this property
-        # Used for custom plugins.
-        # Custom plugins should implement is_distributed property.
-        if hasattr(self.strategy, "is_distributed") and not isinstance(self.accelerator, TPUAccelerator):
-            return self.strategy.is_distributed
-        distributed_strategy = (
-            DDPStrategy,
-            DDPShardedStrategy,
-            DeepSpeedStrategy,
-            XLAStrategy,
-        )
-        is_distributed = isinstance(self.strategy, distributed_strategy)
-        if isinstance(self.accelerator, TPUAccelerator):
-            is_distributed |= self.strategy.is_distributed
-        return is_distributed

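Note (not part of the diff): the hunks below replace the removed property with a check on the strategy's `distributed_sampler_kwargs`. A minimal sketch of the equivalent check for code that previously called `fabric._connector.is_distributed`; the helper name is made up for illustration, and it assumes `lightning_fabric` is installed:

from lightning_fabric import Fabric


def is_distributed(fabric: Fabric) -> bool:
    # Mirrors the new `_requires_distributed_sampler` condition: a strategy that
    # wants a DistributedSampler advertises non-None `distributed_sampler_kwargs`;
    # single-device and DP strategies do not.
    return getattr(fabric.strategy, "distributed_sampler_kwargs", None) is not None
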
@@ -69,7 +69,7 @@ class Fabric:
        accelerator: The hardware to run on. Possible choices are:
            ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
        strategy: Strategy for how to run across multiple devices. Possible choices are:
-            ``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"ddp_sharded"``.
+            ``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"fsdp"``, ``"ddp_sharded"``.
        devices: Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``.
            The value applies per node.
        num_nodes: Number of GPU nodes for distributed training.

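Usage sketch (assumption, not part of the diff): the updated docstring advertises `"fsdp"` as an accepted strategy string, so constructing Fabric with it looks like the other shorthand choices:

from lightning_fabric import Fabric

# Assumes a CUDA machine with at least two GPUs; "fsdp" resolves to the
# FSDP strategy just like the other strategy strings listed in the docstring.
fabric = Fabric(accelerator="cuda", strategy="fsdp", devices=2)
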
@@ -673,7 +673,7 @@ class Fabric:

    def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool:
        return (
-            self._connector.is_distributed
+            getattr(self.strategy, "distributed_sampler_kwargs", None) is not None
            and not isinstance(dataloader.sampler, DistributedSampler)
            and not has_iterable_dataset(dataloader)
        )

@@ -81,10 +81,6 @@ class DDPStrategy(ParallelStrategy):
        assert self.parallel_devices is not None
        return self.parallel_devices[self.local_rank]

-    @property
-    def is_distributed(self) -> bool:
-        return True
-
    @property
    def num_nodes(self) -> int:
        return self._num_nodes

@@ -50,6 +50,10 @@ class DataParallelStrategy(ParallelStrategy):
        assert self.parallel_devices is not None
        return self.parallel_devices[0]

+    @property
+    def distributed_sampler_kwargs(self) -> None:
+        return None
+
    def setup_module(self, module: Module) -> DataParallel:
        """Wraps the given model into a :class:`~torch.nn.parallel.DataParallel` module."""
        return DataParallel(module=module, device_ids=self.parallel_devices)

@@ -133,10 +133,6 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
        assert self.parallel_devices is not None
        return self.parallel_devices[self.local_rank]

-    @property
-    def is_distributed(self) -> bool:
-        return True
-
    @property
    def num_nodes(self) -> int:
        return self._num_nodes

@@ -150,7 +146,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
        return len(self.parallel_devices) if self.parallel_devices is not None else 0

    @property
-    def distributed_sampler_kwargs(self) -> Dict:
+    def distributed_sampler_kwargs(self) -> Dict[str, Any]:
        return dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)

    @property

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from abc import ABC, abstractmethod
+from abc import ABC
from typing import Any, Dict, List, Optional

import torch

@@ -41,11 +41,6 @@ class ParallelStrategy(Strategy, ABC):
        self.parallel_devices = parallel_devices
        self.cluster_environment: Optional[ClusterEnvironment] = cluster_environment

-    @property
-    @abstractmethod
-    def root_device(self) -> torch.device:
-        """Return the root device."""
-
    @property
    def global_rank(self) -> int:
        return self.cluster_environment.global_rank() if self.cluster_environment is not None else 0

@@ -75,11 +70,12 @@ class ParallelStrategy(Strategy, ABC):
        self._parallel_devices = parallel_devices

    @property
-    def distributed_sampler_kwargs(self) -> Dict[str, Any]:
-        return dict(
-            num_replicas=len(self.parallel_devices) if self.parallel_devices is not None else 0,
-            rank=self.global_rank,
-        )
+    def distributed_sampler_kwargs(self) -> Optional[Dict[str, Any]]:
+        """Arguments for the ``DistributedSampler``.
+
+        If this method is not defined, or it returns ``None``, then the ``DistributedSampler`` will not be used.
+        """
+        return {"num_replicas": self.world_size, "rank": self.global_rank}

    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
        """Perform a all_gather on all processes."""

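Sketch (assumption, not from this commit): per the new docstring above, a custom parallel strategy can opt out of automatic `DistributedSampler` injection by returning `None`, instead of implementing an `is_distributed` flag. The class below is hypothetical and omits the remaining abstract `Strategy` methods:

from typing import Any, Dict, Optional

from lightning_fabric.strategies import ParallelStrategy


class MyCustomStrategy(ParallelStrategy):
    # Hypothetical strategy; the other abstract Strategy methods
    # (all_gather, barrier, broadcast, ...) are omitted for brevity.

    @property
    def distributed_sampler_kwargs(self) -> Optional[Dict[str, Any]]:
        # Returning None tells Fabric not to replace the DataLoader's sampler.
        return None
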
@@ -49,10 +49,6 @@ class SingleTPUStrategy(SingleDeviceStrategy):
    def checkpoint_io(self, io: Optional[CheckpointIO]) -> None:
        self._checkpoint_io = io

-    @property
-    def is_distributed(self) -> bool:
-        return False
-
    @classmethod
    def register_strategies(cls, strategy_registry: Dict) -> None:
        strategy_registry.register("single_tpu", cls, description=f"{cls.__class__.__name__}")

@@ -88,11 +88,7 @@ class XLAStrategy(ParallelStrategy):
        self._checkpoint_io = io

-    @property
-    def distributed_sampler_kwargs(self) -> Dict[str, int]:
-        return dict(num_replicas=self.world_size, rank=self.global_rank)
-
    @property
-    def is_distributed(self) -> bool:
+    def _is_distributed(self) -> bool:
        import torch_xla.core.xla_env_vars as xenv

        # HOST_WORLD_SIZE is not set outside the xmp.spawn process

@@ -145,13 +141,13 @@ class XLAStrategy(ParallelStrategy):
        return output

    def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
-        if self.is_distributed:
+        if self._is_distributed:
            import torch_xla.core.xla_model as xm

            xm.rendezvous(name)

    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
-        if not self.is_distributed:
+        if not self._is_distributed:
            return obj
        buffer = io.BytesIO()
        torch.save(obj, buffer)

@@ -479,7 +479,7 @@ def test_setup_dataloaders_replace_custom_sampler(strategy):

    # explicitly asking to replace when a custom sampler is already configured raises an exception
    fabric = EmptyFabric(accelerator="cpu", strategy=strategy, devices=2)
-    if fabric._connector.is_distributed:
+    if hasattr(fabric.strategy, "distributed_sampler_kwargs"):
        with pytest.raises(TypeError, match="You seem to have configured a sampler in your DataLoader"):
            fabric.setup_dataloaders(dataloader, replace_sampler=True)

@@ -504,7 +504,7 @@ def test_setup_dataloaders_replace_custom_sampler(strategy):
def test_setup_dataloaders_replace_standard_sampler(shuffle, strategy):
    """Test that Fabric replaces the default samplers with DistributedSampler automatically."""
    fabric = EmptyFabric(accelerator="cpu", strategy=strategy, devices=2)
-    is_distributed = fabric._connector.is_distributed
+    is_distributed = hasattr(fabric.strategy, "distributed_sampler_kwargs")
    fabric_dataloader = fabric.setup_dataloaders(DataLoader(range(3), shuffle=shuffle))
    assert not is_distributed or isinstance(fabric_dataloader.sampler, DistributedSampler)
