diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index 434e50609b..984dcb3836 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -27,6 +27,7 @@ def reduce_distributed_output(output, nb_gpus):
         output[k] = reduced
     return output
 
+
 class Trainer(TrainerIO):
 
     def __init__(self,
@@ -427,7 +428,7 @@ class Trainer(TrainerIO):
             loss.backward()
 
             if self.check_grad_nans:
                 model = self.model.module if self.data_parallel else self.model
                 for param in model.parameters():
                     print(param.grad.float().sum())
 