contextlib.AbstractContextManager
parent b9920bdd78
commit a35f1fe34b
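
The change below is mechanical: `typing.ContextManager` has been a deprecated alias (since Python 3.9) of `contextlib.AbstractContextManager`, so each touched file swaps the `typing` import for the `contextlib` one and updates the annotations. A minimal sketch of the pattern, shown with an illustrative free function rather than any of the actual methods in the diff:

from contextlib import AbstractContextManager, nullcontext

# Before (deprecated alias):
#   from typing import ContextManager
#   def forward_context() -> ContextManager: ...

# After: annotate with the contextlib ABC directly.
def forward_context() -> AbstractContextManager:
    """Return a no-op context manager; the real plugins return e.g. torch.autocast(...)."""
    return nullcontext()
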
@@ -14,13 +14,12 @@
 import inspect
 import os
 from collections.abc import Generator, Mapping, Sequence
-from contextlib import contextmanager, nullcontext
+from contextlib import AbstractContextManager, contextmanager, nullcontext
 from functools import partial
 from pathlib import Path
 from typing import (
 Any,
 Callable,
-ContextManager,
 Optional,
 Union,
 cast,
@@ -484,7 +483,7 @@ class Fabric:
 )
 raise ValueError("You have to specify either `clip_val` or `max_norm` to do gradient clipping!")
 
-def autocast(self) -> ContextManager:
+def autocast(self) -> AbstractContextManager:
 """A context manager to automatically convert operations for the chosen precision.
 
 Use this only if the `forward` method of your model does not cover all operations you wish to run with the
@@ -634,7 +633,7 @@ class Fabric:
 if rank == 0:
 barrier()
 
-def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> ContextManager:
+def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> AbstractContextManager:
 r"""Skip gradient synchronization during backward to avoid redundant communication overhead.
 
 Use this context manager when performing gradient accumulation to speed up training with multiple devices.
@@ -676,7 +675,7 @@ class Fabric:
 forward_module, _ = _unwrap_compiled(module._forward_module)
 return self._strategy._backward_sync_control.no_backward_sync(forward_module, enabled)
 
-def sharded_model(self) -> ContextManager:
+def sharded_model(self) -> AbstractContextManager:
 r"""Instantiate a model under this context manager to prepare it for model-parallel sharding.
 
 .. deprecated:: This context manager is deprecated in favor of :meth:`init_module`, use it instead.
@@ -688,12 +687,12 @@ class Fabric:
 return self.strategy.module_sharded_context()
 return nullcontext()
 
-def init_tensor(self) -> ContextManager:
+def init_tensor(self) -> AbstractContextManager:
 """Tensors that you instantiate under this context manager will be created on the device right away and have
 the right data type depending on the precision setting in Fabric."""
 return self._strategy.tensor_init_context()
 
-def init_module(self, empty_init: Optional[bool] = None) -> ContextManager:
+def init_module(self, empty_init: Optional[bool] = None) -> AbstractContextManager:
 """Instantiate the model and its parameters under this context manager to reduce peak memory usage.
 
 The parameters get created on the device and with the right data type right away without wasting memory being

@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, ContextManager, Literal, Optional
+from contextlib import AbstractContextManager
+from typing import Any, Literal, Optional
 
 import torch
 from lightning_utilities.core.apply_func import apply_to_collection
@@ -59,7 +60,7 @@ class MixedPrecision(Precision):
 self._desired_input_dtype = torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16
 
 @override
-def forward_context(self) -> ContextManager:
+def forward_context(self) -> AbstractContextManager:
 return torch.autocast(self.device, dtype=self._desired_input_dtype)
 
 @override

@@ -17,10 +17,10 @@ import math
 import os
 import warnings
 from collections import OrderedDict
-from contextlib import ExitStack
+from contextlib import AbstractContextManager, ExitStack
 from functools import partial
 from types import ModuleType
-from typing import Any, Callable, ContextManager, Literal, Optional, cast
+from typing import Any, Callable, Literal, Optional, cast
 
 import torch
 from lightning_utilities import apply_to_collection
@@ -123,11 +123,11 @@ class BitsandbytesPrecision(Precision):
 return module
 
 @override
-def tensor_init_context(self) -> ContextManager:
+def tensor_init_context(self) -> AbstractContextManager:
 return _DtypeContextManager(self.dtype)
 
 @override
-def module_init_context(self) -> ContextManager:
+def module_init_context(self) -> AbstractContextManager:
 if self.ignore_modules:
 # cannot patch the Linear class if the user wants to skip some submodules
 raise RuntimeError(
@@ -145,7 +145,7 @@ class BitsandbytesPrecision(Precision):
 return stack
 
 @override
-def forward_context(self) -> ContextManager:
+def forward_context(self) -> AbstractContextManager:
 return _DtypeContextManager(self.dtype)
 
 @override

@@ -11,8 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from contextlib import nullcontext
-from typing import TYPE_CHECKING, Any, ContextManager, Literal
+from contextlib import AbstractContextManager, nullcontext
+from typing import TYPE_CHECKING, Any, Literal
 
 import torch
 from lightning_utilities.core.apply_func import apply_to_collection
@@ -68,13 +68,13 @@ class DeepSpeedPrecision(Precision):
 return module
 
 @override
-def tensor_init_context(self) -> ContextManager:
+def tensor_init_context(self) -> AbstractContextManager:
 if "true" not in self.precision:
 return nullcontext()
 return _DtypeContextManager(self._desired_dtype)
 
 @override
-def module_init_context(self) -> ContextManager:
+def module_init_context(self) -> AbstractContextManager:
 return self.tensor_init_context()
 
 @override

@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, ContextManager, Literal
+from contextlib import AbstractContextManager
+from typing import Any, Literal
 
 import torch
 from lightning_utilities.core.apply_func import apply_to_collection
@@ -33,15 +34,15 @@ class DoublePrecision(Precision):
 return module.double()
 
 @override
-def tensor_init_context(self) -> ContextManager:
+def tensor_init_context(self) -> AbstractContextManager:
 return _DtypeContextManager(torch.double)
 
 @override
-def module_init_context(self) -> ContextManager:
+def module_init_context(self) -> AbstractContextManager:
 return self.tensor_init_context()
 
 @override
-def forward_context(self) -> ContextManager:
+def forward_context(self) -> AbstractContextManager:
 return self.tensor_init_context()
 
 @override

@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Any, ContextManager, Literal, Optional
+from contextlib import AbstractContextManager
+from typing import TYPE_CHECKING, Any, Literal, Optional
 
 import torch
 from lightning_utilities import apply_to_collection
@@ -100,15 +101,15 @@ class FSDPPrecision(Precision):
 )
 
 @override
-def tensor_init_context(self) -> ContextManager:
+def tensor_init_context(self) -> AbstractContextManager:
 return _DtypeContextManager(self._desired_input_dtype)
 
 @override
-def module_init_context(self) -> ContextManager:
+def module_init_context(self) -> AbstractContextManager:
 return _DtypeContextManager(self.mixed_precision_config.param_dtype or torch.float32)
 
 @override
-def forward_context(self) -> ContextManager:
+def forward_context(self) -> AbstractContextManager:
 if "mixed" in self.precision:
 return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
 return self.tensor_init_context()

@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, ContextManager, Literal
+from contextlib import AbstractContextManager
+from typing import Any, Literal
 
 import torch
 from lightning_utilities.core.apply_func import apply_to_collection
@@ -42,15 +43,15 @@ class HalfPrecision(Precision):
 return module.to(dtype=self._desired_input_dtype)
 
 @override
-def tensor_init_context(self) -> ContextManager:
+def tensor_init_context(self) -> AbstractContextManager:
 return _DtypeContextManager(self._desired_input_dtype)
 
 @override
-def module_init_context(self) -> ContextManager:
+def module_init_context(self) -> AbstractContextManager:
 return self.tensor_init_context()
 
 @override
-def forward_context(self) -> ContextManager:
+def forward_context(self) -> AbstractContextManager:
 return self.tensor_init_context()
 
 @override

@@ -11,8 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from contextlib import nullcontext
-from typing import Any, ContextManager, Literal, Optional, Union
+from contextlib import AbstractContextManager, nullcontext
+from typing import Any, Literal, Optional, Union
 
 from torch import Tensor
 from torch.nn import Module
@@ -53,11 +53,11 @@ class Precision:
 """
 return module
 
-def tensor_init_context(self) -> ContextManager:
+def tensor_init_context(self) -> AbstractContextManager:
 """Controls how tensors get created (device, dtype)."""
 return nullcontext()
 
-def module_init_context(self) -> ContextManager:
+def module_init_context(self) -> AbstractContextManager:
 """Instantiate module parameters or tensors in the precision type this plugin handles.
 
 This is optional and depends on the precision limitations during optimization.
@@ -65,7 +65,7 @@ class Precision:
 """
 return nullcontext()
 
-def forward_context(self) -> ContextManager:
+def forward_context(self) -> AbstractContextManager:
 """A contextmanager for managing model forward/training_step/evaluation_step/predict_step."""
 return nullcontext()
 

@@ -13,8 +13,8 @@
 # limitations under the License.
 import logging
 from collections.abc import Mapping
-from contextlib import ExitStack
-from typing import TYPE_CHECKING, Any, ContextManager, Literal, Optional, Union
+from contextlib import AbstractContextManager, ExitStack
+from typing import TYPE_CHECKING, Any, Literal, Optional, Union
 
 import torch
 from lightning_utilities import apply_to_collection
@@ -107,11 +107,11 @@ class TransformerEnginePrecision(Precision):
 return module
 
 @override
-def tensor_init_context(self) -> ContextManager:
+def tensor_init_context(self) -> AbstractContextManager:
 return _DtypeContextManager(self.weights_dtype)
 
 @override
-def module_init_context(self) -> ContextManager:
+def module_init_context(self) -> AbstractContextManager:
 dtype_ctx = self.tensor_init_context()
 stack = ExitStack()
 if self.replace_layers:
@@ -126,7 +126,7 @@ class TransformerEnginePrecision(Precision):
 return stack
 
 @override
-def forward_context(self) -> ContextManager:
+def forward_context(self) -> AbstractContextManager:
 dtype_ctx = _DtypeContextManager(self.weights_dtype)
 fallback_autocast_ctx = torch.autocast(device_type="cuda", dtype=self.fallback_compute_dtype)
 import transformer_engine.pytorch as te

@@ -11,9 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from contextlib import nullcontext
+from contextlib import AbstractContextManager, nullcontext
 from datetime import timedelta
-from typing import Any, ContextManager, Literal, Optional, Union
+from typing import Any, Literal, Optional, Union
 
 import torch
 import torch.distributed
@@ -231,7 +231,7 @@ class DDPStrategy(ParallelStrategy):
 
 class _DDPBackwardSyncControl(_BackwardSyncControl):
 @override
-def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
+def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager:
 """Blocks gradient synchronization inside the :class:`~torch.nn.parallel.distributed.DistributedDataParallel`
 wrapper."""
 if not enabled:

@@ -17,10 +17,10 @@ import logging
 import os
 import platform
 from collections.abc import Mapping
-from contextlib import ExitStack
+from contextlib import AbstractContextManager, ExitStack
 from itertools import chain
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, ContextManager, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import torch
 from lightning_utilities.core.imports import RequirementCache
@@ -353,7 +353,7 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
 raise NotImplementedError(self._err_msg_joint_setup_required())
 
 @override
-def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager:
+def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager:
 if self.zero_stage_3 and empty_init is False:
 raise NotImplementedError(
 f"`{empty_init=}` is not a valid choice with `DeepSpeedStrategy` when ZeRO stage 3 is enabled."
@@ -366,7 +366,7 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
 return stack
 
 @override
-def module_sharded_context(self) -> ContextManager:
+def module_sharded_context(self) -> AbstractContextManager:
 # Current limitation in Fabric: The config needs to be fully determined at the time of calling the context
 # manager. Later modifications through e.g. `Fabric.setup()` won't have an effect here.
 

@@ -14,7 +14,7 @@
 import shutil
 import warnings
 from collections.abc import Generator
-from contextlib import ExitStack, nullcontext
+from contextlib import AbstractContextManager, ExitStack, nullcontext
 from datetime import timedelta
 from functools import partial
 from pathlib import Path
@@ -22,7 +22,6 @@ from typing import (
 TYPE_CHECKING,
 Any,
 Callable,
-ContextManager,
 Literal,
 Optional,
 Union,
@@ -335,7 +334,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
 pass
 
 @override
-def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager:
+def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager:
 precision_init_ctx = self.precision.module_init_context()
 module_sharded_ctx = self.module_sharded_context()
 stack = ExitStack()
@@ -349,7 +348,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
 return stack
 
 @override
-def module_sharded_context(self) -> ContextManager:
+def module_sharded_context(self) -> AbstractContextManager:
 from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
 from torch.distributed.fsdp.wrap import enable_wrap
 
@@ -740,7 +739,7 @@ def _setup_activation_checkpointing(module: Module, activation_checkpointing_kwa
 
 class _FSDPBackwardSyncControl(_BackwardSyncControl):
 @override
-def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
+def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager:
 """Blocks gradient synchronization inside the :class:`~torch.distributed.fsdp.FullyShardedDataParallel`
 wrapper."""
 if not enabled:

@@ -14,10 +14,10 @@
 import itertools
 import shutil
 from collections.abc import Generator
-from contextlib import ExitStack
+from contextlib import AbstractContextManager, ExitStack
 from datetime import timedelta
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, ContextManager, Literal, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, TypeVar, Union
 
 import torch
 from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
@@ -195,7 +195,7 @@ class ModelParallelStrategy(ParallelStrategy):
 pass
 
 @override
-def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager:
+def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager:
 precision_init_ctx = self.precision.module_init_context()
 stack = ExitStack()
 if empty_init:
@@ -319,12 +319,12 @@ class ModelParallelStrategy(ParallelStrategy):
 
 class _ParallelBackwardSyncControl(_BackwardSyncControl):
 @override
-def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
+def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager:
 """Blocks gradient synchronization inside the FSDP2 modules."""
 return _FSDPNoSync(module=module, enabled=enabled)
 
 
-class _FSDPNoSync(ContextManager):
+class _FSDPNoSync(AbstractContextManager):
 def __init__(self, module: Module, enabled: bool) -> None:
 self._module = module
 self._enabled = enabled

@@ -14,8 +14,8 @@
 import logging
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
-from contextlib import ExitStack
-from typing import Any, Callable, ContextManager, Optional, TypeVar, Union
+from contextlib import AbstractContextManager, ExitStack
+from typing import Any, Callable, Optional, TypeVar, Union
 
 import torch
 from torch import Tensor
@@ -118,7 +118,7 @@ class Strategy(ABC):
 """
 return dataloader
 
-def tensor_init_context(self) -> ContextManager:
+def tensor_init_context(self) -> AbstractContextManager:
 """Controls how tensors get created (device, dtype)."""
 precision_init_ctx = self.precision.tensor_init_context()
 stack = ExitStack()
@@ -126,7 +126,7 @@ class Strategy(ABC):
 stack.enter_context(precision_init_ctx)
 return stack
 
-def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager:
+def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager:
 """A context manager wrapping the model instantiation.
 
 Here, the strategy can control how the parameters of the model get created (device, dtype) and or apply other
@@ -422,7 +422,7 @@ class _BackwardSyncControl(ABC):
 """
 
 @abstractmethod
-def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
+def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager:
 """Blocks the synchronization of gradients during the backward pass.
 
 This is a context manager. It is only effective if it wraps a call to `.backward()`.
@@ -434,7 +434,7 @@ class _Sharded(ABC):
 """Mixin-interface for any :class:`Strategy` that wants to expose functionality for sharding model parameters."""
 
 @abstractmethod
-def module_sharded_context(self) -> ContextManager:
+def module_sharded_context(self) -> AbstractContextManager:
 """A context manager that goes over the instantiation of an :class:`torch.nn.Module` and handles sharding of
 parameters on creation.
 

@@ -12,10 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import io
-from contextlib import ExitStack, nullcontext
+from contextlib import AbstractContextManager, ExitStack, nullcontext
 from functools import partial
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, ContextManager, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
 
 import torch
 from torch import Tensor
@@ -225,7 +225,7 @@ class XLAFSDPStrategy(ParallelStrategy, _Sharded):
 def module_to_device(self, module: Module) -> None:
 pass
 
-def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager:
+def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager:
 precision_init_ctx = self.precision.module_init_context()
 module_sharded_ctx = self.module_sharded_context()
 stack = ExitStack()
@@ -235,7 +235,7 @@ class XLAFSDPStrategy(ParallelStrategy, _Sharded):
 return stack
 
 @override
-def module_sharded_context(self) -> ContextManager:
+def module_sharded_context(self) -> AbstractContextManager:
 return nullcontext()
 
 @override
@@ -668,7 +668,7 @@ def _activation_checkpointing_kwargs(policy: Optional[_POLICY_SET], kwargs: dict
 
 class _XLAFSDPBackwardSyncControl(_BackwardSyncControl):
 @override
-def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
+def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager:
 """Blocks gradient synchronization inside the :class:`~torch_xla.distributed.fsdp.XlaFullyShardedDataParallel`
 wrapper."""
 if not enabled:

@@ -13,8 +13,8 @@
 # limitations under the License.
 import inspect
 from collections.abc import Generator
-from contextlib import contextmanager
-from typing import Any, Callable, ContextManager, Optional
+from contextlib import AbstractContextManager, contextmanager
+from typing import Any, Callable, Optional
 
 import torch
 import torch.distributed as dist
@@ -160,7 +160,7 @@ def _no_grad_context(loop_run: Callable) -> Callable:
 raise TypeError(f"`{type(self).__name__}` needs to be a Loop.")
 if not hasattr(self, "inference_mode"):
 raise TypeError(f"`{type(self).__name__}.inference_mode` needs to be defined")
-context_manager: type[ContextManager]
+context_manager: type[AbstractContextManager]
 if _distributed_is_initialized() and dist.get_backend() == "gloo":
 # gloo backend does not work properly.
 # https://github.com/Lightning-AI/lightning/pull/12715/files#r854569110

@@ -11,8 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from contextlib import nullcontext
-from typing import TYPE_CHECKING, Any, Callable, ContextManager, Optional, Union
+from contextlib import AbstractContextManager, nullcontext
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import torch
 from lightning_utilities import apply_to_collection
@@ -80,13 +80,13 @@ class DeepSpeedPrecision(Precision):
 return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype)
 
 @override
-def tensor_init_context(self) -> ContextManager:
+def tensor_init_context(self) -> AbstractContextManager:
 if "true" not in self.precision:
 return nullcontext()
 return _DtypeContextManager(self._desired_dtype)
 
 @override
-def module_init_context(self) -> ContextManager:
+def module_init_context(self) -> AbstractContextManager:
 return self.tensor_init_context()
 
 @override

@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from collections.abc import Generator
-from contextlib import contextmanager
-from typing import Any, ContextManager, Literal
+from contextlib import AbstractContextManager, contextmanager
+from typing import Any, Literal
 
 import torch
 import torch.nn as nn
@@ -38,11 +38,11 @@ class DoublePrecision(Precision):
 return module.double()
 
 @override
-def tensor_init_context(self) -> ContextManager:
+def tensor_init_context(self) -> AbstractContextManager:
 return _DtypeContextManager(torch.float64)
 
 @override
-def module_init_context(self) -> ContextManager:
+def module_init_context(self) -> AbstractContextManager:
 return self.tensor_init_context()
 
 @override

@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Any, Callable, ContextManager, Optional
+from contextlib import AbstractContextManager
+from typing import TYPE_CHECKING, Any, Callable, Optional
 
 import torch
 from lightning_utilities import apply_to_collection
@@ -109,15 +110,15 @@ class FSDPPrecision(Precision):
 )
 
 @override
-def tensor_init_context(self) -> ContextManager:
+def tensor_init_context(self) -> AbstractContextManager:
 return _DtypeContextManager(self._desired_input_dtype)
 
 @override
-def module_init_context(self) -> ContextManager:
+def module_init_context(self) -> AbstractContextManager:
 return _DtypeContextManager(self.mixed_precision_config.param_dtype or torch.float32)
 
 @override
-def forward_context(self) -> ContextManager:
+def forward_context(self) -> AbstractContextManager:
 if "mixed" in self.precision:
 return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
 return _DtypeContextManager(self._desired_input_dtype)

@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from collections.abc import Generator
-from contextlib import contextmanager
-from typing import Any, ContextManager, Literal
+from contextlib import AbstractContextManager, contextmanager
+from typing import Any, Literal
 
 import torch
 from lightning_utilities import apply_to_collection
@@ -44,11 +44,11 @@ class HalfPrecision(Precision):
 return module.to(dtype=self._desired_input_dtype)
 
 @override
-def tensor_init_context(self) -> ContextManager:
+def tensor_init_context(self) -> AbstractContextManager:
 return _DtypeContextManager(self._desired_input_dtype)
 
 @override
-def module_init_context(self) -> ContextManager:
+def module_init_context(self) -> AbstractContextManager:
 return self.tensor_init_context()
 
 @override

@@ -16,9 +16,10 @@
 import inspect
 import logging
 import os
+from contextlib import AbstractContextManager
 from functools import lru_cache, partial
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, ContextManager, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import torch
 from torch import Tensor, nn
@@ -305,7 +306,7 @@ class PyTorchProfiler(Profiler):
 self.function_events: Optional[EventList] = None
 self._lightning_module: Optional[LightningModule] = None # set by ProfilerConnector
 self._register: Optional[RegisterRecordFunction] = None
-self._parent_profiler: Optional[ContextManager] = None
+self._parent_profiler: Optional[AbstractContextManager] = None
 self._recording_map: dict[str, record_function] = {}
 self._start_action_name: Optional[str] = None
 self._schedule: Optional[ScheduleWrapper] = None

@@ -13,8 +13,9 @@
 # limitations under the License.
 import logging
 import os
+from contextlib import AbstractContextManager
 from pathlib import Path
-from typing import ContextManager, Optional
+from typing import Optional
 from unittest import mock
 
 import pytest
@@ -382,5 +383,5 @@ def test_misconfiguration_error_with_sharded_model(tmp_path, strategy: str):
 trainer.fit(model)
 
 
-def _backward_patch(trainer: Trainer) -> ContextManager:
+def _backward_patch(trainer: Trainer) -> AbstractContextManager:
 return mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward)
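
One hunk above changes a base class rather than just an annotation: `class _FSDPNoSync(ContextManager)` becomes `class _FSDPNoSync(AbstractContextManager)`. Because the typing name is an alias of the contextlib ABC, subclassing behaves the same either way: `__enter__` has a default implementation that returns `self`, and `__exit__` is the abstract method a subclass must provide. A minimal sketch under those assumptions (the class below is hypothetical, not Lightning's `_FSDPNoSync`):

from contextlib import AbstractContextManager


class _ToggleFlag(AbstractContextManager):
    """Flip a flag for the duration of the `with` block."""

    def __init__(self, enabled: bool) -> None:
        self.enabled = enabled
        self.active = False

    def __enter__(self) -> "_ToggleFlag":
        # Overrides the default __enter__ (which just returns self) to set the flag.
        self.active = self.enabled
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # The one method AbstractContextManager requires a subclass to implement.
        self.active = False


# Usage: the flag is only active inside the block.
with _ToggleFlag(enabled=True) as ctx:
    assert ctx.active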