diff --git a/pytorch_lightning/root_module/memory.py b/pytorch_lightning/root_module/memory.py index ed8854ba30..a2feb689bd 100644 --- a/pytorch_lightning/root_module/memory.py +++ b/pytorch_lightning/root_module/memory.py @@ -41,7 +41,7 @@ class ModelSummary(object): if self.model.trainer.use_amp: input_ = input_.half() - with torch.no_grad: + with torch.no_grad(): for i in range(1, len(mods)): m = mods[i]