diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 78c698d5f2..20c0673d93 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -58,9 +58,9 @@ def reduce_distributed_output(output, nb_gpus): class Trainer(TrainerIO): def __init__(self, - logger=None, - checkpoint_callback=None, - early_stop_callback=None, + logger=True, + checkpoint_callback=True, + early_stop_callback=True, default_save_path=None, gradient_clip_val=0, process_position=0, @@ -126,7 +126,6 @@ class Trainer(TrainerIO): self.log_gpu_memory = log_gpu_memory self.gradient_clip_val = gradient_clip_val self.check_val_every_n_epoch = check_val_every_n_epoch - self.enable_early_stop = early_stop_callback is not None self.track_grad_norm = track_grad_norm self.on_gpu = gpus is not None and torch.cuda.is_available() self.process_position = process_position @@ -176,24 +175,35 @@ class Trainer(TrainerIO): # configure early stop callback # creates a default one if none passed in - self.early_stop_callback = early_stop_callback - if self.early_stop_callback is None: - self.early_stop = EarlyStopping( + if early_stop_callback is True: + self.early_stop_callback = EarlyStopping( monitor='val_loss', patience=3, verbose=True, mode='min' ) + self.enable_early_stop = True + elif not early_stop_callback: + self.early_stop_callback = None + self.enable_early_stop = False + else: + self.early_stop_callback = early_stop_callback + self.enable_early_stop = True # configure logger - self.logger = logger - if self.logger is None: + if logger is True: + # default logger self.logger = TestTubeLogger( save_dir=self.default_save_path, version=self.slurm_job_id, name='lightning_logs' ) - self.logger.rank = 0 + self.logger.rank = 0 + elif logger is False: + self.logger = None + else: + self.logger = logger + self.logger.rank = 0 # configure checkpoint callback self.checkpoint_callback = checkpoint_callback @@ -257,7 +267,7 @@ class Trainer(TrainerIO): User provided weights_saved_path Otherwise use os.getcwd() """ - if self.checkpoint_callback is None: + if self.checkpoint_callback is True: # init a default one if isinstance(self.logger, TestTubeLogger): ckpt_path = '{}/{}/version_{:04d}/{}'.format( @@ -271,12 +281,15 @@ class Trainer(TrainerIO): self.checkpoint_callback = ModelCheckpoint( filepath=ckpt_path ) + elif self.checkpoint_callback is False: + self.checkpoint_callback = None - # set the path for the callbacks - self.checkpoint_callback.save_function = self.save_checkpoint + if self.checkpoint_callback: + # set the path for the callbacks + self.checkpoint_callback.save_function = self.save_checkpoint - # if checkpoint callback used, then override the weights path - self.weights_save_path = self.checkpoint_callback.filepath + # if checkpoint callback used, then override the weights path + self.weights_save_path = self.checkpoint_callback.filepath # if weights_save_path is still none here, set to current working dir if self.weights_save_path is None: