diff --git a/pytorch_lightning/root_module/memory.py b/pytorch_lightning/root_module/memory.py index 4a166d63a5..909f97f119 100644 --- a/pytorch_lightning/root_module/memory.py +++ b/pytorch_lightning/root_module/memory.py @@ -33,7 +33,7 @@ class ModelSummary(object): mods = list(self.model.modules()) in_sizes = [] out_sizes = [] - input_ = self.example_input_array + input_ = self.model.example_input_array for i in range(1, len(mods)): m = mods[i] if type(input_) is list or type(input_) is tuple: @@ -121,7 +121,7 @@ class ModelSummary(object): df['Type'] = self.layer_types df['Params'] = self.param_nums - if self.example_input_array: + if self.model.example_input_array: df.columns.extend(['In_sizes', 'Out_sizes']) df['In_sizes'] = self.in_sizes df['Out_sizes'] = self.out_sizes @@ -134,7 +134,7 @@ class ModelSummary(object): self.get_parameter_sizes() self.get_parameter_nums() - if self.example_input_array: + if self.model.example_input_array: self.get_variable_sizes() self.make_summary()