Fix TPU Spawn gather (#6896)

This commit is contained in:
Kaushik B 2021-04-09 18:30:59 +05:30 committed by GitHub
parent 2e53fd3332
commit 55525031c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 21 additions and 22 deletions

View File

@ -231,6 +231,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `EarlyStopping` logic when `min_epochs` or `min_steps` requirement is not met ([#6705](https://github.com/PyTorchLightning/pytorch-lightning/pull/6705))
- Fixed TPU Spawn all gather ([#6896](https://github.com/PyTorchLightning/pytorch-lightning/pull/6896))
- Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898))

View File

@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, TYPE_CHECKING, Union
import torch
from torch.optim import Optimizer
from pytorch_lightning.accelerators.accelerator import Accelerator
@ -57,21 +56,6 @@ class TPUAccelerator(Accelerator):
) -> None:
xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': lambda_closure, **kwargs})
def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
"""
Function to gather a tensor from several distributed processes
Args:
tensor: tensor of shape (batch, ...)
group: not available with TPUs
sync_grads: not available with TPUs
Return:
A tensor of shape (world_size, batch, ...)
"""
# todo: Add support for backward with all_gather
if isinstance(self.training_type_plugin, TPUSpawnPlugin) and self.training_type_plugin.is_distributed:
return xm.all_gather(tensor).view(-1, *tensor.shape)
return tensor
def clip_gradients(self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0):
model = self.lightning_module

View File

@ -195,14 +195,14 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
return obj
def reduce_boolean_decision(self, decision: bool) -> bool:
decision = torch.tensor(int(decision), device=self.device)
decision = self.reduce(decision, "sum")
decision = torch.tensor(int(decision), device=self.lightning_module.device)
decision = self.reduce(decision, reduce_op="sum")
decision = bool(decision == self.world_size)
return decision
def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
if not isinstance(output, torch.Tensor):
output = torch.tensor(output, device=self.device)
output = torch.tensor(output, device=self.lightning_module.device)
_invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
_invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
@ -267,3 +267,15 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
if _OMEGACONF_AVAILABLE:
checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath)
def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
"""
Function to gather a tensor from several distributed processes
Args:
tensor: tensor of shape (batch, ...)
group: not available with TPUs
sync_grads: not available with TPUs
Return:
A tensor of shape (world_size, batch, ...)
"""
return xm.all_gather(tensor.unsqueeze(0))

View File

@ -229,8 +229,8 @@ def test_tpu_clip_grad_by_value(tmpdir):
progress_bar_refresh_rate=0,
max_epochs=4,
tpu_cores=1,
limit_train_batches=4,
limit_val_batches=4,
limit_train_batches=10,
limit_val_batches=10,
gradient_clip_val=0.5,
gradient_clip_algorithm='value'
)