diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index a07bc534d9..d9506cafd2 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -52,6 +52,7 @@ class Trainer(TrainerIO): lr_scheduler_milestones=None, use_amp=False, print_nan_grads=False, + print_weights_summary=True, amp_level='O2', nb_sanity_val_steps=5): @@ -69,6 +70,7 @@ class Trainer(TrainerIO): self.cluster = cluster self.process_position = process_position self.current_gpu_name = current_gpu_name + self.print_weights_summary = print_weights_summary self.checkpoint_callback = checkpoint_callback if self.checkpoint_callback is not None: @@ -445,7 +447,7 @@ class Trainer(TrainerIO): self.lr_schedulers.append(scheduler) # print model summary - if self.proc_rank == 0: + if self.proc_rank == 0 and self.print_weights_summary: ref_model.summarize() # give model convenience properties