From d562172b4cd363b86f9d670120932cd333b03cf5 Mon Sep 17 00:00:00 2001
From: VSJMilewski <6348139+VSJMilewski@users.noreply.github.com>
Date: Mon, 9 Dec 2019 13:42:07 +0100
Subject: [PATCH] Allow for multiple example inputs when creating summary
 (#543)

---
 pytorch_lightning/core/memory.py | 21 ++++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py
index b01451f8cf..1abc349e1e 100644
--- a/pytorch_lightning/core/memory.py
+++ b/pytorch_lightning/core/memory.py
@@ -50,20 +50,31 @@ class ModelSummary(object):
         input_ = self.model.example_input_array
 
         if self.model.on_gpu:
-            input_ = input_.cuda(0)
+            device = next(self.model.parameters()).get_device()
+            # test if input is a list or a tuple
+            if isinstance(input_, (list, tuple)):
+                input_ = [input_i.cuda(device) if torch.is_tensor(input_i) else input_i
+                          for input_i in input_]
+            else:
+                input_ = input_.cuda(device)
 
         if self.model.trainer.use_amp:
-            input_ = input_.half()
+            # test if input is a list or a tuple
+            if isinstance(input_, (list, tuple)):
+                input_ = [input_i.half() if torch.is_tensor(input_i) else input_i
+                          for input_i in input_]
+            else:
+                input_ = input_.half()
 
         with torch.no_grad():
 
             for _, m in mods:
-                if type(input_) is list or type(input_) is tuple:  # pragma: no cover
+                if isinstance(input_, (list, tuple)):  # pragma: no cover
                     out = m(*input_)
                 else:
                     out = m(input_)
 
-                if type(input_) is tuple or type(input_) is list:  # pragma: no cover
+                if isinstance(input_, (list, tuple)):  # pragma: no cover
                     in_size = []
                     for x in input_:
                         if type(x) is list:
@@ -75,7 +86,7 @@ class ModelSummary(object):
 
                 in_sizes.append(in_size)
 
-                if type(out) is tuple or type(out) is list:  # pragma: no cover
+                if isinstance(out, (list, tuple)):  # pragma: no cover
                     out_size = np.asarray([x.size() for x in out])
                 else:
                     out_size = np.array(out.size())
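
For context, a minimal sketch of a module that benefits from this change; the class name LitModel, the layer names encoder and mask_proj, and the tensor shapes are illustrative assumptions, not code from the repository. With example_input_array set to a tuple, the patched ModelSummary moves each tensor to the model's device (and casts it to half precision under AMP) instead of assuming a single tensor input.

import torch
from torch import nn
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    """Hypothetical module whose forward() takes two inputs (illustrative sketch)."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(16, 8)
        self.mask_proj = nn.Linear(4, 8)
        # A tuple of example tensors: after this patch, ModelSummary iterates
        # over the tuple instead of assuming example_input_array is one tensor.
        self.example_input_array = (torch.rand(2, 16), torch.rand(2, 4))

    def forward(self, x, mask):
        return self.encoder(x) + self.mask_proj(mask)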