87 lines
2.9 KiB
Python
87 lines
2.9 KiB
Python
# Copyright The PyTorch Lightning team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import gc
|
|
|
|
import torch
|
|
|
|
|
|
def recursive_detach(in_dict: dict, to_cpu: bool = False) -> dict:
    """Build a copy of `in_dict` in which every detachable value is detached.

    Values that are dictionaries are processed recursively. Any other value
    exposing a callable ``detach`` attribute (e.g. a `torch.Tensor`) is
    replaced by the result of ``detach()``, and additionally by the result of
    ``.cpu()`` on that detached value when `to_cpu` is set. All remaining
    values are copied over unchanged. The input dictionary is not modified.

    Args:
        in_dict: Dictionary with tensors to detach
        to_cpu: Whether to move tensor to cpu

    Return:
        out_dict: Dictionary with detached tensors
    """

    def _detach_value(value):
        # Nested dictionaries are handled by recursing with the same flag.
        if isinstance(value, dict):
            return recursive_detach(value, to_cpu=to_cpu)
        # Duck-typed check: anything without a callable `detach` passes through.
        if not callable(getattr(value, 'detach', None)):
            return value
        detached = value.detach()
        return detached.cpu() if to_cpu else detached

    return {key: _detach_value(value) for key, value in in_dict.items()}
|
|
|
|
|
|
def is_oom_error(exception):
    """Tell whether `exception` is one of the recognised out-of-memory errors.

    The exception is checked, in order, against the CUDA out-of-memory
    predicate, the cuDNN "not supported" predicate and the CPU allocator
    predicate; the first match short-circuits the remaining checks.

    Args:
        exception: The exception instance to classify.

    Return:
        ``True`` if any of the three predicates matches, ``False`` otherwise.
    """
    if is_cuda_out_of_memory(exception):
        return True
    if is_cudnn_snafu(exception):
        return True
    return is_out_of_cpu_memory(exception)
|
|
|
|
|
|
# based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py
|
|
def is_cuda_out_of_memory(exception):
    """Tell whether `exception` looks like a CUDA out-of-memory error.

    An exception matches when it is a `RuntimeError` carrying exactly one
    argument, that argument is a string, and the string contains both
    ``"CUDA"`` and ``"out of memory"``.

    Args:
        exception: The exception instance to inspect.

    Return:
        ``True`` if the exception matches the pattern above, else ``False``.
    """
    if not isinstance(exception, RuntimeError) or len(exception.args) != 1:
        return False
    message = exception.args[0]
    # The payload of a RuntimeError is not guaranteed to be a string. Without
    # this guard, `"CUDA" in message` raises TypeError for e.g. an int and
    # performs element membership (not substring search) for a list/tuple.
    if not isinstance(message, str):
        return False
    return "CUDA" in message and "out of memory" in message
|
|
|
|
|
|
# based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py
|
|
def is_cudnn_snafu(exception):
    """Tell whether `exception` is the cuDNN "status not supported" error.

    An exception matches when it is a `RuntimeError` carrying exactly one
    argument, that argument is a string, and the string contains
    ``"cuDNN error: CUDNN_STATUS_NOT_SUPPORTED."``.

    Args:
        exception: The exception instance to inspect.

    Return:
        ``True`` if the exception matches the pattern above, else ``False``.
    """
    # For/because of https://github.com/pytorch/pytorch/issues/4107
    if not isinstance(exception, RuntimeError) or len(exception.args) != 1:
        return False
    message = exception.args[0]
    # Guard against non-string payloads: the substring test below would raise
    # TypeError on e.g. an int instead of answering "not a match".
    if not isinstance(message, str):
        return False
    return "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED." in message
|
|
|
|
|
|
# based on https://github.com/BlackHC/toma/blob/master/toma/cpu_memory.py
|
|
def is_out_of_cpu_memory(exception):
    """Tell whether `exception` is a CPU allocator out-of-memory error.

    An exception matches when it is a `RuntimeError` carrying exactly one
    argument, that argument is a string, and the string contains
    ``"DefaultCPUAllocator: can't allocate memory"``.

    Args:
        exception: The exception instance to inspect.

    Return:
        ``True`` if the exception matches the pattern above, else ``False``.
    """
    if not isinstance(exception, RuntimeError) or len(exception.args) != 1:
        return False
    message = exception.args[0]
    # Guard against non-string payloads: the substring test below would raise
    # TypeError on e.g. an int instead of answering "not a match".
    if not isinstance(message, str):
        return False
    return "DefaultCPUAllocator: can't allocate memory" in message
|
|
|
|
|
|
# based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py
|
|
def garbage_collection_cuda():
    """Run Python garbage collection, then empty the Torch CUDA cache.

    ``gc.collect()`` is always invoked. The CUDA cache is emptied only when
    ``torch.cuda.is_available()`` reports true. A `RuntimeError` raised while
    emptying the cache is suppressed if `is_oom_error` classifies it as an
    out-of-memory error; any other `RuntimeError` propagates to the caller.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    try:
        # This is the last thing that should cause an OOM error, but seemingly it can.
        torch.cuda.empty_cache()
    except RuntimeError as error:
        if is_oom_error(error):
            # Only OOM errors are handled (by ignoring them).
            return
        raise
|