Fix TPU Spawn gather (#6896)
parent 2e53fd3332
commit 55525031c6
@@ -231,6 +231,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed `EarlyStopping` logic when `min_epochs` or `min_steps` requirement is not met ([#6705](https://github.com/PyTorchLightning/pytorch-lightning/pull/6705))

+- Fixed TPU Spawn all gather ([#6896](https://github.com/PyTorchLightning/pytorch-lightning/pull/6896))
+
 - Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898))
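For context on the #6896 entry above: the user-facing path is `LightningModule.all_gather`, which on TPU spawn is expected to stack each replica's tensor along a new leading dimension, i.e. `(batch, ...)` in, `(world_size, batch, ...)` out. A minimal, hypothetical usage sketch; the module and its tensor shapes are illustrative assumptions, not part of this diff:

```python
import torch
import pytorch_lightning as pl


class GatherDemo(pl.LightningModule):
    """Illustrative module; only the all_gather call is relevant here."""

    def validation_step(self, batch, batch_idx):
        # Per-replica tensor of shape (batch, ...); values are placeholders.
        local_logits = torch.zeros(4, 10)
        # On TPU spawn this is expected to come back as (world_size, batch, ...).
        gathered = self.all_gather(local_logits)
        return gathered
```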
@@ -11,9 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Optional, TYPE_CHECKING, Union
+from typing import Any, Callable, TYPE_CHECKING, Union

 import torch
 from torch.optim import Optimizer

 from pytorch_lightning.accelerators.accelerator import Accelerator
@@ -57,21 +56,6 @@ class TPUAccelerator(Accelerator):
     ) -> None:
         xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': lambda_closure, **kwargs})

-    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
-        """
-        Function to gather a tensor from several distributed processes
-        Args:
-            tensor: tensor of shape (batch, ...)
-            group: not available with TPUs
-            sync_grads: not available with TPUs
-        Return:
-            A tensor of shape (world_size, batch, ...)
-        """
-        # todo: Add support for backward with all_gather
-        if isinstance(self.training_type_plugin, TPUSpawnPlugin) and self.training_type_plugin.is_distributed:
-            return xm.all_gather(tensor).view(-1, *tensor.shape)
-        return tensor
-
     def clip_gradients(self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0):

         model = self.lightning_module
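A CPU-only sketch of the shape arithmetic behind the removed method and its replacement (the `unsqueeze(0)`-based form added to `TPUSpawnPlugin` later in this diff). `torch.cat` stands in for `xm.all_gather`, which concatenates along dim 0; `world_size` and the tensor shape are illustrative assumptions:

```python
import torch

world_size = 8           # e.g. one process per TPU core (assumption)
t = torch.randn(4, 16)   # per-replica tensor of shape (batch, ...)

# Removed formulation: gather flat along dim 0, then reshape.
removed = torch.cat([t] * world_size, dim=0).view(-1, *t.shape)

# Replacement formulation (added to TPUSpawnPlugin below): add a leading
# dim first, so the gather itself produces the stacked layout.
replacement = torch.cat([t.unsqueeze(0)] * world_size, dim=0)

assert removed.shape == replacement.shape == (world_size, *t.shape)
```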
@@ -195,14 +195,14 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
         return obj

     def reduce_boolean_decision(self, decision: bool) -> bool:
-        decision = torch.tensor(int(decision), device=self.device)
-        decision = self.reduce(decision, "sum")
+        decision = torch.tensor(int(decision), device=self.lightning_module.device)
+        decision = self.reduce(decision, reduce_op="sum")
         decision = bool(decision == self.world_size)
         return decision

     def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
         if not isinstance(output, torch.Tensor):
-            output = torch.tensor(output, device=self.device)
+            output = torch.tensor(output, device=self.lightning_module.device)

         _invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
         _invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
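The reworked `reduce_boolean_decision` implements an all-replicas-must-agree rule: each replica contributes 0 or 1, the values are sum-reduced across the world, and the result is True only if the sum equals `world_size`. A CPU sketch of that logic, with an assumed `world_size` and per-replica votes:

```python
import torch

world_size = 8
# One local decision per replica; here one replica disagrees (illustrative).
local_decisions = [True, True, True, True, True, False, True, True]

# Mirror reduce_boolean_decision: cast to int tensors, sum-reduce, compare.
summed = torch.stack([torch.tensor(int(d)) for d in local_decisions]).sum()
global_decision = bool(summed == world_size)

assert global_decision is False  # any dissenting replica vetoes the decision
```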
@@ -267,3 +267,15 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
         if _OMEGACONF_AVAILABLE:
             checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
         self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath)
+
+    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
+        """
+        Function to gather a tensor from several distributed processes
+        Args:
+            tensor: tensor of shape (batch, ...)
+            group: not available with TPUs
+            sync_grads: not available with TPUs
+        Return:
+            A tensor of shape (world_size, batch, ...)
+        """
+        return xm.all_gather(tensor.unsqueeze(0))
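To see what the new plugin-level `all_gather` is meant to produce, here is a small CPU simulation in which each replica holds a differently valued tensor; `torch.cat` again stands in for `xm.all_gather`, and the replica count and shapes are assumptions for illustration:

```python
import torch

world_size = 4
batch, features = 2, 3

# Pretend each replica holds a differently valued tensor of shape (batch, features).
per_replica = [torch.full((batch, features), float(rank)) for rank in range(world_size)]

# xm.all_gather(tensor.unsqueeze(0)) concatenates the (1, batch, features)
# tensors along dim 0, giving (world_size, batch, features).
gathered = torch.cat([t.unsqueeze(0) for t in per_replica], dim=0)

assert gathered.shape == (world_size, batch, features)
assert torch.equal(gathered[2], per_replica[2])  # replica 2's data sits at index 2
```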
@@ -229,8 +229,8 @@ def test_tpu_clip_grad_by_value(tmpdir):
         progress_bar_refresh_rate=0,
         max_epochs=4,
         tpu_cores=1,
-        limit_train_batches=4,
-        limit_val_batches=4,
+        limit_train_batches=10,
+        limit_val_batches=10,
         gradient_clip_val=0.5,
         gradient_clip_algorithm='value'
     )
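The test above drives `gradient_clip_algorithm='value'` with `gradient_clip_val=0.5`. For reference, a short sketch of how value clipping differs from the default norm clipping, using the underlying torch utilities (gradient values are illustrative):

```python
import torch

p = torch.nn.Parameter(torch.zeros(3))
p.grad = torch.tensor([2.0, -0.1, 0.7])

# 'value' clipping: each gradient element is clamped into [-0.5, 0.5].
torch.nn.utils.clip_grad_value_([p], clip_value=0.5)
print(p.grad)  # tensor([ 0.5000, -0.1000,  0.5000])

# 'norm' clipping (the default) rescales the whole gradient vector instead.
p.grad = torch.tensor([2.0, -0.1, 0.7])
torch.nn.utils.clip_grad_norm_([p], max_norm=0.5)
print(p.grad.norm())  # ~0.5
```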