Add toma comments to auto_scale_batch_size (#1994)

* Add source comments

* Update training_tricks.rst
Andreas Kirsch, 2020-05-29 06:57:50 +01:00 (committed by GitHub)
parent cd3fed03a2
commit 3af3f37d43
2 changed files with 6 additions and 1 deletion

training_tricks.rst

@@ -39,7 +39,7 @@ Auto scaling of batch size
--------------------------
Auto scaling of batch size may be enabled to find the largest batch size that fits into
memory. Larger batch size often yields better estimates of gradients, but may also result in
-longer training time.
+longer training time. Inspired by https://github.com/BlackHC/toma.

.. seealso:: :class:`~pytorch_lightning.trainer.trainer.Trainer`
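
For orientation, a minimal end-to-end sketch of the feature these docs describe, against the Trainer API of this era; ToyModel and all of its internals are illustrative, not part of the library:

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl

class ToyModel(pl.LightningModule):
    def __init__(self, batch_size=32):
        super().__init__()
        self.batch_size = batch_size  # attribute the batch size finder tunes
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        data = TensorDataset(torch.randn(512, 32), torch.randint(0, 2, (512,)))
        return DataLoader(data, batch_size=self.batch_size)

# 'power' mode repeatedly doubles the batch size until an out-of-memory
# error is raised, then settles on the last size that fit.
trainer = pl.Trainer(auto_scale_batch_size='power', max_epochs=1)
trainer.fit(ToyModel())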

pytorch_lightning/trainer/training_tricks.py

@@ -32,24 +32,29 @@ def is_oom_error(exception):
        or is_out_of_cpu_memory(exception)


+# based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py
def is_cuda_out_of_memory(exception):
    return isinstance(exception, RuntimeError) \
        and len(exception.args) == 1 \
        and "CUDA out of memory." in exception.args[0]


+# based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py
def is_cudnn_snafu(exception):
+    # For/because of https://github.com/pytorch/pytorch/issues/4107
    return isinstance(exception, RuntimeError) \
        and len(exception.args) == 1 \
        and "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED." in exception.args[0]


+# based on https://github.com/BlackHC/toma/blob/master/toma/cpu_memory.py
def is_out_of_cpu_memory(exception):
    return isinstance(exception, RuntimeError) \
        and len(exception.args) == 1 \
        and "DefaultCPUAllocator: can't allocate memory" in exception.args[0]


+# based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py
def garbage_collection_cuda():
    """Garbage collection Torch (CUDA) memory."""
    gc.collect()
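
As context for the comments above, a minimal sketch of the retry pattern these helpers enable, in the spirit of BlackHC/toma; run_with_retry and train_step are illustrative names, not library API:

def run_with_retry(train_step, batch_size):
    # Illustrative only: halve the batch size until train_step fits in memory.
    while batch_size >= 1:
        try:
            return train_step(batch_size)
        except RuntimeError as exception:
            if is_oom_error(exception):   # classifier helper from the diff above
                garbage_collection_cuda()  # reclaim cached CUDA memory before retrying
                batch_size //= 2
            else:
                raise  # not a memory error: propagate unchanged
    raise RuntimeError("could not fit a batch size of 1 into memory")

toma itself wraps this pattern in decorators that retry the wrapped function; the helpers here only classify the exception, which leaves the retry policy (doubling, halving, binary search) up to the caller.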