diff --git a/pytorch_lightning/root_module/memory.py b/pytorch_lightning/root_module/memory.py index 17fbe99197..0482da6e3d 100644 --- a/pytorch_lightning/root_module/memory.py +++ b/pytorch_lightning/root_module/memory.py @@ -34,6 +34,11 @@ class ModelSummary(object): in_sizes = [] out_sizes = [] input_ = self.model.example_input_array + + if self.model.on_gpu: + input_ = input_.cuda(0) + + for i in range(1, len(mods)): m = mods[i] if type(input_) is list or type(input_) is tuple: