Hang (#2488)
* added early stop tpu test
This commit is contained in:
parent
fc61c200c0
commit
e5a979990e
|
@ -149,7 +149,10 @@ class EarlyStopping(Callback):
|
|||
if not isinstance(current, torch.Tensor):
|
||||
current = torch.tensor(current, device=pl_module.device)
|
||||
|
||||
if self.monitor_op(current - self.min_delta, self.best_score.to(pl_module.device)):
|
||||
if trainer.use_tpu and XLA_AVAILABLE:
|
||||
current = current.cpu()
|
||||
|
||||
if self.monitor_op(current - self.min_delta, self.best_score):
|
||||
self.best_score = current
|
||||
self.wait_count = 0
|
||||
else:
|
||||
|
@ -172,12 +175,11 @@ class EarlyStopping(Callback):
|
|||
dist.barrier()
|
||||
trainer.should_stop = stop == trainer.world_size
|
||||
|
||||
# if trainer.use_tpu:
|
||||
# stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
|
||||
# xm.all_reduce('sum', [stop])
|
||||
# print(type(stop))
|
||||
# torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
|
||||
# trainer.should_stop = stop.item() == trainer.world_size
|
||||
if trainer.use_tpu:
|
||||
stop = torch.tensor(int(trainer.should_stop), device=pl_module.device, dtype=torch.int32)
|
||||
stop = xm.mesh_reduce("stop_signal", stop, torch.cat)
|
||||
torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
|
||||
trainer.should_stop = int(stop.item()) == trainer.world_size
|
||||
|
||||
def on_train_end(self, trainer, pl_module):
|
||||
if self.stopped_epoch > 0 and self.verbose > 0:
|
||||
|
|
Loading…
Reference in New Issue