Clean `cuda.empty_cache` usage (#8199)
This commit is contained in:
parent 57dce7244c
commit 74eb6cc7e9
@@ -42,10 +42,7 @@ class GPUAccelerator(Accelerator):
     def on_train_start(self) -> None:
         # clear cache before training
-        # use context because of:
-        # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898
-        with torch.cuda.device(self.root_device):
-            torch.cuda.empty_cache()
+        torch.cuda.empty_cache()

     @staticmethod
     def set_nvidia_flags(local_rank: int) -> None:
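The hunk above drops the `torch.cuda.device` context around `empty_cache()`. Below is a minimal sketch, not part of the diff, contrasting the two patterns; it assumes (as the simplification implies) that the accelerator has already bound the process to its GPU with `torch.cuda.set_device` before `on_train_start` runs, and the helper names are illustrative only.

import torch

def clear_cache_with_context(root_device: torch.device) -> None:
    # Old pattern: the device context keeps empty_cache() from touching GPU 0
    # when the process owns a different device (the reason given in the removed
    # discuss.pytorch.org link).
    if torch.cuda.is_available():
        with torch.cuda.device(root_device):
            torch.cuda.empty_cache()

def clear_cache_on_current_device() -> None:
    # New pattern: once torch.cuda.set_device() has been called for this process,
    # a bare empty_cache() already releases cached blocks on the right device.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()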
@@ -41,7 +41,13 @@ from pytorch_lightning.utilities import (
     rank_zero_deprecation,
     rank_zero_warn,
 )
-from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only, ReduceOp, sync_ddp_if_available
+from pytorch_lightning.utilities.distributed import (
+    distributed_available,
+    rank_zero_info,
+    rank_zero_only,
+    ReduceOp,
+    sync_ddp_if_available,
+)
 from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed
@@ -347,7 +353,7 @@ class DDPPlugin(ParallelPlugin):
         self.cluster_environment.teardown()

     def barrier(self, *args, **kwargs) -> None:
-        if not torch.distributed.is_initialized():
+        if not distributed_available():
             return
         if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
             torch.distributed.barrier(device_ids=self.determine_ddp_device_ids())
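Both DDP plugins replace the direct `torch.distributed.is_initialized()` check in `barrier()` with `distributed_available()`, newly imported above. A hedged sketch of what such a guard typically amounts to; the real helper lives in `pytorch_lightning.utilities.distributed` and may cover additional cases (for example TPU process groups).

import torch.distributed

def distributed_available_sketch() -> bool:
    # Check availability before initialization state, so the guard also behaves
    # sensibly on builds compiled without distributed support.
    return torch.distributed.is_available() and torch.distributed.is_initialized()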
@@ -36,7 +36,13 @@ from pytorch_lightning.utilities import (
 )
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.cloud_io import load as pl_load
-from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only, ReduceOp, sync_ddp_if_available
+from pytorch_lightning.utilities.distributed import (
+    distributed_available,
+    rank_zero_info,
+    rank_zero_only,
+    ReduceOp,
+    sync_ddp_if_available,
+)
 from pytorch_lightning.utilities.seed import reset_seed

 if _TORCH_GREATER_EQUAL_1_8:
@@ -312,7 +318,7 @@ class DDPSpawnPlugin(ParallelPlugin):
         self.lightning_module.load_state_dict(ckpt)

     def barrier(self, *args, **kwargs) -> None:
-        if not torch.distributed.is_initialized():
+        if not distributed_available():
             return
         if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
             torch.distributed.barrier(device_ids=self.determine_ddp_device_ids())
@@ -15,13 +15,12 @@ from contextlib import ExitStack
 from typing import Any, List, Optional, Union

 import torch
-import torch.distributed
 from torch.optim.lr_scheduler import _LRScheduler, Optimizer

 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
 from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
-from pytorch_lightning.utilities.distributed import group, rank_zero_only, ReduceOp
+from pytorch_lightning.utilities.distributed import distributed_available, group, rank_zero_only, ReduceOp

 if _HOROVOD_AVAILABLE:
     import horovod.torch as hvd
@@ -125,7 +124,7 @@ class HorovodPlugin(ParallelPlugin):
         self.join()

     def barrier(self, *args, **kwargs):
-        if torch.distributed.is_initialized():
+        if distributed_available():
             self.join()

     def broadcast(self, obj: object, src: int = 0) -> object:
@@ -137,5 +137,4 @@ class ParallelPlugin(TrainingTypePlugin, ABC):
         # GPU teardown
         self.lightning_module.cpu()
         # clean up memory
-        with torch.cuda.device(self.root_device):
-            torch.cuda.empty_cache()
+        torch.cuda.empty_cache()
@@ -85,5 +85,4 @@ class SingleDevicePlugin(TrainingTypePlugin):
         # GPU teardown
         self.lightning_module.cpu()
         # clean up memory
-        with torch.cuda.device(self.root_device):
-            torch.cuda.empty_cache()
+        torch.cuda.empty_cache()
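The two teardown hunks above apply the same simplification as the accelerator change: no device context is needed around `empty_cache()`. A minimal sketch of that teardown sequence, assuming the process has already selected its CUDA device; `teardown_sketch` is an illustrative name, not a Lightning API.

import torch

def teardown_sketch(module: torch.nn.Module) -> None:
    # Move parameters and buffers back to the CPU so the GPU blocks they used
    # become reclaimable, then return those cached blocks to the driver.
    module.cpu()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()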
@@ -20,13 +20,7 @@ from typing import Optional, Union
 import torch

 import pytorch_lightning as pl
-from pytorch_lightning.utilities import (
-    _OMEGACONF_AVAILABLE,
-    DeviceType,
-    rank_zero_deprecation,
-    rank_zero_info,
-    rank_zero_warn,
-)
+from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
@@ -68,8 +62,7 @@ class CheckpointConnector:
             return

         # clear cache before restore
-        if self.trainer._device_type == DeviceType.GPU:
-            torch.cuda.empty_cache()
+        torch.cuda.empty_cache()

         # Try to read the checkpoint file at `checkpoint_path`. If not exist, do not restore checkpoint.
         fs = get_filesystem(checkpoint_path)
@@ -87,8 +80,7 @@ class CheckpointConnector:
         self._loaded_checkpoint = dict()

         # clear cache after restore
-        if self.trainer._device_type == DeviceType.GPU:
-            torch.cuda.empty_cache()
+        torch.cuda.empty_cache()

         # wait for all to catch up
         self.trainer.training_type_plugin.barrier("CheckpointConnector.resume_end")
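Both `CheckpointConnector` hunks drop the `DeviceType.GPU` guard, which is also why `DeviceType` disappears from the import block above. The simplification relies on `torch.cuda.empty_cache()` being a no-op when no CUDA context has been initialized; a hedged illustration:

import torch

def clear_cache_unconditionally() -> None:
    # On a CPU-only run nothing has been cached, so this returns immediately;
    # on a GPU run it releases cached, currently unused blocks.
    torch.cuda.empty_cache()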
@@ -76,11 +76,10 @@ def is_out_of_cpu_memory(exception):
 def garbage_collection_cuda():
     """Garbage collection Torch (CUDA) memory."""
     gc.collect()
-    if torch.cuda.is_available():
-        try:
-            # This is the last thing that should cause an OOM error, but seemingly it can.
-            torch.cuda.empty_cache()
-        except RuntimeError as exception:
-            if not is_oom_error(exception):
-                # Only handle OOM errors
-                raise
+    try:
+        # This is the last thing that should cause an OOM error, but seemingly it can.
+        torch.cuda.empty_cache()
+    except RuntimeError as exception:
+        if not is_oom_error(exception):
+            # Only handle OOM errors
+            raise
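For context, a hedged usage sketch of the helper in the last hunk inside an OOM-retry loop. It assumes the hunk comes from `pytorch_lightning.utilities.memory`; `run_with_oom_retry` and its `step` callback are hypothetical, and only `garbage_collection_cuda` and `is_oom_error` are taken from that module.

from typing import Callable

from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error

def run_with_oom_retry(step: Callable[[int], None], batch_size: int) -> int:
    # Halve the batch size after each CUDA out-of-memory error, cleaning up the
    # cache between attempts; anything that is not an OOM error is re-raised.
    while batch_size > 0:
        try:
            step(batch_size)
            return batch_size
        except RuntimeError as exception:
            if not is_oom_error(exception):
                raise
            garbage_collection_cuda()
            batch_size //= 2
    raise RuntimeError("no batch size fit in memory")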