From 74eb6cc7e90b4b06c0136504101d6f5c343e93dc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Wed, 30 Jun 2021 13:04:24 +0200
Subject: [PATCH] Clean `cuda.empty_cache` usage (#8199)

---
 pytorch_lightning/accelerators/gpu.py          |  5 +----
 pytorch_lightning/plugins/training_type/ddp.py | 10 ++++++++--
 .../plugins/training_type/ddp_spawn.py         | 10 ++++++++--
 .../plugins/training_type/horovod.py           |  5 ++---
 .../plugins/training_type/parallel.py          |  3 +--
 .../plugins/training_type/single_device.py     |  3 +--
 .../trainer/connectors/checkpoint_connector.py | 14 +++-----------
 pytorch_lightning/utilities/memory.py          | 15 +++++++--------
 8 files changed, 31 insertions(+), 34 deletions(-)

diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py
index 7543a2b794..1c5ff56d80 100644
--- a/pytorch_lightning/accelerators/gpu.py
+++ b/pytorch_lightning/accelerators/gpu.py
@@ -42,10 +42,7 @@ class GPUAccelerator(Accelerator):
 
     def on_train_start(self) -> None:
         # clear cache before training
-        # use context because of:
-        # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898
-        with torch.cuda.device(self.root_device):
-            torch.cuda.empty_cache()
+        torch.cuda.empty_cache()
 
     @staticmethod
     def set_nvidia_flags(local_rank: int) -> None:
diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index 64c402cc5c..aba2bf242b 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -41,7 +41,13 @@ from pytorch_lightning.utilities import (
     rank_zero_deprecation,
     rank_zero_warn,
 )
-from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only, ReduceOp, sync_ddp_if_available
+from pytorch_lightning.utilities.distributed import (
+    distributed_available,
+    rank_zero_info,
+    rank_zero_only,
+    ReduceOp,
+    sync_ddp_if_available,
+)
 from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed
 
@@ -347,7 +353,7 @@ class DDPPlugin(ParallelPlugin):
         self.cluster_environment.teardown()
 
     def barrier(self, *args, **kwargs) -> None:
-        if not torch.distributed.is_initialized():
+        if not distributed_available():
             return
         if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
             torch.distributed.barrier(device_ids=self.determine_ddp_device_ids())
diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index 6c7b24d6fe..e5084adb1a 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -36,7 +36,13 @@ from pytorch_lightning.utilities import (
 )
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.cloud_io import load as pl_load
-from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only, ReduceOp, sync_ddp_if_available
+from pytorch_lightning.utilities.distributed import (
+    distributed_available,
+    rank_zero_info,
+    rank_zero_only,
+    ReduceOp,
+    sync_ddp_if_available,
+)
 from pytorch_lightning.utilities.seed import reset_seed
 
 if _TORCH_GREATER_EQUAL_1_8:
@@ -312,7 +318,7 @@ class DDPSpawnPlugin(ParallelPlugin):
             self.lightning_module.load_state_dict(ckpt)
 
     def barrier(self, *args, **kwargs) -> None:
-        if not torch.distributed.is_initialized():
+        if not distributed_available():
             return
         if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
             torch.distributed.barrier(device_ids=self.determine_ddp_device_ids())
diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py
index cbd9e80dab..a402f4b19a 100644
--- a/pytorch_lightning/plugins/training_type/horovod.py
+++ b/pytorch_lightning/plugins/training_type/horovod.py
@@ -15,13 +15,12 @@ from contextlib import ExitStack
 from typing import Any, List, Optional, Union
 
 import torch
-import torch.distributed
 from torch.optim.lr_scheduler import _LRScheduler, Optimizer
 
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
 from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
-from pytorch_lightning.utilities.distributed import group, rank_zero_only, ReduceOp
+from pytorch_lightning.utilities.distributed import distributed_available, group, rank_zero_only, ReduceOp
 
 if _HOROVOD_AVAILABLE:
     import horovod.torch as hvd
@@ -125,7 +124,7 @@ class HorovodPlugin(ParallelPlugin):
         self.join()
 
     def barrier(self, *args, **kwargs):
-        if torch.distributed.is_initialized():
+        if distributed_available():
             self.join()
 
     def broadcast(self, obj: object, src: int = 0) -> object:
diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py
index a2a35a6e9e..e1c9a7149d 100644
--- a/pytorch_lightning/plugins/training_type/parallel.py
+++ b/pytorch_lightning/plugins/training_type/parallel.py
@@ -137,5 +137,4 @@ class ParallelPlugin(TrainingTypePlugin, ABC):
             # GPU teardown
             self.lightning_module.cpu()
             # clean up memory
-            with torch.cuda.device(self.root_device):
-                torch.cuda.empty_cache()
+            torch.cuda.empty_cache()
diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py
index 1816f5838c..d4a328902e 100644
--- a/pytorch_lightning/plugins/training_type/single_device.py
+++ b/pytorch_lightning/plugins/training_type/single_device.py
@@ -85,5 +85,4 @@ class SingleDevicePlugin(TrainingTypePlugin):
             # GPU teardown
             self.lightning_module.cpu()
             # clean up memory
-            with torch.cuda.device(self.root_device):
-                torch.cuda.empty_cache()
+            torch.cuda.empty_cache()
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index f1620c10bb..24d21ace4a 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -20,13 +20,7 @@ from typing import Optional, Union
 import torch
 
 import pytorch_lightning as pl
-from pytorch_lightning.utilities import (
-    _OMEGACONF_AVAILABLE,
-    DeviceType,
-    rank_zero_deprecation,
-    rank_zero_info,
-    rank_zero_warn,
-)
+from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
@@ -68,8 +62,7 @@ class CheckpointConnector:
             return
 
         # clear cache before restore
-        if self.trainer._device_type == DeviceType.GPU:
-            torch.cuda.empty_cache()
+        torch.cuda.empty_cache()
 
         # Try to read the checkpoint file at `checkpoint_path`. If not exist, do not restore checkpoint.
         fs = get_filesystem(checkpoint_path)
@@ -87,8 +80,7 @@ class CheckpointConnector:
         self._loaded_checkpoint = dict()
 
         # clear cache after restore
-        if self.trainer._device_type == DeviceType.GPU:
-            torch.cuda.empty_cache()
+        torch.cuda.empty_cache()
 
         # wait for all to catch up
         self.trainer.training_type_plugin.barrier("CheckpointConnector.resume_end")
diff --git a/pytorch_lightning/utilities/memory.py b/pytorch_lightning/utilities/memory.py
index 6c01390a8c..0ae88e8995 100644
--- a/pytorch_lightning/utilities/memory.py
+++ b/pytorch_lightning/utilities/memory.py
@@ -76,11 +76,10 @@ def is_out_of_cpu_memory(exception):
 def garbage_collection_cuda():
     """Garbage collection Torch (CUDA) memory."""
     gc.collect()
-    if torch.cuda.is_available():
-        try:
-            # This is the last thing that should cause an OOM error, but seemingly it can.
-            torch.cuda.empty_cache()
-        except RuntimeError as exception:
-            if not is_oom_error(exception):
-                # Only handle OOM errors
-                raise
+    try:
+        # This is the last thing that should cause an OOM error, but seemingly it can.
+        torch.cuda.empty_cache()
+    except RuntimeError as exception:
+        if not is_oom_error(exception):
+            # Only handle OOM errors
+            raise
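
Note: the hunks above route the `barrier()` guards through the `distributed_available()` helper from `pytorch_lightning.utilities.distributed` instead of calling `torch.distributed.is_initialized()` directly. As a rough sketch only (the actual Lightning helper may perform additional checks, e.g. for TPU processes), such a guard reduces to the two checks `torch.distributed` itself exposes:

    import torch.distributed as dist

    def distributed_available() -> bool:
        # Hypothetical sketch: True only when the distributed package is built
        # into this torch install *and* a process group has been initialized.
        return dist.is_available() and dist.is_initialized()

With a guard along these lines, `barrier()` becomes a no-op in single-process runs rather than raising because no process group was ever created.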