improve pickle tests for callbacks (#1717)
* improve pickle tests for callbacks * set mode dict as a class attr
This commit is contained in:
parent
2b03d34931
commit
fc7f5919b5
|
@ -45,11 +45,14 @@ class EarlyStopping(Callback):
|
|||
>>> early_stopping = EarlyStopping('val_loss')
|
||||
>>> trainer = Trainer(early_stop_callback=early_stopping)
|
||||
"""
|
||||
mode_dict = {
|
||||
'min': torch.lt,
|
||||
'max': torch.gt,
|
||||
}
|
||||
|
||||
def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience: int = 3,
|
||||
verbose: bool = False, mode: str = 'auto', strict: bool = True):
|
||||
super().__init__()
|
||||
|
||||
self.monitor = monitor
|
||||
self.patience = patience
|
||||
self.verbose = verbose
|
||||
|
@ -59,17 +62,19 @@ class EarlyStopping(Callback):
|
|||
self.stopped_epoch = 0
|
||||
self.mode = mode
|
||||
|
||||
mode_dict = {
|
||||
'min': torch.lt,
|
||||
'max': torch.gt,
|
||||
'auto': torch.gt if 'acc' in self.monitor else torch.lt
|
||||
}
|
||||
|
||||
if mode not in mode_dict:
|
||||
if mode not in self.mode_dict:
|
||||
if self.verbose > 0:
|
||||
log.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.')
|
||||
self.mode = 'auto'
|
||||
|
||||
if self.mode == 'auto':
|
||||
if self.monitor == 'acc':
|
||||
self.mode = 'max'
|
||||
else:
|
||||
self.mode = 'min'
|
||||
if self.verbose > 0:
|
||||
log.info(f'EarlyStopping mode set to {self.mode} for monitoring {self.monitor}.')
|
||||
|
||||
self.min_delta *= 1 if self.monitor_op == torch.gt else -1
|
||||
|
||||
def _validate_condition_metric(self, logs):
|
||||
|
@ -96,12 +101,7 @@ class EarlyStopping(Callback):
|
|||
|
||||
@property
|
||||
def monitor_op(self):
|
||||
mode_dict = {
|
||||
'min': torch.lt,
|
||||
'max': torch.gt,
|
||||
'auto': torch.gt if 'acc' in self.monitor else torch.lt
|
||||
}
|
||||
return mode_dict[self.mode]
|
||||
return self.mode_dict[self.mode]
|
||||
|
||||
def on_train_start(self, trainer, pl_module):
|
||||
# Allow instances to be re-used
|
||||
|
|
|
@ -229,8 +229,14 @@ def test_pickling(tmpdir):
|
|||
early_stopping = EarlyStopping()
|
||||
ckpt = ModelCheckpoint(tmpdir)
|
||||
|
||||
pickle.dumps(ckpt)
|
||||
pickle.dumps(early_stopping)
|
||||
early_stopping_pickled = pickle.dumps(early_stopping)
|
||||
ckpt_pickled = pickle.dumps(ckpt)
|
||||
|
||||
early_stopping_loaded = pickle.loads(early_stopping_pickled)
|
||||
ckpt_loaded = pickle.loads(ckpt_pickled)
|
||||
|
||||
assert vars(early_stopping) == vars(early_stopping_loaded)
|
||||
assert vars(ckpt) == vars(ckpt_loaded)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2])
|
||||
|
|
Loading…
Reference in New Issue