diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index d9a396f717..544854fa4e 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -149,7 +149,10 @@ class EarlyStopping(Callback):
         if not isinstance(current, torch.Tensor):
             current = torch.tensor(current, device=pl_module.device)
 
-        if self.monitor_op(current - self.min_delta, self.best_score.to(pl_module.device)):
+        if trainer.use_tpu and XLA_AVAILABLE:
+            current = current.cpu()
+
+        if self.monitor_op(current - self.min_delta, self.best_score):
             self.best_score = current
             self.wait_count = 0
         else:
@@ -172,12 +175,11 @@ class EarlyStopping(Callback):
             dist.barrier()
             trainer.should_stop = stop == trainer.world_size
 
-        # if trainer.use_tpu:
-        #     stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
-        #     xm.all_reduce('sum', [stop])
-        #     print(type(stop))
-        #     torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
-        #     trainer.should_stop = stop.item() == trainer.world_size
+        if trainer.use_tpu:
+            stop = torch.tensor(int(trainer.should_stop), device=pl_module.device, dtype=torch.int32)
+            stop = xm.mesh_reduce("stop_signal", stop, torch.cat)
+            torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
+            trainer.should_stop = int(stop.item()) == trainer.world_size
 
     def on_train_end(self, trainer, pl_module):
         if self.stopped_epoch > 0 and self.verbose > 0:
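
The TPU branch above gathers each replica's local stop flag via `xm.mesh_reduce`, synchronizes at a rendezvous point, and only stops training when every replica agrees. The following is a minimal CPU-only sketch (not part of the patch, and not using `torch_xla`) that simulates what that aggregation computes; the value of `world_size` and the per-core flags are made up for illustration:

```python
import torch

# Simulated per-core stop flags (1 = this replica hit its early-stopping
# criterion). On a real TPU pod each core holds exactly one of these, and
# xm.mesh_reduce("stop_signal", stop, torch.cat) delivers the gathered list
# to the reduce function on every core.
world_size = 8
local_flags = [torch.tensor([int(i % 2 == 0)], dtype=torch.int32)
               for i in range(world_size)]

# The reduce function (torch.cat in the patch) combines the gathered flags.
gathered = torch.cat(local_flags)

# Decision rule the patch encodes: stop globally only if *all* replicas
# raised their flag, i.e. the aggregated count equals the world size.
should_stop = int(gathered.sum().item()) == world_size
print(gathered, should_stop)  # only True when every entry is 1
```

The rendezvous call in the patch is what keeps the replicas in lockstep so that no core races ahead while the others are still deciding whether to stop.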