improve pickle tests for callbacks (#1717)

* improve pickle tests for callbacks

* set mode dict as a class attr
This commit is contained in:
Jeremy Jordan 2020-05-05 14:08:54 -04:00 committed by GitHub
parent 2b03d34931
commit fc7f5919b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 16 deletions

View File

@ -45,11 +45,14 @@ class EarlyStopping(Callback):
>>> early_stopping = EarlyStopping('val_loss')
>>> trainer = Trainer(early_stop_callback=early_stopping)
"""
mode_dict = {
'min': torch.lt,
'max': torch.gt,
}
def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience: int = 3,
verbose: bool = False, mode: str = 'auto', strict: bool = True):
super().__init__()
self.monitor = monitor
self.patience = patience
self.verbose = verbose
@ -59,17 +62,19 @@ class EarlyStopping(Callback):
self.stopped_epoch = 0
self.mode = mode
mode_dict = {
'min': torch.lt,
'max': torch.gt,
'auto': torch.gt if 'acc' in self.monitor else torch.lt
}
if mode not in mode_dict:
if mode not in self.mode_dict:
if self.verbose > 0:
log.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.')
self.mode = 'auto'
if self.mode == 'auto':
if self.monitor == 'acc':
self.mode = 'max'
else:
self.mode = 'min'
if self.verbose > 0:
log.info(f'EarlyStopping mode set to {self.mode} for monitoring {self.monitor}.')
self.min_delta *= 1 if self.monitor_op == torch.gt else -1
def _validate_condition_metric(self, logs):
@ -96,12 +101,7 @@ class EarlyStopping(Callback):
@property
def monitor_op(self):
mode_dict = {
'min': torch.lt,
'max': torch.gt,
'auto': torch.gt if 'acc' in self.monitor else torch.lt
}
return mode_dict[self.mode]
return self.mode_dict[self.mode]
def on_train_start(self, trainer, pl_module):
# Allow instances to be re-used

View File

@ -229,8 +229,14 @@ def test_pickling(tmpdir):
early_stopping = EarlyStopping()
ckpt = ModelCheckpoint(tmpdir)
pickle.dumps(ckpt)
pickle.dumps(early_stopping)
early_stopping_pickled = pickle.dumps(early_stopping)
ckpt_pickled = pickle.dumps(ckpt)
early_stopping_loaded = pickle.loads(early_stopping_pickled)
ckpt_loaded = pickle.loads(ckpt_pickled)
assert vars(early_stopping) == vars(early_stopping_loaded)
assert vars(ckpt) == vars(ckpt_loaded)
@pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2])