diff --git a/pytorch_lightning/root_module/memory.py b/pytorch_lightning/root_module/memory.py index 909f97f119..17fbe99197 100644 --- a/pytorch_lightning/root_module/memory.py +++ b/pytorch_lightning/root_module/memory.py @@ -121,7 +121,7 @@ class ModelSummary(object): df['Type'] = self.layer_types df['Params'] = self.param_nums - if self.model.example_input_array: + if self.model.example_input_array is not None: df.columns.extend(['In_sizes', 'Out_sizes']) df['In_sizes'] = self.in_sizes df['Out_sizes'] = self.out_sizes @@ -134,7 +134,7 @@ class ModelSummary(object): self.get_parameter_sizes() self.get_parameter_nums() - if self.model.example_input_array: + if self.model.example_input_array is not None: self.get_variable_sizes() self.make_summary()