Bugfix/cuda oom detection and handling (#6934)
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
895bea1ad3
commit
5bd3cd5f71
|
@@ -53,7 +53,8 @@ def is_oom_error(exception):
|
|||
def is_cuda_out_of_memory(exception):
    """Return ``True`` if ``exception`` looks like a CUDA out-of-memory error.

    ``"CUDA"`` and ``"out of memory"`` are matched as two independent
    substrings rather than as the single literal ``"CUDA out of memory."``,
    so a message only needs to contain both fragments somewhere.

    Args:
        exception: The exception instance to inspect.

    Returns:
        bool: ``True`` when ``exception`` is a ``RuntimeError`` carrying
        exactly one string argument that contains both substrings,
        ``False`` otherwise.
    """
    if not isinstance(exception, RuntimeError) or len(exception.args) != 1:
        return False
    message = exception.args[0]
    # A non-string payload cannot be an OOM message; bail out instead of
    # letting the ``in`` operator raise a TypeError.
    if not isinstance(message, str):
        return False
    return "CUDA" in message and "out of memory" in message
|
||||
|
||||
|
||||
# based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py
|
||||
|
def garbage_collection_cuda():
    """Garbage collection Torch (CUDA) memory.

    Runs Python's garbage collector and then, when CUDA is available, asks
    PyTorch to empty its CUDA cache. A ``RuntimeError`` raised while emptying
    the cache is suppressed if ``is_oom_error`` classifies it as an
    out-of-memory error; any other ``RuntimeError`` is re-raised.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    try:
        # This is the last thing that should cause an OOM error, but seemingly it can.
        torch.cuda.empty_cache()
    except RuntimeError as exception:
        if not is_oom_error(exception):
            # Only handle OOM errors
            raise
|
||||
|
|
Loading…
Reference in New Issue