[bugfix] Prevent a DDP failure using copy (#9239)

thomas chaton, 2021-08-31 22:02:33 +01:00 (committed by GitHub)
parent 3e71046f49
commit ff7305f74d
3 changed files with 14 additions and 5 deletions


@@ -273,9 +273,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed not setting a default value for `max_epochs` if `max_time` was specified on the `Trainer` constructor ([#9072](https://github.com/PyTorchLightning/pytorch-lightning/pull/9072))
 - Fixed the CometLogger, no longer modifies the metrics in place. Instead creates a copy of metrics before performing any operations ([#9150](https://github.com/PyTorchLightning/pytorch-lightning/pull/9150))
+- Fixed `DDP` "CUDA error: initialization error" due to a `copy` instead of `deepcopy` on `ResultCollection` ([#9239](https://github.com/PyTorchLightning/pytorch-lightning/pull/9239))

 ## [1.4.3] - 2021-08-17
 - Fixed plateau scheduler stepping on incomplete epoch ([#8861](https://github.com/PyTorchLightning/pytorch-lightning/pull/8861))


@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from copy import copy
+from copy import deepcopy
 from functools import partial
 from typing import Any, Callable, Dict, List, Optional, Tuple
@@ -142,12 +142,12 @@ class TrainingBatchLoop(Loop):
                 result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer)
                 if result:
-                    self.batch_outputs[opt_idx].append(copy(result.result_collection))
+                    self.batch_outputs[opt_idx].append(deepcopy(result.result_collection))
         else:
             # in manual optimization, there is no looping over optimizers
             result = self._run_optimization(batch_idx, split_batch)
             if result:
-                self.batch_outputs[0].append(copy(result.result_collection))
+                self.batch_outputs[0].append(deepcopy(result.result_collection))

     def teardown(self) -> None:
         # release memory

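The hunk above stores a `deepcopy` of each optimization step's `ResultCollection` instead of a shallow `copy`. A minimal sketch of why that matters, using a plain nested dict as an illustrative stand-in for the result collection (not Lightning's actual class): with a shallow copy, every stored "snapshot" keeps referencing the same nested objects that the live collection goes on mutating, so all snapshots end up reflecting the final state.

```python
from copy import copy, deepcopy

# Toy stand-in for a result collection that is updated in place each batch.
results = {"train_loss": {"value": 0.0}}

shallow_snapshots, deep_snapshots = [], []
for batch_idx in range(3):
    results["train_loss"]["value"] = float(batch_idx)  # in-place update
    shallow_snapshots.append(copy(results))      # shares the inner dict
    deep_snapshots.append(deepcopy(results))     # independent snapshot

print([s["train_loss"]["value"] for s in shallow_snapshots])  # [2.0, 2.0, 2.0]
print([s["train_loss"]["value"] for s in deep_snapshots])     # [0.0, 1.0, 2.0]
```

The changelog entry ties this aliasing to the `DDP` "CUDA error: initialization error"; the sketch only demonstrates the copy-vs-deepcopy behaviour itself, not that downstream failure.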

@@ -17,6 +17,7 @@ from functools import partial, wraps
 from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

 import torch
+from torch.functional import Tensor
 from torchmetrics import Metric

 from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
@@ -435,8 +436,12 @@ class ResultCollection(dict):
     ) -> None:
         """See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`"""
         # no metrics should be logged with graphs
-        if not enable_graph and isinstance(value, torch.Tensor):
-            value = value.detach()
+        if not enable_graph:
+
+            def detach_fn(tensor: Tensor) -> Tensor:
+                return tensor.detach()
+
+            value = apply_to_collection(value, Tensor, detach_fn)

         # move metrics to cpu on TPU.
         if isinstance(value, torch.Tensor) and value.device.type == "xla":
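The new branch detaches through `apply_to_collection`, so values logged as nested containers of tensors (dicts, lists, tuples) get detached too, not just a bare tensor. A rough usage sketch; the import path `pytorch_lightning.utilities.apply_func` is assumed from the Lightning release of that era and may differ across versions:

```python
import torch
from torch.functional import Tensor

# Assumed import location for Lightning's collection helper (may vary by version).
from pytorch_lightning.utilities.apply_func import apply_to_collection


def detach_fn(tensor: Tensor) -> Tensor:
    return tensor.detach()


# A logged value that is a nested container of tensors rather than a single tensor.
value = {
    "loss": torch.tensor(0.5, requires_grad=True),
    "parts": [torch.tensor(0.2, requires_grad=True), torch.tensor(0.3, requires_grad=True)],
}

# apply_to_collection walks the container and applies detach_fn to every Tensor leaf,
# preserving the dict/list structure.
detached = apply_to_collection(value, Tensor, detach_fn)

assert not detached["loss"].requires_grad
assert all(not t.requires_grad for t in detached["parts"])
```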