From 741c452551780e110938c8635db496682784be07 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 23 Mar 2021 22:07:48 +0000 Subject: [PATCH] Fix disabled grads after call to predict (#6657) --- CHANGELOG.md | 3 +++ pytorch_lightning/trainer/trainer.py | 4 ++++ tests/trainer/test_trainer.py | 13 +++++++++++++ 3 files changed, 20 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b3359ace5..6a1e85d4ad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -191,6 +191,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed comparing required versions ([#6434](https://github.com/PyTorchLightning/pytorch-lightning/pull/6434)) +- Fixed a bug where gradients were disabled after calling `Trainer.predict` ([#6657](https://github.com/PyTorchLightning/pytorch-lightning/pull/6657)) + + ## [1.2.4] - 2021-03-16 ### Changed diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index bb5d691996..dbc493aa76 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -800,6 +800,10 @@ class Trainer( results = self.predict_loop.on_predict_epoch_end() self.predict_loop.on_predict_end() + + # re-enable grads + torch.set_grad_enabled(True) + return results def run_sanity_check(self, ref_model): diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index d461d9d152..490f205a7b 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1450,6 +1450,19 @@ def test_trainer_predict_no_return(tmpdir): predict(tmpdir, None, None, 1, model=CustomBoringModel()) +def test_trainer_predict_grad(tmpdir): + class CustomBoringModel(BoringModel): + + def predict_step(self, batch, batch_idx, dataloader_idx=None): + assert batch.expand_as(batch).grad_fn is None + return super().predict_step(batch, batch_idx, dataloader_idx) + + predict(tmpdir, None, None, 1, model=CustomBoringModel()) + + x = torch.zeros(1, requires_grad=True) + assert x.expand_as(x).grad_fn is not None + + @pytest.mark.parametrize('datamodule', [False, True]) def test_trainer_predict_cpu(tmpdir, datamodule): predict(tmpdir, None, None, 1, datamodule=datamodule)