diff --git a/src/lightning_fabric/CHANGELOG.md b/src/lightning_fabric/CHANGELOG.md
index 4abc53d090..ecc541f433 100644
--- a/src/lightning_fabric/CHANGELOG.md
+++ b/src/lightning_fabric/CHANGELOG.md
@@ -5,6 +5,9 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
+- Renamed `Strategy.reduce` to `Strategy.all_reduce` in all strategies ([#16370](https://github.com/Lightning-AI/lightning/issues/16370))
+
+
 ## [1.9.0] - 2023-01-12
 
 ### Added
diff --git a/src/lightning_fabric/strategies/ddp.py b/src/lightning_fabric/strategies/ddp.py
index b17545bc0b..91b0b84319 100644
--- a/src/lightning_fabric/strategies/ddp.py
+++ b/src/lightning_fabric/strategies/ddp.py
@@ -120,7 +120,7 @@ class DDPStrategy(ParallelStrategy):
     def module_to_device(self, module: Module) -> None:
         module.to(self.root_device)
 
-    def reduce(
+    def all_reduce(
         self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
     ) -> Tensor:
         """Reduces a tensor from several distributed processes to one aggregated tensor.
diff --git a/src/lightning_fabric/strategies/dp.py b/src/lightning_fabric/strategies/dp.py
index 1fcc2b4c67..181ca0bc5b 100644
--- a/src/lightning_fabric/strategies/dp.py
+++ b/src/lightning_fabric/strategies/dp.py
@@ -65,7 +65,7 @@ class DataParallelStrategy(ParallelStrategy):
         # DataParallel handles the transfer of batch to the device
         return batch
 
-    def reduce(
+    def all_reduce(
         self, collection: TReduce, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
     ) -> TReduce:
         def mean(t: Tensor) -> Tensor:
diff --git a/src/lightning_fabric/strategies/fsdp.py b/src/lightning_fabric/strategies/fsdp.py
index 63b7ecfe18..1c776fd366 100644
--- a/src/lightning_fabric/strategies/fsdp.py
+++ b/src/lightning_fabric/strategies/fsdp.py
@@ -245,7 +245,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
         ):
             yield
 
-    def reduce(
+    def all_reduce(
         self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
     ) -> Tensor:
         if isinstance(tensor, Tensor):
diff --git a/src/lightning_fabric/strategies/parallel.py b/src/lightning_fabric/strategies/parallel.py
index 28a5c6fc23..ffdfa283aa 100644
--- a/src/lightning_fabric/strategies/parallel.py
+++ b/src/lightning_fabric/strategies/parallel.py
@@ -94,7 +94,7 @@ class ParallelStrategy(Strategy, ABC):
             bool: The reduced boolean decision.
         """
         decision = torch.tensor(int(decision), device=self.root_device)
-        decision = self.reduce(decision, reduce_op=ReduceOp.SUM)
+        decision = self.all_reduce(decision, reduce_op=ReduceOp.SUM)
         decision = bool(decision == self.world_size) if all else bool(decision)
         return decision
 
diff --git a/src/lightning_fabric/strategies/single_device.py b/src/lightning_fabric/strategies/single_device.py
index c29b49743a..b7b557fe57 100644
--- a/src/lightning_fabric/strategies/single_device.py
+++ b/src/lightning_fabric/strategies/single_device.py
@@ -53,7 +53,7 @@ class SingleDeviceStrategy(Strategy):
     def module_to_device(self, module: Module) -> None:
         module.to(self.root_device)
 
-    def reduce(self, tensor: Any | Tensor, *args: Any, **kwargs: Any) -> Any | Tensor:
+    def all_reduce(self, tensor: Any | Tensor, *args: Any, **kwargs: Any) -> Any | Tensor:
         """Reduces a tensor from several distributed processes to one aggregated tensor.
 
         As this plugin only operates with a single device, the reduction is simply the identity.
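Taken together, the hunks above are a pure rename: every strategy's `reduce` becomes `all_reduce`, with signatures and semantics unchanged. As a minimal sketch of a call site after this patch (hypothetical caller code, not part of the diff; assumes `strategy` is an already-launched distributed strategy such as `DDPStrategy`):

    import torch

    # `strategy` is assumed to be set up for distributed execution already
    loss = torch.tensor(1.0, device=strategy.root_device)

    # before this patch: strategy.reduce(loss, reduce_op="mean")
    mean_loss = strategy.all_reduce(loss, reduce_op="mean")  # averaged across ranks
    total_loss = strategy.all_reduce(loss, reduce_op="sum")  # summed across ranks
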
diff --git a/src/lightning_fabric/strategies/strategy.py b/src/lightning_fabric/strategies/strategy.py
index a2bee3c933..07b25bdc92 100644
--- a/src/lightning_fabric/strategies/strategy.py
+++ b/src/lightning_fabric/strategies/strategy.py
@@ -169,7 +169,17 @@ class Strategy(ABC):
         return self.precision.optimizer_step(optimizer, **kwargs)
 
     @abstractmethod
-    def reduce(
+    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
+        """Perform an all_gather on all processes.
+
+        Args:
+            tensor: the tensor to all_gather
+            group: the process group to gather results from
+            sync_grads: flag that allows users to synchronize gradients for all_gather op
+        """
+
+    @abstractmethod
+    def all_reduce(
         self,
         tensor: Union[Tensor, Any],
         group: Optional[Any] = None,
@@ -201,16 +211,6 @@ class Strategy(ABC):
             src: source rank
         """
 
-    @abstractmethod
-    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
-        """Perform an all_gather on all processes.
-
-        Args:
-            tensor: the tensor to all_gather
-            group: the process group to gather results from
-            sync_grads: flag that allows users to synchronize gradients for all_gather op
-        """
-
     def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool:
         """Reduce a boolean decision across all processes."""
         return decision
diff --git a/src/lightning_fabric/strategies/xla.py b/src/lightning_fabric/strategies/xla.py
index 08c232e13c..a9ea400ccb 100644
--- a/src/lightning_fabric/strategies/xla.py
+++ b/src/lightning_fabric/strategies/xla.py
@@ -118,7 +118,25 @@ class XLAStrategy(ParallelStrategy):
             dataloader.dataset = dataloader._loader.dataset
         return dataloader
 
-    def reduce(
+    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
+        """Function to gather a tensor from several distributed processes.
+
+        Args:
+            tensor: tensor of shape (batch, ...)
+            group: not available with TPUs
+            sync_grads: flag that allows users to synchronize gradients for the all_gather operation
+        Return:
+            A tensor of shape (world_size, batch, ...)
+        """
+        if isinstance(tensor, Tensor) and tensor.dim() == 0:
+            tensor = tensor.unsqueeze(0)
+
+        import torch_xla.core.functions as xf
+        import torch_xla.core.xla_model as xm
+
+        return xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
+
+    def all_reduce(
         self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
     ) -> Tensor:
         if not isinstance(output, Tensor):
@@ -160,24 +178,6 @@ class XLAStrategy(ParallelStrategy):
             obj = torch.load(buffer)
         return obj
 
-    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
-        """Function to gather a tensor from several distributed processes.
-
-        Args:
-            tensor: tensor of shape (batch, ...)
-            group: not available with TPUs
-            sync_grads: flag that allows users to synchronize gradients for the all_gather operation
-        Return:
-            A tensor of shape (world_size, batch, ...)
-        """
-        if isinstance(tensor, Tensor) and tensor.dim() == 0:
-            tensor = tensor.unsqueeze(0)
-
-        import torch_xla.core.functions as xf
-        import torch_xla.core.xla_model as xm
-
-        return xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
-
     def save_checkpoint(
         self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
     ) -> None:
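Note that the xla.py hunks move `all_gather` next to the renamed `all_reduce` without changing its body. A usage sketch (hypothetical, not part of the diff; assumes `strategy` is an `XLAStrategy` already running on several TPU processes):

    import torch

    t = torch.tensor(3.0)  # 0-dim tensor on each process
    # 0-dim tensors are unsqueezed to shape (1,) before gathering, and the
    # per-process results are concatenated along the first dimension
    gathered = strategy.all_gather(t)
    # sync_grads=True routes through torch_xla's differentiable xf.all_gather
    gathered_with_grads = strategy.all_gather(t, sync_grads=True)
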
diff --git a/tests/tests_fabric/strategies/launchers/test_xla.py b/tests/tests_fabric/strategies/launchers/test_xla.py
index edd2cfebb9..b99dbdbc3c 100644
--- a/tests/tests_fabric/strategies/launchers/test_xla.py
+++ b/tests/tests_fabric/strategies/launchers/test_xla.py
@@ -60,14 +60,14 @@ def test_broadcast_on_tpu():
 
 def tpu_reduce_fn(strategy):
     with pytest.raises(ValueError, match="XLAStrategy only supports"):
-        strategy.reduce(1, reduce_op="undefined")
+        strategy.all_reduce(1, reduce_op="undefined")
 
     with pytest.raises(ValueError, match="XLAStrategy only supports"):
-        strategy.reduce(1, reduce_op=ReduceOp.MAX)
+        strategy.all_reduce(1, reduce_op=ReduceOp.MAX)
 
     # it is faster to loop over here than to parameterize the test
     for reduce_op in ("mean", "AVG", "sum", ReduceOp.SUM):
-        result = strategy.reduce(1, reduce_op=reduce_op)
+        result = strategy.all_reduce(1, reduce_op=reduce_op)
         if isinstance(reduce_op, str) and reduce_op.lower() in ("mean", "avg"):
             assert result.item() == 1
         else:
@@ -77,7 +77,7 @@ def tpu_reduce_fn(strategy):
 
 @RunIf(tpu=True)
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 def test_tpu_reduce():
-    """Test tpu spawn reduce operation."""
+    """Test tpu spawn all_reduce operation."""
     xla_launch(tpu_reduce_fn)
diff --git a/tests/tests_fabric/strategies/test_single_device.py b/tests/tests_fabric/strategies/test_single_device.py
index d548a92ebd..4e91e9e04b 100644
--- a/tests/tests_fabric/strategies/test_single_device.py
+++ b/tests/tests_fabric/strategies/test_single_device.py
@@ -42,7 +42,7 @@ def test_single_device_collectives():
     strategy = SingleDeviceStrategy()
     tensor = Mock()
     assert strategy.all_gather(tensor) == tensor
-    assert strategy.reduce(tensor) == tensor
+    assert strategy.all_reduce(tensor) == tensor
     assert strategy.broadcast(tensor) == tensor
 
 
diff --git a/tests/tests_fabric/strategies/test_xla.py b/tests/tests_fabric/strategies/test_xla.py
index edd2cfebb9..b99dbdbc3c 100644
--- a/tests/tests_fabric/strategies/test_xla.py
+++ b/tests/tests_fabric/strategies/test_xla.py
@@ -60,14 +60,14 @@ def test_broadcast_on_tpu():
 
 def tpu_reduce_fn(strategy):
     with pytest.raises(ValueError, match="XLAStrategy only supports"):
-        strategy.reduce(1, reduce_op="undefined")
+        strategy.all_reduce(1, reduce_op="undefined")
 
     with pytest.raises(ValueError, match="XLAStrategy only supports"):
-        strategy.reduce(1, reduce_op=ReduceOp.MAX)
+        strategy.all_reduce(1, reduce_op=ReduceOp.MAX)
 
     # it is faster to loop over here than to parameterize the test
     for reduce_op in ("mean", "AVG", "sum", ReduceOp.SUM):
-        result = strategy.reduce(1, reduce_op=reduce_op)
+        result = strategy.all_reduce(1, reduce_op=reduce_op)
         if isinstance(reduce_op, str) and reduce_op.lower() in ("mean", "avg"):
             assert result.item() == 1
         else:
@@ -77,7 +77,7 @@ def tpu_reduce_fn(strategy):
 
 @RunIf(tpu=True)
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 def test_tpu_reduce():
-    """Test tpu spawn reduce operation."""
+    """Test tpu spawn all_reduce operation."""
     xla_launch(tpu_reduce_fn)
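Downstream code that calls or overrides the old method needs the matching one-line rename, e.g. (hypothetical caller code, not part of this patch; `strategy` and `metric` are assumed to exist):

    from torch.distributed import ReduceOp

    # before: total = strategy.reduce(metric, reduce_op=ReduceOp.SUM)
    total = strategy.all_reduce(metric, reduce_op=ReduceOp.SUM)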