Rename `Strategy.reduce` to `Strategy.all_reduce` in Lite (#16370)
This commit is contained in:
parent 596494b719
commit f1e0fda879
@@ -5,6 +5,9 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
+- Renamed `Strategy.reduce` to `Strategy.all_reduce` in all strategies ([#16370](https://github.com/Lightning-AI/lightning/issues/16370))
 
 ## [1.9.0] - 2023-01-12
 
 ### Added
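For callers, the change is a pure rename: every `strategy.reduce(...)` call becomes `strategy.all_reduce(...)` with the same arguments. A minimal sketch of a call site, using the single-device strategy because it runs on CPU without a process group; the `lightning_fabric.strategies` import path is an assumption based on the 1.9 package layout, not something this commit introduces:

    # Minimal sketch of the rename at a call site. SingleDeviceStrategy() with no
    # arguments mirrors the test further down in this diff; the import path is an
    # assumed 1.9 package layout.
    import torch
    from lightning_fabric.strategies import SingleDeviceStrategy

    strategy = SingleDeviceStrategy()
    loss = torch.tensor(0.25)

    # before: reduced = strategy.reduce(loss, reduce_op="mean")
    reduced = strategy.all_reduce(loss, reduce_op="mean")
    assert torch.equal(reduced, loss)  # single device: the reduction is the identity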
@@ -120,7 +120,7 @@ class DDPStrategy(ParallelStrategy):
     def module_to_device(self, module: Module) -> None:
         module.to(self.root_device)
 
-    def reduce(
+    def all_reduce(
         self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
     ) -> Tensor:
         """Reduces a tensor from several distributed processes to one aggregated tensor.
@@ -65,7 +65,7 @@ class DataParallelStrategy(ParallelStrategy):
         # DataParallel handles the transfer of batch to the device
         return batch
 
-    def reduce(
+    def all_reduce(
         self, collection: TReduce, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
     ) -> TReduce:
         def mean(t: Tensor) -> Tensor:
@@ -245,7 +245,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
         ):
             yield
 
-    def reduce(
+    def all_reduce(
         self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
     ) -> Tensor:
         if isinstance(tensor, Tensor):
@@ -94,7 +94,7 @@ class ParallelStrategy(Strategy, ABC):
             bool: The reduced boolean decision.
         """
         decision = torch.tensor(int(decision), device=self.root_device)
-        decision = self.reduce(decision, reduce_op=ReduceOp.SUM)
+        decision = self.all_reduce(decision, reduce_op=ReduceOp.SUM)
         decision = bool(decision == self.world_size) if all else bool(decision)
         return decision
 
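The renamed call inside `reduce_boolean_decision` keeps the same arithmetic: each rank contributes 0 or 1, the contributions are summed with `ReduceOp.SUM`, and the sum is either compared against `world_size` (all ranks must agree) or merely checked for truthiness (any rank suffices). A standalone sketch of that logic with no process group involved; the flag list and helper name are made up for illustration:

    # Standalone sketch of the reduce_boolean_decision arithmetic. `flags` stands in
    # for the 0/1 values the ranks would contribute; the function name is made up.
    import torch

    def simulated_boolean_decision(flags, all_ranks=True):
        world_size = len(flags)
        summed = torch.tensor([int(f) for f in flags]).sum()  # what a SUM all_reduce yields
        return bool(summed == world_size) if all_ranks else bool(summed)

    assert simulated_boolean_decision([True, True, True]) is True                    # unanimous
    assert simulated_boolean_decision([True, False, True]) is False                  # not unanimous
    assert simulated_boolean_decision([True, False, True], all_ranks=False) is True  # any rank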
@@ -53,7 +53,7 @@ class SingleDeviceStrategy(Strategy):
     def module_to_device(self, module: Module) -> None:
         module.to(self.root_device)
 
-    def reduce(self, tensor: Any | Tensor, *args: Any, **kwargs: Any) -> Any | Tensor:
+    def all_reduce(self, tensor: Any | Tensor, *args: Any, **kwargs: Any) -> Any | Tensor:
         """Reduces a tensor from several distributed processes to one aggregated tensor. As this plugin only
         operates with a single device, the reduction is simply the identity.
 
@@ -169,7 +169,17 @@ class Strategy(ABC):
         return self.precision.optimizer_step(optimizer, **kwargs)
 
     @abstractmethod
-    def reduce(
+    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
+        """Perform an all_gather on all processes.
+
+        Args:
+            tensor: the tensor to all_gather
+            group: the process group to gather results from
+            sync_grads: flag that allows users to synchronize gradients for all_gather op
+        """
+
+    @abstractmethod
+    def all_reduce(
         self,
         tensor: Union[Tensor, Any],
         group: Optional[Any] = None,
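Concrete strategies back the two abstract methods above with the matching collectives of their backend. A hedged sketch of what the pair boils down to on top of raw `torch.distributed`, run as a single-process gloo group so it executes on CPU; this illustrates the contract only and is not the library's DDP implementation:

    # Single-process gloo sketch of the all_reduce / all_gather contract. The "mean"
    # reduction is expressed as sum / world_size; names here are illustrative.
    import torch
    import torch.distributed as dist

    def sketch_all_reduce(tensor):
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)   # in-place sum across ranks
        return tensor / dist.get_world_size()           # "mean" = sum / world_size

    def sketch_all_gather(tensor):
        gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, tensor)
        return torch.stack(gathered)                    # shape (world_size, *tensor.shape)

    if __name__ == "__main__":
        dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1)
        t = torch.tensor([2.0, 4.0])
        print(sketch_all_reduce(t.clone()))  # unchanged with a single rank
        print(sketch_all_gather(t).shape)    # torch.Size([1, 2])
        dist.destroy_process_group()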
@@ -201,16 +211,6 @@ class Strategy(ABC):
             src: source rank
         """
 
-    @abstractmethod
-    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
-        """Perform an all_gather on all processes.
-
-        Args:
-            tensor: the tensor to all_gather
-            group: the process group to gather results from
-            sync_grads: flag that allows users to synchronize gradients for all_gather op
-        """
-
     def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool:
         """Reduce a boolean decision across all processes."""
         return decision
@@ -118,7 +118,25 @@ class XLAStrategy(ParallelStrategy):
             dataloader.dataset = dataloader._loader.dataset
         return dataloader
 
-    def reduce(
+    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
+        """Function to gather a tensor from several distributed processes.
+
+        Args:
+            tensor: tensor of shape (batch, ...)
+            group: not available with TPUs
+            sync_grads: flag that allows users to synchronize gradients for the all_gather operation
+        Return:
+            A tensor of shape (world_size, batch, ...)
+        """
+        if isinstance(tensor, Tensor) and tensor.dim() == 0:
+            tensor = tensor.unsqueeze(0)
+
+        import torch_xla.core.functions as xf
+        import torch_xla.core.xla_model as xm
+
+        return xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
+
+    def all_reduce(
         self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
     ) -> Tensor:
         if not isinstance(output, Tensor):
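The relocated `all_gather` keeps the existing shape handling: a 0-dim tensor is first unsqueezed to shape `(1,)` so the gathered result can take the documented `(world_size, batch, ...)` form. A plain-torch illustration of that promotion; the fake gather below simply repeats the local tensor and is not an XLA call:

    # Plain-torch illustration of the scalar promotion in XLAStrategy.all_gather.
    # fake_all_gather is a stand-in that repeats the local tensor; no torch_xla needed.
    import torch

    def fake_all_gather(tensor, world_size=2):
        if tensor.dim() == 0:
            tensor = tensor.unsqueeze(0)  # 0-dim -> shape (1,) before gathering
        return torch.stack([tensor.clone() for _ in range(world_size)])

    print(fake_all_gather(torch.tensor(3.0)).shape)  # torch.Size([2, 1])
    print(fake_all_gather(torch.ones(4, 8)).shape)   # torch.Size([2, 4, 8])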
@@ -160,24 +178,6 @@ class XLAStrategy(ParallelStrategy):
         obj = torch.load(buffer)
         return obj
 
-    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
-        """Function to gather a tensor from several distributed processes.
-
-        Args:
-            tensor: tensor of shape (batch, ...)
-            group: not available with TPUs
-            sync_grads: flag that allows users to synchronize gradients for the all_gather operation
-        Return:
-            A tensor of shape (world_size, batch, ...)
-        """
-        if isinstance(tensor, Tensor) and tensor.dim() == 0:
-            tensor = tensor.unsqueeze(0)
-
-        import torch_xla.core.functions as xf
-        import torch_xla.core.xla_model as xm
-
-        return xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
-
     def save_checkpoint(
         self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
     ) -> None:
@@ -60,14 +60,14 @@ def test_broadcast_on_tpu():
 
 def tpu_reduce_fn(strategy):
     with pytest.raises(ValueError, match="XLAStrategy only supports"):
-        strategy.reduce(1, reduce_op="undefined")
+        strategy.all_reduce(1, reduce_op="undefined")
 
     with pytest.raises(ValueError, match="XLAStrategy only supports"):
-        strategy.reduce(1, reduce_op=ReduceOp.MAX)
+        strategy.all_reduce(1, reduce_op=ReduceOp.MAX)
 
     # it is faster to loop over here than to parameterize the test
    for reduce_op in ("mean", "AVG", "sum", ReduceOp.SUM):
-        result = strategy.reduce(1, reduce_op=reduce_op)
+        result = strategy.all_reduce(1, reduce_op=reduce_op)
         if isinstance(reduce_op, str) and reduce_op.lower() in ("mean", "avg"):
             assert result.item() == 1
         else:
@@ -77,7 +77,7 @@ def tpu_reduce_fn(strategy):
 
 @RunIf(tpu=True)
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 def test_tpu_reduce():
-    """Test tpu spawn reduce operation."""
+    """Test tpu spawn all_reduce operation."""
     xla_launch(tpu_reduce_fn)
 
@@ -42,7 +42,7 @@ def test_single_device_collectives():
     strategy = SingleDeviceStrategy()
     tensor = Mock()
     assert strategy.all_gather(tensor) == tensor
-    assert strategy.reduce(tensor) == tensor
+    assert strategy.all_reduce(tensor) == tensor
     assert strategy.broadcast(tensor) == tensor
 
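The same assertions also hold with a real tensor instead of a `Mock`, since every collective on a single device hands the input straight back. A CPU-only sketch; the import path is again assumed from the 1.9 package layout:

    # CPU-only variant of the single-device collectives check with a real tensor.
    # Import path assumed from the 1.9 lightning_fabric layout.
    import torch
    from lightning_fabric.strategies import SingleDeviceStrategy

    strategy = SingleDeviceStrategy()
    tensor = torch.arange(4.0)
    assert torch.equal(strategy.all_gather(tensor), tensor)
    assert torch.equal(strategy.all_reduce(tensor), tensor)
    assert torch.equal(strategy.broadcast(tensor), tensor)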