* added early stop tpu test
William Falcon 2020-07-03 15:16:45 -04:00 committed by GitHub
parent fc61c200c0
commit e5a979990e
1 changed file with 9 additions and 7 deletions


@@ -149,7 +149,10 @@ class EarlyStopping(Callback):
         if not isinstance(current, torch.Tensor):
             current = torch.tensor(current, device=pl_module.device)

-        if self.monitor_op(current - self.min_delta, self.best_score.to(pl_module.device)):
+        if trainer.use_tpu and XLA_AVAILABLE:
+            current = current.cpu()
+
+        if self.monitor_op(current - self.min_delta, self.best_score):
             self.best_score = current
             self.wait_count = 0
         else:
@@ -172,12 +175,11 @@ class EarlyStopping(Callback):
             dist.barrier()
             trainer.should_stop = stop == trainer.world_size

-        # if trainer.use_tpu:
-        #     stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
-        #     xm.all_reduce('sum', [stop])
-        #     print(type(stop))
-        #     torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
-        #     trainer.should_stop = stop.item() == trainer.world_size
+        if trainer.use_tpu:
+            stop = torch.tensor(int(trainer.should_stop), device=pl_module.device, dtype=torch.int32)
+            stop = xm.mesh_reduce("stop_signal", stop, torch.cat)
+            torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
+            trainer.should_stop = int(stop.item()) == trainer.world_size

     def on_train_end(self, trainer, pl_module):
         if self.stopped_epoch > 0 and self.verbose > 0:
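
For reference, the reduction pattern used in the new "if trainer.use_tpu:" branch can be sketched outside the callback. This is a minimal illustration, assuming a torch_xla (TPU) environment; the helper name should_stop_all_replicas and the use of plain Python sum as the reduction are illustrative choices, not part of this commit:

import torch
import torch_xla.core.xla_model as xm

def should_stop_all_replicas(local_should_stop: bool, world_size: int) -> bool:
    # Hypothetical helper mirroring the commit's TPU branch: each replica
    # contributes a 0/1 flag and training stops only when all replicas agree.
    device = xm.xla_device()
    # Encode the local decision as an int32 tensor on the XLA device,
    # as the callback does with trainer.should_stop.
    stop = torch.tensor(int(local_should_stop), device=device, dtype=torch.int32)
    # mesh_reduce gathers one value per replica and applies the reduction;
    # summing the flags counts how many replicas want to stop.
    stop_count = xm.mesh_reduce("stop_signal", stop.item(), sum)
    # Synchronize all replicas before acting on the reduced value.
    xm.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
    # Same check as in the diff: stop only if every replica raised the flag.
    return int(stop_count) == world_size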