Allow disabling default logger, checkpoint_callback, and early_stop_callback (#360)
* Allow disabling logger, early stopping, and checkpoints
* Typo
* Get tests passing
* Update trainer.py
parent 792ba59b78
commit 19c2b8fc9e
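In terms of usage, the change flips the defaults for logger, checkpoint_callback, and early_stop_callback from None to True and lets each component be switched off by passing False. A minimal sketch of the resulting API (the import path is assumed to be the package root):

    from pytorch_lightning import Trainer

    # With this commit, passing False disables each component entirely;
    # previously, None meant "build the default one".
    trainer = Trainer(
        logger=False,               # no default TestTubeLogger is created
        checkpoint_callback=False,  # no default ModelCheckpoint is created
        early_stop_callback=False,  # no default EarlyStopping is created
    )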
@@ -58,9 +58,9 @@ def reduce_distributed_output(output, nb_gpus):
 class Trainer(TrainerIO):
 
     def __init__(self,
-                 logger=None,
-                 checkpoint_callback=None,
-                 early_stop_callback=None,
+                 logger=True,
+                 checkpoint_callback=True,
+                 early_stop_callback=True,
                  default_save_path=None,
                  gradient_clip_val=0,
                  process_position=0,
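Because the defaults are now True rather than None, constructing a Trainer with no arguments keeps the previous behaviour (default logger, checkpointing, and early stopping all enabled); only an explicit False opts out. A quick sketch, assuming the package-root import:

    from pytorch_lightning import Trainer

    # Defaults are True, so the default TestTubeLogger, ModelCheckpoint,
    # and EarlyStopping are still created, matching the pre-change behaviour.
    trainer = Trainer()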
@@ -126,7 +126,6 @@ class Trainer(TrainerIO):
         self.log_gpu_memory = log_gpu_memory
         self.gradient_clip_val = gradient_clip_val
         self.check_val_every_n_epoch = check_val_every_n_epoch
-        self.enable_early_stop = early_stop_callback is not None
         self.track_grad_norm = track_grad_norm
         self.on_gpu = gpus is not None and torch.cuda.is_available()
         self.process_position = process_position
@@ -176,24 +175,35 @@ class Trainer(TrainerIO):
 
         # configure early stop callback
         # creates a default one if none passed in
-        self.early_stop_callback = early_stop_callback
-        if self.early_stop_callback is None:
-            self.early_stop = EarlyStopping(
+        if early_stop_callback is True:
+            self.early_stop_callback = EarlyStopping(
                 monitor='val_loss',
                 patience=3,
                 verbose=True,
                 mode='min'
             )
+            self.enable_early_stop = True
+        elif not early_stop_callback:
+            self.early_stop_callback = None
+            self.enable_early_stop = False
+        else:
+            self.early_stop_callback = early_stop_callback
+            self.enable_early_stop = True
 
         # configure logger
-        self.logger = logger
-        if self.logger is None:
+        if logger is True:
             # default logger
             self.logger = TestTubeLogger(
                 save_dir=self.default_save_path,
                 version=self.slurm_job_id,
                 name='lightning_logs'
             )
-        self.logger.rank = 0
+            self.logger.rank = 0
+        elif logger is False:
+            self.logger = None
+        else:
+            self.logger = logger
+            self.logger.rank = 0
 
         # configure checkpoint callback
         self.checkpoint_callback = checkpoint_callback
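The same True / False / object dispatch applies to both early_stop_callback and logger: True builds the default shown in the hunk, False disables the feature, and anything else is used as passed in. A hedged usage sketch (the EarlyStopping import path is assumed to match the callbacks module used in this file):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    # Passing a configured callback takes the final "else" branch above:
    # it is stored as-is and enable_early_stop is set to True.
    early_stop = EarlyStopping(monitor='val_loss', patience=10, mode='min')
    trainer = Trainer(early_stop_callback=early_stop)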
@@ -257,7 +267,7 @@ class Trainer(TrainerIO):
        User provided weights_saved_path
        Otherwise use os.getcwd()
        """
-        if self.checkpoint_callback is None:
+        if self.checkpoint_callback is True:
            # init a default one
            if isinstance(self.logger, TestTubeLogger):
                ckpt_path = '{}/{}/version_{:04d}/{}'.format(
@@ -271,12 +281,15 @@
             self.checkpoint_callback = ModelCheckpoint(
                 filepath=ckpt_path
             )
+        elif self.checkpoint_callback is False:
+            self.checkpoint_callback = None
 
-        # set the path for the callbacks
-        self.checkpoint_callback.save_function = self.save_checkpoint
+        if self.checkpoint_callback:
+            # set the path for the callbacks
+            self.checkpoint_callback.save_function = self.save_checkpoint
 
-        # if checkpoint callback used, then override the weights path
-        self.weights_save_path = self.checkpoint_callback.filepath
+            # if checkpoint callback used, then override the weights path
+            self.weights_save_path = self.checkpoint_callback.filepath
 
         # if weights_save_path is still none here, set to current working dir
         if self.weights_save_path is None:
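For the checkpoint path logic above, a user-supplied ModelCheckpoint flows through unchanged: it receives the trainer's save_function, and its filepath overrides weights_save_path. A hedged sketch (the ModelCheckpoint import path is assumed to match the callbacks module used in this file):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    # The custom callback is neither True nor False, so it is kept as-is;
    # its filepath then becomes trainer.weights_save_path.
    checkpoint = ModelCheckpoint(filepath='my_checkpoints/')
    trainer = Trainer(checkpoint_callback=checkpoint)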