From b8cc62ee5260f016472a6d719313baaf81876207 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 24 Jul 2019 16:24:58 -0400 Subject: [PATCH] added sample input for summary --- pytorch_lightning/root_module/memory.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pytorch_lightning/root_module/memory.py b/pytorch_lightning/root_module/memory.py index 17fbe99197..0482da6e3d 100644 --- a/pytorch_lightning/root_module/memory.py +++ b/pytorch_lightning/root_module/memory.py @@ -34,6 +34,11 @@ class ModelSummary(object): in_sizes = [] out_sizes = [] input_ = self.model.example_input_array + + if self.model.on_gpu: + input_ = input_.cuda(0) + + for i in range(1, len(mods)): m = mods[i] if type(input_) is list or type(input_) is tuple: