* added early stop tpu test

* added early stop tpu test

* added early stop tpu test

* added early stop tpu test

* added early stop tpu test

* added early stop tpu test

* added early stop tpu test

* added early stop tpu test

* added early stop tpu test

* added early stop tpu test

* added early stop tpu test

* added early stop tpu test

* added early stop tpu test

* added early stop tpu test

* added early stop tpu test

* added early stop tpu test

* added early stop tpu test
This commit is contained in:
William Falcon 2020-07-03 15:16:45 -04:00 committed by GitHub
parent fc61c200c0
commit e5a979990e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 9 additions and 7 deletions

View File

@@ -149,7 +149,10 @@ class EarlyStopping(Callback):
if not isinstance(current, torch.Tensor): if not isinstance(current, torch.Tensor):
current = torch.tensor(current, device=pl_module.device) current = torch.tensor(current, device=pl_module.device)
if self.monitor_op(current - self.min_delta, self.best_score.to(pl_module.device)): if trainer.use_tpu and XLA_AVAILABLE:
current = current.cpu()
if self.monitor_op(current - self.min_delta, self.best_score):
self.best_score = current self.best_score = current
self.wait_count = 0 self.wait_count = 0
else: else:
@@ -172,12 +175,11 @@ class EarlyStopping(Callback):
dist.barrier() dist.barrier()
trainer.should_stop = stop == trainer.world_size trainer.should_stop = stop == trainer.world_size
# if trainer.use_tpu: if trainer.use_tpu:
# stop = torch.tensor(int(trainer.should_stop), device=pl_module.device) stop = torch.tensor(int(trainer.should_stop), device=pl_module.device, dtype=torch.int32)
# xm.all_reduce('sum', [stop]) stop = xm.mesh_reduce("stop_signal", stop, torch.cat)
# print(type(stop)) torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
# torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check") trainer.should_stop = int(stop.item()) == trainer.world_size
# trainer.should_stop = stop.item() == trainer.world_size
def on_train_end(self, trainer, pl_module): def on_train_end(self, trainer, pl_module):
if self.stopped_epoch > 0 and self.verbose > 0: if self.stopped_epoch > 0 and self.verbose > 0: