diff --git a/pytorch_lightning/examples/new_project_templates/lightning_module_template.py b/pytorch_lightning/examples/new_project_templates/lightning_module_template.py index e4ca0c8fad..fd68ef2e79 100644 --- a/pytorch_lightning/examples/new_project_templates/lightning_module_template.py +++ b/pytorch_lightning/examples/new_project_templates/lightning_module_template.py @@ -28,7 +28,8 @@ class LightningTemplateModel(LightningModule): self.batch_size = hparams.batch_size - self.example_input_array = torch.rand(5, 3 * 28 * 28) + # if you specify an example input, the summary will show input/output for each layer + self.example_input_array = torch.rand(5, 28 * 28) # build model self.__build_model() diff --git a/pytorch_lightning/root_module/memory.py b/pytorch_lightning/root_module/memory.py index 0482da6e3d..ed8854ba30 100644 --- a/pytorch_lightning/root_module/memory.py +++ b/pytorch_lightning/root_module/memory.py @@ -38,33 +38,37 @@ class ModelSummary(object): if self.model.on_gpu: input_ = input_.cuda(0) + if self.model.trainer.use_amp: + input_ = input_.half() - for i in range(1, len(mods)): - m = mods[i] - if type(input_) is list or type(input_) is tuple: - out = m(*input_) - else: - out = m(input_) + with torch.no_grad: - if type(input_) is tuple or type(input_) is list: - in_size = [] - for x in input_: - if type(x) is list: - in_size.append(len(x)) - else: - in_size.append(x.size()) - else: - in_size = np.array(input_.size()) + for i in range(1, len(mods)): + m = mods[i] + if type(input_) is list or type(input_) is tuple: + out = m(*input_) + else: + out = m(input_) - in_sizes.append(in_size) + if type(input_) is tuple or type(input_) is list: + in_size = [] + for x in input_: + if type(x) is list: + in_size.append(len(x)) + else: + in_size.append(x.size()) + else: + in_size = np.array(input_.size()) - if type(out) is tuple or type(out) is list: - out_size = np.asarray([x.size() for x in out]) - else: - out_size = np.array(out.size()) + in_sizes.append(in_size) - out_sizes.append(out_size) - input_ = out + if type(out) is tuple or type(out) is list: + out_size = np.asarray([x.size() for x in out]) + else: + out_size = np.array(out.size()) + + out_sizes.append(out_size) + input_ = out self.in_sizes = in_sizes self.out_sizes = out_sizes