contextlib.AbstractContextManager

Jirka B 2024-11-12 22:22:57 +01:00
parent b9920bdd78
commit a35f1fe34b
22 changed files with 96 additions and 91 deletions

View File

@ -14,13 +14,12 @@
import inspect
import os
from collections.abc import Generator, Mapping, Sequence
from contextlib import contextmanager, nullcontext
from contextlib import AbstractContextManager, contextmanager, nullcontext
from functools import partial
from pathlib import Path
from typing import (
Any,
Callable,
ContextManager,
Optional,
Union,
cast,
@ -484,7 +483,7 @@ class Fabric:
)
raise ValueError("You have to specify either `clip_val` or `max_norm` to do gradient clipping!")
def autocast(self) -> ContextManager:
def autocast(self) -> AbstractContextManager:
"""A context manager to automatically convert operations for the chosen precision.
Use this only if the `forward` method of your model does not cover all operations you wish to run with the
@ -634,7 +633,7 @@ class Fabric:
if rank == 0:
barrier()
def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> ContextManager:
def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> AbstractContextManager:
r"""Skip gradient synchronization during backward to avoid redundant communication overhead.
Use this context manager when performing gradient accumulation to speed up training with multiple devices.
@ -676,7 +675,7 @@ class Fabric:
forward_module, _ = _unwrap_compiled(module._forward_module)
return self._strategy._backward_sync_control.no_backward_sync(forward_module, enabled)
def sharded_model(self) -> ContextManager:
def sharded_model(self) -> AbstractContextManager:
r"""Instantiate a model under this context manager to prepare it for model-parallel sharding.
.. deprecated:: This context manager is deprecated in favor of :meth:`init_module`, use it instead.
@ -688,12 +687,12 @@ class Fabric:
return self.strategy.module_sharded_context()
return nullcontext()
def init_tensor(self) -> ContextManager:
def init_tensor(self) -> AbstractContextManager:
"""Tensors that you instantiate under this context manager will be created on the device right away and have
the right data type depending on the precision setting in Fabric."""
return self._strategy.tensor_init_context()
def init_module(self, empty_init: Optional[bool] = None) -> ContextManager:
def init_module(self, empty_init: Optional[bool] = None) -> AbstractContextManager:
"""Instantiate the model and its parameters under this context manager to reduce peak memory usage.
The parameters get created on the device and with the right data type right away without wasting memory being
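
A minimal usage sketch, not part of this commit: every Fabric method retyped above is consumed in a plain `with` block, so replacing the `typing.ContextManager` alias (deprecated since Python 3.9 in favor of `contextlib.AbstractContextManager`) changes only the annotation, never runtime behavior. The accelerator, model, and optimizer choices below are illustrative assumptions.

# Illustrative only; import path assumed for recent Lightning releases.
import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()

with fabric.init_module():  # parameters created with the device/dtype Fabric decides
    model = torch.nn.Linear(32, 2)

model, optimizer = fabric.setup(model, torch.optim.SGD(model.parameters(), lr=0.1))

batch = torch.randn(4, 32)
with fabric.no_backward_sync(model):  # skip gradient sync while accumulating
    with fabric.autocast():           # precision-aware forward
        loss = model(batch).sum()
    fabric.backward(loss)
optimizer.step()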

View File

@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, ContextManager, Literal, Optional
from contextlib import AbstractContextManager
from typing import Any, Literal, Optional
import torch
from lightning_utilities.core.apply_func import apply_to_collection
@ -59,7 +60,7 @@ class MixedPrecision(Precision):
self._desired_input_dtype = torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16
@override
def forward_context(self) -> ContextManager:
def forward_context(self) -> AbstractContextManager:
return torch.autocast(self.device, dtype=self._desired_input_dtype)
@override

View File

@ -17,10 +17,10 @@ import math
import os
import warnings
from collections import OrderedDict
from contextlib import ExitStack
from contextlib import AbstractContextManager, ExitStack
from functools import partial
from types import ModuleType
from typing import Any, Callable, ContextManager, Literal, Optional, cast
from typing import Any, Callable, Literal, Optional, cast
import torch
from lightning_utilities import apply_to_collection
@ -123,11 +123,11 @@ class BitsandbytesPrecision(Precision):
return module
@override
def tensor_init_context(self) -> ContextManager:
def tensor_init_context(self) -> AbstractContextManager:
return _DtypeContextManager(self.dtype)
@override
def module_init_context(self) -> ContextManager:
def module_init_context(self) -> AbstractContextManager:
if self.ignore_modules:
# cannot patch the Linear class if the user wants to skip some submodules
raise RuntimeError(
@ -145,7 +145,7 @@ class BitsandbytesPrecision(Precision):
return stack
@override
def forward_context(self) -> ContextManager:
def forward_context(self) -> AbstractContextManager:
return _DtypeContextManager(self.dtype)
@override

View File

@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
from typing import TYPE_CHECKING, Any, ContextManager, Literal
from contextlib import AbstractContextManager, nullcontext
from typing import TYPE_CHECKING, Any, Literal
import torch
from lightning_utilities.core.apply_func import apply_to_collection
@ -68,13 +68,13 @@ class DeepSpeedPrecision(Precision):
return module
@override
def tensor_init_context(self) -> ContextManager:
def tensor_init_context(self) -> AbstractContextManager:
if "true" not in self.precision:
return nullcontext()
return _DtypeContextManager(self._desired_dtype)
@override
def module_init_context(self) -> ContextManager:
def module_init_context(self) -> AbstractContextManager:
return self.tensor_init_context()
@override

View File

@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, ContextManager, Literal
from contextlib import AbstractContextManager
from typing import Any, Literal
import torch
from lightning_utilities.core.apply_func import apply_to_collection
@ -33,15 +34,15 @@ class DoublePrecision(Precision):
return module.double()
@override
def tensor_init_context(self) -> ContextManager:
def tensor_init_context(self) -> AbstractContextManager:
return _DtypeContextManager(torch.double)
@override
def module_init_context(self) -> ContextManager:
def module_init_context(self) -> AbstractContextManager:
return self.tensor_init_context()
@override
def forward_context(self) -> ContextManager:
def forward_context(self) -> AbstractContextManager:
return self.tensor_init_context()
@override

View File

@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, ContextManager, Literal, Optional
from contextlib import AbstractContextManager
from typing import TYPE_CHECKING, Any, Literal, Optional
import torch
from lightning_utilities import apply_to_collection
@ -100,15 +101,15 @@ class FSDPPrecision(Precision):
)
@override
def tensor_init_context(self) -> ContextManager:
def tensor_init_context(self) -> AbstractContextManager:
return _DtypeContextManager(self._desired_input_dtype)
@override
def module_init_context(self) -> ContextManager:
def module_init_context(self) -> AbstractContextManager:
return _DtypeContextManager(self.mixed_precision_config.param_dtype or torch.float32)
@override
def forward_context(self) -> ContextManager:
def forward_context(self) -> AbstractContextManager:
if "mixed" in self.precision:
return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
return self.tensor_init_context()

View File

@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, ContextManager, Literal
from contextlib import AbstractContextManager
from typing import Any, Literal
import torch
from lightning_utilities.core.apply_func import apply_to_collection
@ -42,15 +43,15 @@ class HalfPrecision(Precision):
return module.to(dtype=self._desired_input_dtype)
@override
def tensor_init_context(self) -> ContextManager:
def tensor_init_context(self) -> AbstractContextManager:
return _DtypeContextManager(self._desired_input_dtype)
@override
def module_init_context(self) -> ContextManager:
def module_init_context(self) -> AbstractContextManager:
return self.tensor_init_context()
@override
def forward_context(self) -> ContextManager:
def forward_context(self) -> AbstractContextManager:
return self.tensor_init_context()
@override

View File

@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
from typing import Any, ContextManager, Literal, Optional, Union
from contextlib import AbstractContextManager, nullcontext
from typing import Any, Literal, Optional, Union
from torch import Tensor
from torch.nn import Module
@ -53,11 +53,11 @@ class Precision:
"""
return module
def tensor_init_context(self) -> ContextManager:
def tensor_init_context(self) -> AbstractContextManager:
"""Controls how tensors get created (device, dtype)."""
return nullcontext()
def module_init_context(self) -> ContextManager:
def module_init_context(self) -> AbstractContextManager:
"""Instantiate module parameters or tensors in the precision type this plugin handles.
This is optional and depends on the precision limitations during optimization.
@ -65,7 +65,7 @@ class Precision:
"""
return nullcontext()
def forward_context(self) -> ContextManager:
def forward_context(self) -> AbstractContextManager:
"""A contextmanager for managing model forward/training_step/evaluation_step/predict_step."""
return nullcontext()
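
The `Precision` base class above is the template the plugin files in this commit follow: every hook returns some context manager (a `nullcontext`, a dtype context, or an autocast context). Below is a sketch of a custom plugin written against the updated annotations; the class name, the CPU bfloat16 choice, and the import path are assumptions made for illustration.

# Hypothetical plugin, not part of this commit.
from contextlib import AbstractContextManager, nullcontext

import torch
from lightning.fabric.plugins.precision import Precision  # import path assumed


class CPUBf16Forward(Precision):
    """Run forward under CPU bfloat16 autocast; leave tensor/module init untouched."""

    def tensor_init_context(self) -> AbstractContextManager:
        return nullcontext()

    def module_init_context(self) -> AbstractContextManager:
        return nullcontext()

    def forward_context(self) -> AbstractContextManager:
        return torch.autocast("cpu", dtype=torch.bfloat16)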

View File

@ -13,8 +13,8 @@
# limitations under the License.
import logging
from collections.abc import Mapping
from contextlib import ExitStack
from typing import TYPE_CHECKING, Any, ContextManager, Literal, Optional, Union
from contextlib import AbstractContextManager, ExitStack
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
import torch
from lightning_utilities import apply_to_collection
@ -107,11 +107,11 @@ class TransformerEnginePrecision(Precision):
return module
@override
def tensor_init_context(self) -> ContextManager:
def tensor_init_context(self) -> AbstractContextManager:
return _DtypeContextManager(self.weights_dtype)
@override
def module_init_context(self) -> ContextManager:
def module_init_context(self) -> AbstractContextManager:
dtype_ctx = self.tensor_init_context()
stack = ExitStack()
if self.replace_layers:
@ -126,7 +126,7 @@ class TransformerEnginePrecision(Precision):
return stack
@override
def forward_context(self) -> ContextManager:
def forward_context(self) -> AbstractContextManager:
dtype_ctx = _DtypeContextManager(self.weights_dtype)
fallback_autocast_ctx = torch.autocast(device_type="cuda", dtype=self.fallback_compute_dtype)
import transformer_engine.pytorch as te

View File

@ -11,9 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
from contextlib import AbstractContextManager, nullcontext
from datetime import timedelta
from typing import Any, ContextManager, Literal, Optional, Union
from typing import Any, Literal, Optional, Union
import torch
import torch.distributed
@ -231,7 +231,7 @@ class DDPStrategy(ParallelStrategy):
class _DDPBackwardSyncControl(_BackwardSyncControl):
@override
def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager:
"""Blocks gradient synchronization inside the :class:`~torch.nn.parallel.distributed.DistributedDataParallel`
wrapper."""
if not enabled:
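
The DDP sync control truncated above wraps the standard PyTorch pattern: suppress the gradient all-reduce on accumulation steps and let it run on the boundary step. A standalone sketch of that pattern with plain `DistributedDataParallel`; the process group, wrapped model, optimizer, and batches are assumed to be provided by the caller.

# Sketch only: DistributedDataParallel.no_sync() is what suppresses the all-reduce.
from contextlib import nullcontext

from torch.nn.parallel import DistributedDataParallel as DDP


def train_with_accumulation(ddp_model: DDP, batches, optimizer, every: int = 4) -> None:
    for i, batch in enumerate(batches):
        sync_now = (i + 1) % every == 0
        # Roughly what `no_backward_sync(module, enabled=not sync_now)` arranges in Fabric.
        ctx = nullcontext() if sync_now else ddp_model.no_sync()
        with ctx:
            loss = ddp_model(batch).sum()
            loss.backward()
        if sync_now:
            optimizer.step()
            optimizer.zero_grad()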

View File

@ -17,10 +17,10 @@ import logging
import os
import platform
from collections.abc import Mapping
from contextlib import ExitStack
from contextlib import AbstractContextManager, ExitStack
from itertools import chain
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch
from lightning_utilities.core.imports import RequirementCache
@ -353,7 +353,7 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
raise NotImplementedError(self._err_msg_joint_setup_required())
@override
def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager:
def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager:
if self.zero_stage_3 and empty_init is False:
raise NotImplementedError(
f"`{empty_init=}` is not a valid choice with `DeepSpeedStrategy` when ZeRO stage 3 is enabled."
@ -366,7 +366,7 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
return stack
@override
def module_sharded_context(self) -> ContextManager:
def module_sharded_context(self) -> AbstractContextManager:
# Current limitation in Fabric: The config needs to be fully determined at the time of calling the context
# manager. Later modifications through e.g. `Fabric.setup()` won't have an effect here.

View File

@ -14,7 +14,7 @@
import shutil
import warnings
from collections.abc import Generator
from contextlib import ExitStack, nullcontext
from contextlib import AbstractContextManager, ExitStack, nullcontext
from datetime import timedelta
from functools import partial
from pathlib import Path
@ -22,7 +22,6 @@ from typing import (
TYPE_CHECKING,
Any,
Callable,
ContextManager,
Literal,
Optional,
Union,
@ -335,7 +334,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
pass
@override
def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager:
def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager:
precision_init_ctx = self.precision.module_init_context()
module_sharded_ctx = self.module_sharded_context()
stack = ExitStack()
@ -349,7 +348,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
return stack
@override
def module_sharded_context(self) -> ContextManager:
def module_sharded_context(self) -> AbstractContextManager:
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
from torch.distributed.fsdp.wrap import enable_wrap
@ -740,7 +739,7 @@ def _setup_activation_checkpointing(module: Module, activation_checkpointing_kwa
class _FSDPBackwardSyncControl(_BackwardSyncControl):
@override
def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager:
"""Blocks gradient synchronization inside the :class:`~torch.distributed.fsdp.FullyShardedDataParallel`
wrapper."""
if not enabled:

View File

@ -14,10 +14,10 @@
import itertools
import shutil
from collections.abc import Generator
from contextlib import ExitStack
from contextlib import AbstractContextManager, ExitStack
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Literal, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, TypeVar, Union
import torch
from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
@ -195,7 +195,7 @@ class ModelParallelStrategy(ParallelStrategy):
pass
@override
def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager:
def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager:
precision_init_ctx = self.precision.module_init_context()
stack = ExitStack()
if empty_init:
@ -319,12 +319,12 @@ class ModelParallelStrategy(ParallelStrategy):
class _ParallelBackwardSyncControl(_BackwardSyncControl):
@override
def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager:
"""Blocks gradient synchronization inside the FSDP2 modules."""
return _FSDPNoSync(module=module, enabled=enabled)
class _FSDPNoSync(ContextManager):
class _FSDPNoSync(AbstractContextManager):
def __init__(self, module: Module, enabled: bool) -> None:
self._module = module
self._enabled = enabled
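
`_FSDPNoSync` now inherits from the `contextlib` ABC directly. `AbstractContextManager` supplies a default `__enter__` that returns `self` and declares `__exit__` as abstract, so a subclass only has to implement `__exit__` (plus its own `__enter__` when entry has side effects). A self-contained sketch with a hypothetical class, unrelated to FSDP:

# Hypothetical example of subclassing the ABC; not part of this commit.
from contextlib import AbstractContextManager


class ToggleFlag(AbstractContextManager):
    """Set a boolean attribute on an object for the duration of the `with` block."""

    def __init__(self, target: object, attr: str) -> None:
        self._target = target
        self._attr = attr

    def __enter__(self) -> "ToggleFlag":
        setattr(self._target, self._attr, True)
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        setattr(self._target, self._attr, False)


class Worker:
    busy = False


w = Worker()
with ToggleFlag(w, "busy"):
    assert w.busy
assert not w.busy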

View File

@ -14,8 +14,8 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Iterable
from contextlib import ExitStack
from typing import Any, Callable, ContextManager, Optional, TypeVar, Union
from contextlib import AbstractContextManager, ExitStack
from typing import Any, Callable, Optional, TypeVar, Union
import torch
from torch import Tensor
@ -118,7 +118,7 @@ class Strategy(ABC):
"""
return dataloader
def tensor_init_context(self) -> ContextManager:
def tensor_init_context(self) -> AbstractContextManager:
"""Controls how tensors get created (device, dtype)."""
precision_init_ctx = self.precision.tensor_init_context()
stack = ExitStack()
@ -126,7 +126,7 @@ class Strategy(ABC):
stack.enter_context(precision_init_ctx)
return stack
def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager:
def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager:
"""A context manager wrapping the model instantiation.
Here, the strategy can control how the parameters of the model get created (device, dtype) and or apply other
@ -422,7 +422,7 @@ class _BackwardSyncControl(ABC):
"""
@abstractmethod
def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager:
"""Blocks the synchronization of gradients during the backward pass.
This is a context manager. It is only effective if it wraps a call to `.backward()`.
@ -434,7 +434,7 @@ class _Sharded(ABC):
"""Mixin-interface for any :class:`Strategy` that wants to expose functionality for sharding model parameters."""
@abstractmethod
def module_sharded_context(self) -> ContextManager:
def module_sharded_context(self) -> AbstractContextManager:
"""A context manager that goes over the instantiation of an :class:`torch.nn.Module` and handles sharding of
parameters on creation.
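
Several strategy hooks above return an `ExitStack` that has already entered its inner contexts; since `ExitStack` is itself an `AbstractContextManager`, the caller's single `with` block unwinds everything in reverse order. A small sketch of that composition pattern, with placeholder contexts standing in for the precision and device contexts used by the real hooks.

# Sketch of the ExitStack composition used by module_init_context()-style hooks.
from contextlib import AbstractContextManager, ExitStack, contextmanager


@contextmanager
def tag(name: str, log: list):
    log.append(f"enter {name}")
    try:
        yield
    finally:
        log.append(f"exit {name}")


def combined_init_context(log: list) -> AbstractContextManager:
    stack = ExitStack()
    stack.enter_context(tag("precision", log))  # e.g. precision.module_init_context()
    stack.enter_context(tag("device", log))     # e.g. an empty-weights/device context
    return stack  # the caller wraps model creation in `with combined_init_context(...)`


events: list = []
with combined_init_context(events):
    events.append("build model")
print(events)  # ['enter precision', 'enter device', 'build model', 'exit device', 'exit precision']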

View File

@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import io
from contextlib import ExitStack, nullcontext
from contextlib import AbstractContextManager, ExitStack, nullcontext
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Literal, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
import torch
from torch import Tensor
@ -225,7 +225,7 @@ class XLAFSDPStrategy(ParallelStrategy, _Sharded):
def module_to_device(self, module: Module) -> None:
pass
def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager:
def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager:
precision_init_ctx = self.precision.module_init_context()
module_sharded_ctx = self.module_sharded_context()
stack = ExitStack()
@ -235,7 +235,7 @@ class XLAFSDPStrategy(ParallelStrategy, _Sharded):
return stack
@override
def module_sharded_context(self) -> ContextManager:
def module_sharded_context(self) -> AbstractContextManager:
return nullcontext()
@override
@ -668,7 +668,7 @@ def _activation_checkpointing_kwargs(policy: Optional[_POLICY_SET], kwargs: dict
class _XLAFSDPBackwardSyncControl(_BackwardSyncControl):
@override
def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager:
"""Blocks gradient synchronization inside the :class:`~torch_xla.distributed.fsdp.XlaFullyShardedDataParallel`
wrapper."""
if not enabled:

View File

@ -13,8 +13,8 @@
# limitations under the License.
import inspect
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, Callable, ContextManager, Optional
from contextlib import AbstractContextManager, contextmanager
from typing import Any, Callable, Optional
import torch
import torch.distributed as dist
@ -160,7 +160,7 @@ def _no_grad_context(loop_run: Callable) -> Callable:
raise TypeError(f"`{type(self).__name__}` needs to be a Loop.")
if not hasattr(self, "inference_mode"):
raise TypeError(f"`{type(self).__name__}.inference_mode` needs to be defined")
context_manager: type[ContextManager]
context_manager: type[AbstractContextManager]
if _distributed_is_initialized() and dist.get_backend() == "gloo":
# gloo backend does not work properly.
# https://github.com/Lightning-AI/lightning/pull/12715/files#r854569110
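
The annotation here is `type[AbstractContextManager]` because the variable holds a context-manager class chosen at runtime rather than an instance; the ABC's subclass hook makes classes such as `torch.no_grad` and `torch.inference_mode` count as virtual subclasses. A simplified sketch of that selection pattern, with the gloo special case paraphrased from the comment above (the helper name and flag are hypothetical):

# Sketch only: store the class, instantiate it at the call site.
from contextlib import AbstractContextManager

import torch
import torch.distributed as dist


def _choose_no_grad_ctx(prefer_inference_mode: bool = True) -> type[AbstractContextManager]:
    # Per the comment above, inference_mode is avoided with the gloo backend.
    if dist.is_available() and dist.is_initialized() and dist.get_backend() == "gloo":
        return torch.no_grad
    return torch.inference_mode if prefer_inference_mode else torch.no_grad


context_manager = _choose_no_grad_ctx()
with context_manager():  # instantiate the selected class
    _ = torch.ones(2) * 2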

View File

@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Optional, Union
from contextlib import AbstractContextManager, nullcontext
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch
from lightning_utilities import apply_to_collection
@ -80,13 +80,13 @@ class DeepSpeedPrecision(Precision):
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype)
@override
def tensor_init_context(self) -> ContextManager:
def tensor_init_context(self) -> AbstractContextManager:
if "true" not in self.precision:
return nullcontext()
return _DtypeContextManager(self._desired_dtype)
@override
def module_init_context(self) -> ContextManager:
def module_init_context(self) -> AbstractContextManager:
return self.tensor_init_context()
@override

View File

@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, ContextManager, Literal
from contextlib import AbstractContextManager, contextmanager
from typing import Any, Literal
import torch
import torch.nn as nn
@ -38,11 +38,11 @@ class DoublePrecision(Precision):
return module.double()
@override
def tensor_init_context(self) -> ContextManager:
def tensor_init_context(self) -> AbstractContextManager:
return _DtypeContextManager(torch.float64)
@override
def module_init_context(self) -> ContextManager:
def module_init_context(self) -> AbstractContextManager:
return self.tensor_init_context()
@override

View File

@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Optional
from contextlib import AbstractContextManager
from typing import TYPE_CHECKING, Any, Callable, Optional
import torch
from lightning_utilities import apply_to_collection
@ -109,15 +110,15 @@ class FSDPPrecision(Precision):
)
@override
def tensor_init_context(self) -> ContextManager:
def tensor_init_context(self) -> AbstractContextManager:
return _DtypeContextManager(self._desired_input_dtype)
@override
def module_init_context(self) -> ContextManager:
def module_init_context(self) -> AbstractContextManager:
return _DtypeContextManager(self.mixed_precision_config.param_dtype or torch.float32)
@override
def forward_context(self) -> ContextManager:
def forward_context(self) -> AbstractContextManager:
if "mixed" in self.precision:
return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
return _DtypeContextManager(self._desired_input_dtype)

View File

@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, ContextManager, Literal
from contextlib import AbstractContextManager, contextmanager
from typing import Any, Literal
import torch
from lightning_utilities import apply_to_collection
@ -44,11 +44,11 @@ class HalfPrecision(Precision):
return module.to(dtype=self._desired_input_dtype)
@override
def tensor_init_context(self) -> ContextManager:
def tensor_init_context(self) -> AbstractContextManager:
return _DtypeContextManager(self._desired_input_dtype)
@override
def module_init_context(self) -> ContextManager:
def module_init_context(self) -> AbstractContextManager:
return self.tensor_init_context()
@override

View File

@ -16,9 +16,10 @@
import inspect
import logging
import os
from contextlib import AbstractContextManager
from functools import lru_cache, partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch
from torch import Tensor, nn
@ -305,7 +306,7 @@ class PyTorchProfiler(Profiler):
self.function_events: Optional[EventList] = None
self._lightning_module: Optional[LightningModule] = None # set by ProfilerConnector
self._register: Optional[RegisterRecordFunction] = None
self._parent_profiler: Optional[ContextManager] = None
self._parent_profiler: Optional[AbstractContextManager] = None
self._recording_map: dict[str, record_function] = {}
self._start_action_name: Optional[str] = None
self._schedule: Optional[ScheduleWrapper] = None

View File

@ -13,8 +13,9 @@
# limitations under the License.
import logging
import os
from contextlib import AbstractContextManager
from pathlib import Path
from typing import ContextManager, Optional
from typing import Optional
from unittest import mock
import pytest
@ -382,5 +383,5 @@ def test_misconfiguration_error_with_sharded_model(tmp_path, strategy: str):
trainer.fit(model)
def _backward_patch(trainer: Trainer) -> ContextManager:
def _backward_patch(trainer: Trainer) -> AbstractContextManager:
return mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward)