diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py index 5d32e9d96f..058e5e7c40 100644 --- a/src/lightning/fabric/fabric.py +++ b/src/lightning/fabric/fabric.py @@ -14,13 +14,12 @@ import inspect import os from collections.abc import Generator, Mapping, Sequence -from contextlib import contextmanager, nullcontext +from contextlib import AbstractContextManager, contextmanager, nullcontext from functools import partial from pathlib import Path from typing import ( Any, Callable, - ContextManager, Optional, Union, cast, @@ -484,7 +483,7 @@ class Fabric: ) raise ValueError("You have to specify either `clip_val` or `max_norm` to do gradient clipping!") - def autocast(self) -> ContextManager: + def autocast(self) -> AbstractContextManager: """A context manager to automatically convert operations for the chosen precision. Use this only if the `forward` method of your model does not cover all operations you wish to run with the @@ -634,7 +633,7 @@ class Fabric: if rank == 0: barrier() - def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> ContextManager: + def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> AbstractContextManager: r"""Skip gradient synchronization during backward to avoid redundant communication overhead. Use this context manager when performing gradient accumulation to speed up training with multiple devices. @@ -676,7 +675,7 @@ class Fabric: forward_module, _ = _unwrap_compiled(module._forward_module) return self._strategy._backward_sync_control.no_backward_sync(forward_module, enabled) - def sharded_model(self) -> ContextManager: + def sharded_model(self) -> AbstractContextManager: r"""Instantiate a model under this context manager to prepare it for model-parallel sharding. .. deprecated:: This context manager is deprecated in favor of :meth:`init_module`, use it instead. @@ -688,12 +687,12 @@ class Fabric: return self.strategy.module_sharded_context() return nullcontext() - def init_tensor(self) -> ContextManager: + def init_tensor(self) -> AbstractContextManager: """Tensors that you instantiate under this context manager will be created on the device right away and have the right data type depending on the precision setting in Fabric.""" return self._strategy.tensor_init_context() - def init_module(self, empty_init: Optional[bool] = None) -> ContextManager: + def init_module(self, empty_init: Optional[bool] = None) -> AbstractContextManager: """Instantiate the model and its parameters under this context manager to reduce peak memory usage. The parameters get created on the device and with the right data type right away without wasting memory being diff --git a/src/lightning/fabric/plugins/precision/amp.py b/src/lightning/fabric/plugins/precision/amp.py index c9205ea2b4..d5fc1f0c1c 100644 --- a/src/lightning/fabric/plugins/precision/amp.py +++ b/src/lightning/fabric/plugins/precision/amp.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, ContextManager, Literal, Optional +from contextlib import AbstractContextManager +from typing import Any, Literal, Optional import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -59,7 +60,7 @@ class MixedPrecision(Precision): self._desired_input_dtype = torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16 @override - def forward_context(self) -> ContextManager: + def forward_context(self) -> AbstractContextManager: return torch.autocast(self.device, dtype=self._desired_input_dtype) @override diff --git a/src/lightning/fabric/plugins/precision/bitsandbytes.py b/src/lightning/fabric/plugins/precision/bitsandbytes.py index f0e80cbd77..ecb1d8a442 100644 --- a/src/lightning/fabric/plugins/precision/bitsandbytes.py +++ b/src/lightning/fabric/plugins/precision/bitsandbytes.py @@ -17,10 +17,10 @@ import math import os import warnings from collections import OrderedDict -from contextlib import ExitStack +from contextlib import AbstractContextManager, ExitStack from functools import partial from types import ModuleType -from typing import Any, Callable, ContextManager, Literal, Optional, cast +from typing import Any, Callable, Literal, Optional, cast import torch from lightning_utilities import apply_to_collection @@ -123,11 +123,11 @@ class BitsandbytesPrecision(Precision): return module @override - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: return _DtypeContextManager(self.dtype) @override - def module_init_context(self) -> ContextManager: + def module_init_context(self) -> AbstractContextManager: if self.ignore_modules: # cannot patch the Linear class if the user wants to skip some submodules raise RuntimeError( @@ -145,7 +145,7 @@ class BitsandbytesPrecision(Precision): return stack @override - def forward_context(self) -> ContextManager: + def forward_context(self) -> AbstractContextManager: return _DtypeContextManager(self.dtype) @override diff --git a/src/lightning/fabric/plugins/precision/deepspeed.py b/src/lightning/fabric/plugins/precision/deepspeed.py index 2fcaa38258..526095008f 100644 --- a/src/lightning/fabric/plugins/precision/deepspeed.py +++ b/src/lightning/fabric/plugins/precision/deepspeed.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import nullcontext -from typing import TYPE_CHECKING, Any, ContextManager, Literal +from contextlib import AbstractContextManager, nullcontext +from typing import TYPE_CHECKING, Any, Literal import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -68,13 +68,13 @@ class DeepSpeedPrecision(Precision): return module @override - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: if "true" not in self.precision: return nullcontext() return _DtypeContextManager(self._desired_dtype) @override - def module_init_context(self) -> ContextManager: + def module_init_context(self) -> AbstractContextManager: return self.tensor_init_context() @override diff --git a/src/lightning/fabric/plugins/precision/double.py b/src/lightning/fabric/plugins/precision/double.py index 0a857499f3..9aa0365a55 100644 --- a/src/lightning/fabric/plugins/precision/double.py +++ b/src/lightning/fabric/plugins/precision/double.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, ContextManager, Literal +from contextlib import AbstractContextManager +from typing import Any, Literal import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -33,15 +34,15 @@ class DoublePrecision(Precision): return module.double() @override - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: return _DtypeContextManager(torch.double) @override - def module_init_context(self) -> ContextManager: + def module_init_context(self) -> AbstractContextManager: return self.tensor_init_context() @override - def forward_context(self) -> ContextManager: + def forward_context(self) -> AbstractContextManager: return self.tensor_init_context() @override diff --git a/src/lightning/fabric/plugins/precision/fsdp.py b/src/lightning/fabric/plugins/precision/fsdp.py index 7e3170c2c7..0b78ad72a4 100644 --- a/src/lightning/fabric/plugins/precision/fsdp.py +++ b/src/lightning/fabric/plugins/precision/fsdp.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, ContextManager, Literal, Optional +from contextlib import AbstractContextManager +from typing import TYPE_CHECKING, Any, Literal, Optional import torch from lightning_utilities import apply_to_collection @@ -100,15 +101,15 @@ class FSDPPrecision(Precision): ) @override - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: return _DtypeContextManager(self._desired_input_dtype) @override - def module_init_context(self) -> ContextManager: + def module_init_context(self) -> AbstractContextManager: return _DtypeContextManager(self.mixed_precision_config.param_dtype or torch.float32) @override - def forward_context(self) -> ContextManager: + def forward_context(self) -> AbstractContextManager: if "mixed" in self.precision: return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16)) return self.tensor_init_context() diff --git a/src/lightning/fabric/plugins/precision/half.py b/src/lightning/fabric/plugins/precision/half.py index 32ca7da815..fcb28ad332 100644 --- a/src/lightning/fabric/plugins/precision/half.py +++ b/src/lightning/fabric/plugins/precision/half.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, ContextManager, Literal +from contextlib import AbstractContextManager +from typing import Any, Literal import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -42,15 +43,15 @@ class HalfPrecision(Precision): return module.to(dtype=self._desired_input_dtype) @override - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: return _DtypeContextManager(self._desired_input_dtype) @override - def module_init_context(self) -> ContextManager: + def module_init_context(self) -> AbstractContextManager: return self.tensor_init_context() @override - def forward_context(self) -> ContextManager: + def forward_context(self) -> AbstractContextManager: return self.tensor_init_context() @override diff --git a/src/lightning/fabric/plugins/precision/precision.py b/src/lightning/fabric/plugins/precision/precision.py index d32ccbabb4..1dfab2a7bc 100644 --- a/src/lightning/fabric/plugins/precision/precision.py +++ b/src/lightning/fabric/plugins/precision/precision.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import nullcontext -from typing import Any, ContextManager, Literal, Optional, Union +from contextlib import AbstractContextManager, nullcontext +from typing import Any, Literal, Optional, Union from torch import Tensor from torch.nn import Module @@ -53,11 +53,11 @@ class Precision: """ return module - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: """Controls how tensors get created (device, dtype).""" return nullcontext() - def module_init_context(self) -> ContextManager: + def module_init_context(self) -> AbstractContextManager: """Instantiate module parameters or tensors in the precision type this plugin handles. This is optional and depends on the precision limitations during optimization. @@ -65,7 +65,7 @@ class Precision: """ return nullcontext() - def forward_context(self) -> ContextManager: + def forward_context(self) -> AbstractContextManager: """A contextmanager for managing model forward/training_step/evaluation_step/predict_step.""" return nullcontext() diff --git a/src/lightning/fabric/plugins/precision/transformer_engine.py b/src/lightning/fabric/plugins/precision/transformer_engine.py index ddc30c2e1b..c3ef84a453 100644 --- a/src/lightning/fabric/plugins/precision/transformer_engine.py +++ b/src/lightning/fabric/plugins/precision/transformer_engine.py @@ -13,8 +13,8 @@ # limitations under the License. import logging from collections.abc import Mapping -from contextlib import ExitStack -from typing import TYPE_CHECKING, Any, ContextManager, Literal, Optional, Union +from contextlib import AbstractContextManager, ExitStack +from typing import TYPE_CHECKING, Any, Literal, Optional, Union import torch from lightning_utilities import apply_to_collection @@ -107,11 +107,11 @@ class TransformerEnginePrecision(Precision): return module @override - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: return _DtypeContextManager(self.weights_dtype) @override - def module_init_context(self) -> ContextManager: + def module_init_context(self) -> AbstractContextManager: dtype_ctx = self.tensor_init_context() stack = ExitStack() if self.replace_layers: @@ -126,7 +126,7 @@ class TransformerEnginePrecision(Precision): return stack @override - def forward_context(self) -> ContextManager: + def forward_context(self) -> AbstractContextManager: dtype_ctx = _DtypeContextManager(self.weights_dtype) fallback_autocast_ctx = torch.autocast(device_type="cuda", dtype=self.fallback_compute_dtype) import transformer_engine.pytorch as te diff --git a/src/lightning/fabric/strategies/ddp.py b/src/lightning/fabric/strategies/ddp.py index 2fd325c9c3..ce47e4e403 100644 --- a/src/lightning/fabric/strategies/ddp.py +++ b/src/lightning/fabric/strategies/ddp.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import nullcontext +from contextlib import AbstractContextManager, nullcontext from datetime import timedelta -from typing import Any, ContextManager, Literal, Optional, Union +from typing import Any, Literal, Optional, Union import torch import torch.distributed @@ -231,7 +231,7 @@ class DDPStrategy(ParallelStrategy): class _DDPBackwardSyncControl(_BackwardSyncControl): @override - def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager: + def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager: """Blocks gradient synchronization inside the :class:`~torch.nn.parallel.distributed.DistributedDataParallel` wrapper.""" if not enabled: diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index 1cd5690966..03d90cd5df 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -17,10 +17,10 @@ import logging import os import platform from collections.abc import Mapping -from contextlib import ExitStack +from contextlib import AbstractContextManager, ExitStack from itertools import chain from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, ContextManager, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch from lightning_utilities.core.imports import RequirementCache @@ -353,7 +353,7 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded): raise NotImplementedError(self._err_msg_joint_setup_required()) @override - def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager: + def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager: if self.zero_stage_3 and empty_init is False: raise NotImplementedError( f"`{empty_init=}` is not a valid choice with `DeepSpeedStrategy` when ZeRO stage 3 is enabled." @@ -366,7 +366,7 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded): return stack @override - def module_sharded_context(self) -> ContextManager: + def module_sharded_context(self) -> AbstractContextManager: # Current limitation in Fabric: The config needs to be fully determined at the time of calling the context # manager. Later modifications through e.g. `Fabric.setup()` won't have an effect here. diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index 7e598191ce..9dd5b2c62d 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -14,7 +14,7 @@ import shutil import warnings from collections.abc import Generator -from contextlib import ExitStack, nullcontext +from contextlib import AbstractContextManager, ExitStack, nullcontext from datetime import timedelta from functools import partial from pathlib import Path @@ -22,7 +22,6 @@ from typing import ( TYPE_CHECKING, Any, Callable, - ContextManager, Literal, Optional, Union, @@ -335,7 +334,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded): pass @override - def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager: + def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager: precision_init_ctx = self.precision.module_init_context() module_sharded_ctx = self.module_sharded_context() stack = ExitStack() @@ -349,7 +348,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded): return stack @override - def module_sharded_context(self) -> ContextManager: + def module_sharded_context(self) -> AbstractContextManager: from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel from torch.distributed.fsdp.wrap import enable_wrap @@ -740,7 +739,7 @@ def _setup_activation_checkpointing(module: Module, activation_checkpointing_kwa class _FSDPBackwardSyncControl(_BackwardSyncControl): @override - def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager: + def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager: """Blocks gradient synchronization inside the :class:`~torch.distributed.fsdp.FullyShardedDataParallel` wrapper.""" if not enabled: diff --git a/src/lightning/fabric/strategies/model_parallel.py b/src/lightning/fabric/strategies/model_parallel.py index 435dc3dad1..ad1fc19074 100644 --- a/src/lightning/fabric/strategies/model_parallel.py +++ b/src/lightning/fabric/strategies/model_parallel.py @@ -14,10 +14,10 @@ import itertools import shutil from collections.abc import Generator -from contextlib import ExitStack +from contextlib import AbstractContextManager, ExitStack from datetime import timedelta from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, ContextManager, Literal, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, TypeVar, Union import torch from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only @@ -195,7 +195,7 @@ class ModelParallelStrategy(ParallelStrategy): pass @override - def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager: + def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager: precision_init_ctx = self.precision.module_init_context() stack = ExitStack() if empty_init: @@ -319,12 +319,12 @@ class ModelParallelStrategy(ParallelStrategy): class _ParallelBackwardSyncControl(_BackwardSyncControl): @override - def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager: + def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager: """Blocks gradient synchronization inside the FSDP2 modules.""" return _FSDPNoSync(module=module, enabled=enabled) -class _FSDPNoSync(ContextManager): +class _FSDPNoSync(AbstractContextManager): def __init__(self, module: Module, enabled: bool) -> None: self._module = module self._enabled = enabled diff --git a/src/lightning/fabric/strategies/strategy.py b/src/lightning/fabric/strategies/strategy.py index c83160a3a2..4daad9b954 100644 --- a/src/lightning/fabric/strategies/strategy.py +++ b/src/lightning/fabric/strategies/strategy.py @@ -14,8 +14,8 @@ import logging from abc import ABC, abstractmethod from collections.abc import Iterable -from contextlib import ExitStack -from typing import Any, Callable, ContextManager, Optional, TypeVar, Union +from contextlib import AbstractContextManager, ExitStack +from typing import Any, Callable, Optional, TypeVar, Union import torch from torch import Tensor @@ -118,7 +118,7 @@ class Strategy(ABC): """ return dataloader - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: """Controls how tensors get created (device, dtype).""" precision_init_ctx = self.precision.tensor_init_context() stack = ExitStack() @@ -126,7 +126,7 @@ class Strategy(ABC): stack.enter_context(precision_init_ctx) return stack - def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager: + def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager: """A context manager wrapping the model instantiation. Here, the strategy can control how the parameters of the model get created (device, dtype) and or apply other @@ -422,7 +422,7 @@ class _BackwardSyncControl(ABC): """ @abstractmethod - def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager: + def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager: """Blocks the synchronization of gradients during the backward pass. This is a context manager. It is only effective if it wraps a call to `.backward()`. @@ -434,7 +434,7 @@ class _Sharded(ABC): """Mixin-interface for any :class:`Strategy` that wants to expose functionality for sharding model parameters.""" @abstractmethod - def module_sharded_context(self) -> ContextManager: + def module_sharded_context(self) -> AbstractContextManager: """A context manager that goes over the instantiation of an :class:`torch.nn.Module` and handles sharding of parameters on creation. diff --git a/src/lightning/fabric/strategies/xla_fsdp.py b/src/lightning/fabric/strategies/xla_fsdp.py index 9b2ffe3505..935ef72713 100644 --- a/src/lightning/fabric/strategies/xla_fsdp.py +++ b/src/lightning/fabric/strategies/xla_fsdp.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import io -from contextlib import ExitStack, nullcontext +from contextlib import AbstractContextManager, ExitStack, nullcontext from functools import partial from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, ContextManager, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union import torch from torch import Tensor @@ -225,7 +225,7 @@ class XLAFSDPStrategy(ParallelStrategy, _Sharded): def module_to_device(self, module: Module) -> None: pass - def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager: + def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager: precision_init_ctx = self.precision.module_init_context() module_sharded_ctx = self.module_sharded_context() stack = ExitStack() @@ -235,7 +235,7 @@ class XLAFSDPStrategy(ParallelStrategy, _Sharded): return stack @override - def module_sharded_context(self) -> ContextManager: + def module_sharded_context(self) -> AbstractContextManager: return nullcontext() @override @@ -668,7 +668,7 @@ def _activation_checkpointing_kwargs(policy: Optional[_POLICY_SET], kwargs: dict class _XLAFSDPBackwardSyncControl(_BackwardSyncControl): @override - def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager: + def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager: """Blocks gradient synchronization inside the :class:`~torch_xla.distributed.fsdp.XlaFullyShardedDataParallel` wrapper.""" if not enabled: diff --git a/src/lightning/pytorch/loops/utilities.py b/src/lightning/pytorch/loops/utilities.py index 43d7e333ef..2aaf877c89 100644 --- a/src/lightning/pytorch/loops/utilities.py +++ b/src/lightning/pytorch/loops/utilities.py @@ -13,8 +13,8 @@ # limitations under the License. import inspect from collections.abc import Generator -from contextlib import contextmanager -from typing import Any, Callable, ContextManager, Optional +from contextlib import AbstractContextManager, contextmanager +from typing import Any, Callable, Optional import torch import torch.distributed as dist @@ -160,7 +160,7 @@ def _no_grad_context(loop_run: Callable) -> Callable: raise TypeError(f"`{type(self).__name__}` needs to be a Loop.") if not hasattr(self, "inference_mode"): raise TypeError(f"`{type(self).__name__}.inference_mode` needs to be defined") - context_manager: type[ContextManager] + context_manager: type[AbstractContextManager] if _distributed_is_initialized() and dist.get_backend() == "gloo": # gloo backend does not work properly. # https://github.com/Lightning-AI/lightning/pull/12715/files#r854569110 diff --git a/src/lightning/pytorch/plugins/precision/deepspeed.py b/src/lightning/pytorch/plugins/precision/deepspeed.py index e1e90281cf..9225e3bb9e 100644 --- a/src/lightning/pytorch/plugins/precision/deepspeed.py +++ b/src/lightning/pytorch/plugins/precision/deepspeed.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import nullcontext -from typing import TYPE_CHECKING, Any, Callable, ContextManager, Optional, Union +from contextlib import AbstractContextManager, nullcontext +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch from lightning_utilities import apply_to_collection @@ -80,13 +80,13 @@ class DeepSpeedPrecision(Precision): return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype) @override - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: if "true" not in self.precision: return nullcontext() return _DtypeContextManager(self._desired_dtype) @override - def module_init_context(self) -> ContextManager: + def module_init_context(self) -> AbstractContextManager: return self.tensor_init_context() @override diff --git a/src/lightning/pytorch/plugins/precision/double.py b/src/lightning/pytorch/plugins/precision/double.py index 5d0af8b992..efa1aa008a 100644 --- a/src/lightning/pytorch/plugins/precision/double.py +++ b/src/lightning/pytorch/plugins/precision/double.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Generator -from contextlib import contextmanager -from typing import Any, ContextManager, Literal +from contextlib import AbstractContextManager, contextmanager +from typing import Any, Literal import torch import torch.nn as nn @@ -38,11 +38,11 @@ class DoublePrecision(Precision): return module.double() @override - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: return _DtypeContextManager(torch.float64) @override - def module_init_context(self) -> ContextManager: + def module_init_context(self) -> AbstractContextManager: return self.tensor_init_context() @override diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py index 64cee32359..7029497c17 100644 --- a/src/lightning/pytorch/plugins/precision/fsdp.py +++ b/src/lightning/pytorch/plugins/precision/fsdp.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Callable, ContextManager, Optional +from contextlib import AbstractContextManager +from typing import TYPE_CHECKING, Any, Callable, Optional import torch from lightning_utilities import apply_to_collection @@ -109,15 +110,15 @@ class FSDPPrecision(Precision): ) @override - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: return _DtypeContextManager(self._desired_input_dtype) @override - def module_init_context(self) -> ContextManager: + def module_init_context(self) -> AbstractContextManager: return _DtypeContextManager(self.mixed_precision_config.param_dtype or torch.float32) @override - def forward_context(self) -> ContextManager: + def forward_context(self) -> AbstractContextManager: if "mixed" in self.precision: return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16)) return _DtypeContextManager(self._desired_input_dtype) diff --git a/src/lightning/pytorch/plugins/precision/half.py b/src/lightning/pytorch/plugins/precision/half.py index 2ad30de2b8..fe9deb44c3 100644 --- a/src/lightning/pytorch/plugins/precision/half.py +++ b/src/lightning/pytorch/plugins/precision/half.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Generator -from contextlib import contextmanager -from typing import Any, ContextManager, Literal +from contextlib import AbstractContextManager, contextmanager +from typing import Any, Literal import torch from lightning_utilities import apply_to_collection @@ -44,11 +44,11 @@ class HalfPrecision(Precision): return module.to(dtype=self._desired_input_dtype) @override - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: return _DtypeContextManager(self._desired_input_dtype) @override - def module_init_context(self) -> ContextManager: + def module_init_context(self) -> AbstractContextManager: return self.tensor_init_context() @override diff --git a/src/lightning/pytorch/profilers/pytorch.py b/src/lightning/pytorch/profilers/pytorch.py index 6ba7466ee0..e264d5154f 100644 --- a/src/lightning/pytorch/profilers/pytorch.py +++ b/src/lightning/pytorch/profilers/pytorch.py @@ -16,9 +16,10 @@ import inspect import logging import os +from contextlib import AbstractContextManager from functools import lru_cache, partial from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, ContextManager, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch from torch import Tensor, nn @@ -305,7 +306,7 @@ class PyTorchProfiler(Profiler): self.function_events: Optional[EventList] = None self._lightning_module: Optional[LightningModule] = None # set by ProfilerConnector self._register: Optional[RegisterRecordFunction] = None - self._parent_profiler: Optional[ContextManager] = None + self._parent_profiler: Optional[AbstractContextManager] = None self._recording_map: dict[str, record_function] = {} self._start_action_name: Optional[str] = None self._schedule: Optional[ScheduleWrapper] = None diff --git a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py index d57ac76f04..8d3a1800e1 100644 --- a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py +++ b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py @@ -13,8 +13,9 @@ # limitations under the License. import logging import os +from contextlib import AbstractContextManager from pathlib import Path -from typing import ContextManager, Optional +from typing import Optional from unittest import mock import pytest @@ -382,5 +383,5 @@ def test_misconfiguration_error_with_sharded_model(tmp_path, strategy: str): trainer.fit(model) -def _backward_patch(trainer: Trainer) -> ContextManager: +def _backward_patch(trainer: Trainer) -> AbstractContextManager: return mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward)