added summary flag

This commit is contained in:
William Falcon 2019-07-15 21:11:29 -04:00
parent 182c025c88
commit d12f6b7dd8
1 changed files with 3 additions and 1 deletions

View File

@ -52,6 +52,7 @@ class Trainer(TrainerIO):
lr_scheduler_milestones=None,
use_amp=False,
print_nan_grads=False,
print_weights_summary=True,
amp_level='O2',
nb_sanity_val_steps=5):
@ -69,6 +70,7 @@ class Trainer(TrainerIO):
self.cluster = cluster
self.process_position = process_position
self.current_gpu_name = current_gpu_name
self.print_weights_summary = print_weights_summary
self.checkpoint_callback = checkpoint_callback
if self.checkpoint_callback is not None:
@ -445,7 +447,7 @@ class Trainer(TrainerIO):
self.lr_schedulers.append(scheduler)
# print model summary
if self.proc_rank == 0:
if self.proc_rank == 0 and self.print_weights_summary:
ref_model.summarize()
# give model convenience properties