added sample input for summary
This commit is contained in:
parent
83ccd21bec
commit
7c3786aa52
|
@ -33,7 +33,7 @@ class ModelSummary(object):
|
|||
mods = list(self.model.modules())
|
||||
in_sizes = []
|
||||
out_sizes = []
|
||||
input_ = self.example_input_array
|
||||
input_ = self.model.example_input_array
|
||||
for i in range(1, len(mods)):
|
||||
m = mods[i]
|
||||
if type(input_) is list or type(input_) is tuple:
|
||||
|
@ -121,7 +121,7 @@ class ModelSummary(object):
|
|||
df['Type'] = self.layer_types
|
||||
df['Params'] = self.param_nums
|
||||
|
||||
if self.example_input_array:
|
||||
if self.model.example_input_array:
|
||||
df.columns.extend(['In_sizes', 'Out_sizes'])
|
||||
df['In_sizes'] = self.in_sizes
|
||||
df['Out_sizes'] = self.out_sizes
|
||||
|
@ -134,7 +134,7 @@ class ModelSummary(object):
|
|||
self.get_parameter_sizes()
|
||||
self.get_parameter_nums()
|
||||
|
||||
if self.example_input_array:
|
||||
if self.model.example_input_array:
|
||||
self.get_variable_sizes()
|
||||
self.make_summary()
|
||||
|
||||
|
|
Loading…
Reference in New Issue