docs: 1/3 enable Sphinx nitpicky [fabric] (#18069)
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent 795f2909b5
commit 547e7aa393
@@ -35,6 +35,7 @@ env:

jobs:
  docs-checks:
    if: github.event.pull_request.draft == false
    runs-on: ubuntu-20.04
    strategy:
      fail-fast: false
@@ -142,7 +142,7 @@ We can specify a list of layer classes in the **wrapping policy** to inform FSDP
    # 3. Pass it to the FSDPStrategy object
    strategy = FSDPStrategy(auto_wrap_policy=policy)

-PyTorch provides several of these functional policies under :mod:`torch.distributed.fsdp.wrap`.
+PyTorch provides several of these functional policies under ``torch.distributed.fsdp.wrap``.
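
For context, the policy object in the snippet above can be built from a set of layer classes. A minimal sketch, assuming PyTorch's ``ModuleWrapPolicy`` and a user-defined ``TransformerBlock`` layer (the layer and the device setup here are illustrative, not part of the diff):

    import torch.nn as nn
    from torch.distributed.fsdp.wrap import ModuleWrapPolicy

    from lightning.fabric import Fabric
    from lightning.fabric.strategies import FSDPStrategy


    class TransformerBlock(nn.Module):  # hypothetical user-defined layer
        def __init__(self) -> None:
            super().__init__()
            self.linear = nn.Linear(32, 32)

        def forward(self, x):
            return self.linear(x)


    # 1. Collect the layer classes that FSDP should wrap into their own shards
    policy = ModuleWrapPolicy({TransformerBlock})

    # 2. Pass the policy to the strategy, and the strategy to Fabric
    strategy = FSDPStrategy(auto_wrap_policy=policy)
    fabric = Fabric(strategy=strategy, accelerator="cuda", devices=2)
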
@@ -149,7 +149,7 @@ plugins

Plugins allow you to connect arbitrary backends, precision libraries, clusters, etc. For example:
To define your own behavior, subclass the relevant class and pass it in. Here's an example linking up your own
-:class:`~lightning.fabric.plugins.environments.ClusterEnvironment`.
+:class:`~lightning.fabric.plugins.environments.cluster_environment.ClusterEnvironment`.

.. code-block:: python
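
As a rough sketch of what "linking up your own ClusterEnvironment" can look like: subclass it, answer the rank/world-size queries from whatever your scheduler provides, and pass the instance to ``Fabric(plugins=...)``. The environment variable names below are made up for illustration, and the exact set of abstract methods may differ between versions:

    import os

    from lightning.fabric import Fabric
    from lightning.fabric.plugins.environments import ClusterEnvironment


    class MyClusterEnvironment(ClusterEnvironment):
        """Hypothetical environment that reads its topology from custom variables."""

        @property
        def creates_processes_externally(self) -> bool:
            return True  # the scheduler already launches one process per device

        @property
        def main_address(self) -> str:
            return os.environ["MY_MAIN_ADDRESS"]

        @property
        def main_port(self) -> int:
            return int(os.environ["MY_MAIN_PORT"])

        @staticmethod
        def detect() -> bool:
            return "MY_MAIN_ADDRESS" in os.environ

        def world_size(self) -> int:
            return int(os.environ["MY_WORLD_SIZE"])

        def set_world_size(self, size: int) -> None:
            pass  # fixed by the scheduler

        def global_rank(self) -> int:
            return int(os.environ["MY_RANK"])

        def set_global_rank(self, rank: int) -> None:
            pass  # fixed by the scheduler

        def local_rank(self) -> int:
            return int(os.environ["MY_LOCAL_RANK"])

        def node_rank(self) -> int:
            return int(os.environ["MY_NODE_RANK"])


    fabric = Fabric(plugins=MyClusterEnvironment(), devices=2)
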
@@ -0,0 +1,11 @@
+:orphan:
+
+.. include:: ../links.rst
+
+##########################
+lightning.fabric.utilities
+##########################
+
+.. autofunction:: lightning.fabric.utilities.seed.seed_everything
+
+.. autofunction:: lightning.fabric.utilities.seed.pl_worker_init_function
@@ -252,9 +252,29 @@ epub_exclude_files = ["search.html"]
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
    "python": ("https://docs.python.org/3", None),
    "typing_extensions": ("https://typing-extensions.readthedocs.io/en/stable/", None),
    "torch": ("https://pytorch.org/docs/stable/", None),
    "pytorch_lightning": ("https://lightning.ai/docs/pytorch/stable/", None),
    "tensorboardX": ("https://tensorboardx.readthedocs.io/en/stable/", None),
}
nitpicky = True

nitpick_ignore = [
    ("py:class", "typing.Self"),
    # these are not generated with docs API ref
    ("py:class", "lightning.fabric.utilities.types.Optimizable"),
    ("py:class", "lightning.fabric.utilities.types.Steppable"),
    # Nitpick does not see protected or private API
    ("py:class", "lightning.fabric.wrappers._FabricModule"),
    ("py:class", "lightning.fabric.wrappers._FabricOptimizer"),
    ("py:class", "lightning.fabric.loggers.csv_logs._ExperimentWriter"),
    ("py:class", "lightning.fabric.strategies.strategy._Sharded"),
    # Nitpick does not see abstract API
    ("py:meth", "lightning.fabric.plugins.collectives.Collective.init_group"),
    # These seem to be missing in reference generated API
    ("py:class", "torch.distributed.fsdp.wrap.ModuleWrapPolicy"),
    ("py:class", "torch.distributed.fsdp.sharded_grad_scaler.ShardedGradScaler"),
]

# -- Options for todo extension ----------------------------------------------
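
Side note on the configuration above: with ``nitpicky = True``, Sphinx warns about every cross-reference it cannot resolve, and each ``nitpick_ignore`` entry silences exactly one ``(role, target)`` pair. Recent Sphinx versions (4.1+) also accept ``nitpick_ignore_regex`` for pattern-based suppression; a hypothetical variant that would cover whole families of private classes:

    # Hypothetical alternative, not part of this commit: suppress by regex instead of
    # listing every private class individually (requires Sphinx >= 4.1).
    nitpick_ignore_regex = [
        ("py:class", r"lightning\.fabric\.wrappers\._Fabric.*"),
        ("py:class", r"torch\.distributed\.fsdp\..*"),
    ]
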
@@ -35,7 +35,7 @@ A page opens in your browser where you can follow the instructions to complete t
Launch multi-node training in the cloud
***************************************

-**Step 1:** Put your code inside a :class:`~lightning_app.core.work.LightningWork`:
+**Step 1:** Put your code inside a ``lightning.app.core.work.LightningWork``:

.. code-block:: python
    :emphasize-lines: 5
@@ -58,7 +58,7 @@ Launch multi-node training in the cloud
        model, optimizer = fabric.setup(model, optimizer)
        ...

-**Step 2:** Init a :class:`~lightning_app.core.app.LightningApp` with the ``FabricMultiNode`` component.
+**Step 2:** Init a ``lightning.app.core.app.LightningApp`` with the ``FabricMultiNode`` component.
Configure the number of nodes, the number of GPUs per node, and the type of GPU:

.. code-block:: python
@@ -160,7 +160,7 @@ We can specify a list of layer classes in the **wrapping policy** to inform FSDP
    # 3. Pass it to the FSDPStrategy object
    strategy = FSDPStrategy(auto_wrap_policy=policy)

-PyTorch provides several of these functional policies under :mod:`torch.distributed.fsdp.wrap`.
+PyTorch provides several of these functional policies under ``torch.distributed.fsdp.wrap``.
@@ -41,7 +41,7 @@ Built-in Checkpoint IO Plugins
     - CheckpointIO that utilizes :func:`torch.save` and :func:`torch.load` to save and load checkpoints
       respectively, common for most use cases.
   * - :class:`~lightning.pytorch.plugins.io.XLACheckpointIO`
-    - CheckpointIO that utilizes :func:`xm.save` to save checkpoints for TPU training strategies.
+    - CheckpointIO that utilizes ``xm.save`` to save checkpoints for TPU training strategies.
   * - :class:`~lightning.pytorch.plugins.io.AsyncCheckpointIO`
     - ``AsyncCheckpointIO`` enables saving the checkpoints asynchronously in a thread.
@@ -72,7 +72,7 @@ def _do_nothing(*_: Any) -> None:


class Fabric:
-    """Fabric accelerates your PyTorch training or inference code with minimal changes required.
+    r"""Fabric accelerates your PyTorch training or inference code with minimal changes required.

    - Automatic placement of models and data onto the device.
    - Automatic support for mixed and double precision (smaller memory footprint).
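
The docstring above summarizes the whole workflow; a minimal, self-contained sketch of it (toy model and random data, illustrative only):

    import torch
    import torch.nn.functional as F
    from torch.utils.data import DataLoader, TensorDataset

    from lightning.fabric import Fabric

    fabric = Fabric(accelerator="auto", devices=1)
    fabric.launch()

    model = torch.nn.Linear(16, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    model, optimizer = fabric.setup(model, optimizer)  # wraps and moves to the device

    dataset = TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,)))
    dataloader = fabric.setup_dataloaders(DataLoader(dataset, batch_size=8))

    for batch, target in dataloader:  # batches arrive already on the right device
        optimizer.zero_grad()
        loss = F.cross_entropy(model(batch), target)
        fabric.backward(loss)  # replaces loss.backward()
        optimizer.step()
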
@@ -200,7 +200,7 @@ class Fabric:
        *optimizers: Optimizer,
        move_to_device: bool = True,
    ) -> Any:  # no specific return because the way we want our API to look does not play well with mypy
-        """Set up a model and its optimizers for accelerated training.
+        r"""Set up a model and its optimizers for accelerated training.

        Args:
            module: A :class:`torch.nn.Module` to set up
@@ -255,7 +255,7 @@ class Fabric:
        return module

    def setup_module(self, module: nn.Module, move_to_device: bool = True) -> _FabricModule:
-        """Set up a model for accelerated training or inference.
+        r"""Set up a model for accelerated training or inference.

        This is the same as calling ``.setup(model)`` with no optimizers. It is useful for inference or for certain
        strategies like `FSDP` that require setting up the module before the optimizer can be created and set up.
@@ -295,7 +295,7 @@ class Fabric:
        return module

    def setup_optimizers(self, *optimizers: Optimizer) -> Union[_FabricOptimizer, Tuple[_FabricOptimizer, ...]]:
-        """Set up one or more optimizers for accelerated training.
+        r"""Set up one or more optimizers for accelerated training.

        Some strategies do not allow setting up model and optimizer independently. For them, you should call
        ``.setup(model, optimizer, ...)`` instead to jointly set them up.
@@ -318,7 +318,7 @@ class Fabric:
    def setup_dataloaders(
        self, *dataloaders: DataLoader, use_distributed_sampler: bool = True, move_to_device: bool = True
    ) -> Union[DataLoader, List[DataLoader]]:
-        """Set up one or multiple dataloaders for accelerated training. If you need different settings for each
+        r"""Set up one or multiple dataloaders for accelerated training. If you need different settings for each
        dataloader, call this method individually for each one.

        Args:
@@ -347,7 +347,7 @@ class Fabric:
    def _setup_dataloader(
        self, dataloader: DataLoader, use_distributed_sampler: bool = True, move_to_device: bool = True
    ) -> DataLoader:
-        """Set up a single dataloader for accelerated training.
+        r"""Set up a single dataloader for accelerated training.

        Args:
            dataloader: The dataloader to accelerate.
@@ -378,7 +378,7 @@ class Fabric:
        return fabric_dataloader

    def backward(self, tensor: Tensor, *args: Any, model: Optional[_FabricModule] = None, **kwargs: Any) -> None:
-        """Replaces ``loss.backward()`` in your training loop. Handles precision and automatically for you.
+        r"""Replaces ``loss.backward()`` in your training loop. Handles precision and automatically for you.

        Args:
            tensor: The tensor (loss) to back-propagate gradients from.
@@ -471,8 +471,8 @@ class Fabric:
        ...

    def to_device(self, obj: Union[nn.Module, Tensor, Any]) -> Union[nn.Module, Tensor, Any]:
-        """Move a :class:`torch.nn.Module` or a collection of tensors to the current device, if it is not already on
-        that device.
+        r"""Move a :class:`torch.nn.Module` or a collection of tensors to the current device, if it is not already
+        on that device.

        Args:
            obj: An object to move to the device. Can be an instance of :class:`torch.nn.Module`, a tensor, or a
@@ -489,7 +489,7 @@ class Fabric:
        return move_data_to_device(obj, device=self.device)

    def print(self, *args: Any, **kwargs: Any) -> None:
-        """Print something only on the first process. If running on multiple machines, it will print from the first
+        r"""Print something only on the first process. If running on multiple machines, it will print from the first
        process in each machine.

        Arguments passed to this method are forwarded to the Python built-in :func:`print` function.
@@ -510,7 +510,7 @@ class Fabric:
        self._strategy.barrier(name=name)

    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
-        """Send a tensor from one process to all others.
+        r"""Send a tensor from one process to all others.

        This method needs to be called on all processes. Failing to do so will cause your program to stall forever.
@@ -537,7 +537,7 @@ class Fabric:
        Args:
            data: int, float, tensor of shape (batch, ...), or a (possibly nested) collection thereof.
            group: the process group to gather results from. Defaults to all processes (world).
-            sync_grads: flag that allows users to synchronize gradients for the all_gather operation
+            sync_grads: flag that allows users to synchronize gradients for the ``all_gather`` operation

        Return:
            A tensor of shape (world_size, batch, ...), or if the input was a collection
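
A tiny usage sketch for the docstring above (illustrative; with two processes the gathered tensor has a leading ``world_size`` dimension of 2):

    import torch
    from lightning.fabric import Fabric


    def main(fabric: Fabric) -> None:
        # Every process contributes a (1,)-shaped tensor holding its own rank.
        mine = torch.tensor([fabric.global_rank], device=fabric.device)
        everyone = fabric.all_gather(mine)
        fabric.print(everyone.shape)  # torch.Size([2, 1])


    if __name__ == "__main__":
        fabric = Fabric(accelerator="cpu", devices=2)
        fabric.launch(main)
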
@@ -580,7 +580,7 @@ class Fabric:

    @contextmanager
    def rank_zero_first(self, local: bool = False) -> Generator:
-        """The code block under this context manager gets executed first on the main process (rank 0) and only when
+        r"""The code block under this context manager gets executed first on the main process (rank 0) and only when
        completed, the other processes get to run the code in parallel.

        Args:
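
Typical use of ``rank_zero_first`` is one-time data preparation: rank 0 runs the block first and fills a cache, the other ranks enter afterwards and find the files already in place. A sketch with a hypothetical ``prepare_dataset`` helper, assuming a launched Fabric instance named ``fabric``:

    from pathlib import Path


    def prepare_dataset(root: str) -> Path:
        cache = Path(root) / "cache.pt"
        if not cache.exists():
            cache.parent.mkdir(parents=True, exist_ok=True)
            cache.write_bytes(b"...")  # stand-in for a real download/preprocessing step
        return cache


    # local=True: first per node; local=False (default): first across the whole world.
    with fabric.rank_zero_first(local=True):
        dataset_path = prepare_dataset("./data")
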
@@ -603,7 +603,7 @@ class Fabric:

    @contextmanager
    def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Generator:
-        """Skip gradient synchronization during backward to avoid redundant communication overhead.
+        r"""Skip gradient synchronization during backward to avoid redundant communication overhead.

        Use this context manager when performing gradient accumulation to speed up training with multiple devices.
@@ -617,7 +617,7 @@ class Fabric:
        ...

        For those strategies that don't support it, a warning is emitted. For single-device strategies, it is a no-op.
-        Both the model's `.forward()` and the `fabric.backward()` call need to run under this context.
+        Both the model's ``.forward()`` and the ``fabric.backward()`` call need to run under this context.

        Args:
            module: The module for which to control the gradient synchronization.
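
A sketch of the gradient-accumulation pattern both hunks describe, assuming an already set-up ``fabric``, ``model``, ``optimizer``, and ``dataloader``:

    import torch.nn.functional as F


    def train_with_accumulation(fabric, model, optimizer, dataloader, accumulate: int = 4) -> None:
        # Gradients are synchronized across processes only on the iterations that call
        # optimizer.step(); the intermediate iterations skip the all-reduce entirely.
        for i, (batch, target) in enumerate(dataloader):
            is_accumulating = (i + 1) % accumulate != 0
            with fabric.no_backward_sync(model, enabled=is_accumulating):
                loss = F.cross_entropy(model(batch), target)  # forward under the context
                fabric.backward(loss)  # backward under the context as well
            if not is_accumulating:
                optimizer.step()
                optimizer.zero_grad()
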
@@ -650,7 +650,7 @@ class Fabric:

    @contextmanager
    def sharded_model(self) -> Generator:
-        """Instantiate a model under this context manager to prepare it for model-parallel sharding.
+        r"""Instantiate a model under this context manager to prepare it for model-parallel sharding.

        .. deprecated:: This context manager is deprecated in favor of :meth:`init_module`, use it instead.
@@ -691,7 +691,7 @@ class Fabric:
        Args:
            empty_init: Whether to initialize the model with empty weights (uninitialized memory).
                If ``None``, the strategy will decide. Some strategies may not support all options.
-                Set this to ``True`` if you are loading a checkpoint into a large model. Requires `torch >= 1.13`.
+                Set this to ``True`` if you are loading a checkpoint into a large model. Requires ``torch >= 1.13``.

        """
        if not _TORCH_GREATER_EQUAL_2_0 and self.device.type != "cpu":
@@ -710,7 +710,7 @@ class Fabric:
        state: Dict[str, Union[nn.Module, Optimizer, Any]],
        filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None,
    ) -> None:
-        """Save checkpoint contents to a file.
+        r"""Save checkpoint contents to a file.

        How and which processes save gets determined by the `strategy`. For example, the `ddp` strategy
        saves checkpoints only on process 0, while the `fsdp` strategy saves files from every rank.
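
A short usage sketch for ``save``/``load`` (the checkpoint path and the extra ``epoch`` entry are placeholders), assuming a set-up ``fabric``, ``model``, and ``optimizer``:

    # Everything that defines training progress goes into one dictionary; the strategy
    # decides which processes write and in what format.
    state = {"model": model, "optimizer": optimizer, "epoch": 0}
    fabric.save("checkpoints/latest.ckpt", state)

    # On resume, objects in `state` are restored in place; unused keys are returned.
    remainder = fabric.load("checkpoints/latest.ckpt", state)
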
@@ -807,7 +807,7 @@ class Fabric:
        the code (programmatically). If you are launching with the Lightning CLI, ``lightning run model ...``, remove
        ``launch()`` from your code.

-        ``launch()`` is a no-op when called multiple times and no function is passed in.
+        The ``launch()`` is a no-op when called multiple times and no function is passed in.

        """
        if _is_using_cli():
@@ -834,7 +834,7 @@ class Fabric:
        return self._wrap_and_launch(function, self, *args, **kwargs)

    def call(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
-        """Trigger the callback methods with the given name and arguments.
+        r"""Trigger the callback methods with the given name and arguments.

        Not all objects registered via ``Fabric(callbacks=...)`` must implement a method with the given name. The ones
        that have a matching method name will get called.
@@ -901,9 +901,9 @@ class Fabric:

    @staticmethod
    def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None) -> int:
-        """Helper function to seed everything without explicitly importing Lightning.
+        r"""Helper function to seed everything without explicitly importing Lightning.

-        See :func:`lightning.fabric.utilities.seed.seed_everything` for more details.
+        See :func:`~lightning.fabric.utilities.seed.seed_everything` for more details.

        """
        if workers is None:
@@ -25,7 +25,7 @@ from lightning.fabric.utilities.types import _PATH


class XLACheckpointIO(TorchCheckpointIO):
-    """CheckpointIO that utilizes :func:`xm.save` to save checkpoints for TPU training strategies.
+    """CheckpointIO that utilizes ``xm.save`` to save checkpoints for TPU training strategies.

    .. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.
@@ -47,7 +47,7 @@ class DoublePrecision(Precision):
    def forward_context(self) -> Generator[None, None, None]:
        """A context manager to change the default tensor type.

-        See: :meth:`torch.set_default_dtype`
+        See: :func:`torch.set_default_dtype`

        """
        default_dtype = torch.get_default_dtype()
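
Illustrative only (not the plugin's actual code): the kind of context manager these ``forward_context``/``init_context`` docstrings describe swaps the default dtype on entry and restores it on exit:

    from contextlib import contextmanager
    from typing import Generator

    import torch


    @contextmanager
    def double_precision_context() -> Generator[None, None, None]:
        previous = torch.get_default_dtype()
        torch.set_default_dtype(torch.float64)  # new tensors default to float64 inside
        try:
            yield
        finally:
            torch.set_default_dtype(previous)  # restore whatever was set before


    with double_precision_context():
        weights = torch.zeros(4)  # dtype: torch.float64
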
@@ -111,7 +111,7 @@ class FSDPPrecision(Precision):
    def init_context(self) -> Generator[None, None, None]:
        """A context manager to change the default tensor type when initializing module parameters or tensors.

-        See: :meth:`torch.set_default_dtype`
+        See: :func:`torch.set_default_dtype`

        """
        default_dtype = torch.get_default_dtype()
@@ -44,7 +44,7 @@ class HalfPrecision(Precision):
    def init_context(self) -> Generator[None, None, None]:
        """A context manager to change the default tensor type when initializing module parameters or tensors.

-        See: :meth:`torch.set_default_dtype`
+        See: :func:`torch.set_default_dtype`

        """
        default_dtype = torch.get_default_dtype()
@@ -56,7 +56,7 @@ class HalfPrecision(Precision):
    def forward_context(self) -> Generator[None, None, None]:
        """A context manager to change the default tensor type when tensors get created during the module's forward.

-        See: :meth:`torch.set_default_dtype`
+        See: :func:`torch.set_default_dtype`

        """
        default_dtype = torch.get_default_dtype()
@@ -56,7 +56,7 @@ class DataParallelStrategy(ParallelStrategy):
        return None

    def setup_module(self, module: Module) -> DataParallel:
-        """Wraps the given model into a :class:`~torch.nn.parallel.DataParallel` module."""
+        """Wraps the given model into a :class:`~torch.nn.DataParallel` module."""
        return DataParallel(module=module, device_ids=self.parallel_devices)

    def module_to_device(self, module: Module) -> None:
@@ -16,7 +16,6 @@ from __future__ import annotations
from typing import Any

import torch
-from torch import Tensor
from torch.nn import Module

from lightning.fabric.accelerators import Accelerator
@@ -55,7 +54,7 @@ class SingleDeviceStrategy(Strategy):
    def module_to_device(self, module: Module) -> None:
        module.to(self.root_device)

-    def all_reduce(self, tensor: Any | Tensor, *args: Any, **kwargs: Any) -> Any | Tensor:
+    def all_reduce(self, tensor: Any | torch.Tensor, *args: Any, **kwargs: Any) -> Any | torch.Tensor:
        """Reduces a tensor from several distributed processes to one aggregated tensor. As this plugin only operates
        with a single device, the reduction is simply the identity.
@@ -70,8 +69,8 @@ class SingleDeviceStrategy(Strategy):
        """
        return tensor

-    def all_gather(self, tensor: Tensor, group: Any | None = None, sync_grads: bool = False) -> Tensor:
-        """Perform a all_gather on all processes."""
+    def all_gather(self, tensor: torch.Tensor, group: Any | None = None, sync_grads: bool = False) -> torch.Tensor:
+        """Perform a ``all_gather`` on all processes."""
        return tensor

    def barrier(self, *args: Any, **kwargs: Any) -> None:
@@ -17,15 +17,15 @@ min_seed_value = np.iinfo(np.uint32).min


def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
-    """Function that sets seed for pseudo-random number generators in: pytorch, numpy, python.random In addition, sets
-    the following environment variables:
+    r"""Function that sets seed for pseudo-random number generators in: pytorch, numpy, python.random In addition,
+    sets the following environment variables:

-    - `PL_GLOBAL_SEED`: will be passed to spawned subprocesses (e.g. ddp_spawn backend).
-    - `PL_SEED_WORKERS`: (optional) is set to 1 if ``workers=True``.
+    - ``PL_GLOBAL_SEED``: will be passed to spawned subprocesses (e.g. ddp_spawn backend).
+    - ``PL_SEED_WORKERS``: (optional) is set to 1 if ``workers=True``.

    Args:
        seed: the integer value seed for global random state in Lightning.
-            If `None`, will read seed from `PL_GLOBAL_SEED` env variable
+            If ``None``, will read seed from ``PL_GLOBAL_SEED`` env variable
            or select it randomly.
        workers: if set to ``True``, will properly configure all dataloaders passed to the
            Trainer with a ``worker_init_fn``. If the user already provides such a function
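
Usage is a single call at the start of the program; with ``workers=True`` Lightning also configures dataloaders with ``pl_worker_init_function`` so each worker gets a distinct, derived seed:

    from lightning.fabric.utilities.seed import seed_everything

    # Seeds python's `random`, numpy, and torch, and exports PL_GLOBAL_SEED / PL_SEED_WORKERS.
    seed_everything(42, workers=True)
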
@@ -68,9 +68,9 @@ def _select_seed_randomly(min_seed_value: int = min_seed_value, max_seed_value:


def reset_seed() -> None:
-    """Reset the seed to the value that :func:`lightning.fabric.utilities.seed.seed_everything` previously set.
+    r"""Reset the seed to the value that :func:`~lightning.fabric.utilities.seed.seed_everything` previously set.

-    If :func:`lightning.fabric.utilities.seed.seed_everything` is unused, this function will do nothing.
+    If :func:`~lightning.fabric.utilities.seed.seed_everything` is unused, this function will do nothing.

    """
    seed = os.environ.get("PL_GLOBAL_SEED", None)
@@ -81,7 +81,7 @@ def reset_seed() -> None:


def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:  # pragma: no cover
-    """The worker_init_fn that Lightning automatically adds to your dataloader if you previously set the seed with
+    r"""The worker_init_fn that Lightning automatically adds to your dataloader if you previously set the seed with
    ``seed_everything(seed, workers=True)``.

    See also the PyTorch documentation on
@@ -107,7 +107,7 @@ def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:


def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]:
-    """Collect the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python."""
+    r"""Collect the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python."""
    states = {
        "torch": torch.get_rng_state(),
        "numpy": np.random.get_state(),
@@ -119,7 +119,7 @@ def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]:


def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
-    """Set the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python in the current
+    r"""Set the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python in the current
    process."""
    torch.set_rng_state(rng_state_dict["torch"])
    # torch.cuda rng_state is only included since v1.8.
@@ -48,6 +48,7 @@ class _FabricOptimizer:
            optimizer: The optimizer to wrap
            strategy: Reference to the strategy for handling the optimizer step

        """
        # `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would
        # not want to call on destruction of the `_FabricOptimizer
@@ -38,7 +38,7 @@ class DoublePrecisionPlugin(PrecisionPlugin):
    def init_context(self) -> Generator[None, None, None]:
        """A context manager to change the default tensor type when initializing module parameters or tensors.

-        See: :meth:`torch.set_default_dtype`
+        See: :func:`torch.set_default_dtype`

        """
        default_dtype = torch.get_default_dtype()
@@ -50,7 +50,7 @@ class DoublePrecisionPlugin(PrecisionPlugin):
    def forward_context(self) -> Generator[None, None, None]:
        """A context manager to change the default tensor type.

-        See: :meth:`torch.set_default_dtype`
+        See: :func:`torch.set_default_dtype`

        """
        default_dtype = torch.get_default_dtype()
@@ -120,7 +120,7 @@ class FSDPPrecisionPlugin(PrecisionPlugin):
    def init_context(self) -> Generator[None, None, None]:
        """A context manager to change the default tensor type when initializing module parameters or tensors.

-        See: :meth:`torch.set_default_dtype`
+        See: :func:`torch.set_default_dtype`

        """
        default_dtype = torch.get_default_dtype()
@@ -44,7 +44,7 @@ class HalfPrecisionPlugin(PrecisionPlugin):
    def init_context(self) -> Generator[None, None, None]:
        """A context manager to change the default tensor type when initializing module parameters or tensors.

-        See: :meth:`torch.set_default_dtype`
+        See: :func:`torch.set_default_dtype`

        """
        default_dtype = torch.get_default_dtype()