Standalone Lite: Remaining Utilities (#14492)
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: Laverne Henderson <laverne.henderson@coupa.com> Co-authored-by: Felonious-Spellfire <felonious.spellfire@gmail.com>
This commit is contained in:
parent
31dc6c6714
commit
d2459df2ff
|
@ -88,6 +88,13 @@ jobs:
|
|||
pip list
|
||||
shell: bash
|
||||
|
||||
- name: Testing Warnings
|
||||
# the stacklevel can only be set on >=3.7
|
||||
if: matrix.python-version != '3.7'
|
||||
working-directory: tests/tests_lite
|
||||
# needs to run outside of `pytest`
|
||||
run: python utilities/test_warnings.py
|
||||
|
||||
- name: Testing Lite
|
||||
working-directory: tests/tests_lite
|
||||
# NOTE: do not include coverage report here, see: https://github.com/nedbat/coveragepy/issues/1003
|
||||
|
|
|
@ -231,7 +231,7 @@ Use the :func:`~pytorch_lightning.loggers.logger.rank_zero_experiment` and :func
|
|||
.. testcode::
|
||||
|
||||
from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_only
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
|
||||
|
||||
class MyLogger(Logger):
|
||||
|
|
|
@ -26,8 +26,8 @@ from torch.utils.data import DataLoader, random_split
|
|||
from pytorch_lightning import callbacks, cli_lightning_logo, LightningDataModule, LightningModule, Trainer
|
||||
from pytorch_lightning.cli import LightningCLI
|
||||
from pytorch_lightning.demos.mnist_datamodule import MNIST
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_only
|
||||
|
||||
if _TORCHVISION_AVAILABLE:
|
||||
import torchvision
|
||||
|
|
|
@ -57,7 +57,7 @@ from torchvision.datasets.utils import download_and_extract_archive
|
|||
from pytorch_lightning import cli_lightning_logo, LightningDataModule, LightningModule
|
||||
from pytorch_lightning.callbacks.finetuning import BaseFinetuning
|
||||
from pytorch_lightning.cli import LightningCLI
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_info
|
||||
from pytorch_lightning.utilities import rank_zero_info
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip"
|
||||
|
|
|
@ -27,6 +27,7 @@ exclude = '(_notebooks/.*)'
|
|||
[tool.mypy]
|
||||
files = [
|
||||
"src/pytorch_lightning",
|
||||
"src/lightning_lite",
|
||||
# TODO: Check typing in app source
|
||||
# "src/lightning_app",
|
||||
]
|
||||
|
@ -57,5 +58,6 @@ module = [
|
|||
"pytorch_lightning.trainer.trainer",
|
||||
"pytorch_lightning.tuner.batch_size_scaling",
|
||||
"pytorch_lightning.utilities.data",
|
||||
"lightning_lite.utilities.data",
|
||||
]
|
||||
ignore_errors = "True"
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
|
||||
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
|
||||
|
||||
numpy>=1.17.2, <1.23.1
|
||||
torch>=1.9.*, <1.13.0
|
||||
fsspec[http]>=2021.05.0, !=2021.06.0, <2022.6.0
|
||||
packaging>=17.0, <=21.3
|
||||
|
|
|
@ -1,4 +1,21 @@
|
|||
"""Root package info."""
|
||||
import logging
|
||||
|
||||
from lightning_lite.__about__ import * # noqa: F401, F403
|
||||
from lightning_lite.__version__ import version as __version__ # noqa: F401
|
||||
|
||||
_root_logger = logging.getLogger()
|
||||
_logger = logging.getLogger(__name__)
|
||||
_logger.setLevel(logging.INFO)
|
||||
|
||||
if not _root_logger.hasHandlers():
|
||||
_logger.addHandler(logging.StreamHandler())
|
||||
_logger.propagate = False
|
||||
|
||||
from lightning_lite.lite import LightningLite # noqa: E402
|
||||
from lightning_lite.utilities.seed import seed_everything # noqa: E402
|
||||
|
||||
__all__ = ["LightningLite", "seed_everything"]
|
||||
|
||||
# for compatibility with namespace packages
|
||||
__import__("pkg_resources").declare_namespace(__name__)
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""General utilities."""
|
||||
|
||||
from lightning_lite.utilities.apply_func import move_data_to_device # noqa: F401
|
||||
from lightning_lite.utilities.distributed import AllGatherGrad # noqa: F401
|
||||
from lightning_lite.utilities.enums import _AcceleratorType, _StrategyType, AMPType, LightningEnum # noqa: F401
|
||||
|
||||
# TODO(lite): Avoid importing protected attributes in `__init__.py` files
|
||||
from lightning_lite.utilities.imports import ( # noqa: F401
|
||||
_HIVEMIND_AVAILABLE,
|
||||
_HOROVOD_AVAILABLE,
|
||||
_HPU_AVAILABLE,
|
||||
_IPU_AVAILABLE,
|
||||
_IS_INTERACTIVE,
|
||||
_IS_WINDOWS,
|
||||
_POPTORCH_AVAILABLE,
|
||||
_TORCH_GREATER_EQUAL_1_10,
|
||||
_TORCH_GREATER_EQUAL_1_11,
|
||||
_TORCH_GREATER_EQUAL_1_12,
|
||||
_TPU_AVAILABLE,
|
||||
_XLA_AVAILABLE,
|
||||
)
|
||||
from lightning_lite.utilities.rank_zero import ( # noqa: F401
|
||||
rank_zero_deprecation,
|
||||
rank_zero_info,
|
||||
rank_zero_only,
|
||||
rank_zero_warn,
|
||||
)
|
|
@ -22,7 +22,7 @@ import torch
|
|||
from fsspec.core import url_to_fs
|
||||
from fsspec.implementations.local import AbstractFileSystem
|
||||
|
||||
from pytorch_lightning.utilities.types import _MAP_LOCATION_TYPE, _PATH
|
||||
from lightning_lite.utilities.types import _MAP_LOCATION_TYPE, _PATH
|
||||
|
||||
|
||||
def load(
|
||||
|
|
|
@ -0,0 +1,411 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple, Type, Union
|
||||
|
||||
from lightning_utilities.core.inheritance import get_all_subclasses
|
||||
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler
|
||||
|
||||
from lightning_lite.utilities.enums import LightningEnum
|
||||
from lightning_lite.utilities.exceptions import MisconfigurationException
|
||||
from lightning_lite.utilities.rank_zero import rank_zero_warn
|
||||
from lightning_lite.utilities.seed import pl_worker_init_function
|
||||
|
||||
|
||||
class _WrapAttrTag(LightningEnum):
|
||||
SET = "set"
|
||||
DEL = "del"
|
||||
|
||||
def __call__(self, *args):
|
||||
if self == self.SET:
|
||||
fn = setattr
|
||||
else:
|
||||
fn = delattr
|
||||
return fn(*args)
|
||||
|
||||
|
||||
def has_iterable_dataset(dataloader: DataLoader) -> bool:
|
||||
return hasattr(dataloader, "dataset") and isinstance(dataloader.dataset, IterableDataset)
|
||||
|
||||
|
||||
def has_len(dataloader: Union[DataLoader, Iterable]) -> bool:
|
||||
"""Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
|
||||
infinite dataloader."""
|
||||
try:
|
||||
# try getting the length
|
||||
if len(dataloader) == 0:
|
||||
rank_zero_warn(
|
||||
f"`{dataloader.__class__.__name__}` returned 0 length. Please make sure this was your intention."
|
||||
)
|
||||
has_len = True
|
||||
except (TypeError, NotImplementedError):
|
||||
has_len = False
|
||||
|
||||
if has_len and has_iterable_dataset(dataloader):
|
||||
rank_zero_warn(
|
||||
"Your `IterableDataset` has `__len__` defined."
|
||||
" In combination with multi-process data loading (when num_workers > 1),"
|
||||
" `__len__` could be inaccurate if each worker is not configured independently"
|
||||
" to avoid having duplicate data."
|
||||
)
|
||||
return has_len
|
||||
|
||||
|
||||
def _update_dataloader(dataloader: DataLoader, sampler: Union[Sampler, Iterable]) -> DataLoader:
|
||||
dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler)
|
||||
dataloader = _reinstantiate_wrapped_cls(dataloader, *dl_args, **dl_kwargs)
|
||||
return dataloader
|
||||
|
||||
|
||||
def _get_dataloader_init_args_and_kwargs(
|
||||
dataloader: DataLoader,
|
||||
sampler: Optional[Sampler],
|
||||
disallow_batch_sampler: bool = False,
|
||||
) -> Tuple[Tuple[Any], Dict[str, Any]]:
|
||||
if not isinstance(dataloader, DataLoader):
|
||||
raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")
|
||||
|
||||
was_wrapped = hasattr(dataloader, "__pl_saved_args")
|
||||
if was_wrapped:
|
||||
dl_args = dataloader.__pl_saved_args
|
||||
dl_kwargs = dataloader.__pl_saved_kwargs
|
||||
arg_names = dataloader.__pl_saved_arg_names
|
||||
original_dataset = dataloader.__dataset # we have this saved from _wrap_init
|
||||
else:
|
||||
# get the dataloader instance attributes
|
||||
attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}
|
||||
# We cannot be 100% sure the class sets dataset argument. Let's set it to None to be safe
|
||||
# and hope we can get it from the instance attributes
|
||||
original_dataset = None
|
||||
# not part of `vars`
|
||||
attrs["multiprocessing_context"] = dataloader.multiprocessing_context
|
||||
arg_names = ()
|
||||
|
||||
# get the dataloader instance `__init__` parameters
|
||||
params = dict(inspect.signature(dataloader.__init__).parameters)
|
||||
has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
|
||||
if has_variadic_kwargs:
|
||||
# if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)`
|
||||
|
||||
if was_wrapped:
|
||||
# if the dataloader was wrapped in a hook, only take arguments with default values
|
||||
# and assume user passes their kwargs correctly
|
||||
params.update(
|
||||
{k: v for k, v in inspect.signature(DataLoader.__init__).parameters.items() if v.default is not v.empty}
|
||||
)
|
||||
else:
|
||||
params.update(inspect.signature(DataLoader.__init__).parameters)
|
||||
params.pop("self", None)
|
||||
|
||||
if not was_wrapped:
|
||||
# keep only the params whose default is different to the current attr value
|
||||
non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]}
|
||||
|
||||
# add `dataset` as it might have been replaced with `*args`
|
||||
non_defaults.add("dataset")
|
||||
# kwargs to re-construct the dataloader
|
||||
dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
|
||||
dl_args = ()
|
||||
|
||||
dataset = dl_kwargs.get("dataset", original_dataset)
|
||||
if isinstance(dataset, IterableDataset):
|
||||
dl_kwargs["batch_sampler"] = None
|
||||
dl_kwargs["sampler"] = None
|
||||
else:
|
||||
dl_kwargs.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler, disallow_batch_sampler))
|
||||
|
||||
required_args = {
|
||||
p.name
|
||||
for p in params.values()
|
||||
if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
|
||||
and p.default is p.empty
|
||||
and p.name not in dl_kwargs
|
||||
and p.name not in arg_names
|
||||
}
|
||||
# the dataloader has required args which we could not extract from the existing attributes
|
||||
if required_args:
|
||||
required_args = sorted(required_args)
|
||||
dataloader_cls_name = dataloader.__class__.__name__
|
||||
missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in required_args)
|
||||
raise MisconfigurationException(
|
||||
f"Trying to inject custom `Sampler` into the `{dataloader_cls_name}` instance. "
|
||||
"This would fail as some of the `__init__` arguments are not available as instance attributes. "
|
||||
f"The missing attributes are {required_args}. If you instantiate your `{dataloader_cls_name}` inside a "
|
||||
"`*_dataloader` hook of your module, we will do this for you."
|
||||
f" Otherwise, define {missing_args_message} inside your `__init__`."
|
||||
)
|
||||
|
||||
if not has_variadic_kwargs:
|
||||
# the dataloader signature does not allow keyword arguments that need to be passed
|
||||
missing_kwargs = (set(dl_kwargs) | set(arg_names)) - params.keys()
|
||||
if missing_kwargs:
|
||||
missing_kwargs = sorted(missing_kwargs)
|
||||
dataloader_cls_name = dataloader.__class__.__name__
|
||||
raise TypeError(
|
||||
f"Trying to inject parameters into the `{dataloader_cls_name}` instance. "
|
||||
"This would fail as it doesn't expose all its attributes in the `__init__` signature. "
|
||||
f"The missing arguments are {missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` class, "
|
||||
"add the `__init__` arguments or allow passing `**kwargs`"
|
||||
)
|
||||
|
||||
return dl_args, dl_kwargs
|
||||
|
||||
|
||||
def _dataloader_init_kwargs_resolve_sampler(
|
||||
dataloader: DataLoader,
|
||||
sampler: Optional[Sampler],
|
||||
disallow_batch_sampler: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its
|
||||
re-instantiation.
|
||||
|
||||
If there are multiple devices in IPU mode, it is necessary to disallow BatchSampler that isn't instantiated
|
||||
automatically, since `poptorch.DataLoader` will try to increase the batch_size
|
||||
"""
|
||||
batch_sampler = getattr(dataloader, "batch_sampler")
|
||||
|
||||
if batch_sampler is not None:
|
||||
if disallow_batch_sampler:
|
||||
# Check that we don't have a PyTorch default batch sampler that was instantiated in DataLoader __init__
|
||||
if not (
|
||||
type(batch_sampler) is BatchSampler
|
||||
and batch_sampler.sampler == sampler
|
||||
and dataloader.batch_size == batch_sampler.batch_size
|
||||
):
|
||||
raise MisconfigurationException(
|
||||
"It is not possible to have a batch sampler in your dataloader, "
|
||||
"when running on multiple IPU devices."
|
||||
)
|
||||
elif type(batch_sampler) is not BatchSampler:
|
||||
batch_sampler_cls = type(batch_sampler)
|
||||
if hasattr(batch_sampler, "__pl_saved_args"):
|
||||
args = batch_sampler.__pl_saved_args
|
||||
kwargs = batch_sampler.__pl_saved_kwargs
|
||||
default_kwargs = batch_sampler.__pl_saved_default_kwargs
|
||||
arg_names = batch_sampler.__pl_saved_arg_names
|
||||
|
||||
success, args, kwargs = _replace_value_in_saved_args(
|
||||
"sampler", sampler, args, kwargs, default_kwargs, arg_names
|
||||
)
|
||||
if not success:
|
||||
raise TypeError(
|
||||
"Trying to inject a modified sampler into the batch sampler; however, it seems the class "
|
||||
f"`{batch_sampler_cls.__qualname__}` does not have an argument called `sampler.` To mitigate "
|
||||
"this, expose an argument `sampler` in the `__init__` method of your custom class."
|
||||
)
|
||||
|
||||
batch_sampler = _reinstantiate_wrapped_cls(batch_sampler, *args, **kwargs)
|
||||
else:
|
||||
try:
|
||||
batch_sampler = batch_sampler_cls(
|
||||
sampler,
|
||||
batch_size=batch_sampler.batch_size,
|
||||
drop_last=batch_sampler.drop_last,
|
||||
)
|
||||
except TypeError as e:
|
||||
import re
|
||||
|
||||
match = re.match(r".*__init__\(\) (got multiple values)|(missing \d required)", str(e))
|
||||
if not match:
|
||||
# an unexpected `TypeError`, continue failure
|
||||
raise
|
||||
|
||||
# There could either be too few or too many arguments. Customizing the message based on this doesn't
|
||||
# make much sense since our MisconfigurationException is going to be raised from the original one.
|
||||
raise TypeError(
|
||||
"We tried to re-instantiate your custom batch sampler and failed. "
|
||||
"To mitigate this, either follow the API of `BatchSampler` or instantiate "
|
||||
"your custom batch sampler inside `*_dataloader` hooks of your module."
|
||||
) from e
|
||||
|
||||
return {
|
||||
"sampler": None,
|
||||
"shuffle": False,
|
||||
"batch_sampler": batch_sampler,
|
||||
"batch_size": 1,
|
||||
"drop_last": False,
|
||||
}
|
||||
|
||||
return {"sampler": sampler, "shuffle": False, "batch_sampler": None}
|
||||
|
||||
|
||||
def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None:
|
||||
if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None:
|
||||
dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank)
|
||||
|
||||
|
||||
def _reinstantiate_wrapped_cls(orig_object: Any, *args: Any, explicit_cls: Optional[Type] = None, **kwargs: Any) -> Any:
|
||||
constructor = type(orig_object) if explicit_cls is None else explicit_cls
|
||||
|
||||
try:
|
||||
result = constructor(*args, **kwargs)
|
||||
except TypeError as e:
|
||||
# improve exception message due to an incorrect implementation of the `DataLoader` where multiple subclass
|
||||
# `__init__` arguments map to one `DataLoader.__init__` argument
|
||||
import re
|
||||
|
||||
match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", str(e))
|
||||
if not match:
|
||||
# an unexpected `TypeError`, continue failure
|
||||
raise
|
||||
argument = match.groups()[0]
|
||||
message = (
|
||||
f"The {constructor.__name__} implementation has an error where more than one `__init__` argument"
|
||||
f" can be passed to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing"
|
||||
f" passing both a custom argument that will map to the `{argument}` argument as well as `**kwargs`."
|
||||
f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key."
|
||||
" This argument was automatically passed to your object by PyTorch Lightning."
|
||||
)
|
||||
raise MisconfigurationException(message) from e
|
||||
|
||||
attrs_record = getattr(orig_object, "__pl_attrs_record", list())
|
||||
for args, fn in attrs_record:
|
||||
fn(result, *args)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _wrap_init_method(init: Callable, store_explicit_arg: Optional[str] = None) -> Callable:
|
||||
"""Wraps the ``__init__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and
|
||||
:class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses."""
|
||||
|
||||
@functools.wraps(init)
|
||||
def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None:
|
||||
# We need to inspect `init`, as inspecting `obj.__init__`
|
||||
# can lead to inspecting the wrong function with multiple inheritance
|
||||
old_inside_init = getattr(obj, "__pl_inside_init", False)
|
||||
object.__setattr__(obj, "__pl_inside_init", True)
|
||||
params = inspect.signature(init).parameters
|
||||
|
||||
parameters_defaults = OrderedDict(
|
||||
(param.name, param.default)
|
||||
for param in params.values()
|
||||
if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
|
||||
)
|
||||
|
||||
param_names = tuple(parameters_defaults)[: len(args)]
|
||||
|
||||
default_kwargs = {
|
||||
name: value
|
||||
for name, value in parameters_defaults.items()
|
||||
if name not in kwargs and name not in param_names and value != inspect.Parameter.empty
|
||||
}
|
||||
|
||||
if not hasattr(obj, "__pl_saved_args"):
|
||||
object.__setattr__(obj, "__pl_saved_args", args)
|
||||
object.__setattr__(obj, "__pl_saved_kwargs", kwargs)
|
||||
object.__setattr__(obj, "__pl_saved_arg_names", param_names)
|
||||
object.__setattr__(obj, "__pl_saved_default_kwargs", default_kwargs)
|
||||
|
||||
# We want to use the latest possible value for explicit argument (i.e. ideally what gets passed to base class)
|
||||
# so that we can be sure, that it will not get changed anymore.
|
||||
# That is why we are setting this in every `__init__`
|
||||
if store_explicit_arg is not None:
|
||||
if store_explicit_arg in param_names:
|
||||
object.__setattr__(obj, f"__{store_explicit_arg}", args[param_names.index(store_explicit_arg)])
|
||||
elif store_explicit_arg in kwargs:
|
||||
object.__setattr__(obj, f"__{store_explicit_arg}", kwargs[store_explicit_arg])
|
||||
|
||||
init(obj, *args, **kwargs)
|
||||
object.__setattr__(obj, "__pl_inside_init", old_inside_init)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def _wrap_attr_method(method: Callable, tag: _WrapAttrTag) -> Callable:
|
||||
"""Wraps the ``__setattr__`` or ``__delattr__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and
|
||||
:class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses."""
|
||||
|
||||
@functools.wraps(method)
|
||||
def wrapper(obj: Any, *args: Any):
|
||||
# First, let's find out if we're the first in inheritance chain calling the patched method.
|
||||
name, *_ = args
|
||||
prev_call_name, prev_call_method = getattr(obj, "__pl_current_call", (None, "method"))
|
||||
first_call = not (prev_call_name == name and prev_call_method == tag)
|
||||
|
||||
# Then mark the current called method
|
||||
object.__setattr__(obj, "__pl_current_call", (name, tag))
|
||||
|
||||
# call original method
|
||||
method(obj, *args)
|
||||
if first_call and not getattr(obj, "__pl_inside_init", True):
|
||||
# and save the value it was called with to the internal list,
|
||||
# if we're outside of __init__ and the original call did not fail and we're the first call
|
||||
attrs_record = getattr(obj, "__pl_attrs_record", list())
|
||||
attrs_record.append((args, tag))
|
||||
object.__setattr__(obj, "__pl_attrs_record", attrs_record)
|
||||
object.__setattr__(obj, "__pl_current_call", (prev_call_name, prev_call_method))
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _replace_dunder_methods(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]:
|
||||
"""This context manager is used to add support for re-instantiation of custom (subclasses) of `base_cls`.
|
||||
|
||||
It patches the ``__init__``, ``__setattr__`` and ``__delattr__`` methods.
|
||||
"""
|
||||
classes = get_all_subclasses(base_cls) | {base_cls}
|
||||
for cls in classes:
|
||||
# Check that __init__ belongs to the class
|
||||
# https://stackoverflow.com/a/5253424
|
||||
if "__init__" in cls.__dict__:
|
||||
cls.__old__init__ = cls.__init__
|
||||
cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg)
|
||||
|
||||
# we want at least one setattr/delattr in the chain to be patched and it can happen, that none of the subclasses
|
||||
# implement `__setattr__`/`__delattr__`. Therefore, we are always patching the `base_cls`
|
||||
for patch_fn_name, tag in (("__setattr__", _WrapAttrTag.SET), ("__delattr__", _WrapAttrTag.DEL)):
|
||||
if patch_fn_name in cls.__dict__ or cls is base_cls:
|
||||
saved_name = f"__old{patch_fn_name}"
|
||||
setattr(cls, saved_name, getattr(cls, patch_fn_name))
|
||||
setattr(cls, patch_fn_name, _wrap_attr_method(getattr(cls, patch_fn_name), tag))
|
||||
yield
|
||||
for cls in classes:
|
||||
for patched_name in ("__setattr__", "__delattr__", "__init__"):
|
||||
# Check that __old__{init,setattr,delattr} belongs to the class
|
||||
# https://stackoverflow.com/a/5253424
|
||||
if f"__old{patched_name}" in cls.__dict__:
|
||||
setattr(cls, patched_name, getattr(cls, f"__old{patched_name}"))
|
||||
delattr(cls, f"__old{patched_name}")
|
||||
|
||||
|
||||
def _replace_value_in_saved_args(
|
||||
replace_key: str,
|
||||
replace_value: Any,
|
||||
args: Tuple[Any, ...],
|
||||
kwargs: Dict[str, Any],
|
||||
default_kwargs: Dict[str, Any],
|
||||
arg_names: Tuple[str, ...],
|
||||
) -> Tuple[bool, Tuple[Any, ...], Dict[str, Any]]:
|
||||
"""Tries to replace an argument value in a saved list of args and kwargs.
|
||||
|
||||
Returns a tuple indicating success of the operation and modified saved args and kwargs
|
||||
"""
|
||||
|
||||
if replace_key in arg_names:
|
||||
replace_index = arg_names.index(replace_key)
|
||||
args = args[:replace_index] + (replace_value,) + args[replace_index + 1 :]
|
||||
return True, args, kwargs
|
||||
elif replace_key in kwargs or replace_key in default_kwargs:
|
||||
kwargs[replace_key] = replace_value
|
||||
return True, args, kwargs
|
||||
|
||||
return False, args, kwargs
|
|
@ -0,0 +1,316 @@
|
|||
import multiprocessing
|
||||
import os
|
||||
from typing import Any, List, MutableSequence, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
# TODO(lite): Fix the imports
|
||||
# from lightning_lite.plugins.environments import TorchElasticEnvironment
|
||||
# from lightning_lite.strategies.launchers.multiprocessing import _is_forking_disabled
|
||||
from lightning_lite.utilities.exceptions import MisconfigurationException
|
||||
from lightning_lite.utilities.types import _DEVICE
|
||||
|
||||
|
||||
def determine_root_gpu_device(gpus: List[_DEVICE]) -> Optional[_DEVICE]:
|
||||
"""
|
||||
Args:
|
||||
gpus: Non-empty list of ints representing which GPUs to use
|
||||
|
||||
Returns:
|
||||
Designated root GPU device id
|
||||
|
||||
Raises:
|
||||
TypeError:
|
||||
If ``gpus`` is not a list
|
||||
AssertionError:
|
||||
If GPU list is empty
|
||||
"""
|
||||
if gpus is None:
|
||||
return None
|
||||
|
||||
if not isinstance(gpus, list):
|
||||
raise TypeError("GPUs should be a list")
|
||||
|
||||
assert len(gpus) > 0, "GPUs should be a non-empty list"
|
||||
|
||||
# set root gpu
|
||||
root_gpu = gpus[0]
|
||||
|
||||
return root_gpu
|
||||
|
||||
|
||||
def parse_gpu_ids(
|
||||
gpus: Optional[Union[int, str, List[int]]],
|
||||
include_cuda: bool = False,
|
||||
include_mps: bool = False,
|
||||
) -> Optional[List[int]]:
|
||||
"""
|
||||
Parses the GPU IDs given in the format as accepted by the
|
||||
:class:`~pytorch_lightning.trainer.Trainer`.
|
||||
|
||||
Args:
|
||||
gpus: An int -1 or string '-1' indicate that all available GPUs should be used.
|
||||
A list of unique ints or a string containing a list of comma separated unique integers
|
||||
indicates specific GPUs to use.
|
||||
An int of 0 means that no GPUs should be used.
|
||||
Any int N > 0 indicates that GPUs [0..N) should be used.
|
||||
include_cuda: A boolean value indicating whether to include CUDA devices for GPU parsing.
|
||||
include_mps: A boolean value indicating whether to include MPS devices for GPU parsing.
|
||||
|
||||
Returns:
|
||||
A list of GPUs to be used or ``None`` if no GPUs were requested
|
||||
|
||||
Raises:
|
||||
MisconfigurationException:
|
||||
If no GPUs are available but the value of gpus variable indicates request for GPUs
|
||||
|
||||
.. note::
|
||||
``include_cuda`` and ``include_mps`` default to ``False`` so that you only
|
||||
have to specify which device type to use and all other devices are not disabled.
|
||||
"""
|
||||
# Check that gpus param is None, Int, String or Sequence of Ints
|
||||
_check_data_type(gpus)
|
||||
|
||||
# Handle the case when no GPUs are requested
|
||||
if gpus is None or (isinstance(gpus, int) and gpus == 0) or str(gpus).strip() in ("0", "[]"):
|
||||
return None
|
||||
|
||||
# We know the user requested GPUs therefore if some of the
|
||||
# requested GPUs are not available an exception is thrown.
|
||||
gpus = _normalize_parse_gpu_string_input(gpus)
|
||||
gpus = _normalize_parse_gpu_input_to_list(gpus, include_cuda=include_cuda, include_mps=include_mps)
|
||||
if not gpus:
|
||||
raise MisconfigurationException("GPUs requested but none are available.")
|
||||
|
||||
if (
|
||||
True # TorchElasticEnvironment.detect() # TODO(lite): Revert this once environments have moved
|
||||
and len(gpus) != 1
|
||||
and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps)) == 1
|
||||
):
|
||||
# Omit sanity check on torchelastic because by default it shows one visible GPU per process
|
||||
return gpus
|
||||
|
||||
# Check that GPUs are unique. Duplicate GPUs are not supported by the backend.
|
||||
_check_unique(gpus)
|
||||
|
||||
return _sanitize_gpu_ids(gpus, include_cuda=include_cuda, include_mps=include_mps)
|
||||
|
||||
|
||||
def parse_tpu_cores(tpu_cores: Optional[Union[int, str, List[int]]]) -> Optional[Union[int, List[int]]]:
|
||||
"""
|
||||
Parses the tpu_cores given in the format as accepted by the
|
||||
:class:`~pytorch_lightning.trainer.Trainer`.
|
||||
|
||||
Args:
|
||||
tpu_cores: An int of 1 or string '1' indicates that 1 core with multi-processing should be used
|
||||
An int 8 or string '8' indicates that all 8 cores with multi-processing should be used
|
||||
A list of ints or a strings containing a list of comma separated integers
|
||||
indicates the specific TPU core to use.
|
||||
|
||||
Returns:
|
||||
A list of tpu_cores to be used or ``None`` if no TPU cores were requested
|
||||
|
||||
Raises:
|
||||
MisconfigurationException:
|
||||
If TPU cores aren't 1, 8 or [<1-8>]
|
||||
"""
|
||||
_check_data_type(tpu_cores)
|
||||
|
||||
if isinstance(tpu_cores, str):
|
||||
tpu_cores = _parse_tpu_cores_str(tpu_cores.strip())
|
||||
|
||||
if not _tpu_cores_valid(tpu_cores):
|
||||
raise MisconfigurationException("`tpu_cores` can only be 1, 8 or [<1-8>]")
|
||||
|
||||
return tpu_cores
|
||||
|
||||
|
||||
def parse_cpu_cores(cpu_cores: Union[int, str, List[int]]) -> int:
|
||||
"""Parses the cpu_cores given in the format as accepted by the ``devices`` argument in the
|
||||
:class:`~pytorch_lightning.trainer.Trainer`.
|
||||
|
||||
Args:
|
||||
cpu_cores: An int > 0.
|
||||
|
||||
Returns:
|
||||
An int representing the number of processes
|
||||
|
||||
Raises:
|
||||
MisconfigurationException:
|
||||
If cpu_cores is not an int > 0
|
||||
"""
|
||||
if isinstance(cpu_cores, str) and cpu_cores.strip().isdigit():
|
||||
cpu_cores = int(cpu_cores)
|
||||
|
||||
if not isinstance(cpu_cores, int) or cpu_cores <= 0:
|
||||
raise MisconfigurationException("`devices` selected with `CPUAccelerator` should be an int > 0.")
|
||||
|
||||
return cpu_cores
|
||||
|
||||
|
||||
def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[int, List[int]]:
|
||||
if not isinstance(s, str):
|
||||
return s
|
||||
if s == "-1":
|
||||
return -1
|
||||
if "," in s:
|
||||
return [int(x.strip()) for x in s.split(",") if len(x) > 0]
|
||||
return int(s.strip())
|
||||
|
||||
|
||||
def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: bool = False) -> List[int]:
|
||||
"""Checks that each of the GPUs in the list is actually available. Raises a MisconfigurationException if any of
|
||||
the GPUs is not available.
|
||||
|
||||
Args:
|
||||
gpus: List of ints corresponding to GPU indices
|
||||
|
||||
Returns:
|
||||
Unmodified gpus variable
|
||||
|
||||
Raises:
|
||||
MisconfigurationException:
|
||||
If machine has fewer available GPUs than requested.
|
||||
"""
|
||||
if sum((include_cuda, include_mps)) == 0:
|
||||
raise ValueError("At least one gpu type should be specified!")
|
||||
all_available_gpus = _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps)
|
||||
for gpu in gpus:
|
||||
if gpu not in all_available_gpus:
|
||||
raise MisconfigurationException(
|
||||
f"You requested gpu: {gpus}\n But your machine only has: {all_available_gpus}"
|
||||
)
|
||||
return gpus
|
||||
|
||||
|
||||
def _normalize_parse_gpu_input_to_list(
|
||||
gpus: Union[int, List[int], Tuple[int, ...]], include_cuda: bool, include_mps: bool
|
||||
) -> Optional[List[int]]:
|
||||
assert gpus is not None
|
||||
if isinstance(gpus, (MutableSequence, tuple)):
|
||||
return list(gpus)
|
||||
|
||||
# must be an int
|
||||
if not gpus: # gpus==0
|
||||
return None
|
||||
if gpus == -1:
|
||||
return _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps)
|
||||
|
||||
return list(range(gpus))
|
||||
|
||||
|
||||
def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = False) -> List[int]:
|
||||
"""
|
||||
Returns:
|
||||
A list of all available GPUs
|
||||
"""
|
||||
cuda_gpus = _get_all_available_cuda_gpus() if include_cuda else []
|
||||
mps_gpus = _get_all_available_mps_gpus() if include_mps else []
|
||||
return cuda_gpus + mps_gpus
|
||||
|
||||
|
||||
def _get_all_available_mps_gpus() -> List[int]:
|
||||
"""
|
||||
Returns:
|
||||
A list of all available MPS GPUs
|
||||
"""
|
||||
# lazy import to avoid circular dependencies
|
||||
# from lightning_lite.accelerators.mps import _MPS_AVAILABLE
|
||||
_MPS_AVAILABLE = False # TODO(lite): revert this once MPS utils have moved
|
||||
return [0] if _MPS_AVAILABLE else []
|
||||
|
||||
|
||||
def _get_all_available_cuda_gpus() -> List[int]:
|
||||
"""
|
||||
Returns:
|
||||
A list of all available CUDA GPUs
|
||||
"""
|
||||
return list(range(num_cuda_devices()))
|
||||
|
||||
|
||||
def _check_unique(device_ids: List[int]) -> None:
|
||||
"""Checks that the device_ids are unique.
|
||||
|
||||
Args:
|
||||
device_ids: List of ints corresponding to GPUs indices
|
||||
|
||||
Raises:
|
||||
MisconfigurationException:
|
||||
If ``device_ids`` of GPUs aren't unique
|
||||
"""
|
||||
if len(device_ids) != len(set(device_ids)):
|
||||
raise MisconfigurationException("Device ID's (GPU) must be unique.")
|
||||
|
||||
|
||||
def _check_data_type(device_ids: Any) -> None:
|
||||
"""Checks that the device_ids argument is one of the following: None, int, string, or sequence of integers.
|
||||
|
||||
Args:
|
||||
device_ids: gpus/tpu_cores parameter as passed to the Trainer
|
||||
|
||||
Raises:
|
||||
MisconfigurationException:
|
||||
If ``device_ids`` of GPU/TPUs aren't ``int``, ``str``, sequence of ``int`` or ``None``
|
||||
"""
|
||||
msg = "Device IDs (GPU/TPU) must be an int, a string, a sequence of ints or None, but you passed"
|
||||
|
||||
if device_ids is None:
|
||||
return
|
||||
elif isinstance(device_ids, (MutableSequence, tuple)):
|
||||
for id_ in device_ids:
|
||||
if type(id_) is not int:
|
||||
raise MisconfigurationException(f"{msg} a sequence of {type(id_).__name__}.")
|
||||
elif type(device_ids) not in (int, str):
|
||||
raise MisconfigurationException(f"{msg} {type(device_ids).__name__}.")
|
||||
|
||||
|
||||
def _tpu_cores_valid(tpu_cores: Any) -> bool:
|
||||
# allow 1 or 8 cores
|
||||
if tpu_cores in (1, 8, None):
|
||||
return True
|
||||
|
||||
# allow picking 1 of 8 indexes
|
||||
if isinstance(tpu_cores, (list, tuple, set)):
|
||||
has_1_tpu_idx = len(tpu_cores) == 1
|
||||
is_valid_tpu_idx = 1 <= list(tpu_cores)[0] <= 8
|
||||
|
||||
is_valid_tpu_core_choice = has_1_tpu_idx and is_valid_tpu_idx
|
||||
return is_valid_tpu_core_choice
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _parse_tpu_cores_str(tpu_cores: str) -> Union[int, List[int]]:
|
||||
if tpu_cores in ("1", "8"):
|
||||
return int(tpu_cores)
|
||||
return [int(x.strip()) for x in tpu_cores.split(",") if len(x) > 0]
|
||||
|
||||
|
||||
def num_cuda_devices() -> int:
|
||||
"""Returns the number of GPUs available.
|
||||
|
||||
Unlike :func:`torch.cuda.device_count`, this function does its best not to create a CUDA context for fork support,
|
||||
if the platform allows it.
|
||||
"""
|
||||
if "fork" not in torch.multiprocessing.get_all_start_methods() or _is_forking_disabled():
|
||||
return torch.cuda.device_count()
|
||||
with multiprocessing.get_context("fork").Pool(1) as pool:
|
||||
return pool.apply(torch.cuda.device_count)
|
||||
|
||||
|
||||
def is_cuda_available() -> bool:
|
||||
"""Returns a bool indicating if CUDA is currently available.
|
||||
|
||||
Unlike :func:`torch.cuda.is_available`, this function does its best not to create a CUDA context for fork support,
|
||||
if the platform allows it.
|
||||
"""
|
||||
if "fork" not in torch.multiprocessing.get_all_start_methods() or _is_forking_disabled():
|
||||
return torch.cuda.is_available()
|
||||
with multiprocessing.get_context("fork").Pool(1) as pool:
|
||||
return pool.apply(torch.cuda.is_available)
|
||||
|
||||
|
||||
# TODO(lite): move this back to launchers/multiprocessing.py once launchers have moved
|
||||
def _is_forking_disabled() -> bool:
|
||||
"""Returns whether forking is disabled through the environment variable ``PL_DISABLE_FORK``."""
|
||||
return bool(int(os.environ.get("PL_DISABLE_FORK", "0")))
|
|
@ -0,0 +1,264 @@
|
|||
import logging
|
||||
import os
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.nn import functional as F
|
||||
|
||||
from lightning_lite.utilities.imports import _HPU_AVAILABLE, _TPU_AVAILABLE
|
||||
from lightning_lite.utilities.rank_zero import rank_zero_deprecation
|
||||
from lightning_lite.utilities.rank_zero import rank_zero_info as new_rank_zero_info
|
||||
|
||||
if _TPU_AVAILABLE:
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
|
||||
if torch.distributed.is_available():
|
||||
from torch.distributed import group, ReduceOp
|
||||
else:
|
||||
|
||||
class ReduceOp: # type: ignore # (see https://github.com/python/mypy/issues/1153)
|
||||
SUM = None
|
||||
|
||||
class group: # type: ignore
|
||||
WORLD = None
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tensor]:
|
||||
"""Function to gather all tensors from several DDP processes onto a list that is broadcasted to all processes.
|
||||
|
||||
Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case
|
||||
tensors are padded, gathered and then trimmed to secure equal workload for all processes.
|
||||
|
||||
Args:
|
||||
result: The value to sync
|
||||
group: The process group to gather results from. Defaults to all processes (world)
|
||||
|
||||
Return:
|
||||
gathered_result: List with size equal to the process group where
|
||||
gathered_result[i] corresponds to result tensor from process i
|
||||
"""
|
||||
if group is None:
|
||||
group = torch.distributed.group.WORLD
|
||||
|
||||
# Convert tensors to contiguous format
|
||||
result = result.contiguous()
|
||||
|
||||
world_size = torch.distributed.get_world_size(group)
|
||||
torch.distributed.barrier(group=group)
|
||||
|
||||
# If the tensor is scalar, things are easy
|
||||
if result.ndim == 0:
|
||||
return _simple_gather_all_tensors(result, group, world_size)
|
||||
|
||||
# 1. Gather sizes of all tensors
|
||||
local_size = torch.tensor(result.shape, device=result.device)
|
||||
local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
|
||||
torch.distributed.all_gather(local_sizes, local_size, group=group)
|
||||
max_size = torch.stack(local_sizes).max(dim=0).values
|
||||
all_sizes_equal = all(all(ls == max_size) for ls in local_sizes)
|
||||
|
||||
# 2. If shapes are all the same, then do a simple gather:
|
||||
if all_sizes_equal:
|
||||
return _simple_gather_all_tensors(result, group, world_size)
|
||||
|
||||
# 3. If not, we need to pad each local tensor to maximum size, gather and then truncate
|
||||
pad_dims = []
|
||||
pad_by = (max_size - local_size).detach().cpu()
|
||||
for val in reversed(pad_by):
|
||||
pad_dims.append(0)
|
||||
pad_dims.append(val.item())
|
||||
result_padded = F.pad(result, pad_dims)
|
||||
gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)]
|
||||
torch.distributed.all_gather(gathered_result, result_padded, group)
|
||||
for idx, item_size in enumerate(local_sizes):
|
||||
slice_param = [slice(dim_size) for dim_size in item_size]
|
||||
gathered_result[idx] = gathered_result[idx][slice_param]
|
||||
return gathered_result
|
||||
|
||||
|
||||
def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]:
|
||||
gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
|
||||
torch.distributed.all_gather(gathered_result, result, group)
|
||||
return gathered_result
|
||||
|
||||
|
||||
def distributed_available() -> bool:
|
||||
return torch.distributed.is_available() and torch.distributed.is_initialized() or tpu_distributed()
|
||||
|
||||
|
||||
def sync_ddp_if_available(
|
||||
result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
|
||||
) -> Tensor:
|
||||
"""Function to reduce a tensor across worker processes during distributed training.
|
||||
|
||||
Args:
|
||||
result: The value to sync and reduce (typically tensor or number)
|
||||
group: The process group to gather results from. Defaults to all processes (world)
|
||||
reduce_op: The reduction operation. Defaults to sum.
|
||||
Can also be a string of 'avg', 'mean' to calculate the mean during reduction.
|
||||
|
||||
Return:
|
||||
reduced value
|
||||
"""
|
||||
if distributed_available():
|
||||
return sync_ddp(result, group=group, reduce_op=reduce_op)
|
||||
return result
|
||||
|
||||
|
||||
def sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> Tensor:
|
||||
"""Function to reduce the tensors from several DDP processes to one main process.
|
||||
|
||||
Args:
|
||||
result: The value to sync and reduce (typically tensor or number)
|
||||
group: The process group to gather results from. Defaults to all processes (world)
|
||||
reduce_op: The reduction operation. Defaults to sum.
|
||||
Can also be a string of 'avg', 'mean' to calculate the mean during reduction.
|
||||
|
||||
Return:
|
||||
reduced value
|
||||
"""
|
||||
divide_by_world_size = False
|
||||
|
||||
if group is None:
|
||||
group = torch.distributed.group.WORLD
|
||||
|
||||
op: Optional[ReduceOp]
|
||||
if isinstance(reduce_op, str):
|
||||
if reduce_op.lower() in ("avg", "mean"):
|
||||
op = ReduceOp.SUM
|
||||
divide_by_world_size = True
|
||||
else:
|
||||
op = getattr(ReduceOp, reduce_op.upper())
|
||||
else:
|
||||
op = reduce_op
|
||||
|
||||
# WA for HPU. HPU doesn't support Long types, forcefully set it to float
|
||||
if _HPU_AVAILABLE:
|
||||
is_hpu_backend = os.environ.get("HCCL_DISTRIBUTED_BACKEND") == "1"
|
||||
if is_hpu_backend:
|
||||
if (result.type() == "torch.LongTensor") or (result.type() == "torch.hpu.LongTensor"):
|
||||
new_rank_zero_info("Long tensor unsupported on HPU, casting to float")
|
||||
result = result.float()
|
||||
|
||||
# Sync all processes before reduction
|
||||
torch.distributed.barrier(group=group)
|
||||
torch.distributed.all_reduce(result, op=op, group=group, async_op=False)
|
||||
|
||||
if divide_by_world_size:
|
||||
result = result / torch.distributed.get_world_size(group)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class AllGatherGrad(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward( # type: ignore[override]
|
||||
ctx: Any,
|
||||
tensor: Tensor,
|
||||
group: Optional["torch.distributed.ProcessGroup"] = group.WORLD,
|
||||
) -> Tensor:
|
||||
ctx.group = group
|
||||
|
||||
gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
|
||||
|
||||
torch.distributed.all_gather(gathered_tensor, tensor, group=group)
|
||||
gathered_tensor = torch.stack(gathered_tensor, dim=0)
|
||||
|
||||
return gathered_tensor
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, *grad_output: Tensor) -> Tuple[Tensor, None]:
|
||||
grad_output = torch.cat(grad_output)
|
||||
|
||||
torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group)
|
||||
|
||||
return grad_output[torch.distributed.get_rank()], None
|
||||
|
||||
|
||||
def all_gather_ddp_if_available(
|
||||
tensor: Tensor, group: Optional["torch.distributed.ProcessGroup"] = None, sync_grads: bool = False
|
||||
) -> Tensor:
|
||||
"""Function to gather a tensor from several distributed processes.
|
||||
|
||||
Args:
|
||||
tensor: Tensor of shape (batch, ...)
|
||||
group: The process group to gather results from. Defaults to all processes (world)
|
||||
sync_grads: Flag that allows users to synchronize gradients for all_gather op
|
||||
|
||||
Return:
|
||||
A tensor of shape (world_size, batch, ...)
|
||||
"""
|
||||
group = group if group is not None else torch.distributed.group.WORLD
|
||||
if distributed_available():
|
||||
if sync_grads:
|
||||
return AllGatherGrad.apply(tensor, group)
|
||||
with torch.no_grad():
|
||||
return AllGatherGrad.apply(tensor, group)
|
||||
return tensor
|
||||
|
||||
|
||||
def init_dist_connection(
|
||||
# TODO(lite): Fix this type error
|
||||
cluster_environment: "ClusterEnvironment", # type: ignore[name-defined] # noqa: F821
|
||||
torch_distributed_backend: str,
|
||||
global_rank: Optional[int] = None,
|
||||
world_size: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Utility function to initialize distributed connection by setting env variables and initializing the
|
||||
distributed process group.
|
||||
|
||||
Args:
|
||||
cluster_environment: ``ClusterEnvironment`` instance
|
||||
torch_distributed_backend: Backend to use (includes `nccl` and `gloo`)
|
||||
global_rank: Rank of the current process
|
||||
world_size: Number of processes in the group
|
||||
kwargs: Kwargs for ``init_process_group``
|
||||
|
||||
Raises:
|
||||
RuntimeError:
|
||||
If ``torch.distributed`` is not available
|
||||
"""
|
||||
if not torch.distributed.is_available():
|
||||
raise RuntimeError("torch.distributed is not available. Cannot initialize distributed process group")
|
||||
if torch.distributed.is_initialized():
|
||||
log.debug("torch.distributed is already initialized. Exiting early")
|
||||
return
|
||||
global_rank = global_rank if global_rank is not None else cluster_environment.global_rank()
|
||||
world_size = world_size if world_size is not None else cluster_environment.world_size()
|
||||
os.environ["MASTER_ADDR"] = cluster_environment.main_address
|
||||
os.environ["MASTER_PORT"] = str(cluster_environment.main_port)
|
||||
log.info(f"Initializing distributed: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
|
||||
torch.distributed.init_process_group(torch_distributed_backend, rank=global_rank, world_size=world_size, **kwargs)
|
||||
|
||||
# On rank=0 let everyone know training is starting
|
||||
new_rank_zero_info(
|
||||
f"{'-' * 100}\n"
|
||||
f"distributed_backend={torch_distributed_backend}\n"
|
||||
f"All distributed processes registered. Starting with {world_size} processes\n"
|
||||
f"{'-' * 100}\n"
|
||||
)
|
||||
|
||||
|
||||
def tpu_distributed() -> bool:
|
||||
return _TPU_AVAILABLE and xm.xrt_world_size() > 1
|
||||
|
||||
|
||||
def get_default_process_group_backend_for_device(device: torch.device) -> str:
|
||||
return "nccl" if device.type == "cuda" else "gloo"
|
||||
|
||||
|
||||
def _get_process_group_backend_from_env() -> Optional[str]:
|
||||
torch_backend = os.getenv("PL_TORCH_DISTRIBUTED_BACKEND")
|
||||
if torch_backend is not None:
|
||||
rank_zero_deprecation(
|
||||
"Environment variable `PL_TORCH_DISTRIBUTED_BACKEND`"
|
||||
" was deprecated in v1.6 and will be removed in v1.8."
|
||||
" Specify `process_group_backend` directly on the strategy constructor."
|
||||
)
|
||||
return torch_backend
|
|
@ -0,0 +1,95 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Enumerated utilities."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from lightning_utilities.core.enums import StrEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from enum import Enum
|
||||
|
||||
# re-defined because `mypy` infers `StrEnum` as `Any`
|
||||
class LightningEnum(StrEnum, Enum):
|
||||
...
|
||||
|
||||
else:
|
||||
LightningEnum = StrEnum
|
||||
|
||||
|
||||
class AMPType(LightningEnum):
|
||||
"""Type of Automatic Mixed Precission used for training."""
|
||||
|
||||
APEX = "apex"
|
||||
NATIVE = "native"
|
||||
|
||||
|
||||
class PrecisionType(LightningEnum):
|
||||
"""Type of precision used."""
|
||||
|
||||
HALF = "16"
|
||||
FLOAT = "32"
|
||||
FULL = "64"
|
||||
BFLOAT = "bf16"
|
||||
MIXED = "mixed"
|
||||
|
||||
@staticmethod
|
||||
def supported_type(precision: str | int) -> bool:
|
||||
return any(x == precision for x in PrecisionType)
|
||||
|
||||
@staticmethod
|
||||
def supported_types() -> list[str]:
|
||||
return [x.value for x in PrecisionType]
|
||||
|
||||
|
||||
class _StrategyType(LightningEnum):
|
||||
"""Define type of training strategy."""
|
||||
|
||||
DP = "dp"
|
||||
DDP = "ddp"
|
||||
DDP_SPAWN = "ddp_spawn"
|
||||
DDP_FORK = "ddp_fork"
|
||||
TPU_SPAWN = "tpu_spawn"
|
||||
DEEPSPEED = "deepspeed"
|
||||
HOROVOD = "horovod"
|
||||
DDP_SHARDED = "ddp_sharded"
|
||||
DDP_SHARDED_SPAWN = "ddp_sharded_spawn"
|
||||
DDP_FULLY_SHARDED = "ddp_fully_sharded"
|
||||
BAGUA = "bagua"
|
||||
HPU_PARALLEL = "hpu_parallel"
|
||||
|
||||
@staticmethod
|
||||
def interactive_compatible_types() -> list[_StrategyType]:
|
||||
"""Returns a list containing interactive compatible _StrategyTypes."""
|
||||
return [
|
||||
_StrategyType.DP,
|
||||
_StrategyType.TPU_SPAWN,
|
||||
_StrategyType.DDP_FORK,
|
||||
]
|
||||
|
||||
def is_interactive_compatible(self) -> bool:
|
||||
"""Returns whether self is interactive compatible."""
|
||||
return self in _StrategyType.interactive_compatible_types()
|
||||
|
||||
|
||||
class _AcceleratorType(LightningEnum):
|
||||
"""Define Accelerator type by its nature."""
|
||||
|
||||
CPU = "CPU"
|
||||
CUDA = "CUDA"
|
||||
IPU = "IPU"
|
||||
TPU = "TPU"
|
||||
HPU = "HPU"
|
||||
MPS = "MPS"
|
|
@ -0,0 +1,17 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
class MisconfigurationException(Exception):
|
||||
"""Exception used to inform users of misuse with Lightning."""
|
|
@ -0,0 +1,61 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""General utilities."""
|
||||
import operator
|
||||
import platform
|
||||
import sys
|
||||
|
||||
from lightning_utilities.core.imports import compare_version, module_available, package_available
|
||||
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
_IS_INTERACTIVE = hasattr(sys, "ps1") # https://stackoverflow.com/a/64523765
|
||||
_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
|
||||
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
|
||||
_TORCH_GREATER_EQUAL_1_9_1 = compare_version("torch", operator.ge, "1.9.1")
|
||||
_TORCH_GREATER_EQUAL_1_10 = compare_version("torch", operator.ge, "1.10.0")
|
||||
_TORCH_LESSER_EQUAL_1_10_2 = compare_version("torch", operator.le, "1.10.2")
|
||||
_TORCH_GREATER_EQUAL_1_11 = compare_version("torch", operator.ge, "1.11.0")
|
||||
_TORCH_GREATER_EQUAL_1_12 = compare_version("torch", operator.ge, "1.12.0")
|
||||
_TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0", use_base_version=True)
|
||||
|
||||
_APEX_AVAILABLE = module_available("apex.amp")
|
||||
_HABANA_FRAMEWORK_AVAILABLE = package_available("habana_frameworks")
|
||||
_HIVEMIND_AVAILABLE = package_available("hivemind")
|
||||
_HOROVOD_AVAILABLE = module_available("horovod.torch")
|
||||
_OMEGACONF_AVAILABLE = package_available("omegaconf")
|
||||
_POPTORCH_AVAILABLE = package_available("poptorch")
|
||||
_PSUTIL_AVAILABLE = package_available("psutil")
|
||||
_XLA_AVAILABLE: bool = package_available("torch_xla")
|
||||
|
||||
# TODO(lite): import this from the fairscale files once they move to lite package
|
||||
_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and module_available("fairscale.nn")
|
||||
|
||||
|
||||
from lightning_lite.utilities.xla_device import XLADeviceUtils # noqa: E402
|
||||
|
||||
_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
|
||||
|
||||
if _POPTORCH_AVAILABLE:
|
||||
import poptorch
|
||||
|
||||
_IPU_AVAILABLE = poptorch.ipuHardwareIsAvailable()
|
||||
else:
|
||||
_IPU_AVAILABLE = False
|
||||
|
||||
if _HABANA_FRAMEWORK_AVAILABLE:
|
||||
from habana_frameworks.torch.utils.library_loader import is_habana_avaialble
|
||||
|
||||
_HPU_AVAILABLE = is_habana_avaialble()
|
||||
else:
|
||||
_HPU_AVAILABLE = False
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Iterable
|
||||
|
||||
from lightning_utilities.core.apply_func import apply_to_collection
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from lightning_lite.utilities.apply_func import move_data_to_device
|
||||
from lightning_lite.utilities.types import _DEVICE
|
||||
|
||||
|
||||
def optimizers_to_device(optimizers: Iterable[Optimizer], device: _DEVICE) -> None:
|
||||
"""Moves optimizer states for a sequence of optimizers to the device."""
|
||||
for opt in optimizers:
|
||||
optimizer_to_device(opt, device)
|
||||
|
||||
|
||||
def optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None:
|
||||
"""Moves the state of a single optimizer to the device."""
|
||||
for p, v in optimizer.state.items():
|
||||
optimizer.state[p] = apply_to_collection(v, Tensor, move_data_to_device, device)

@@ -0,0 +1,60 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities that can be used for calling functions on a particular rank."""
import logging
import os
from typing import Optional

import lightning_utilities.core.rank_zero as rank_zero_module

# note: we want to keep these indirections so the `rank_zero_only.rank` is set on import
from lightning_utilities.core.rank_zero import (  # noqa: F401
    rank_zero_debug,
    rank_zero_deprecation,
    rank_zero_info,
    rank_zero_only,
    rank_zero_warn,
)

import lightning_lite

rank_zero_module.log = logging.getLogger(__name__)


def _get_rank(
    strategy: Optional["lightning_lite.strategies.Strategy"] = None,  # type: ignore[name-defined]
) -> Optional[int]:
    if strategy is not None:
        return strategy.global_rank
    # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
    # therefore LOCAL_RANK needs to be checked first
    rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
    for key in rank_keys:
        rank = os.environ.get(key)
        if rank is not None:
            return int(rank)
    # None to differentiate whether an environment variable was set at all
    return None


# add the attribute to the function but don't overwrite in case Trainer has already set it
rank_zero_only.rank = getattr(rank_zero_only, "rank", _get_rank() or 0)


class LightningDeprecationWarning(DeprecationWarning):
    """Deprecation warnings raised by Lightning."""


rank_zero_module.rank_zero_deprecation_category = LightningDeprecationWarning
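The re-exports above keep a single import location for the rank-zero helpers; a minimal usage sketch (illustrative, not part of the diff):

from lightning_lite.utilities.rank_zero import rank_zero_info, rank_zero_only


@rank_zero_only
def write_summary(path: str) -> None:
    # executed only on the process where rank_zero_only.rank == 0
    with open(path, "w") as f:
        f.write("training finished\n")


rank_zero_info("logged once per job, not once per process")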

@@ -0,0 +1,127 @@
import logging
import os
import random
from random import getstate as python_get_rng_state
from random import setstate as python_set_rng_state
from typing import Any, Dict, Optional

import numpy as np
import torch
from lightning_utilities.core.rank_zero import rank_prefixed_message

from lightning_lite.utilities.rank_zero import _get_rank, rank_zero_only, rank_zero_warn

log = logging.getLogger(__name__)

max_seed_value = np.iinfo(np.uint32).max
min_seed_value = np.iinfo(np.uint32).min


def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
    """Function that sets seed for pseudo-random number generators in: pytorch, numpy, python.random In addition,
    sets the following environment variables:

    - `PL_GLOBAL_SEED`: will be passed to spawned subprocesses (e.g. ddp_spawn backend).
    - `PL_SEED_WORKERS`: (optional) is set to 1 if ``workers=True``.

    Args:
        seed: the integer value seed for global random state in Lightning.
            If `None`, will read seed from `PL_GLOBAL_SEED` env variable
            or select it randomly.
        workers: if set to ``True``, will properly configure all dataloaders passed to the
            Trainer with a ``worker_init_fn``. If the user already provides such a function
            for their dataloaders, setting this argument will have no influence. See also:
            :func:`~lightning_lite.utilities.seed.pl_worker_init_function`.
    """
    if seed is None:
        env_seed = os.environ.get("PL_GLOBAL_SEED")
        if env_seed is None:
            seed = _select_seed_randomly(min_seed_value, max_seed_value)
            rank_zero_warn(f"No seed found, seed set to {seed}")
        else:
            try:
                seed = int(env_seed)
            except ValueError:
                seed = _select_seed_randomly(min_seed_value, max_seed_value)
                rank_zero_warn(f"Invalid seed found: {repr(env_seed)}, seed set to {seed}")
    elif not isinstance(seed, int):
        seed = int(seed)

    if not (min_seed_value <= seed <= max_seed_value):
        rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}")
        seed = _select_seed_randomly(min_seed_value, max_seed_value)

    log.info(rank_prefixed_message(f"Global seed set to {seed}", _get_rank()))
    os.environ["PL_GLOBAL_SEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    os.environ["PL_SEED_WORKERS"] = f"{int(workers)}"

    return seed


def _select_seed_randomly(min_seed_value: int = min_seed_value, max_seed_value: int = max_seed_value) -> int:
    return random.randint(min_seed_value, max_seed_value)


def reset_seed() -> None:
    """Reset the seed to the value that :func:`lightning_lite.utilities.seed.seed_everything` previously set.

    If :func:`lightning_lite.utilities.seed.seed_everything` is unused, this function will do nothing.
    """
    seed = os.environ.get("PL_GLOBAL_SEED", None)
    if seed is None:
        return
    workers = os.environ.get("PL_SEED_WORKERS", "0")
    seed_everything(int(seed), workers=bool(int(workers)))


def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:  # pragma: no cover
    """The worker_init_fn that Lightning automatically adds to your dataloader if you previously set the seed with
    ``seed_everything(seed, workers=True)``.

    See also the PyTorch documentation on
    `randomness in DataLoaders <https://pytorch.org/docs/stable/notes/randomness.html#dataloader>`_.
    """
    # implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562
    global_rank = rank if rank is not None else rank_zero_only.rank
    process_seed = torch.initial_seed()
    # back out the base seed so we can use all the bits
    base_seed = process_seed - worker_id
    log.debug(
        f"Initializing random number generators of process {global_rank} worker {worker_id} with base seed {base_seed}"
    )
    ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
    # use 128 bits (4 x 32-bit words)
    np.random.seed(ss.generate_state(4))
    # Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module
    torch_ss, stdlib_ss = ss.spawn(2)
    torch.manual_seed(torch_ss.generate_state(1, dtype=np.uint64)[0])
    # use 128 bits expressed as an integer
    stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum()
    random.seed(stdlib_seed)


def _collect_rng_states() -> Dict[str, Any]:
    """Collect the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python."""
    return {
        "torch": torch.get_rng_state(),
        "torch.cuda": torch.cuda.get_rng_state_all(),
        "numpy": np.random.get_state(),
        "python": python_get_rng_state(),
    }


def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
    """Set the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python in the current
    process."""
    torch.set_rng_state(rng_state_dict["torch"])
    # torch.cuda rng_state is only included since v1.8.
    if "torch.cuda" in rng_state_dict:
        torch.cuda.set_rng_state_all(rng_state_dict["torch.cuda"])
    np.random.set_state(rng_state_dict["numpy"])
    version, state, gauss = rng_state_dict["python"]
    python_set_rng_state((version, tuple(state), gauss))
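Typical use of the seeding utilities above (illustrative sketch; outside the Trainer the worker_init_fn has to be attached by hand):

import torch
from torch.utils.data import DataLoader, TensorDataset

from lightning_lite.utilities.seed import pl_worker_init_function, seed_everything

seed_everything(42, workers=True)  # seeds python, numpy and torch and exports PL_GLOBAL_SEED / PL_SEED_WORKERS
dataset = TensorDataset(torch.randn(8, 2))
loader = DataLoader(dataset, num_workers=2, worker_init_fn=pl_worker_init_function)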

@@ -11,8 +11,69 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
from pathlib import Path
from typing import Any, Callable, Dict, Iterator, List, Optional, TypeVar, Union

import torch
from torch import Tensor
from torch.optim import Optimizer
from typing_extensions import Protocol, runtime_checkable

_PATH = Union[str, Path]
_DEVICE = Union[torch.device, str, int]
_MAP_LOCATION_TYPE = Optional[Union[_DEVICE, Callable[[_DEVICE], _DEVICE], Dict[_DEVICE, _DEVICE]]]
_PARAMETERS = Iterator[torch.nn.Parameter]


_DictKey = TypeVar("_DictKey")


@runtime_checkable
class _Stateful(Protocol[_DictKey]):
    """This class is used to detect if an object is stateful using `isinstance(obj, _Stateful)`."""

    def state_dict(self) -> Dict[_DictKey, Any]:
        ...

    def load_state_dict(self, state_dict: Dict[_DictKey, Any]) -> None:
        ...


# Inferred from `torch.optim.lr_scheduler.pyi`
# Missing attributes were added to improve typing
@runtime_checkable
class _LRScheduler(_Stateful[str], Protocol):
    optimizer: Optimizer
    base_lrs: List[float]

    def __init__(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> None:
        ...

    def step(self, epoch: Optional[int] = None) -> None:
        ...


# Inferred from `torch.optim.lr_scheduler.pyi`
# Missing attributes were added to improve typing
@runtime_checkable
class ReduceLROnPlateau(_Stateful[str], Protocol):
    in_cooldown: bool
    optimizer: Optimizer

    def __init__(
        self,
        optimizer: Optimizer,
        mode: str = ...,
        factor: float = ...,
        patience: int = ...,
        verbose: bool = ...,
        threshold: float = ...,
        threshold_mode: str = ...,
        cooldown: int = ...,
        min_lr: float = ...,
        eps: float = ...,
    ) -> None:
        ...

    def step(self, metrics: Union[float, int, Tensor], epoch: Optional[int] = None) -> None:
        ...
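Because the protocols above are runtime-checkable, objects can be duck-typed with isinstance instead of inheriting from a base class. A small sketch (illustrative only):

from typing import Any, Dict

from lightning_lite.utilities.types import _Stateful


class Counter:
    def __init__(self) -> None:
        self.count = 0

    def state_dict(self) -> Dict[str, Any]:
        return {"count": self.count}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.count = state_dict["count"]


assert isinstance(Counter(), _Stateful)  # structural check: only the two methods are required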

@@ -0,0 +1,24 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Warning-related utilities."""
import warnings

from lightning_lite.utilities.rank_zero import LightningDeprecationWarning

# enable our warnings
warnings.simplefilter("default", category=LightningDeprecationWarning)


class PossibleUserWarning(UserWarning):
    """Warnings that could be false positives."""

@@ -18,7 +18,7 @@ import traceback
from multiprocessing import Process, Queue
from typing import Any, Callable, Union

from pytorch_lightning.utilities.imports import _XLA_AVAILABLE
from lightning_lite.utilities.imports import _XLA_AVAILABLE

if _XLA_AVAILABLE:
    import torch_xla.core.xla_model as xm

@@ -31,10 +31,10 @@ if not _root_logger.hasHandlers():
    _logger.addHandler(logging.StreamHandler())
    _logger.propagate = False

from lightning_lite.utilities.seed import seed_everything  # noqa: E402
from pytorch_lightning.callbacks import Callback  # noqa: E402
from pytorch_lightning.core import LightningDataModule, LightningModule  # noqa: E402
from pytorch_lightning.trainer import Trainer  # noqa: E402
from pytorch_lightning.utilities.seed import seed_everything  # noqa: E402

__all__ = ["Trainer", "LightningDataModule", "LightningModule", "Callback", "seed_everything"]

@@ -15,11 +15,11 @@ from typing import Any, Dict, List, Union

import torch

from lightning_lite.utilities.device_parser import parse_cpu_cores
from lightning_lite.utilities.types import _DEVICE
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.device_parser import parse_cpu_cores
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _PSUTIL_AVAILABLE
from pytorch_lightning.utilities.types import _DEVICE


class CPUAccelerator(Accelerator):

@@ -20,10 +20,10 @@ from typing import Any, Dict, List, Optional, Union
import torch

import pytorch_lightning as pl
from lightning_lite.utilities import device_parser
from lightning_lite.utilities.types import _DEVICE
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities import device_parser
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _DEVICE

_log = logging.getLogger(__name__)

@@ -16,11 +16,11 @@ from typing import Any, Dict, List, Optional, Union

import torch

from lightning_lite.utilities import device_parser
from lightning_lite.utilities.types import _DEVICE
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities import device_parser
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _PSUTIL_AVAILABLE, _TORCH_GREATER_EQUAL_1_12
from pytorch_lightning.utilities.types import _DEVICE

# For using the `MPSAccelerator`, user's machine should have `torch>=1.12`, Metal programming framework and
# the ARM-based Apple Silicon processors.

@@ -15,9 +15,9 @@ import importlib
from inspect import getmembers, isclass
from typing import Any, Callable, Dict, List, Optional

from lightning_lite.utilities.registry import _is_register_method_overridden
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.registry import _is_register_method_overridden


class _AcceleratorRegistry(dict):

@@ -15,8 +15,8 @@ from typing import Any, Dict, List, Optional, Union

import torch

from lightning_lite.utilities import device_parser
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities import device_parser
from pytorch_lightning.utilities.imports import _TPU_AVAILABLE, _XLA_AVAILABLE

if _XLA_AVAILABLE:

@@ -14,7 +14,7 @@
from typing import Any

from pytorch_lightning.callbacks.callback import Callback as NewCallback
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation


class Callback(NewCallback):

@@ -27,9 +27,10 @@ from lightning_utilities.core.rank_zero import rank_prefixed_message
from torch import Tensor

import pytorch_lightning as pl
from lightning_lite.utilities.rank_zero import _get_rank
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_warn

log = logging.getLogger(__name__)

@@ -21,8 +21,8 @@ import os
from typing import Any

import pytorch_lightning as pl
from lightning_lite.utilities.types import _PATH
from pytorch_lightning.callbacks import Checkpoint
from pytorch_lightning.utilities.types import _PATH


class _FaultToleranceCheckpoint(Checkpoint):

@@ -36,10 +36,11 @@ from torch import Tensor

import pytorch_lightning as pl
from lightning_lite.utilities.cloud_io import get_filesystem
from lightning_lite.utilities.types import _PATH
from pytorch_lightning.callbacks import Checkpoint
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
from pytorch_lightning.utilities.types import STEP_OUTPUT

log = logging.getLogger(__name__)
warning_cache = WarningCache()

@@ -23,12 +23,13 @@ from torch import nn, Tensor
from torch.optim.swa_utils import SWALR

import pytorch_lightning as pl
from lightning_lite.utilities.types import _LRScheduler
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.strategies import DDPFullyShardedStrategy, DeepSpeedStrategy
from pytorch_lightning.strategies.fully_sharded_native import DDPFullyShardedNativeStrategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.types import _LRScheduler, LRSchedulerConfig
from pytorch_lightning.utilities.types import LRSchedulerConfig

_AVG_FN = Callable[[Tensor, Tensor, Tensor], Tensor]

@@ -287,7 +287,7 @@ class LightningCLI:
this argument will not be configurable from a configuration file and will always be present for
this particular CLI. Alternatively, configurable callbacks can be added as explained in
:ref:`the CLI docs <lightning-cli>`.
seed_everything_default: Value for the :func:`~pytorch_lightning.utilities.seed.seed_everything`
seed_everything_default: Value for the :func:`~lightning_lite.utilities.seed.seed_everything`
seed argument. Set to True to automatically choose a valid seed.
Setting it to False will not call seed_everything.
description: Description of the tool shown when running ``--help``.

@@ -19,6 +19,7 @@ from typing import Any, Dict, IO, List, Mapping, Optional, Sequence, Tuple, Unio
from torch.utils.data import DataLoader, Dataset, IterableDataset

import pytorch_lightning as pl
from lightning_lite.utilities.types import _PATH
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
from pytorch_lightning.core.mixins import HyperparametersMixin
from pytorch_lightning.core.saving import _load_from_checkpoint

@@ -28,7 +29,7 @@ from pytorch_lightning.utilities.argparse import (
    get_init_arguments_and_types,
    parse_argparser,
)
from pytorch_lightning.utilities.types import _ADD_ARGPARSE_RETURN, _PATH, EVAL_DATALOADERS, TRAIN_DATALOADERS
from pytorch_lightning.utilities.types import _ADD_ARGPARSE_RETURN, EVAL_DATALOADERS, TRAIN_DATALOADERS


class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):

@@ -14,7 +14,7 @@
from typing import Any

from pytorch_lightning.core.module import LightningModule as NewLightningModule
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation


class LightningModule(NewLightningModule):

@@ -37,6 +37,7 @@ import pytorch_lightning as pl
from lightning_lite.utilities.apply_func import convert_to_tensors
from lightning_lite.utilities.cloud_io import get_filesystem
from lightning_lite.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning_lite.utilities.distributed import distributed_available, sync_ddp
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from pytorch_lightning.core.mixins import HyperparametersMixin

@@ -45,7 +46,6 @@ from pytorch_lightning.core.saving import ModelIO
from pytorch_lightning.loggers import Logger, LoggerCollection
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_10, GradClipAlgorithmType
from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_1_13
from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_deprecation, rank_zero_warn

@@ -21,10 +21,11 @@ from torch import optim
from torch.optim import Optimizer

import pytorch_lightning as pl
from lightning_lite.utilities.types import _Stateful, ReduceLROnPlateau
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.types import _Stateful, LRSchedulerConfig, LRSchedulerTypeTuple, ReduceLROnPlateau
from pytorch_lightning.utilities.types import LRSchedulerConfig, LRSchedulerTypeTuple


def do_nothing_closure() -> None:

@@ -29,11 +29,11 @@ from lightning_utilities.core.apply_func import apply_to_collection
import pytorch_lightning as pl
from lightning_lite.utilities.cloud_io import get_filesystem
from lightning_lite.utilities.cloud_io import load as pl_load
from lightning_lite.utilities.types import _MAP_LOCATION_TYPE, _PATH
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict
from pytorch_lightning.utilities.migration import pl_legacy_patch
from pytorch_lightning.utilities.parsing import parse_class_init_keys
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.types import _MAP_LOCATION_TYPE, _PATH

log = logging.getLogger(__name__)
PRIMITIVE_TYPES = (bool, int, float, str)

@@ -25,7 +25,15 @@ from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler

from lightning_lite.utilities import _AcceleratorType, _StrategyType, move_data_to_device
from lightning_lite.utilities.apply_func import convert_to_tensors
from lightning_lite.utilities.data import (
    _auto_add_worker_init_fn,
    _replace_dunder_methods,
    _update_dataloader,
    has_iterable_dataset,
)
from lightning_lite.utilities.seed import seed_everything
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
from pytorch_lightning.overrides.distributed import DistributedSamplerWrapper

@@ -33,15 +41,7 @@ from pytorch_lightning.plugins import PLUGIN_INPUT
from pytorch_lightning.strategies import DeepSpeedStrategy, Strategy, TPUSpawnStrategy
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.utilities import _AcceleratorType, _StrategyType, move_data_to_device
from pytorch_lightning.utilities.data import (
    _auto_add_worker_init_fn,
    _replace_dunder_methods,
    _update_dataloader,
    has_iterable_dataset,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything


class LightningLite(ABC):

@@ -23,16 +23,16 @@ from torch.optim import Optimizer
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from lightning_lite.utilities.warnings import PossibleUserWarning
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.loops import Loop
from pytorch_lightning.strategies import ParallelStrategy, Strategy
from pytorch_lightning.trainer.progress import BaseProgress
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import STEP_OUTPUT
from pytorch_lightning.utilities.warnings import PossibleUserWarning


def check_finite_loss(loss: Optional[Tensor]) -> None:

@@ -20,7 +20,7 @@ from torch.nn.parallel import DistributedDataParallel

import pytorch_lightning as pl
from lightning_lite.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation


class _LightningPrecisionModuleWrapperBase(_DeviceDtypeModuleMixin, torch.nn.Module):

@@ -22,8 +22,8 @@ from pytorch_lightning.overrides.base import (
    _LightningPrecisionModuleWrapperBase,
    unwrap_lightning_module,
)
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.imports import _IS_WINDOWS
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation

_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and module_available("fairscale.nn")

@@ -14,7 +14,7 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

from pytorch_lightning.utilities.types import _PATH
from lightning_lite.utilities.types import _PATH


class CheckpointIO(ABC):

@@ -19,8 +19,8 @@ import torch

from lightning_lite.utilities.apply_func import move_data_to_device
from lightning_lite.utilities.cloud_io import atomic_save, get_filesystem
from lightning_lite.utilities.types import _PATH
from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
from pytorch_lightning.utilities.types import _PATH


class HPUCheckpointIO(TorchCheckpointIO):

@@ -18,9 +18,9 @@ from typing import Any, Callable, Dict, Optional
import pytorch_lightning as pl
from lightning_lite.utilities.cloud_io import atomic_save, get_filesystem
from lightning_lite.utilities.cloud_io import load as pl_load
from lightning_lite.utilities.types import _PATH
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.types import _PATH

log = logging.getLogger(__name__)

@@ -17,9 +17,9 @@ from typing import Any, Dict, Optional
from lightning_utilities.core.apply_func import apply_to_collection

from lightning_lite.utilities.cloud_io import get_filesystem
from lightning_lite.utilities.types import _PATH
from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE
from pytorch_lightning.utilities.types import _PATH

if _TPU_AVAILABLE:
    import torch_xla.core.xla_model as xm

@@ -18,10 +18,10 @@ from torch.nn import Module
from torch.optim import LBFGS, Optimizer

import pytorch_lightning as pl
from lightning_lite.utilities.types import _PARAMETERS
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _PARAMETERS

if _APEX_AVAILABLE:
    from apex import amp

@@ -20,9 +20,9 @@ from torch.nn import Module
from torch.optim import LBFGS, Optimizer

import pytorch_lightning as pl
from lightning_lite.utilities.enums import AMPType, PrecisionType
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.enums import AMPType, PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _APEX_AVAILABLE
from pytorch_lightning.utilities.model_helpers import is_overridden

@@ -15,8 +15,8 @@ from typing import Any, Optional, Union

import torch

from lightning_lite.utilities.enums import PrecisionType
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12

@@ -13,8 +13,8 @@
# limitations under the License.
from typing import Optional, Union

from lightning_lite.utilities.enums import PrecisionType
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _HPU_AVAILABLE

@@ -18,9 +18,9 @@ from torch.nn import Module
from torch.optim import LBFGS, Optimizer

import pytorch_lightning as pl
from lightning_lite.utilities.enums import PrecisionType
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden

@@ -21,9 +21,9 @@ from torch.nn import Module
from torch.optim import Optimizer

import pytorch_lightning as pl
from lightning_lite.utilities.types import _PARAMETERS
from pytorch_lightning.core.hooks import CheckpointHooks
from pytorch_lightning.utilities import grad_norm, GradClipAlgorithmType
from pytorch_lightning.utilities.types import _PARAMETERS


class PrecisionPlugin(CheckpointHooks):

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.profilers.advanced import AdvancedProfiler as NewAdvancedProfiler
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation


class AdvancedProfiler(NewAdvancedProfiler):

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.profilers.profiler import Profiler as NewProfiler
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation


class Profiler(NewProfiler):

@@ -14,7 +14,7 @@
from pytorch_lightning.profilers.pytorch import PyTorchProfiler as NewPyTorchProfiler
from pytorch_lightning.profilers.pytorch import RegisterRecordFunction as NewRegisterRecordFuncion
from pytorch_lightning.profilers.pytorch import ScheduleWrapper as NewScheduleWrapper
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation


class RegisterRecordFunction(NewRegisterRecordFuncion):

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.profilers.simple import SimpleProfiler as NewSimpleProfiler
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation


class SimpleProfiler(NewSimpleProfiler):

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.profilers.xla import XLAProfiler as NewXLAProfiler
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation


class XLAProfiler(NewXLAProfiler):

@@ -24,8 +24,8 @@ from lightning_utilities.core.rank_zero import WarningCache
from torch import nn, Tensor
from torch.autograd.profiler import record_function

from lightning_lite.utilities.device_parser import is_cuda_available
from pytorch_lightning.profilers.profiler import Profiler
from pytorch_lightning.utilities.device_parser import is_cuda_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_warn

@@ -8,6 +8,9 @@ from torch import Tensor
from torch.nn import Module

import pytorch_lightning as pl
from lightning_lite.utilities.distributed import ReduceOp
from lightning_lite.utilities.optimizer import optimizers_to_device
from lightning_lite.utilities.seed import reset_seed
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO

@@ -15,10 +18,7 @@ from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.seed import reset_seed

_BAGUA_AVAILABLE = package_available("bagua")

@@ -29,6 +29,15 @@ from torch.nn.parallel.distributed import DistributedDataParallel
from torch.optim.optimizer import Optimizer

import pytorch_lightning as pl
from lightning_lite.utilities.distributed import (
    _get_process_group_backend_from_env,
    distributed_available,
    get_default_process_group_backend_for_device,
)
from lightning_lite.utilities.distributed import group as _group
from lightning_lite.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
from lightning_lite.utilities.optimizer import optimizers_to_device
from lightning_lite.utilities.seed import reset_seed
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase

@@ -41,23 +50,10 @@ from pytorch_lightning.strategies.launchers.subprocess_script import _Subprocess
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.distributed import (
    _get_process_group_backend_from_env,
    distributed_available,
    get_default_process_group_backend_for_device,
)
from pytorch_lightning.utilities.distributed import group as _group
from pytorch_lightning.utilities.distributed import (
    init_dist_connection,
    ReduceOp,
    register_ddp_comm_hook,
    sync_ddp_if_available,
)
from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException
from pytorch_lightning.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_10, _TORCH_GREATER_EQUAL_1_11
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep

if _FAIRSCALE_AVAILABLE:

@@ -24,6 +24,14 @@ from torch.nn.parallel.distributed import DistributedDataParallel
from typing_extensions import Literal

import pytorch_lightning as pl
from lightning_lite.utilities.distributed import (
    _get_process_group_backend_from_env,
    distributed_available,
    get_default_process_group_backend_for_device,
)
from lightning_lite.utilities.distributed import group as _group
from lightning_lite.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
from lightning_lite.utilities.optimizer import optimizers_to_device
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
from pytorch_lightning.overrides.distributed import prepare_for_backward

@@ -34,20 +42,8 @@ from pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcess
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.distributed import (
    _get_process_group_backend_from_env,
    distributed_available,
    get_default_process_group_backend_for_device,
)
from pytorch_lightning.utilities.distributed import group as _group
from pytorch_lightning.utilities.distributed import (
    init_dist_connection,
    ReduceOp,
    register_ddp_comm_hook,
    sync_ddp_if_available,
)
from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep

@@ -30,6 +30,15 @@ from torch.nn import Module
from torch.optim import Optimizer

import pytorch_lightning as pl
from lightning_lite.utilities.distributed import (
    _get_process_group_backend_from_env,
    get_default_process_group_backend_for_device,
    log,
)
from lightning_lite.utilities.enums import AMPType, PrecisionType
from lightning_lite.utilities.optimizer import optimizers_to_device
from lightning_lite.utilities.seed import reset_seed
from lightning_lite.utilities.types import _LRScheduler, _PATH, ReduceLROnPlateau
from pytorch_lightning.accelerators.cuda import CUDAAccelerator
from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase

@@ -39,18 +48,10 @@ from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.strategies.utils import _fp_to_half
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.distributed import (
    _get_process_group_backend_from_env,
    get_default_process_group_backend_for_device,
    log,
)
from pytorch_lightning.utilities.enums import AMPType, PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import _LRScheduler, _PATH, LRSchedulerConfig, ReduceLROnPlateau, STEP_OUTPUT
from pytorch_lightning.utilities.types import LRSchedulerConfig, STEP_OUTPUT

warning_cache = WarningCache()

@@ -19,13 +19,13 @@ from torch import Tensor
from torch.nn import DataParallel, Module

import pytorch_lightning as pl
from lightning_lite.utilities.distributed import ReduceOp
from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
from pytorch_lightning.overrides.data_parallel import LightningParallelModule
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.strategies.strategy import TBroadcast, TReduce
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import STEP_OUTPUT

@@ -18,6 +18,8 @@ from typing import Any, Dict, Generator, List, Optional
import torch

import pytorch_lightning as pl
from lightning_lite.utilities.enums import PrecisionType
from lightning_lite.utilities.optimizer import optimizers_to_device
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment

@@ -25,10 +27,8 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_info
from pytorch_lightning.utilities.types import STEP_OUTPUT

@@ -19,6 +19,14 @@ import torch
from torch import Tensor

import pytorch_lightning as pl
from lightning_lite.utilities.distributed import (
    _get_process_group_backend_from_env,
    get_default_process_group_backend_for_device,
)
from lightning_lite.utilities.distributed import group as _group
from lightning_lite.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
from lightning_lite.utilities.optimizer import optimizers_to_device
from lightning_lite.utilities.seed import reset_seed
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO

@@ -28,19 +36,10 @@ from pytorch_lightning.strategies.launchers.subprocess_script import _Subprocess
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.distributed import (
    _get_process_group_backend_from_env,
    get_default_process_group_backend_for_device,
)
from pytorch_lightning.utilities.distributed import group as _group
from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_info
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.types import ProcessGroup, STEP_OUTPUT

_distributed_available = torch.distributed.is_available()

@@ -8,14 +8,14 @@ import torch
from torch import Tensor

import pytorch_lightning as pl
from lightning_lite.utilities.enums import PrecisionType
from lightning_lite.utilities.types import _LRScheduler, ReduceLROnPlateau
from pytorch_lightning.strategies.strategy import Strategy, TBroadcast
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.data import extract_batch_size
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _HIVEMIND_AVAILABLE
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import _LRScheduler, ReduceLROnPlateau
from pytorch_lightning.utilities.rank_zero import rank_zero_warn

if _HIVEMIND_AVAILABLE:
    import hivemind

@@ -20,14 +20,14 @@ from torch import Tensor
from torch.optim import Optimizer

import pytorch_lightning as pl
from lightning_lite.utilities.distributed import distributed_available
from lightning_lite.utilities.distributed import group as dist_group
from lightning_lite.utilities.distributed import ReduceOp
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.distributed import group as dist_group
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_only

@@ -18,6 +18,7 @@ from typing import Any, Callable, Dict, List, Optional
import torch.distributed

import pytorch_lightning as pl
from lightning_lite.utilities.distributed import group as _group
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment

@@ -26,7 +27,6 @@ from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.utilities.distributed import group as _group
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _HPU_AVAILABLE, _TORCH_LESSER_EQUAL_1_10_2
from pytorch_lightning.utilities.types import STEP_OUTPUT

@@ -22,6 +22,7 @@ from torch.utils.data import DataLoader, Sampler

import pytorch_lightning as pl
from lightning_lite.utilities.cloud_io import get_filesystem
from lightning_lite.utilities.enums import PrecisionType
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO

@@ -32,7 +33,6 @@ from pytorch_lightning.strategies.utils import _fp_to_half
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.data import _get_dataloader_init_args_and_kwargs, _reinstantiate_wrapped_cls
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation

@@ -27,13 +27,13 @@ from typing_extensions import Literal

import pytorch_lightning as pl
from lightning_lite.utilities.apply_func import move_data_to_device
from lightning_lite.utilities.seed import _collect_rng_states, _set_rng_states
from lightning_lite.utilities.types import _PATH
from pytorch_lightning.strategies.launchers.base import _Launcher
from pytorch_lightning.strategies.strategy import Strategy
from pytorch_lightning.trainer.states import TrainerFn, TrainerState
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11
from pytorch_lightning.utilities.rank_zero import rank_zero_debug
from pytorch_lightning.utilities.seed import _collect_rng_states, _set_rng_states
from pytorch_lightning.utilities.types import _PATH


class _MultiProcessingLauncher(_Launcher):

@@ -19,17 +19,17 @@ import torch
from torch import Tensor

import pytorch_lightning as pl
from pytorch_lightning.plugins import LayerSync
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.strategy import Strategy
from pytorch_lightning.utilities.distributed import (
from lightning_lite.utilities.distributed import (
    _get_process_group_backend_from_env,
    all_gather_ddp_if_available,
    get_default_process_group_backend_for_device,
    ReduceOp,
)
from pytorch_lightning.plugins import LayerSync
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.strategy import Strategy
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation

@@ -19,14 +19,14 @@ from torch.nn import Module
from torch.optim import Optimizer

import pytorch_lightning as pl
from lightning_lite.utilities.enums import PrecisionType
from lightning_lite.utilities.optimizer import optimizers_to_device
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.optimizer import optimizers_to_device

if _FAIRSCALE_AVAILABLE:
    from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel

@@ -19,12 +19,12 @@ from torch.nn import Module
from torch.optim import Optimizer

import pytorch_lightning as pl
from lightning_lite.utilities.optimizer import optimizers_to_device
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.optimizer import optimizers_to_device

if _FAIRSCALE_AVAILABLE:
    from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
@ -19,10 +19,10 @@ import torch
|
|||
from torch import Tensor
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from lightning_lite.utilities.types import _DEVICE
|
||||
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
|
||||
from pytorch_lightning.plugins.precision import PrecisionPlugin
|
||||
from pytorch_lightning.strategies.strategy import Strategy, TBroadcast
|
||||
from pytorch_lightning.utilities.types import _DEVICE
|
||||
|
||||
|
||||
class SingleDeviceStrategy(Strategy):
|
||||
|
|
|
@@ -15,6 +15,7 @@
from typing import Dict, Optional

import pytorch_lightning as pl
from lightning_lite.utilities.types import _DEVICE
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
@@ -22,7 +23,7 @@ from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.single_device import SingleDeviceStrategy
from pytorch_lightning.utilities import _HPU_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _DEVICE, STEP_OUTPUT
from pytorch_lightning.utilities.types import STEP_OUTPUT

if _HPU_AVAILABLE:
import habana_frameworks.torch.core as htcore

@@ -24,6 +24,9 @@ from torch.utils.data import DataLoader

import pytorch_lightning as pl
from lightning_lite.utilities.apply_func import move_data_to_device
from lightning_lite.utilities.distributed import ReduceOp
from lightning_lite.utilities.optimizer import optimizer_to_device, optimizers_to_device
from lightning_lite.utilities.types import _PATH
from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers, LightningOptimizer
from pytorch_lightning.plugins import TorchCheckpointIO
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
@@ -31,10 +34,7 @@ from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.launchers.base import _Launcher
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.optimizer import optimizer_to_device, optimizers_to_device
from pytorch_lightning.utilities.types import (
_PATH,
LRSchedulerConfig,
PredictStep,
STEP_OUTPUT,

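The strategy base class now takes `move_data_to_device` and the optimizer/collective helpers from `lightning_lite`. A minimal sketch of the relocated helper, assuming only that the `lightning_lite` package from this commit is installed:

```python
import torch

from lightning_lite.utilities.apply_func import move_data_to_device

# Works on arbitrarily nested containers of tensors (dicts, lists, tuples).
batch = {"x": torch.zeros(2, 3), "y": [torch.ones(2)]}
batch = move_data_to_device(batch, torch.device("cpu"))
```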
@@ -15,9 +15,9 @@ import importlib
from inspect import getmembers, isclass
from typing import Any, Callable, Dict, List, Optional

from lightning_lite.utilities.registry import _is_register_method_overridden
from pytorch_lightning.strategies.strategy import Strategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.registry import _is_register_method_overridden


class _StrategyRegistry(dict):

@@ -22,6 +22,10 @@ from torch.nn import Module
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from lightning_lite.utilities.data import has_len
from lightning_lite.utilities.distributed import ReduceOp
from lightning_lite.utilities.optimizer import optimizers_to_device
from lightning_lite.utilities.types import _PATH
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.plugins.environments import XLAEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
@@ -34,12 +38,9 @@ from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters
from pytorch_lightning.utilities.data import has_len
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from pytorch_lightning.utilities.types import _PATH, EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS

if _TPU_AVAILABLE:
import torch_xla.core.xla_env_vars as xenv

@@ -15,7 +15,7 @@ import os

import torch

from pytorch_lightning.utilities.enums import PrecisionType
from lightning_lite.utilities.enums import PrecisionType
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation


@@ -13,7 +13,7 @@
# limitations under the License.
""""""

from lightning_lite.utilities.seed import seed_everything
from pytorch_lightning.trainer.trainer import Trainer
from pytorch_lightning.utilities.seed import seed_everything

__all__ = ["Trainer", "seed_everything"]

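`seed_everything` is now re-exported from `lightning_lite.utilities.seed`, so the public entry point is unchanged. A small usage sketch:

```python
from pytorch_lightning import Trainer, seed_everything

seed_everything(42, workers=True)  # workers=True also seeds DataLoader workers
trainer = Trainer(max_epochs=1)
```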
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytorch_lightning as pl
from lightning_lite.utilities.warnings import PossibleUserWarning
from pytorch_lightning.accelerators.ipu import IPUAccelerator
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.strategies import DataParallelStrategy
@@ -20,7 +21,6 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.warnings import PossibleUserWarning


def verify_loop_configurations(trainer: "pl.Trainer") -> None:

@@ -20,6 +20,7 @@ from typing import Dict, List, Optional, Union
import torch
from typing_extensions import Literal

from lightning_lite.utilities import _StrategyType, AMPType, device_parser, LightningEnum
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.accelerators.cpu import CPUAccelerator
from pytorch_lightning.accelerators.cuda import CUDAAccelerator
@@ -75,15 +76,6 @@ from pytorch_lightning.strategies import (
from pytorch_lightning.strategies.ddp_spawn import _DDP_FORK_ALIASES
from pytorch_lightning.strategies.launchers.multiprocessing import _is_forking_disabled
from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
from pytorch_lightning.utilities import (
_StrategyType,
AMPType,
device_parser,
LightningEnum,
rank_zero_deprecation,
rank_zero_info,
rank_zero_warn,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import (
_HOROVOD_AVAILABLE,
@@ -93,6 +85,7 @@ from pytorch_lightning.utilities.imports import (
_TORCH_GREATER_EQUAL_1_11,
_TPU_AVAILABLE,
)
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn

log = logging.getLogger(__name__)

@@ -24,6 +24,7 @@ from torchmetrics import Metric

import pytorch_lightning as pl
from lightning_lite.utilities.cloud_io import get_filesystem
from lightning_lite.utilities.types import _PATH
from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
@@ -31,7 +32,6 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.migration import pl_legacy_patch
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info
from pytorch_lightning.utilities.types import _PATH
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS

if _OMEGACONF_AVAILABLE:

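The checkpoint connector now resolves filesystems through `lightning_lite.utilities.cloud_io.get_filesystem`. A rough sketch of the helper (the path below is made up for illustration):

```python
from lightning_lite.utilities.cloud_io import get_filesystem

ckpt_path = "/tmp/checkpoints/last.ckpt"  # hypothetical local path; remote fsspec URLs work the same way
fs = get_filesystem(ckpt_path)            # returns an fsspec filesystem for the path's protocol
print(fs.exists(ckpt_path))
```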
@@ -23,20 +23,14 @@ from torch.utils.data import BatchSampler, DataLoader, Sampler, SequentialSample
from torch.utils.data.distributed import DistributedSampler

import pytorch_lightning as pl
from lightning_lite.utilities.data import _auto_add_worker_init_fn, _replace_dunder_methods, has_iterable_dataset
from pytorch_lightning.accelerators.ipu import IPUAccelerator
from pytorch_lightning.overrides.distributed import DistributedSamplerWrapper, UnrepeatedDistributedSamplerWrapper
from pytorch_lightning.strategies import DDPSpawnStrategy
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
from pytorch_lightning.utilities.auto_restart import _validate_fault_tolerant_automatic
from pytorch_lightning.utilities.data import (
_auto_add_worker_init_fn,
_is_dataloader_shuffled,
_replace_dunder_methods,
_update_dataloader,
has_iterable_dataset,
has_len_all_ranks,
)
from pytorch_lightning.utilities.data import _is_dataloader_shuffled, _update_dataloader, has_len_all_ranks
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.model_helpers import is_overridden

@@ -24,8 +24,8 @@ from typing_extensions import TypedDict

from lightning_lite.utilities.apply_func import move_data_to_device
from lightning_lite.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning_lite.utilities.distributed import distributed_available
from pytorch_lightning.utilities.data import extract_batch_size
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.memory import recursive_detach

@@ -18,7 +18,7 @@ from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation


class TrainerDataLoadingMixin(ABC):

@@ -19,7 +19,7 @@ from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers, LightningOptimizer
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation


class TrainerOptimizersMixin(ABC):

@@ -21,13 +21,13 @@ from torch.utils.data import Dataset
from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, DataLoader
from torch.utils.data.dataset import IterableDataset

from lightning_lite.utilities.distributed import distributed_available
from pytorch_lightning.utilities.auto_restart import (
_reload_dataloader_state_dict,
MergedIteratorState,
patch_dataloader_iterator,
)
from pytorch_lightning.utilities.data import get_len
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training

@@ -39,6 +39,10 @@ from torch.utils.data import DataLoader

import pytorch_lightning as pl
from lightning_lite.utilities.cloud_io import get_filesystem
from lightning_lite.utilities.data import _auto_add_worker_init_fn
from lightning_lite.utilities.distributed import distributed_available
from lightning_lite.utilities.types import _PATH
from lightning_lite.utilities.warnings import PossibleUserWarning
from pytorch_lightning.accelerators import (
Accelerator,
CUDAAccelerator,
@@ -102,8 +106,7 @@ from pytorch_lightning.utilities.argparse import (
parse_env_variables,
)
from pytorch_lightning.utilities.auto_restart import _add_capture_metadata_collate
from pytorch_lightning.utilities.data import _auto_add_worker_init_fn, has_len_all_ranks
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.data import has_len_all_ranks
from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -111,13 +114,11 @@ from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_ze
from pytorch_lightning.utilities.seed import isolate_rng
from pytorch_lightning.utilities.types import (
_EVALUATE_OUTPUT,
_PATH,
_PREDICT_OUTPUT,
EVAL_DATALOADERS,
LRSchedulerConfig,
TRAIN_DATALOADERS,
)
from pytorch_lightning.utilities.warnings import PossibleUserWarning

log = logging.getLogger(__name__)
# warnings to ignore in trainer

@@ -15,7 +15,7 @@ from typing import List

import torch

from pytorch_lightning.utilities import device_parser
from lightning_lite.utilities import device_parser
from pytorch_lightning.utilities.exceptions import MisconfigurationException


@@ -15,15 +15,9 @@

import numpy

from lightning_lite.utilities import AllGatherGrad, AMPType, LightningEnum # noqa: F401
from lightning_lite.utilities.apply_func import move_data_to_device # noqa: F401
from pytorch_lightning.utilities.distributed import AllGatherGrad # noqa: F401
from pytorch_lightning.utilities.enums import ( # noqa: F401
_AcceleratorType,
_StrategyType,
AMPType,
GradClipAlgorithmType,
LightningEnum,
)
from pytorch_lightning.utilities.enums import GradClipAlgorithmType # noqa: F401
from pytorch_lightning.utilities.grads import grad_norm # noqa: F401
from pytorch_lightning.utilities.imports import ( # noqa: F401
_APEX_AVAILABLE,

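The top-level `pytorch_lightning.utilities` namespace keeps its public names by re-exporting them from `lightning_lite.utilities`. A quick check of the aliasing (a sketch, not part of the diff):

```python
import lightning_lite.utilities as lite_utils
import pytorch_lightning.utilities as pl_utils

# After this commit the re-exported symbols resolve to the same objects.
assert pl_utils.AMPType is lite_utils.AMPType
assert pl_utils.LightningEnum is lite_utils.LightningEnum
```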
@@ -29,11 +29,11 @@ from torch.utils.data.dataloader import (
from typing_extensions import TypedDict

import pytorch_lightning as pl
from lightning_lite.utilities.types import _Stateful
from pytorch_lightning.utilities.distributed import _collect_states_on_rank_zero
from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import _collect_rng_states, _set_rng_states
from pytorch_lightning.utilities.types import _Stateful


class _IteratorStateDict(TypedDict):

@@ -18,7 +18,7 @@ from typing import Any
from lightning_lite.utilities.cloud_io import atomic_save as new_atomic_save
from lightning_lite.utilities.cloud_io import get_filesystem as new_get_filesystem
from lightning_lite.utilities.cloud_io import load as new_load
from pytorch_lightning.utilities import rank_zero_deprecation # TODO(lite): change to lightning_lite.utilities
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation


def atomic_save(*args: Any, **kwargs: Any) -> Any:

@@ -11,18 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
import os
from collections import OrderedDict
from contextlib import contextmanager
from dataclasses import fields
from functools import partial
from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Tuple, Type, Union
from typing import Any, Dict, Generator, Iterable, Mapping, Optional, Tuple, Union

import torch
from lightning_utilities.core.apply_func import is_dataclass_instance
from lightning_utilities.core.inheritance import get_all_subclasses
from lightning_utilities.core.rank_zero import WarningCache
from torch import Tensor
from torch.utils.data import (
@@ -36,13 +30,16 @@ from torch.utils.data import (
)

import pytorch_lightning as pl
from lightning_lite.utilities import LightningEnum
from lightning_lite.utilities.data import _reinstantiate_wrapped_cls, _replace_value_in_saved_args
from lightning_lite.utilities.data import has_iterable_dataset as new_has_iterable_dataset
from lightning_lite.utilities.data import has_len as new_has_len
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset, FastForwardSampler
from pytorch_lightning.utilities.enums import _FaultTolerantMode, LightningEnum
from pytorch_lightning.utilities.enums import _FaultTolerantMode
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.seed import pl_worker_init_function
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn

BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]]

@@ -110,33 +107,6 @@ def extract_batch_size(batch: BType) -> int:
return batch_size


def has_iterable_dataset(dataloader: DataLoader) -> bool:
return hasattr(dataloader, "dataset") and isinstance(dataloader.dataset, IterableDataset)


def has_len(dataloader: Union[DataLoader, Iterable]) -> bool:
"""Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
infinite dataloader."""
try:
# try getting the length
if len(dataloader) == 0:
rank_zero_warn(
f"`{dataloader.__class__.__name__}` returned 0 length. Please make sure this was your intention."
)
has_len = True
except (TypeError, NotImplementedError):
has_len = False

if has_len and has_iterable_dataset(dataloader):
rank_zero_warn(
"Your `IterableDataset` has `__len__` defined."
" In combination with multi-process data loading (when num_workers > 1),"
" `__len__` could be inaccurate if each worker is not configured independently"
" to avoid having duplicate data."
)
return has_len


def has_len_all_ranks(
dataloader: DataLoader,
strategy: "pl.Strategy",
@@ -171,7 +141,7 @@ def has_len_all_ranks(
except (TypeError, NotImplementedError):
has_len = False

if has_len and has_iterable_dataset(dataloader):
if has_len and new_has_iterable_dataset(dataloader):
rank_zero_warn(
"Your `IterableDataset` has `__len__` defined."
" In combination with multi-process data loading (when num_workers > 1),"
@@ -187,7 +157,7 @@ def get_len(dataloader: DataLoader) -> Union[int, float]:
If ``__len__`` method is not implemented, return float('inf').
"""

if has_len(dataloader):
if new_has_len(dataloader):
return len(dataloader)

return float("inf")

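`has_iterable_dataset` and `has_len` now live in `lightning_lite.utilities.data`, and `get_len` falls back to `float("inf")` when a loader has no usable `__len__`. A small sketch of the behaviour, assuming the `lightning_lite` package from this commit:

```python
from torch.utils.data import DataLoader, IterableDataset

from lightning_lite.utilities.data import has_len


class _Stream(IterableDataset):
    def __iter__(self):
        yield from range(10)


assert has_len(DataLoader(list(range(10))))  # finite, sized loader
assert not has_len(DataLoader(_Stream()))    # IterableDataset without __len__ -> no usable length
```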
@ -409,171 +379,6 @@ def _dataloader_init_kwargs_resolve_sampler(
|
|||
return {"sampler": sampler, "shuffle": False, "batch_sampler": None}
|
||||
|
||||
|
||||
def _replace_value_in_saved_args(
|
||||
replace_key: str,
|
||||
replace_value: Any,
|
||||
args: Tuple[Any, ...],
|
||||
kwargs: Dict[str, Any],
|
||||
default_kwargs: Dict[str, Any],
|
||||
arg_names: Tuple[str, ...],
|
||||
) -> Tuple[bool, Tuple[Any, ...], Dict[str, Any]]:
|
||||
"""Tries to replace an argument value in a saved list of args and kwargs.
|
||||
|
||||
Returns a tuple indicating success of the operation and modified saved args and kwargs
|
||||
"""
|
||||
|
||||
if replace_key in arg_names:
|
||||
replace_index = arg_names.index(replace_key)
|
||||
args = args[:replace_index] + (replace_value,) + args[replace_index + 1 :]
|
||||
return True, args, kwargs
|
||||
elif replace_key in kwargs or replace_key in default_kwargs:
|
||||
kwargs[replace_key] = replace_value
|
||||
return True, args, kwargs
|
||||
|
||||
return False, args, kwargs
|
||||
|
||||
|
||||
def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None:
|
||||
if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None:
|
||||
dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank)
|
||||
|
||||
|
||||
def _reinstantiate_wrapped_cls(orig_object: Any, *args: Any, explicit_cls: Optional[Type] = None, **kwargs: Any) -> Any:
|
||||
constructor = type(orig_object) if explicit_cls is None else explicit_cls
|
||||
|
||||
try:
|
||||
result = constructor(*args, **kwargs)
|
||||
except TypeError as e:
|
||||
# improve exception message due to an incorrect implementation of the `DataLoader` where multiple subclass
|
||||
# `__init__` arguments map to one `DataLoader.__init__` argument
|
||||
import re
|
||||
|
||||
match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", str(e))
|
||||
if not match:
|
||||
# an unexpected `TypeError`, continue failure
|
||||
raise
|
||||
argument = match.groups()[0]
|
||||
message = (
|
||||
f"The {constructor.__name__} implementation has an error where more than one `__init__` argument"
|
||||
f" can be passed to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing"
|
||||
f" passing both a custom argument that will map to the `{argument}` argument as well as `**kwargs`."
|
||||
f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key."
|
||||
" This argument was automatically passed to your object by PyTorch Lightning."
|
||||
)
|
||||
raise MisconfigurationException(message) from e
|
||||
|
||||
attrs_record = getattr(orig_object, "__pl_attrs_record", list())
|
||||
for args, fn in attrs_record:
|
||||
fn(result, *args)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _wrap_init_method(init: Callable, store_explicit_arg: Optional[str] = None) -> Callable:
|
||||
"""Wraps the ``__init__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and
|
||||
:class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses."""
|
||||
|
||||
@functools.wraps(init)
|
||||
def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None:
|
||||
# We need to inspect `init`, as inspecting `obj.__init__`
|
||||
# can lead to inspecting the wrong function with multiple inheritance
|
||||
old_inside_init = getattr(obj, "__pl_inside_init", False)
|
||||
object.__setattr__(obj, "__pl_inside_init", True)
|
||||
params = inspect.signature(init).parameters
|
||||
|
||||
parameters_defaults = OrderedDict(
|
||||
(param.name, param.default)
|
||||
for param in params.values()
|
||||
if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
|
||||
)
|
||||
|
||||
param_names = tuple(parameters_defaults)[: len(args)]
|
||||
|
||||
default_kwargs = {
|
||||
name: value
|
||||
for name, value in parameters_defaults.items()
|
||||
if name not in kwargs and name not in param_names and value != inspect.Parameter.empty
|
||||
}
|
||||
|
||||
if not hasattr(obj, "__pl_saved_args"):
|
||||
object.__setattr__(obj, "__pl_saved_args", args)
|
||||
object.__setattr__(obj, "__pl_saved_kwargs", kwargs)
|
||||
object.__setattr__(obj, "__pl_saved_arg_names", param_names)
|
||||
object.__setattr__(obj, "__pl_saved_default_kwargs", default_kwargs)
|
||||
|
||||
# We want to use the latest possible value for explicit argument (i.e. ideally what gets passed to base class)
|
||||
# so that we can be sure, that it will not get changed anymore.
|
||||
# That is why we are setting this in every `__init__`
|
||||
if store_explicit_arg is not None:
|
||||
if store_explicit_arg in param_names:
|
||||
object.__setattr__(obj, f"__{store_explicit_arg}", args[param_names.index(store_explicit_arg)])
|
||||
elif store_explicit_arg in kwargs:
|
||||
object.__setattr__(obj, f"__{store_explicit_arg}", kwargs[store_explicit_arg])
|
||||
|
||||
init(obj, *args, **kwargs)
|
||||
object.__setattr__(obj, "__pl_inside_init", old_inside_init)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def _wrap_attr_method(method: Callable, tag: _WrapAttrTag) -> Callable:
|
||||
"""Wraps the ``__setattr__`` or ``__delattr__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and
|
||||
:class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses."""
|
||||
|
||||
@functools.wraps(method)
|
||||
def wrapper(obj: Any, *args: Any):
|
||||
# First, let's find out if we're the first in inheritance chain calling the patched method.
|
||||
name, *_ = args
|
||||
prev_call_name, prev_call_method = getattr(obj, "__pl_current_call", (None, "method"))
|
||||
first_call = not (prev_call_name == name and prev_call_method == tag)
|
||||
|
||||
# Then mark the current called method
|
||||
object.__setattr__(obj, "__pl_current_call", (name, tag))
|
||||
|
||||
# call original method
|
||||
method(obj, *args)
|
||||
if first_call and not getattr(obj, "__pl_inside_init", True):
|
||||
# and save the value it was called with to the internal list,
|
||||
# if we're outside of __init__ and the original call did not fail and we're the first call
|
||||
attrs_record = getattr(obj, "__pl_attrs_record", list())
|
||||
attrs_record.append((args, tag))
|
||||
object.__setattr__(obj, "__pl_attrs_record", attrs_record)
|
||||
object.__setattr__(obj, "__pl_current_call", (prev_call_name, prev_call_method))
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _replace_dunder_methods(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]:
|
||||
"""This context manager is used to add support for re-instantiation of custom (subclasses) of `base_cls`.
|
||||
|
||||
It patches the ``__init__``, ``__setattr__`` and ``__delattr__`` methods.
|
||||
"""
|
||||
classes = get_all_subclasses(base_cls) | {base_cls}
|
||||
for cls in classes:
|
||||
# Check that __init__ belongs to the class
|
||||
# https://stackoverflow.com/a/5253424
|
||||
if "__init__" in cls.__dict__:
|
||||
cls.__old__init__ = cls.__init__
|
||||
cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg)
|
||||
|
||||
# we want at least one setattr/delattr in the chain to be patched and it can happen, that none of the subclasses
|
||||
# implement `__setattr__`/`__delattr__`. Therefore, we are always patching the `base_cls`
|
||||
for patch_fn_name, tag in (("__setattr__", _WrapAttrTag.SET), ("__delattr__", _WrapAttrTag.DEL)):
|
||||
if patch_fn_name in cls.__dict__ or cls is base_cls:
|
||||
saved_name = f"__old{patch_fn_name}"
|
||||
setattr(cls, saved_name, getattr(cls, patch_fn_name))
|
||||
setattr(cls, patch_fn_name, _wrap_attr_method(getattr(cls, patch_fn_name), tag))
|
||||
yield
|
||||
for cls in classes:
|
||||
for patched_name in ("__setattr__", "__delattr__", "__init__"):
|
||||
# Check that __old__{init,setattr,delattr} belongs to the class
|
||||
# https://stackoverflow.com/a/5253424
|
||||
if f"__old{patched_name}" in cls.__dict__:
|
||||
setattr(cls, patched_name, getattr(cls, f"__old{patched_name}"))
|
||||
delattr(cls, f"__old{patched_name}")
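The DataLoader re-instantiation machinery (`_replace_dunder_methods`, `_wrap_init_method`, `_reinstantiate_wrapped_cls`) also moves to `lightning_lite.utilities.data` and is re-imported above. A rough sketch of what the context manager enables; the custom loader below is a hypothetical example, not code from this commit:

```python
from torch.utils.data import DataLoader

from lightning_lite.utilities.data import _replace_dunder_methods


class MyLoader(DataLoader):  # hypothetical subclass with an extra constructor argument
    def __init__(self, dataset, scale=1.0, **kwargs):
        self.scale = scale
        super().__init__(dataset, **kwargs)


# While patched, constructor args/kwargs are recorded on the instance so Lightning can later
# re-create the loader (e.g. with a different sampler) via `_reinstantiate_wrapped_cls`.
with _replace_dunder_methods(DataLoader, "dataset"):
    loader = MyLoader(list(range(8)), scale=0.5, batch_size=2)

assert getattr(loader, "__pl_saved_kwargs")["scale"] == 0.5
```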
def _wrap_with_capture_dataset(dataset: Dataset) -> Dataset:
if isinstance(dataset, IterableDataset):
# wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.

@@ -627,3 +432,19 @@ def _is_dataloader_shuffled(dataloader: object) -> bool:
if isinstance(sampler, SequentialSampler):
return False
return isinstance(sampler, RandomSampler)


def has_iterable_dataset(*args: Any, **kwargs: Any) -> Any:
rank_zero_deprecation(
"`pytorch_lightning.utilities.data.has_iterable_dataset` has been deprecated in v1.8.0 and will be"
" removed in v1.10.0. Please use `lightning_lite.utilities.data.has_iterable_dataset` instead."
)
return new_has_iterable_dataset(*args, **kwargs)


def has_len(*args: Any, **kwargs: Any) -> Any:
rank_zero_deprecation(
"`pytorch_lightning.utilities.data.has_len` has been deprecated in v1.8.0 and will be"
" removed in v1.10.0. Please use `lightning_lite.utilities.data.has_len` instead."
)
return new_has_len(*args, **kwargs)

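Each removed helper is replaced by a thin shim that emits `rank_zero_deprecation` on rank zero and then forwards to the `lightning_lite` implementation, so existing imports keep working until v1.10.0. A sketch of the call path:

```python
from torch.utils.data import DataLoader

# Old location: warns about the deprecation, then delegates...
from pytorch_lightning.utilities.data import has_len as old_has_len

# ...to the new location, which holds the actual implementation.
from lightning_lite.utilities.data import has_len as new_has_len

loader = DataLoader(list(range(4)))
assert old_has_len(loader) == new_has_len(loader)
```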
@@ -19,8 +19,8 @@ import os

import torch

from lightning_lite.utilities.types import _PATH
from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE
from pytorch_lightning.utilities.types import _PATH

if _DEEPSPEED_AVAILABLE:
from deepspeed.utils.zero_to_fp32 import (

@@ -11,291 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
from typing import Any, List, MutableSequence, Optional, Tuple, Union
from typing import Any, List, Optional, Union

import torch
import torch.cuda

from pytorch_lightning.plugins.environments import TorchElasticEnvironment
from pytorch_lightning.strategies.launchers.multiprocessing import _is_forking_disabled
from lightning_lite.utilities.device_parser import determine_root_gpu_device as new_determine_root_gpu_device
from lightning_lite.utilities.device_parser import is_cuda_available as new_is_cuda_available
from lightning_lite.utilities.device_parser import num_cuda_devices as new_num_cuda_devices
from lightning_lite.utilities.device_parser import parse_cpu_cores as new_parse_cpu_cores
from lightning_lite.utilities.device_parser import parse_gpu_ids as new_parse_gpu_ids
from lightning_lite.utilities.device_parser import parse_tpu_cores as new_parse_tpu_cores
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _DEVICE

def determine_root_gpu_device(gpus: List[_DEVICE]) -> Optional[_DEVICE]:
|
||||
"""
|
||||
Args:
|
||||
gpus: non-empty list of ints representing which gpus to use
|
||||
|
||||
Returns:
|
||||
designated root GPU device id
|
||||
|
||||
Raises:
|
||||
TypeError:
|
||||
If ``gpus`` is not a list
|
||||
AssertionError:
|
||||
If GPU list is empty
|
||||
"""
|
||||
if gpus is None:
|
||||
return None
|
||||
|
||||
if not isinstance(gpus, list):
|
||||
raise TypeError("gpus should be a list")
|
||||
|
||||
assert len(gpus) > 0, "gpus should be a non empty list"
|
||||
|
||||
# set root gpu
|
||||
root_gpu = gpus[0]
|
||||
|
||||
return root_gpu
|
||||
|
||||
|
||||
def parse_gpu_ids(
|
||||
gpus: Optional[Union[int, str, List[int]]],
|
||||
include_cuda: bool = False,
|
||||
include_mps: bool = False,
|
||||
) -> Optional[List[int]]:
|
||||
"""
|
||||
Parses the GPU ids given in the format as accepted by the
|
||||
:class:`~pytorch_lightning.trainer.Trainer`.
|
||||
|
||||
Args:
|
||||
gpus: An int -1 or string '-1' indicate that all available GPUs should be used.
|
||||
A list of unique ints or a string containing list of comma separated unique integers
|
||||
indicates specific GPUs to use.
|
||||
An int 0 means that no GPUs should be used.
|
||||
Any int N > 0 indicates that GPUs [0..N) should be used.
|
||||
include_cuda: A boolean indicating whether to include cuda devices for gpu parsing.
|
||||
include_mps: A boolean indicating whether to include mps devices for gpu parsing.
|
||||
|
||||
Returns:
|
||||
a list of gpus to be used or ``None`` if no GPUs were requested
|
||||
|
||||
Raises:
|
||||
MisconfigurationException:
|
||||
If no GPUs are available but the value of gpus variable indicates request for GPUs
|
||||
|
||||
.. note::
|
||||
``include_cuda`` and ``include_mps`` default to ``False`` so that you only
|
||||
have to specify which device type to use and not disabling all the others.
|
||||
"""
|
||||
# Check that gpus param is None, Int, String or Sequence of Ints
|
||||
_check_data_type(gpus)
|
||||
|
||||
# Handle the case when no gpus are requested
|
||||
if gpus is None or (isinstance(gpus, int) and gpus == 0) or str(gpus).strip() in ("0", "[]"):
|
||||
return None
|
||||
|
||||
# We know user requested GPUs therefore if some of the
|
||||
# requested GPUs are not available an exception is thrown.
|
||||
gpus = _normalize_parse_gpu_string_input(gpus)
|
||||
gpus = _normalize_parse_gpu_input_to_list(gpus, include_cuda=include_cuda, include_mps=include_mps)
|
||||
if not gpus:
|
||||
raise MisconfigurationException("GPUs requested but none are available.")
|
||||
|
||||
if (
|
||||
TorchElasticEnvironment.detect()
|
||||
and len(gpus) != 1
|
||||
and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps)) == 1
|
||||
):
|
||||
# omit sanity check on torchelastic as by default shows one visible GPU per process
|
||||
return gpus
|
||||
|
||||
# Check that gpus are unique. Duplicate gpus are not supported by the backend.
|
||||
_check_unique(gpus)
|
||||
|
||||
return _sanitize_gpu_ids(gpus, include_cuda=include_cuda, include_mps=include_mps)
|
||||
|
||||
|
||||
def parse_tpu_cores(tpu_cores: Optional[Union[int, str, List[int]]]) -> Optional[Union[int, List[int]]]:
|
||||
"""
|
||||
Parses the tpu_cores given in the format as accepted by the
|
||||
:class:`~pytorch_lightning.trainer.Trainer`.
|
||||
|
||||
Args:
|
||||
tpu_cores: An int 1 or string '1' indicate that 1 core with multi-processing should be used
|
||||
An int 8 or string '8' indicate that all 8 cores with multi-processing should be used
|
||||
A list of int or a string containing list of comma separated integer
|
||||
indicates specific TPU core to use.
|
||||
|
||||
Returns:
|
||||
a list of tpu_cores to be used or ``None`` if no TPU cores were requested
|
||||
|
||||
Raises:
|
||||
MisconfigurationException:
|
||||
If TPU cores aren't 1, 8 or [<1-8>]
|
||||
"""
|
||||
_check_data_type(tpu_cores)
|
||||
|
||||
if isinstance(tpu_cores, str):
|
||||
tpu_cores = _parse_tpu_cores_str(tpu_cores.strip())
|
||||
|
||||
if not _tpu_cores_valid(tpu_cores):
|
||||
raise MisconfigurationException("`tpu_cores` can only be 1, 8 or [<1-8>]")
|
||||
|
||||
return tpu_cores
|
||||
|
||||
|
||||
def parse_cpu_cores(cpu_cores: Union[int, str, List[int]]) -> int:
|
||||
"""Parses the cpu_cores given in the format as accepted by the ``devices`` argument in the
|
||||
:class:`~pytorch_lightning.trainer.Trainer`.
|
||||
|
||||
Args:
|
||||
cpu_cores: An int > 0.
|
||||
|
||||
Returns:
|
||||
an int representing the number of processes
|
||||
|
||||
Raises:
|
||||
MisconfigurationException:
|
||||
If cpu_cores is not an int > 0
|
||||
"""
|
||||
if isinstance(cpu_cores, str) and cpu_cores.strip().isdigit():
|
||||
cpu_cores = int(cpu_cores)
|
||||
|
||||
if not isinstance(cpu_cores, int) or cpu_cores <= 0:
|
||||
raise MisconfigurationException("`devices` selected with `CPUAccelerator` should be an int > 0.")
|
||||
|
||||
return cpu_cores
|
||||
|
||||
|
||||
def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[int, List[int]]:
|
||||
if not isinstance(s, str):
|
||||
return s
|
||||
if s == "-1":
|
||||
return -1
|
||||
if "," in s:
|
||||
return [int(x.strip()) for x in s.split(",") if len(x) > 0]
|
||||
return int(s.strip())
|
||||
|
||||
|
||||
def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: bool = False) -> List[int]:
|
||||
"""Checks that each of the GPUs in the list is actually available. Raises a MisconfigurationException if any of
|
||||
the GPUs is not available.
|
||||
|
||||
Args:
|
||||
gpus: list of ints corresponding to GPU indices
|
||||
|
||||
Returns:
|
||||
unmodified gpus variable
|
||||
|
||||
Raises:
|
||||
MisconfigurationException:
|
||||
If machine has fewer available GPUs than requested.
|
||||
"""
|
||||
if sum((include_cuda, include_mps)) == 0:
|
||||
raise ValueError("At least one gpu type should be specified!")
|
||||
all_available_gpus = _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps)
|
||||
for gpu in gpus:
|
||||
if gpu not in all_available_gpus:
|
||||
raise MisconfigurationException(
|
||||
f"You requested gpu: {gpus}\n But your machine only has: {all_available_gpus}"
|
||||
)
|
||||
return gpus
|
||||
|
||||
|
||||
def _normalize_parse_gpu_input_to_list(
|
||||
gpus: Union[int, List[int], Tuple[int, ...]], include_cuda: bool, include_mps: bool
|
||||
) -> Optional[List[int]]:
|
||||
assert gpus is not None
|
||||
if isinstance(gpus, (MutableSequence, tuple)):
|
||||
return list(gpus)
|
||||
|
||||
# must be an int
|
||||
if not gpus: # gpus==0
|
||||
return None
|
||||
if gpus == -1:
|
||||
return _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps)
|
||||
|
||||
return list(range(gpus))
|
||||
|
||||
|
||||
def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = False) -> List[int]:
|
||||
"""
|
||||
Returns:
|
||||
a list of all available gpus
|
||||
"""
|
||||
cuda_gpus = _get_all_available_cuda_gpus() if include_cuda else []
|
||||
mps_gpus = _get_all_available_mps_gpus() if include_mps else []
|
||||
return cuda_gpus + mps_gpus
|
||||
|
||||
|
||||
def _get_all_available_mps_gpus() -> List[int]:
|
||||
"""
|
||||
Returns:
|
||||
a list of all available MPS gpus
|
||||
"""
|
||||
# lazy import to avoid circular dependencies
|
||||
from pytorch_lightning.accelerators.mps import _MPS_AVAILABLE
|
||||
|
||||
return [0] if _MPS_AVAILABLE else []
|
||||
|
||||
|
||||
def _get_all_available_cuda_gpus() -> List[int]:
|
||||
"""
|
||||
Returns:
|
||||
a list of all available CUDA gpus
|
||||
"""
|
||||
return list(range(num_cuda_devices()))
|
||||
|
||||
|
||||
def _check_unique(device_ids: List[int]) -> None:
|
||||
"""Checks that the device_ids are unique.
|
||||
|
||||
Args:
|
||||
device_ids: list of ints corresponding to gpus indices
|
||||
|
||||
Raises:
|
||||
MisconfigurationException:
|
||||
If ``device_ids`` of GPUs aren't unique
|
||||
"""
|
||||
if len(device_ids) != len(set(device_ids)):
|
||||
raise MisconfigurationException("Device ID's (GPU) must be unique.")
|
||||
|
||||
|
||||
def _check_data_type(device_ids: Any) -> None:
|
||||
"""Checks that the device_ids argument is one of None, int, string, or sequence of integers.
|
||||
|
||||
Args:
|
||||
device_ids: gpus/tpu_cores parameter as passed to the Trainer
|
||||
|
||||
Raises:
|
||||
MisconfigurationException:
|
||||
If ``device_ids`` of GPU/TPUs aren't ``int``, ``str``, sequence of ``int`` or ``None``
|
||||
"""
|
||||
msg = "Device IDs (GPU/TPU) must be an int, a string, a sequence of ints or None, but you passed"
|
||||
|
||||
if device_ids is None:
|
||||
return
|
||||
elif isinstance(device_ids, (MutableSequence, tuple)):
|
||||
for id_ in device_ids:
|
||||
if type(id_) is not int:
|
||||
raise MisconfigurationException(f"{msg} a sequence of {type(id_).__name__}.")
|
||||
elif type(device_ids) not in (int, str):
|
||||
raise MisconfigurationException(f"{msg} {type(device_ids).__name__}.")
|
||||
|
||||
|
||||
def _tpu_cores_valid(tpu_cores: Any) -> bool:
|
||||
# allow 1 or 8 cores
|
||||
if tpu_cores in (1, 8, None):
|
||||
return True
|
||||
|
||||
# allow picking 1 of 8 indexes
|
||||
if isinstance(tpu_cores, (list, tuple, set)):
|
||||
has_1_tpu_idx = len(tpu_cores) == 1
|
||||
is_valid_tpu_idx = 1 <= list(tpu_cores)[0] <= 8
|
||||
|
||||
is_valid_tpu_core_choice = has_1_tpu_idx and is_valid_tpu_idx
|
||||
return is_valid_tpu_core_choice
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _parse_tpu_cores_str(tpu_cores: str) -> Union[int, List[int]]:
|
||||
if tpu_cores in ("1", "8"):
|
||||
return int(tpu_cores)
|
||||
return [int(x.strip()) for x in tpu_cores.split(",") if len(x) > 0]
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation


def parse_hpus(devices: Optional[Union[int, str, List[int]]]) -> Optional[int]:
@@ -319,25 +44,49 @@ def parse_hpus(devices: Optional[Union[int, str, List[int]]]) -> Optional[int]:
return int(devices) if isinstance(devices, str) else devices


def num_cuda_devices() -> int:
"""Returns the number of GPUs available.

Unlike :func:`torch.cuda.device_count`, this function will do its best not to create a CUDA context for fork
support, if the platform allows it.
"""
if "fork" not in torch.multiprocessing.get_all_start_methods() or _is_forking_disabled():
return torch.cuda.device_count()
with multiprocessing.get_context("fork").Pool(1) as pool:
return pool.apply(torch.cuda.device_count)
def determine_root_gpu_device(*args: Any, **kwargs: Any) -> Any:
rank_zero_deprecation(
"`pytorch_lightning.utilities.device_parser.determine_root_gpu_device` has been deprecated in v1.8.0 and will"
" be removed in v1.10.0. Please use `lightning_lite.utilities.device_parser.determine_root_gpu_device` instead."
)
return new_determine_root_gpu_device(*args, **kwargs)


def is_cuda_available() -> bool:
"""Returns a bool indicating if CUDA is currently available.
rank_zero_deprecation(
"`pytorch_lightning.utilities.device_parser.is_cuda_available` has been deprecated in v1.8.0 and will"
" be removed in v1.10.0. Please use `lightning_lite.utilities.device_parser.is_cuda_available` instead."
)
return new_is_cuda_available()

Unlike :func:`torch.cuda.is_available`, this function will do its best not to create a CUDA context for fork
support, if the platform allows it.
"""
if "fork" not in torch.multiprocessing.get_all_start_methods() or _is_forking_disabled():
return torch.cuda.is_available()
with multiprocessing.get_context("fork").Pool(1) as pool:
return pool.apply(torch.cuda.is_available)

def num_cuda_devices() -> int:
rank_zero_deprecation(
"`pytorch_lightning.utilities.device_parser.num_cuda_devices` has been deprecated in v1.8.0 and will"
" be removed in v1.10.0. Please use `lightning_lite.utilities.device_parser.num_cuda_devices` instead."
)
return new_num_cuda_devices()


def parse_cpu_cores(*args: Any, **kwargs: Any) -> Any:
rank_zero_deprecation(
"`pytorch_lightning.utilities.device_parser.parse_cpu_cores` has been deprecated in v1.8.0 and will"
" be removed in v1.10.0. Please use `lightning_lite.utilities.device_parser.parse_cpu_cores` instead."
)
return new_parse_cpu_cores(*args, **kwargs)


def parse_gpu_ids(*args: Any, **kwargs: Any) -> Any:
rank_zero_deprecation(
"`pytorch_lightning.utilities.device_parser.parse_gpu_ids` has been deprecated in v1.8.0 and will"
" be removed in v1.10.0. Please use `lightning_lite.utilities.device_parser.parse_gpu_ids` instead."
)
return new_parse_gpu_ids(*args, **kwargs)


def parse_tpu_cores(*args: Any, **kwargs: Any) -> Any:
rank_zero_deprecation(
"`pytorch_lightning.utilities.device_parser.parse_tpu_cores` has been deprecated in v1.8.0 and will"
" be removed in v1.10.0. Please use `lightning_lite.utilities.device_parser.parse_tpu_cores` instead."
)
return new_parse_tpu_cores(*args, **kwargs)

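The GPU/TPU parsing helpers keep their signatures but now live in `lightning_lite.utilities.device_parser`; the old module only re-routes with a deprecation message. A short sketch (the GPU call assumes at least two visible CUDA devices):

```python
from lightning_lite.utilities import device_parser

# A comma-separated string or "-1" selects devices of the requested type(s).
gpu_ids = device_parser.parse_gpu_ids("0,1", include_cuda=True)
assert gpu_ids == [0, 1]

# TPU cores must be 1, 8 or a single index in [1, 8]; only the format is validated here.
assert device_parser.parse_tpu_cores("8") == 8
```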
@@ -12,211 +12,23 @@
# limitations under the License.
"""Utilities that can be used with distributed training."""

import logging
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.parallel.distributed import DistributedDataParallel

import pytorch_lightning as pl
from pytorch_lightning.utilities.imports import _HPU_AVAILABLE, _TPU_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_only # noqa: F401
from lightning_lite.utilities.distributed import all_gather_ddp_if_available as new_all_gather_ddp_if_available
from lightning_lite.utilities.distributed import distributed_available as new_distributed_available
from lightning_lite.utilities.distributed import gather_all_tensors as new_gather_all_tensors
from lightning_lite.utilities.distributed import (
get_default_process_group_backend_for_device as new_get_default_process_group_backend_for_device,
)
from lightning_lite.utilities.distributed import init_dist_connection as new_init_dist_connection
from lightning_lite.utilities.distributed import sync_ddp as new_sync_ddp
from lightning_lite.utilities.distributed import sync_ddp_if_available as new_sync_ddp_if_available
from lightning_lite.utilities.distributed import tpu_distributed as new_tpu_distributed
from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_deprecation, rank_zero_info

if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm

if torch.distributed.is_available():
from torch.distributed import group, ReduceOp

else:

class ReduceOp: # type: ignore # (see https://github.com/python/mypy/issues/1153)
SUM = None

class group: # type: ignore
WORLD = None


log = logging.getLogger(__name__)

def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tensor]:
|
||||
"""Function to gather all tensors from several ddp processes onto a list that is broadcasted to all processes.
|
||||
|
||||
Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case
|
||||
tensors are padded, gathered and then trimmed to secure equal workload for all processes.
|
||||
|
||||
Args:
|
||||
result: the value to sync
|
||||
group: the process group to gather results from. Defaults to all processes (world)
|
||||
|
||||
Return:
|
||||
gathered_result: list with size equal to the process group where
|
||||
gathered_result[i] corresponds to result tensor from process i
|
||||
"""
|
||||
if group is None:
|
||||
group = torch.distributed.group.WORLD
|
||||
|
||||
# convert tensors to contiguous format
|
||||
result = result.contiguous()
|
||||
|
||||
world_size = torch.distributed.get_world_size(group)
|
||||
torch.distributed.barrier(group=group)
|
||||
|
||||
# if the tensor is scalar, things are easy
|
||||
if result.ndim == 0:
|
||||
return _simple_gather_all_tensors(result, group, world_size)
|
||||
|
||||
# 1. Gather sizes of all tensors
|
||||
local_size = torch.tensor(result.shape, device=result.device)
|
||||
local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
|
||||
torch.distributed.all_gather(local_sizes, local_size, group=group)
|
||||
max_size = torch.stack(local_sizes).max(dim=0).values
|
||||
all_sizes_equal = all(all(ls == max_size) for ls in local_sizes)
|
||||
|
||||
# 2. If shapes are all the same, then do a simple gather:
|
||||
if all_sizes_equal:
|
||||
return _simple_gather_all_tensors(result, group, world_size)
|
||||
|
||||
# 3. If not, we need to pad each local tensor to maximum size, gather and then truncate
|
||||
pad_dims = []
|
||||
pad_by = (max_size - local_size).detach().cpu()
|
||||
for val in reversed(pad_by):
|
||||
pad_dims.append(0)
|
||||
pad_dims.append(val.item())
|
||||
result_padded = F.pad(result, pad_dims)
|
||||
gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)]
|
||||
torch.distributed.all_gather(gathered_result, result_padded, group)
|
||||
for idx, item_size in enumerate(local_sizes):
|
||||
slice_param = [slice(dim_size) for dim_size in item_size]
|
||||
gathered_result[idx] = gathered_result[idx][slice_param]
|
||||
return gathered_result
|
||||
|
||||
|
||||
def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]:
|
||||
gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
|
||||
torch.distributed.all_gather(gathered_result, result, group)
|
||||
return gathered_result
|
||||
|
||||
|
||||
def distributed_available() -> bool:
|
||||
return torch.distributed.is_available() and torch.distributed.is_initialized() or tpu_distributed()
|
||||
|
||||
|
||||
def sync_ddp_if_available(
|
||||
result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
|
||||
) -> Tensor:
|
||||
"""Function to reduce a tensor across worker processes during distributed training.
|
||||
|
||||
Args:
|
||||
result: the value to sync and reduce (typically tensor or number)
|
||||
group: the process group to gather results from. Defaults to all processes (world)
|
||||
reduce_op: the reduction operation. Defaults to sum.
|
||||
Can also be a string of 'avg', 'mean' to calculate the mean during reduction.
|
||||
|
||||
Return:
|
||||
reduced value
|
||||
"""
|
||||
if distributed_available():
|
||||
return sync_ddp(result, group=group, reduce_op=reduce_op)
|
||||
return result
|
||||
|
||||
|
||||
def sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> Tensor:
|
||||
"""Function to reduce the tensors from several ddp processes to one main process.
|
||||
|
||||
Args:
|
||||
result: the value to sync and reduce (typically tensor or number)
|
||||
group: the process group to gather results from. Defaults to all processes (world)
|
||||
reduce_op: the reduction operation. Defaults to sum.
|
||||
Can also be a string of 'avg', 'mean' to calculate the mean during reduction.
|
||||
|
||||
Return:
|
||||
reduced value
|
||||
"""
|
||||
divide_by_world_size = False
|
||||
|
||||
if group is None:
|
||||
group = torch.distributed.group.WORLD
|
||||
|
||||
op: Optional[ReduceOp]
|
||||
if isinstance(reduce_op, str):
|
||||
if reduce_op.lower() in ("avg", "mean"):
|
||||
op = ReduceOp.SUM
|
||||
divide_by_world_size = True
|
||||
else:
|
||||
op = getattr(ReduceOp, reduce_op.upper())
|
||||
else:
|
||||
op = reduce_op
|
||||
|
||||
# WA for HPU. HPU doesn't support Long types, forcefully set it to float
|
||||
if _HPU_AVAILABLE:
|
||||
is_hpu_backend = os.environ.get("HCCL_DISTRIBUTED_BACKEND") == "1"
|
||||
if is_hpu_backend:
|
||||
if (result.type() == "torch.LongTensor") or (result.type() == "torch.hpu.LongTensor"):
|
||||
rank_zero_info("Long tensor unsupported on HPU, casting to float")
|
||||
result = result.float()
|
||||
|
||||
# sync all processes before reduction
|
||||
torch.distributed.barrier(group=group)
|
||||
torch.distributed.all_reduce(result, op=op, group=group, async_op=False)
|
||||
|
||||
if divide_by_world_size:
|
||||
result = result / torch.distributed.get_world_size(group)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class AllGatherGrad(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward( # type: ignore[override]
|
||||
ctx: Any,
|
||||
tensor: Tensor,
|
||||
group: Optional["torch.distributed.ProcessGroup"] = group.WORLD,
|
||||
) -> Tensor:
|
||||
ctx.group = group
|
||||
|
||||
gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
|
||||
|
||||
torch.distributed.all_gather(gathered_tensor, tensor, group=group)
|
||||
gathered_tensor = torch.stack(gathered_tensor, dim=0)
|
||||
|
||||
return gathered_tensor
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, *grad_output: Tensor) -> Tuple[Tensor, None]:
|
||||
grad_output = torch.cat(grad_output)
|
||||
|
||||
torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group)
|
||||
|
||||
return grad_output[torch.distributed.get_rank()], None
|
||||
|
||||
|
||||
def all_gather_ddp_if_available(
|
||||
tensor: Tensor, group: Optional["torch.distributed.ProcessGroup"] = None, sync_grads: bool = False
|
||||
) -> Tensor:
|
||||
"""Function to gather a tensor from several distributed processes.
|
||||
|
||||
Args:
|
||||
tensor: tensor of shape (batch, ...)
|
||||
group: the process group to gather results from. Defaults to all processes (world)
|
||||
sync_grads: flag that allows users to synchronize gradients for all_gather op
|
||||
|
||||
Return:
|
||||
A tensor of shape (world_size, batch, ...)
|
||||
"""
|
||||
group = group if group is not None else torch.distributed.group.WORLD
|
||||
if distributed_available():
|
||||
if sync_grads:
|
||||
return AllGatherGrad.apply(tensor, group)
|
||||
with torch.no_grad():
|
||||
return AllGatherGrad.apply(tensor, group)
|
||||
return tensor
|
||||
|
||||
|
||||
def register_ddp_comm_hook(
|
||||
model: DistributedDataParallel,
|
||||
|
@ -319,67 +131,6 @@ def register_ddp_comm_hook(
|
|||
model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook) # type: ignore[operator]
|
||||
|
||||
|
||||
def tpu_distributed() -> bool:
|
||||
return _TPU_AVAILABLE and xm.xrt_world_size() > 1
|
||||
|
||||
|
||||
def get_default_process_group_backend_for_device(device: torch.device) -> str:
|
||||
return "nccl" if device.type == "cuda" else "gloo"
|
||||
|
||||
|
||||
def _get_process_group_backend_from_env() -> Optional[str]:
|
||||
torch_backend = os.getenv("PL_TORCH_DISTRIBUTED_BACKEND")
|
||||
if torch_backend is not None:
|
||||
rank_zero_deprecation(
|
||||
"Environment variable `PL_TORCH_DISTRIBUTED_BACKEND`"
|
||||
" was deprecated in v1.6 and will be removed in v1.8."
|
||||
" Specify `process_group_backend` directly on the strategy constructor."
|
||||
)
|
||||
return torch_backend
|
||||
|
||||
|
||||
def init_dist_connection(
|
||||
cluster_environment: "pl.plugins.environments.ClusterEnvironment",
|
||||
torch_distributed_backend: str,
|
||||
global_rank: Optional[int] = None,
|
||||
world_size: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Utility function to initialize distributed connection by setting env variables and initializing the
|
||||
distributed process group.
|
||||
|
||||
Args:
|
||||
cluster_environment: ``ClusterEnvironment`` instance
|
||||
torch_distributed_backend: backend to use (includes `nccl` and `gloo`)
|
||||
global_rank: rank of the current process
|
||||
world_size: number of processes in the group
|
||||
kwargs: kwargs for ``init_process_group``
|
||||
|
||||
Raises:
|
||||
RuntimeError:
|
||||
If ``torch.distributed`` is not available
|
||||
"""
|
||||
if not torch.distributed.is_available():
|
||||
raise RuntimeError("torch.distributed is not available. Cannot initialize distributed process group")
|
||||
if torch.distributed.is_initialized():
|
||||
log.debug("torch.distributed is already initialized. Exiting early")
|
||||
return
|
||||
global_rank = global_rank if global_rank is not None else cluster_environment.global_rank()
|
||||
world_size = world_size if world_size is not None else cluster_environment.world_size()
|
||||
os.environ["MASTER_ADDR"] = cluster_environment.main_address
|
||||
os.environ["MASTER_PORT"] = str(cluster_environment.main_port)
|
||||
log.info(f"Initializing distributed: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
|
||||
torch.distributed.init_process_group(torch_distributed_backend, rank=global_rank, world_size=world_size, **kwargs)
|
||||
|
||||
# on rank=0 let everyone know training is starting
|
||||
rank_zero_info(
|
||||
f"{'-' * 100}\n"
|
||||
f"distributed_backend={torch_distributed_backend}\n"
|
||||
f"All distributed processes registered. Starting with {world_size} processes\n"
|
||||
f"{'-' * 100}\n"
|
||||
)
|
||||
|
||||
|
||||

def _broadcast_object_list(obj: Any, rank: int) -> Any:
    objects = [obj if torch.distributed.get_rank() == rank else None]
    torch.distributed.broadcast_object_list(objects, src=rank)
@@ -397,6 +148,71 @@ def _collect_states_on_rank_zero(state: Dict[str, Any]) -> Dict[int, Any]:
        states: On global rank 0, a dictionary where the primary keys are
            the process rank and the values their associated states. Otherwise, returns None.
    """
    if not distributed_available():
    if not new_distributed_available():
        return {0: state}
    return {rank: _broadcast_object_list(state, rank) for rank in range(torch.distributed.get_world_size())}


def all_gather_ddp_if_available(*args: Any, **kwargs: Any) -> Any:
    rank_zero_deprecation(
        "`pytorch_lightning.utilities.distributed.all_gather_ddp_if_available` has been deprecated in v1.8.0 and will"
        " be removed in v1.10.0. Please use `lightning_lite.utilities.distributed.all_gather_ddp_if_available` instead."
    )
    return new_all_gather_ddp_if_available(*args, **kwargs)


def distributed_available() -> Any:
    rank_zero_deprecation(
        "`pytorch_lightning.utilities.distributed.distributed_available` has been deprecated in v1.8.0 and will"
        " be removed in v1.10.0. Please use `lightning_lite.utilities.distributed.distributed_available` instead."
    )
    return new_distributed_available()


def gather_all_tensors(*args: Any, **kwargs: Any) -> Any:
    rank_zero_deprecation(
        "`pytorch_lightning.utilities.distributed.gather_all_tensors` has been deprecated in v1.8.0 and will"
        " be removed in v1.10.0. Please use `lightning_lite.utilities.distributed.gather_all_tensors` instead."
    )
    return new_gather_all_tensors(*args, **kwargs)


def get_default_process_group_backend_for_device(*args: Any, **kwargs: Any) -> Any:
    rank_zero_deprecation(
        "`pytorch_lightning.utilities.distributed.get_default_process_group_backend_for_device` has been deprecated"
        " in v1.8.0 and will be removed in v1.10.0. Please use"
        " `lightning_lite.utilities.distributed.get_default_process_group_backend_for_device` instead."
    )
    return new_get_default_process_group_backend_for_device(*args, **kwargs)


def init_dist_connection(*args: Any, **kwargs: Any) -> Any:
    rank_zero_deprecation(
        "`pytorch_lightning.utilities.distributed.init_dist_connection` has been deprecated in v1.8.0 and will"
        " be removed in v1.10.0. Please use `lightning_lite.utilities.distributed.init_dist_connection` instead."
    )
    return new_init_dist_connection(*args, **kwargs)


def sync_ddp(*args: Any, **kwargs: Any) -> Any:
    rank_zero_deprecation(
        "`pytorch_lightning.utilities.distributed.sync_ddp` has been deprecated in v1.8.0 and will"
        " be removed in v1.10.0. Please use `lightning_lite.utilities.distributed.sync_ddp` instead."
    )
    return new_sync_ddp(*args, **kwargs)


def sync_ddp_if_available(*args: Any, **kwargs: Any) -> Any:
    rank_zero_deprecation(
        "`pytorch_lightning.utilities.distributed.sync_ddp_if_available` has been deprecated in v1.8.0 and will"
        " be removed in v1.10.0. Please use `lightning_lite.utilities.distributed.sync_ddp_if_available` instead."
    )
    return new_sync_ddp_if_available(*args, **kwargs)


def tpu_distributed() -> bool:
    rank_zero_deprecation(
        "`pytorch_lightning.utilities.distributed.tpu_distributed` has been deprecated in v1.8.0 and will"
        " be removed in v1.10.0. Please use `lightning_lite.utilities.distributed.tpu_distributed` instead."
    )
    return new_tpu_distributed()
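
To avoid the warnings raised by the shims above, imports can be switched to the ``lightning_lite`` locations named in the messages; a short migration sketch:

import torch

from lightning_lite.utilities.distributed import distributed_available, sync_ddp_if_available

print(distributed_available())  # False unless torch.distributed has been initialized

# sync_ddp_if_available is a no-op outside of a distributed run, so this is safe anywhere.
loss = torch.tensor(1.0)
loss = sync_ddp_if_available(loss)
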
@@ -15,47 +15,10 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING

from lightning_utilities.core.enums import StrEnum

from lightning_lite.utilities.enums import AMPType, LightningEnum, PrecisionType  # noqa: F401
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if TYPE_CHECKING:
    from enum import Enum

    # re-defined because `mypy` infers `StrEnum` as `Any`
    class LightningEnum(StrEnum, Enum):
        ...

else:
    LightningEnum = StrEnum


class AMPType(LightningEnum):
    """Type of Automatic Mixed Precision used for training."""

    APEX = "apex"
    NATIVE = "native"


class PrecisionType(LightningEnum):
    """Type of precision used."""

    HALF = "16"
    FLOAT = "32"
    FULL = "64"
    BFLOAT = "bf16"
    MIXED = "mixed"

    @staticmethod
    def supported_type(precision: str | int) -> bool:
        return any(x == precision for x in PrecisionType)

    @staticmethod
    def supported_types() -> list[str]:
        return [x.value for x in PrecisionType]
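
The re-export near the top of this hunk keeps ``AMPType``, ``LightningEnum`` and ``PrecisionType`` importable from their new home. A short sketch of the ``PrecisionType`` helpers, assuming the string-based, case-insensitive equality that ``LightningEnum`` provides:

from lightning_lite.utilities.enums import PrecisionType

# The enum values are strings, but the comparison also accepts ints such as 16.
assert PrecisionType.supported_type(16)
assert PrecisionType.supported_type("bf16")
assert not PrecisionType.supported_type(8)
print(PrecisionType.supported_types())  # ['16', '32', '64', 'bf16', 'mixed']
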
class GradClipAlgorithmType(LightningEnum):
    """Define gradient_clip_algorithm types - training-tricks.

@@ -85,47 +48,6 @@ class AutoRestartBatchKeys(LightningEnum):
    PL_RESTART_META = "__pl_restart_meta"


class _StrategyType(LightningEnum):
    """Define type of training strategy."""

    DP = "dp"
    DDP = "ddp"
    DDP_SPAWN = "ddp_spawn"
    DDP_FORK = "ddp_fork"
    TPU_SPAWN = "tpu_spawn"
    DEEPSPEED = "deepspeed"
    HOROVOD = "horovod"
    DDP_SHARDED = "ddp_sharded"
    DDP_SHARDED_SPAWN = "ddp_sharded_spawn"
    DDP_FULLY_SHARDED = "ddp_fully_sharded"
    BAGUA = "bagua"
    HPU_PARALLEL = "hpu_parallel"

    @staticmethod
    def interactive_compatible_types() -> list[_StrategyType]:
        """Returns a list containing interactive compatible _StrategyTypes."""
        return [
            _StrategyType.DP,
            _StrategyType.TPU_SPAWN,
            _StrategyType.DDP_FORK,
        ]

    def is_interactive_compatible(self) -> bool:
        """Returns whether self is interactive compatible."""
        return self in _StrategyType.interactive_compatible_types()
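
``_StrategyType`` also leaves this module in the hunk above; assuming it remains importable from ``lightning_lite.utilities.enums`` alongside the public enums (an assumption, not shown in this diff), the interactive-compatibility helper behaves as follows:

# Illustration only: `_StrategyType` is a private enum and its new import location
# is an assumption, not confirmed by this diff.
from lightning_lite.utilities.enums import _StrategyType

assert _StrategyType.DDP_FORK.is_interactive_compatible()
assert not _StrategyType.DDP_SPAWN.is_interactive_compatible()
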
class _AcceleratorType(LightningEnum):
    """Define Accelerator type by its nature."""

    CPU = "CPU"
    CUDA = "CUDA"
    IPU = "IPU"
    TPU = "TPU"
    HPU = "HPU"
    MPS = "MPS"


class _FaultTolerantMode(LightningEnum):

    DISABLED = "disabled"
@@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.


class MisconfigurationException(Exception):
    """Exception used to inform users of misuse with PyTorch Lightning."""
from lightning_lite.utilities.exceptions import MisconfigurationException  # noqa: F401


class DeadlockDetectedException(Exception):
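
With the re-export above, existing ``pytorch_lightning`` imports of ``MisconfigurationException`` keep working; a small sketch (the validation helper is hypothetical, only the import path comes from this diff):

from pytorch_lightning.utilities.exceptions import MisconfigurationException


def validate_num_devices(num_devices: int) -> None:
    # Hypothetical helper shown for illustration.
    if num_devices < 1:
        raise MisconfigurationException(f"`num_devices` must be a positive integer, got {num_devices}.")
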
Some files were not shown because too many files have changed in this diff.