Fix for incorrect usage of detach(), cpu(), to() (#6216)
* Fix for incorrect detach/cpu calls (#6214)
* Fix incorrect use of detach(), to(), and cpu(), #6214
* Fix incorrect use of detach() and cpu(), #6214
* update pr
* add typing
* chlog
* more...
* revert on module
* update on comments
* revert changes on model

Co-authored-by: tchaton <thomas@grid.ai>
Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
commit 651c25feb6
parent 925f082572
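The class of bug being fixed, as a minimal standalone sketch (an illustration, not part of the commit): for plain tensors, `detach()`, `cpu()`, and `to()` leave the receiver untouched and hand back the result, so calling them without using the return value silently does nothing.

    import torch

    t = torch.ones(3, requires_grad=True)
    t.detach()      # bug pattern: the detached tensor is discarded; t is unchanged
    assert t.requires_grad
    t = t.detach()  # fixed pattern: rebind the name to the returned tensor
    assert not t.requires_grad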
@@ -68,6 +68,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed multiple early stopping callbacks ([#6197](https://github.com/PyTorchLightning/pytorch-lightning/pull/6197))
 
+- Fixed incorrect usage of `detach()`, `cpu()`, `to()` ([#6216](https://github.com/PyTorchLightning/pytorch-lightning/pull/6216))
+
 - Fixed LBFGS optimizer support which didn't converge in automatic optimization ([#6147](https://github.com/PyTorchLightning/pytorch-lightning/pull/6147))
@@ -416,20 +416,22 @@ class Result(Dict):
         return result
 
-    def detach(self):
+    def detach(self) -> 'Result':
         for k, v in self.items():
             if isinstance(v, torch.Tensor):
                 self.__setitem__(k, v.detach())
+        return self
 
-    def to(self, *args, **kwargs):
+    def to(self, *args, **kwargs) -> 'Result':
         """Move all self attributes to the given device."""
         for k, v in self.items():
             if isinstance(v, torch.Tensor):
                 self.__setitem__(k, v.to(*args, **kwargs))
+        return self
 
-    def cpu(self):
+    def cpu(self) -> 'Result':
         """Move all self attributes to CPU."""
-        self.to(torch.device("cpu"))
+        return self.to(torch.device("cpu"))
 
     def __repr__(self):
         self_copy = self.copy()
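With the methods above now returning `self`, call sites can rebind or chain uniformly, matching the call convention of `torch.Tensor`. A toy stand-in (hypothetical, not the real `Result` class) showing why returning the container makes both styles safe:

    import torch

    class ToyResult(dict):
        # hypothetical stand-in for pytorch_lightning's Result, for illustration only
        def detach(self) -> 'ToyResult':
            for k, v in self.items():
                if isinstance(v, torch.Tensor):
                    self[k] = v.detach()
            return self

        def to(self, *args, **kwargs) -> 'ToyResult':
            for k, v in self.items():
                if isinstance(v, torch.Tensor):
                    self[k] = v.to(*args, **kwargs)
            return self

        def cpu(self) -> 'ToyResult':
            return self.to(torch.device("cpu"))

    r = ToyResult(loss=torch.ones(1, requires_grad=True) * 2)
    r = r.detach().cpu()  # chaining works because every call returns the container
    assert not r["loss"].requires_grad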
@@ -281,11 +281,11 @@ class EpochResultStore:
         # attach capture batch_size
         Result.attach_batch_size(self._batch_size, hook_result)
 
-        hook_result.detach()
+        hook_result = hook_result.detach()
         if self.trainer.move_metrics_to_cpu:
-            hook_result.cpu()
+            hook_result = hook_result.cpu()
         elif self.trainer._distrib_type == DistributedType.DP:
-            hook_result.to(torch.device("cuda", self.trainer.root_gpu))
+            hook_result = hook_result.to(torch.device("cuda", self.trainer.root_gpu))
 
         self._internals[fx_name].append(hook_result, info)
@@ -736,9 +736,9 @@ class Trainer(
     def track_output_for_epoch_end(self, outputs, output):
         if output is not None:
             if isinstance(output, Result):
-                output.detach()
+                output = output.detach()
                 if self.move_metrics_to_cpu:
-                    output.cpu()
+                    output = output.cpu()
             elif isinstance(output, dict):
                 output = recursive_detach(output, to_cpu=self.move_metrics_to_cpu)
             elif isinstance(output, torch.Tensor) and output.is_cuda and self.move_metrics_to_cpu:
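The dict branch delegates to `recursive_detach`. A minimal sketch of such a helper, assuming the semantics its name and `to_cpu` flag suggest (the real implementation lives in pytorch_lightning's utilities and may differ):

    import torch

    def recursive_detach_sketch(in_dict: dict, to_cpu: bool = False) -> dict:
        # detach every tensor in a (possibly nested) dict, optionally moving it to CPU
        out = {}
        for k, v in in_dict.items():
            if isinstance(v, dict):
                v = recursive_detach_sketch(v, to_cpu=to_cpu)
            elif isinstance(v, torch.Tensor):
                v = v.detach()
                if to_cpu:
                    v = v.cpu()
            out[k] = v
        return out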
@@ -261,7 +261,7 @@ class TrainLoop:
         is_result_obj = isinstance(training_step_output, Result)
 
         if is_result_obj:
-            training_step_output.detach()
+            training_step_output = training_step_output.detach()
         else:
             training_step_output.batch_loss = training_step_output.batch_loss.detach()
@@ -395,9 +395,9 @@ class TrainLoop:
 
         # track metrics without grads for epoch reduction
         training_step_output_for_epoch_end = copy(result)
-        training_step_output_for_epoch_end.detach()
+        training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
         if self.trainer.move_metrics_to_cpu:
-            training_step_output_for_epoch_end.cpu()
+            training_step_output_for_epoch_end = training_step_output_for_epoch_end.cpu()
 
         # what flows back into the system
         training_step_output = result
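Note the interplay with `copy(result)` above: a shallow copy shares tensor objects with the original, and the rebind-style detach replaces values only in the copy's own dict, so `result` keeps its autograd graph for what flows back into the system. A self-contained sketch of that property:

    from copy import copy

    import torch

    src = {"loss": torch.ones(1, requires_grad=True) * 3}
    snapshot = copy(src)                          # shallow copy: new dict, same tensor objects
    snapshot["loss"] = snapshot["loss"].detach()  # rebinding affects only the copy
    assert src["loss"].requires_grad              # the original still carries the graph
    assert not snapshot["loss"].requires_grad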
@@ -145,8 +145,7 @@ def test_lightning_parallel_module_python_scalar_conversion(device):
             output.update({"python scalar": 12.3})
             return output
 
-    model = TestModel()
-    model.to(device)
+    model = TestModel().to(device)
    model.trainer = Mock()
     model.trainer._running_stage = RunningStage.TRAINING
     batch = torch.rand(2, 32).to(device)
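The one-liner in the test works because `nn.Module.to`, unlike `torch.Tensor.to`, moves parameters in place and returns the module itself (which is also why the module-side changes were reverted in this PR). A quick check of that difference:

    import torch
    from torch import nn

    model = nn.Linear(32, 2)
    same = model.to(torch.device("cpu"))  # Module.to is in-place and returns self
    assert same is model

    t = torch.zeros(1)
    converted = t.to(torch.float64)       # Tensor.to returns a different tensor here
    assert converted is not t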