diff --git a/pytorch_lightning/utilities/memory.py b/pytorch_lightning/utilities/memory.py index d67739c3b3..6c01390a8c 100644 --- a/pytorch_lightning/utilities/memory.py +++ b/pytorch_lightning/utilities/memory.py @@ -53,7 +53,8 @@ def is_oom_error(exception): def is_cuda_out_of_memory(exception): return isinstance(exception, RuntimeError) \ and len(exception.args) == 1 \ - and "CUDA out of memory." in exception.args[0] + and "CUDA" in exception.args[0] \ + and "out of memory" in exception.args[0] # based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py @@ -76,4 +77,10 @@ def garbage_collection_cuda(): """Garbage collection Torch (CUDA) memory.""" gc.collect() if torch.cuda.is_available(): - torch.cuda.empty_cache() + try: + # This is the last thing that should cause an OOM error, but seemingly it can. + torch.cuda.empty_cache() + except RuntimeError as exception: + if not is_oom_error(exception): + # Only handle OOM errors + raise