diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 4c30b67778..d383a2fb42 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -57,6 +57,7 @@ class EarlyStopping(Callback): self.min_delta = min_delta self.wait = 0 self.stopped_epoch = 0 + self.mode = mode mode_dict = { 'min': torch.lt, @@ -67,9 +68,8 @@ class EarlyStopping(Callback): if mode not in mode_dict: if self.verbose > 0: log.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.') - mode = 'auto' + self.mode = 'auto' - self.monitor_op = mode_dict[mode] self.min_delta *= 1 if self.monitor_op == torch.gt else -1 def _validate_condition_metric(self, logs): @@ -94,6 +94,15 @@ class EarlyStopping(Callback): return True + @property + def monitor_op(self): + mode_dict = { + 'min': torch.lt, + 'max': torch.gt, + 'auto': torch.gt if 'acc' in self.monitor else torch.lt + } + return mode_dict[self.mode] + def on_train_start(self, trainer, pl_module): # Allow instances to be re-used self.wait = 0 diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py index 659aa7a072..f26901c0ed 100644 --- a/pytorch_lightning/trainer/distrib_data_parallel.py +++ b/pytorch_lightning/trainer/distrib_data_parallel.py @@ -378,6 +378,7 @@ class TrainerDDPMixin(ABC): :param model: :return: """ + import pdb; pdb.set_trace() if self.proc_rank == 0: path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt') self.save_checkpoint(path)