Clean `cuda.empty_cache` usage (#8199)

Carlos Mocholí authored on 2021-06-30 13:04:24 +02:00 (committed by GitHub)
parent 57dce7244c
commit 74eb6cc7e9
8 changed files with 31 additions and 34 deletions


@@ -42,10 +42,7 @@ class GPUAccelerator(Accelerator):
 
     def on_train_start(self) -> None:
         # clear cache before training
-        # use context because of:
-        # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898
-        with torch.cuda.device(self.root_device):
-            torch.cuda.empty_cache()
+        torch.cuda.empty_cache()
 
     @staticmethod
     def set_nvidia_flags(local_rank: int) -> None:
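
Note on the simplification above: `torch.cuda.empty_cache()` is a no-op until CUDA has been initialized and only releases unoccupied blocks held by the caching allocator, which is presumably why the `with torch.cuda.device(...)` wrapper (originally added because of the linked forum thread) is no longer considered necessary. A minimal, standalone sketch of the observable effect, assuming a CUDA-capable machine:

    import torch

    if torch.cuda.is_available():
        x = torch.empty(1024, 1024, device="cuda")   # ~4 MiB served through the caching allocator
        del x                                        # the block is now unoccupied but still cached
        print(torch.cuda.memory_reserved())          # non-zero: memory is still held by the allocator
        torch.cuda.empty_cache()                     # return unoccupied cached blocks to the driver
        print(torch.cuda.memory_reserved())          # now (close to) zero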


@@ -41,7 +41,13 @@ from pytorch_lightning.utilities import (
     rank_zero_deprecation,
     rank_zero_warn,
 )
-from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only, ReduceOp, sync_ddp_if_available
+from pytorch_lightning.utilities.distributed import (
+    distributed_available,
+    rank_zero_info,
+    rank_zero_only,
+    ReduceOp,
+    sync_ddp_if_available,
+)
 from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed
 
@@ -347,7 +353,7 @@ class DDPPlugin(ParallelPlugin):
         self.cluster_environment.teardown()
 
     def barrier(self, *args, **kwargs) -> None:
-        if not torch.distributed.is_initialized():
+        if not distributed_available():
             return
         if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
             torch.distributed.barrier(device_ids=self.determine_ddp_device_ids())
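
Both DDP plugins now gate the barrier on `distributed_available()` instead of probing `torch.distributed` directly. A rough sketch of what the helper is assumed to check (the real function lives in `pytorch_lightning.utilities.distributed` and may additionally account for TPU process groups):

    import torch
    import torch.distributed

    def distributed_available() -> bool:
        # Sketch of the assumed check: a barrier is only safe once torch.distributed
        # is both compiled into the build and an actual process group is initialized.
        return torch.distributed.is_available() and torch.distributed.is_initialized()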


@@ -36,7 +36,13 @@ from pytorch_lightning.utilities import (
 )
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.cloud_io import load as pl_load
-from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only, ReduceOp, sync_ddp_if_available
+from pytorch_lightning.utilities.distributed import (
+    distributed_available,
+    rank_zero_info,
+    rank_zero_only,
+    ReduceOp,
+    sync_ddp_if_available,
+)
 from pytorch_lightning.utilities.seed import reset_seed
 
 if _TORCH_GREATER_EQUAL_1_8:
@@ -312,7 +318,7 @@ class DDPSpawnPlugin(ParallelPlugin):
             self.lightning_module.load_state_dict(ckpt)
 
     def barrier(self, *args, **kwargs) -> None:
-        if not torch.distributed.is_initialized():
+        if not distributed_available():
             return
         if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
             torch.distributed.barrier(device_ids=self.determine_ddp_device_ids())


@@ -15,13 +15,12 @@ from contextlib import ExitStack
 from typing import Any, List, Optional, Union
 
 import torch
-import torch.distributed
 from torch.optim.lr_scheduler import _LRScheduler, Optimizer
 
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
 from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
-from pytorch_lightning.utilities.distributed import group, rank_zero_only, ReduceOp
+from pytorch_lightning.utilities.distributed import distributed_available, group, rank_zero_only, ReduceOp
 
 if _HOROVOD_AVAILABLE:
     import horovod.torch as hvd
@@ -125,7 +124,7 @@ class HorovodPlugin(ParallelPlugin):
         self.join()
 
     def barrier(self, *args, **kwargs):
-        if torch.distributed.is_initialized():
+        if distributed_available():
             self.join()
 
     def broadcast(self, obj: object, src: int = 0) -> object:


@@ -137,5 +137,4 @@ class ParallelPlugin(TrainingTypePlugin, ABC):
             # GPU teardown
             self.lightning_module.cpu()
             # clean up memory
-            with torch.cuda.device(self.root_device):
-                torch.cuda.empty_cache()
+            torch.cuda.empty_cache()
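
The teardown here (and in `SingleDevicePlugin` below) keeps the same two-step order: move the LightningModule to CPU first, then release the cache. Only memory no longer referenced by tensors can actually be handed back, so the order matters. A hypothetical standalone helper, not taken from the codebase, illustrating the pattern:

    import torch
    from torch import nn

    def release_gpu_memory(module: nn.Module) -> None:
        # Moving parameters and buffers to CPU frees their CUDA blocks back to the
        # caching allocator; empty_cache() can then return those blocks to the driver.
        module.cpu()
        torch.cuda.empty_cache()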


@@ -85,5 +85,4 @@ class SingleDevicePlugin(TrainingTypePlugin):
             # GPU teardown
             self.lightning_module.cpu()
             # clean up memory
-            with torch.cuda.device(self.root_device):
-                torch.cuda.empty_cache()
+            torch.cuda.empty_cache()


@@ -20,13 +20,7 @@ from typing import Optional, Union
 import torch
 
 import pytorch_lightning as pl
-from pytorch_lightning.utilities import (
-    _OMEGACONF_AVAILABLE,
-    DeviceType,
-    rank_zero_deprecation,
-    rank_zero_info,
-    rank_zero_warn,
-)
+from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
@@ -68,8 +62,7 @@ class CheckpointConnector:
             return
 
         # clear cache before restore
-        if self.trainer._device_type == DeviceType.GPU:
-            torch.cuda.empty_cache()
+        torch.cuda.empty_cache()
 
         # Try to read the checkpoint file at `checkpoint_path`. If not exist, do not restore checkpoint.
         fs = get_filesystem(checkpoint_path)
@@ -87,8 +80,7 @@ class CheckpointConnector:
         self._loaded_checkpoint = dict()
 
         # clear cache after restore
-        if self.trainer._device_type == DeviceType.GPU:
-            torch.cuda.empty_cache()
+        torch.cuda.empty_cache()
 
         # wait for all to catch up
         self.trainer.training_type_plugin.barrier("CheckpointConnector.resume_end")
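
Dropping the `DeviceType.GPU` guard around the cache clears relies on `torch.cuda.empty_cache()` being harmless when no CUDA context exists, so the checkpoint connector can call it unconditionally. A small sketch of that assumption (the helper name is made up for illustration):

    import torch

    def clear_device_cache() -> None:
        # empty_cache() returns immediately if CUDA was never initialized in this
        # process, so no device-type check is needed around it.
        torch.cuda.empty_cache()

    clear_device_cache()  # safe on CPU-only runs as well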


@@ -76,11 +76,10 @@ def is_out_of_cpu_memory(exception):
 def garbage_collection_cuda():
     """Garbage collection Torch (CUDA) memory."""
     gc.collect()
-    if torch.cuda.is_available():
-        try:
-            # This is the last thing that should cause an OOM error, but seemingly it can.
-            torch.cuda.empty_cache()
-        except RuntimeError as exception:
-            if not is_oom_error(exception):
-                # Only handle OOM errors
-                raise
+    try:
+        # This is the last thing that should cause an OOM error, but seemingly it can.
+        torch.cuda.empty_cache()
+    except RuntimeError as exception:
+        if not is_oom_error(exception):
+            # Only handle OOM errors
+            raise
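
With the `torch.cuda.is_available()` guard removed, `garbage_collection_cuda()` always runs the collect-then-empty-cache sequence and still swallows the rare OOM that `empty_cache()` itself can raise. A hypothetical usage pattern for recovering from an out-of-memory error (`run_big_batch` is made up for illustration):

    from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error

    def run_big_batch() -> None:
        ...  # hypothetical workload that may raise a CUDA out-of-memory RuntimeError

    try:
        run_big_batch()
    except RuntimeError as exception:
        if is_oom_error(exception):
            # Reclaim Python garbage and cached CUDA blocks before retrying,
            # for example with a smaller batch size.
            garbage_collection_cuda()
        else:
            raise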