From 651c25feb66ba0a4c715ca671744a20f0a1355b0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20V=C3=B6lgyes?=
Date: Mon, 1 Mar 2021 16:15:52 +0100
Subject: [PATCH] Fix for incorrect usage of detach(), cpu(), to() (#6216)

* Fix for incorrect detach/cpu calls (#6214)

* Fix incorrect use of detach(), to(), and cpu(), #6214

* Fix incorrect use of detach() and cpu(), #6214

* update pr

* add typing

* chlog

* more...

* revert on module

* update on comments

* revert changes on model

Co-authored-by: tchaton
Co-authored-by: Jirka Borovec
---
 CHANGELOG.md                                                  |  3 +++
 pytorch_lightning/core/step_result.py                         | 10 ++++++----
 .../connectors/logger_connector/epoch_result_store.py         |  6 +++---
 pytorch_lightning/trainer/trainer.py                          |  4 ++--
 pytorch_lightning/trainer/training_loop.py                    |  6 +++---
 tests/overrides/test_data_parallel.py                         |  3 +--
 6 files changed, 18 insertions(+), 14 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 06a91bf973..8f31000b0c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -68,6 +68,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed multiple early stopping callbacks ([#6197](https://github.com/PyTorchLightning/pytorch-lightning/pull/6197))


+- Fixed incorrect usage of `detach()`, `cpu()`, `to()` ([#6216](https://github.com/PyTorchLightning/pytorch-lightning/pull/6216))
+
+
 - Fixed LBFGS optimizer support which didn't converge in automatic optimization ([#6147](https://github.com/PyTorchLightning/pytorch-lightning/pull/6147))


diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py
index 974974b032..f8d7a2ffe3 100644
--- a/pytorch_lightning/core/step_result.py
+++ b/pytorch_lightning/core/step_result.py
@@ -416,20 +416,22 @@ class Result(Dict):

         return result

-    def detach(self):
+    def detach(self) -> 'Result':
         for k, v in self.items():
             if isinstance(v, torch.Tensor):
                 self.__setitem__(k, v.detach())
+        return self

-    def to(self, *args, **kwargs):
+    def to(self, *args, **kwargs) -> 'Result':
         """Move all self attributes to the given device."""
         for k, v in self.items():
             if isinstance(v, torch.Tensor):
                 self.__setitem__(k, v.to(*args, **kwargs))
+        return self

-    def cpu(self):
+    def cpu(self) -> 'Result':
         """Move all self attributes to CPU."""
-        self.to(torch.device("cpu"))
+        return self.to(torch.device("cpu"))

     def __repr__(self):
         self_copy = self.copy()
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
index c435204107..a547144c8a 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
@@ -281,11 +281,11 @@ class EpochResultStore:

         # attach capture batch_size
         Result.attach_batch_size(self._batch_size, hook_result)

-        hook_result.detach()
+        hook_result = hook_result.detach()
         if self.trainer.move_metrics_to_cpu:
-            hook_result.cpu()
+            hook_result = hook_result.cpu()
         elif self.trainer._distrib_type == DistributedType.DP:
-            hook_result.to(torch.device("cuda", self.trainer.root_gpu))
+            hook_result = hook_result.to(torch.device("cuda", self.trainer.root_gpu))

         self._internals[fx_name].append(hook_result, info)
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 530001e0be..68453811da 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -736,9 +736,9 @@ class Trainer(
     def track_output_for_epoch_end(self, outputs, output):
         if output is not None:
             if isinstance(output, Result):
-                output.detach()
+                output = output.detach()
                 if self.move_metrics_to_cpu:
-                    output.cpu()
+                    output = output.cpu()
             elif isinstance(output, dict):
                 output = recursive_detach(output, to_cpu=self.move_metrics_to_cpu)
             elif isinstance(output, torch.Tensor) and output.is_cuda and self.move_metrics_to_cpu:
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 3b9cd6544a..97814bb912 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -261,7 +261,7 @@ class TrainLoop:
         is_result_obj = isinstance(training_step_output, Result)

         if is_result_obj:
-            training_step_output.detach()
+            training_step_output = training_step_output.detach()
         else:
             training_step_output.batch_loss = training_step_output.batch_loss.detach()

@@ -395,9 +395,9 @@ class TrainLoop:

         # track metrics without grads for epoch reduction
         training_step_output_for_epoch_end = copy(result)
-        training_step_output_for_epoch_end.detach()
+        training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
         if self.trainer.move_metrics_to_cpu:
-            training_step_output_for_epoch_end.cpu()
+            training_step_output_for_epoch_end = training_step_output_for_epoch_end.cpu()

         # what flows back into the system
         training_step_output = result
diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py
index 90bb6fac88..128bb22411 100644
--- a/tests/overrides/test_data_parallel.py
+++ b/tests/overrides/test_data_parallel.py
@@ -145,8 +145,7 @@ def test_lightning_parallel_module_python_scalar_conversion(device):
             output.update({"python scalar": 12.3})
             return output

-    model = TestModel()
-    model.to(device)
+    model = TestModel().to(device)
     model.trainer = Mock()
     model.trainer._running_stage = RunningStage.TRAINING
     batch = torch.rand(2, 32).to(device)
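
The call-site changes above all follow the standard tensor idiom: torch.Tensor.detach(), cpu(), and to() return new tensors rather than modifying the receiver, so a bare statement such as `output.detach()` silently leaves `output` untouched. Making Result.detach(), to(), and cpu() return `self` lets every caller use the same reassignment (and chaining) style whether it holds a Result or a plain tensor. A minimal sketch of the pattern, using a hypothetical MiniResult stand-in rather than the real Result class:

    import torch


    class MiniResult(dict):
        """Toy stand-in for Result: detach values in place *and* return self."""

        def detach(self) -> "MiniResult":
            for k, v in self.items():
                if isinstance(v, torch.Tensor):
                    self[k] = v.detach()
            return self

        def cpu(self) -> "MiniResult":
            for k, v in self.items():
                if isinstance(v, torch.Tensor):
                    self[k] = v.cpu()
            return self


    # Plain tensors: detach()/cpu()/to() return new tensors, so a bare call
    # has no effect on the original variable.
    t = torch.ones(1, requires_grad=True)
    t.detach()                 # no-op as a statement; `t` still requires grad
    assert t.requires_grad
    t = t.detach()             # the reassignment is what actually takes effect
    assert not t.requires_grad

    # With the methods returning self, the same idiom (and chaining) works
    # for Result-like objects too.
    out = MiniResult(loss=torch.ones(1, requires_grad=True))
    out = out.detach().cpu()
    assert not out["loss"].requires_grad

Returning `self` keeps Result's in-place behaviour while letting the call sites read identically for Result, dict, and tensor outputs.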