added summary flag
This commit is contained in:
parent
182c025c88
commit
d12f6b7dd8
|
@ -52,6 +52,7 @@ class Trainer(TrainerIO):
|
|||
lr_scheduler_milestones=None,
|
||||
use_amp=False,
|
||||
print_nan_grads=False,
|
||||
print_weights_summary=True,
|
||||
amp_level='O2',
|
||||
nb_sanity_val_steps=5):
|
||||
|
||||
|
@ -69,6 +70,7 @@ class Trainer(TrainerIO):
|
|||
self.cluster = cluster
|
||||
self.process_position = process_position
|
||||
self.current_gpu_name = current_gpu_name
|
||||
self.print_weights_summary = print_weights_summary
|
||||
self.checkpoint_callback = checkpoint_callback
|
||||
|
||||
if self.checkpoint_callback is not None:
|
||||
|
@ -445,7 +447,7 @@ class Trainer(TrainerIO):
|
|||
self.lr_schedulers.append(scheduler)
|
||||
|
||||
# print model summary
|
||||
if self.proc_rank == 0:
|
||||
if self.proc_rank == 0 and self.print_weights_summary:
|
||||
ref_model.summarize()
|
||||
|
||||
# give model convenience properties
|
||||
|
|
Loading…
Reference in New Issue