Rename `Strategy.reduce` to `Strategy.all_reduce` in Lite (#16370)

Author: Adrian Wälchli, 2023-01-16 14:17:45 +01:00 (committed by GitHub)
Parent: 596494b719
Commit: f1e0fda879
11 changed files with 47 additions and 44 deletions

View File

@@ -5,6 +5,9 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
+- Renamed `Strategy.reduce` to `Strategy.all_reduce` in all strategies ([#16370](https://github.com/Lightning-AI/lightning/issues/16370))
 
 ## [1.9.0] - 2023-01-12
 
 ### Added
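For user code that called the strategy hook directly, this is a pure rename: the signature and `reduce_op` semantics are unchanged. A minimal before/after sketch; the `_IdentityStrategy` stand-in below is hypothetical and only mimics the single-device behaviour shown further down in this diff:

```python
import torch


class _IdentityStrategy:
    """Hypothetical stand-in for a Lite strategy on a single process, only here to make the snippet runnable."""

    def all_reduce(self, tensor, group=None, reduce_op="mean"):
        return tensor  # single process: reduction is the identity


strategy = _IdentityStrategy()
loss = torch.tensor(0.25)

# old call, removed by this commit:  strategy.reduce(loss, reduce_op="mean")
# new call:
reduced = strategy.all_reduce(loss, reduce_op="mean")
assert torch.equal(reduced, loss)
```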

View File

@@ -120,7 +120,7 @@ class DDPStrategy(ParallelStrategy):
     def module_to_device(self, module: Module) -> None:
         module.to(self.root_device)
 
-    def reduce(
+    def all_reduce(
         self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
     ) -> Tensor:
         """Reduces a tensor from several distributed processes to one aggregated tensor.

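The docstring above describes an element-wise reduction across processes. As a rough illustration of what `reduce_op="mean"` versus `"sum"` produce, here is the same aggregation computed locally with plain `torch` instead of a real process group (the two per-rank tensors are made up):

```python
import torch

# Pretend two ranks hold these tensors before the collective.
rank_tensors = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]

# Every rank would receive the same aggregated result.
mean_result = torch.stack(rank_tensors).mean(dim=0)  # reduce_op="mean" -> tensor([2., 3.])
sum_result = torch.stack(rank_tensors).sum(dim=0)    # reduce_op="sum"  -> tensor([4., 6.])

assert torch.equal(mean_result, torch.tensor([2.0, 3.0]))
assert torch.equal(sum_result, torch.tensor([4.0, 6.0]))
```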
View File

@@ -65,7 +65,7 @@ class DataParallelStrategy(ParallelStrategy):
         # DataParallel handles the transfer of batch to the device
         return batch
 
-    def reduce(
+    def all_reduce(
         self, collection: TReduce, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
     ) -> TReduce:
         def mean(t: Tensor) -> Tensor:

View File

@@ -245,7 +245,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
     ):
         yield
 
-    def reduce(
+    def all_reduce(
         self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
     ) -> Tensor:
         if isinstance(tensor, Tensor):

View File

@@ -94,7 +94,7 @@ class ParallelStrategy(Strategy, ABC):
             bool: The reduced boolean decision.
         """
         decision = torch.tensor(int(decision), device=self.root_device)
-        decision = self.reduce(decision, reduce_op=ReduceOp.SUM)
+        decision = self.all_reduce(decision, reduce_op=ReduceOp.SUM)
         decision = bool(decision == self.world_size) if all else bool(decision)
         return decision
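This hunk also shows why the rename reaches `reduce_boolean_decision`: the boolean is cast to an integer, summed across ranks via `all_reduce`, and compared against the world size. A local sketch of that logic with the collective replaced by a plain Python sum (the helper name and values are illustrative):

```python
def reduce_boolean_decision_local(per_rank_decisions, require_all=True):
    # Stands in for: decision = self.all_reduce(decision, reduce_op=ReduceOp.SUM)
    total = sum(int(d) for d in per_rank_decisions)
    world_size = len(per_rank_decisions)
    # `require_all` mirrors the `all` argument above.
    return total == world_size if require_all else bool(total)


assert reduce_boolean_decision_local([True, True, True]) is True                        # all ranks agree
assert reduce_boolean_decision_local([True, False, True]) is False                      # one dissent breaks "all"
assert reduce_boolean_decision_local([True, False, False], require_all=False) is True   # "any" semantics
```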

View File

@@ -53,7 +53,7 @@ class SingleDeviceStrategy(Strategy):
     def module_to_device(self, module: Module) -> None:
         module.to(self.root_device)
 
-    def reduce(self, tensor: Any | Tensor, *args: Any, **kwargs: Any) -> Any | Tensor:
+    def all_reduce(self, tensor: Any | Tensor, *args: Any, **kwargs: Any) -> Any | Tensor:
         """Reduces a tensor from several distributed processes to one aggregated tensor. As this plugin only
         operates with a single device, the reduction is simply the identity.

View File

@@ -169,7 +169,17 @@ class Strategy(ABC):
         return self.precision.optimizer_step(optimizer, **kwargs)
 
     @abstractmethod
-    def reduce(
+    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
+        """Perform an all_gather on all processes.
+
+        Args:
+            tensor: the tensor to all_gather
+            group: the process group to gather results from
+            sync_grads: flag that allows users to synchronize gradients for all_gather op
+        """
+
+    @abstractmethod
+    def all_reduce(
         self,
         tensor: Union[Tensor, Any],
         group: Optional[Any] = None,
@@ -201,16 +211,6 @@ class Strategy(ABC):
             src: source rank
         """
 
-    @abstractmethod
-    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
-        """Perform an all_gather on all processes.
-
-        Args:
-            tensor: the tensor to all_gather
-            group: the process group to gather results from
-            sync_grads: flag that allows users to synchronize gradients for all_gather op
-        """
-
     def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool:
         """Reduce a boolean decision across all processes."""
         return decision
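Custom strategy authors now implement the collectives under the new names. A hedged sketch of a minimal single-process class exposing the two hooks from the hunk above; `ToyStrategy` is hypothetical and is not derived from the real `Strategy` base class:

```python
from typing import Any, Optional, Union

from torch import Tensor


class ToyStrategy:
    """Single-process stand-in showing the renamed hooks; not a real Lite strategy."""

    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
        # With one process, "gathering" just adds a world dimension of size 1.
        return tensor.unsqueeze(0)

    def all_reduce(
        self,
        tensor: Union[Tensor, Any],
        group: Optional[Any] = None,
        reduce_op: Optional[Union[Any, str]] = "mean",
    ) -> Union[Tensor, Any]:
        # With one process, any reduction is the identity.
        return tensor
```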

View File

@@ -118,7 +118,25 @@ class XLAStrategy(ParallelStrategy):
         dataloader.dataset = dataloader._loader.dataset
         return dataloader
 
-    def reduce(
+    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
+        """Function to gather a tensor from several distributed processes.
+
+        Args:
+            tensor: tensor of shape (batch, ...)
+            group: not available with TPUs
+            sync_grads: flag that allows users to synchronize gradients for the all_gather operation
+
+        Return:
+            A tensor of shape (world_size, batch, ...)
+        """
+        if isinstance(tensor, Tensor) and tensor.dim() == 0:
+            tensor = tensor.unsqueeze(0)
+        import torch_xla.core.functions as xf
+        import torch_xla.core.xla_model as xm
+
+        return xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
+
+    def all_reduce(
         self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
     ) -> Tensor:
         if not isinstance(output, Tensor):
@@ -160,24 +178,6 @@ class XLAStrategy(ParallelStrategy):
         obj = torch.load(buffer)
         return obj
 
-    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
-        """Function to gather a tensor from several distributed processes.
-
-        Args:
-            tensor: tensor of shape (batch, ...)
-            group: not available with TPUs
-            sync_grads: flag that allows users to synchronize gradients for the all_gather operation
-
-        Return:
-            A tensor of shape (world_size, batch, ...)
-        """
-        if isinstance(tensor, Tensor) and tensor.dim() == 0:
-            tensor = tensor.unsqueeze(0)
-        import torch_xla.core.functions as xf
-        import torch_xla.core.xla_model as xm
-
-        return xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
-
     def save_checkpoint(
         self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
     ) -> None:
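The relocated `all_gather` promotes 0-dim scalars to shape `(1,)` before gathering, because the gather stacks inputs along a new leading world dimension. A plain-`torch` illustration of that shape handling, with no TPU or `torch_xla` required (the world size of 4 is made up):

```python
import torch

scalar = torch.tensor(3.0)           # 0-dim tensor, as often produced by .sum() or a metric
if scalar.dim() == 0:
    scalar = scalar.unsqueeze(0)     # -> shape (1,), mirrors the check in all_gather above

world_size = 4                                   # pretend four XLA processes returned the same value
gathered = torch.stack([scalar] * world_size)    # mimics the (world_size, batch, ...) result
assert gathered.shape == (4, 1)
```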

View File

@@ -60,14 +60,14 @@ def test_broadcast_on_tpu():
 
 def tpu_reduce_fn(strategy):
     with pytest.raises(ValueError, match="XLAStrategy only supports"):
-        strategy.reduce(1, reduce_op="undefined")
+        strategy.all_reduce(1, reduce_op="undefined")
 
     with pytest.raises(ValueError, match="XLAStrategy only supports"):
-        strategy.reduce(1, reduce_op=ReduceOp.MAX)
+        strategy.all_reduce(1, reduce_op=ReduceOp.MAX)
 
     # it is faster to loop over here than to parameterize the test
     for reduce_op in ("mean", "AVG", "sum", ReduceOp.SUM):
-        result = strategy.reduce(1, reduce_op=reduce_op)
+        result = strategy.all_reduce(1, reduce_op=reduce_op)
         if isinstance(reduce_op, str) and reduce_op.lower() in ("mean", "avg"):
             assert result.item() == 1
         else:
@@ -77,7 +77,7 @@ def tpu_reduce_fn(strategy):
 
 @RunIf(tpu=True)
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 def test_tpu_reduce():
-    """Test tpu spawn reduce operation."""
+    """Test tpu spawn all_reduce operation."""
     xla_launch(tpu_reduce_fn)

View File

@@ -42,7 +42,7 @@ def test_single_device_collectives():
     strategy = SingleDeviceStrategy()
     tensor = Mock()
     assert strategy.all_gather(tensor) == tensor
-    assert strategy.reduce(tensor) == tensor
+    assert strategy.all_reduce(tensor) == tensor
     assert strategy.broadcast(tensor) == tensor

View File

@@ -60,14 +60,14 @@ def test_broadcast_on_tpu():
 
 def tpu_reduce_fn(strategy):
     with pytest.raises(ValueError, match="XLAStrategy only supports"):
-        strategy.reduce(1, reduce_op="undefined")
+        strategy.all_reduce(1, reduce_op="undefined")
 
     with pytest.raises(ValueError, match="XLAStrategy only supports"):
-        strategy.reduce(1, reduce_op=ReduceOp.MAX)
+        strategy.all_reduce(1, reduce_op=ReduceOp.MAX)
 
     # it is faster to loop over here than to parameterize the test
     for reduce_op in ("mean", "AVG", "sum", ReduceOp.SUM):
-        result = strategy.reduce(1, reduce_op=reduce_op)
+        result = strategy.all_reduce(1, reduce_op=reduce_op)
         if isinstance(reduce_op, str) and reduce_op.lower() in ("mean", "avg"):
             assert result.item() == 1
         else:
@@ -77,7 +77,7 @@ def tpu_reduce_fn(strategy):
 
 @RunIf(tpu=True)
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 def test_tpu_reduce():
-    """Test tpu spawn reduce operation."""
+    """Test tpu spawn all_reduce operation."""
     xla_launch(tpu_reduce_fn)