Fix disabled grads after call to predict (#6657)
This commit is contained in:
parent
64d0fa4472
commit
741c452551
|
@ -191,6 +191,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed comparing required versions ([#6434](https://github.com/PyTorchLightning/pytorch-lightning/pull/6434))
|
||||
|
||||
|
||||
- Fixed a bug where gradients were disabled after calling `Trainer.predict` ([#6657](https://github.com/PyTorchLightning/pytorch-lightning/pull/6657))
|
||||
|
||||
|
||||
## [1.2.4] - 2021-03-16
|
||||
|
||||
### Changed
|
||||
|
|
|
@ -800,6 +800,10 @@ class Trainer(
|
|||
|
||||
results = self.predict_loop.on_predict_epoch_end()
|
||||
self.predict_loop.on_predict_end()
|
||||
|
||||
# re-enable grads
|
||||
torch.set_grad_enabled(True)
|
||||
|
||||
return results
|
||||
|
||||
def run_sanity_check(self, ref_model):
|
||||
|
|
|
@ -1450,6 +1450,19 @@ def test_trainer_predict_no_return(tmpdir):
|
|||
predict(tmpdir, None, None, 1, model=CustomBoringModel())
|
||||
|
||||
|
||||
def test_trainer_predict_grad(tmpdir):
|
||||
class CustomBoringModel(BoringModel):
|
||||
|
||||
def predict_step(self, batch, batch_idx, dataloader_idx=None):
|
||||
assert batch.expand_as(batch).grad_fn is None
|
||||
return super().predict_step(batch, batch_idx, dataloader_idx)
|
||||
|
||||
predict(tmpdir, None, None, 1, model=CustomBoringModel())
|
||||
|
||||
x = torch.zeros(1, requires_grad=True)
|
||||
assert x.expand_as(x).grad_fn is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize('datamodule', [False, True])
|
||||
def test_trainer_predict_cpu(tmpdir, datamodule):
|
||||
predict(tmpdir, None, None, 1, datamodule=datamodule)
|
||||
|
|
Loading…
Reference in New Issue