Bugfix/cuda oom detection and handling (#6934)
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
895bea1ad3
commit
5bd3cd5f71
|
@ -53,7 +53,8 @@ def is_oom_error(exception):
|
||||||
def is_cuda_out_of_memory(exception):
|
def is_cuda_out_of_memory(exception):
|
||||||
return isinstance(exception, RuntimeError) \
|
return isinstance(exception, RuntimeError) \
|
||||||
and len(exception.args) == 1 \
|
and len(exception.args) == 1 \
|
||||||
and "CUDA out of memory." in exception.args[0]
|
and "CUDA" in exception.args[0] \
|
||||||
|
and "out of memory" in exception.args[0]
|
||||||
|
|
||||||
|
|
||||||
# based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py
|
# based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py
|
||||||
|
@ -76,4 +77,10 @@ def garbage_collection_cuda():
|
||||||
"""Garbage collection Torch (CUDA) memory."""
|
"""Garbage collection Torch (CUDA) memory."""
|
||||||
gc.collect()
|
gc.collect()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
try:
|
||||||
|
# This is the last thing that should cause an OOM error, but seemingly it can.
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
except RuntimeError as exception:
|
||||||
|
if not is_oom_error(exception):
|
||||||
|
# Only handle OOM errors
|
||||||
|
raise
|
||||||
|
|
Loading…
Reference in New Issue