From d2459df2ff79184c596b2f5c865ffb17c3541307 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 7 Sep 2022 17:25:23 +0200 Subject: [PATCH] Standalone Lite: Remaining Utilities (#14492) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jirka Borovec Co-authored-by: Carlos Mocholí Co-authored-by: Laverne Henderson Co-authored-by: Felonious-Spellfire --- .github/workflows/ci-lite-test-full.yml | 7 + docs/source-pytorch/extensions/logging.rst | 2 +- examples/pl_basics/autoencoder.py | 2 +- .../computer_vision_fine_tuning.py | 2 +- pyproject.toml | 2 + requirements/lite/base.txt | 1 + src/lightning_lite/__init__.py | 17 + src/lightning_lite/utilities/__init__.py | 40 ++ src/lightning_lite/utilities/cloud_io.py | 2 +- src/lightning_lite/utilities/data.py | 411 ++++++++++++++ src/lightning_lite/utilities/device_parser.py | 316 +++++++++++ src/lightning_lite/utilities/distributed.py | 264 +++++++++ src/lightning_lite/utilities/enums.py | 95 ++++ src/lightning_lite/utilities/exceptions.py | 17 + src/lightning_lite/utilities/imports.py | 61 +++ src/lightning_lite/utilities/optimizer.py | 34 ++ src/lightning_lite/utilities/rank_zero.py | 60 +++ .../utilities/registry.py | 0 src/lightning_lite/utilities/seed.py | 127 +++++ src/lightning_lite/utilities/types.py | 63 ++- src/lightning_lite/utilities/warnings.py | 24 + src/lightning_lite/utilities/xla_device.py | 2 +- src/pytorch_lightning/__init__.py | 2 +- src/pytorch_lightning/accelerators/cpu.py | 4 +- src/pytorch_lightning/accelerators/cuda.py | 4 +- src/pytorch_lightning/accelerators/mps.py | 4 +- .../accelerators/registry.py | 2 +- src/pytorch_lightning/accelerators/tpu.py | 2 +- src/pytorch_lightning/callbacks/base.py | 2 +- .../callbacks/early_stopping.py | 3 +- .../callbacks/fault_tolerance.py | 2 +- .../callbacks/model_checkpoint.py | 3 +- .../callbacks/stochastic_weight_avg.py | 3 +- src/pytorch_lightning/cli.py | 2 +- src/pytorch_lightning/core/datamodule.py | 3 +- src/pytorch_lightning/core/lightning.py | 2 +- src/pytorch_lightning/core/module.py | 2 +- src/pytorch_lightning/core/optimizer.py | 3 +- src/pytorch_lightning/core/saving.py | 2 +- src/pytorch_lightning/lite/lite.py | 16 +- src/pytorch_lightning/loops/utilities.py | 4 +- src/pytorch_lightning/overrides/base.py | 2 +- src/pytorch_lightning/overrides/fairscale.py | 2 +- .../plugins/io/checkpoint_plugin.py | 2 +- .../plugins/io/hpu_plugin.py | 2 +- .../plugins/io/torch_plugin.py | 2 +- .../plugins/io/xla_plugin.py | 2 +- .../plugins/precision/apex_amp.py | 2 +- .../plugins/precision/deepspeed.py | 2 +- .../precision/fsdp_native_native_amp.py | 2 +- .../plugins/precision/hpu.py | 2 +- .../plugins/precision/ipu.py | 2 +- .../plugins/precision/precision_plugin.py | 2 +- src/pytorch_lightning/profiler/advanced.py | 2 +- src/pytorch_lightning/profiler/profiler.py | 2 +- src/pytorch_lightning/profiler/pytorch.py | 2 +- src/pytorch_lightning/profiler/simple.py | 2 +- src/pytorch_lightning/profiler/xla.py | 2 +- src/pytorch_lightning/profilers/pytorch.py | 2 +- src/pytorch_lightning/strategies/bagua.py | 6 +- src/pytorch_lightning/strategies/ddp.py | 24 +- src/pytorch_lightning/strategies/ddp_spawn.py | 22 +- src/pytorch_lightning/strategies/deepspeed.py | 19 +- src/pytorch_lightning/strategies/dp.py | 2 +- .../strategies/fully_sharded.py | 4 +- .../strategies/fully_sharded_native.py | 19 +- src/pytorch_lightning/strategies/hivemind.py | 6 +- src/pytorch_lightning/strategies/horovod.py | 6 +-
.../strategies/hpu_parallel.py | 2 +- src/pytorch_lightning/strategies/ipu.py | 2 +- .../strategies/launchers/multiprocessing.py | 4 +- src/pytorch_lightning/strategies/parallel.py | 12 +- src/pytorch_lightning/strategies/sharded.py | 4 +- .../strategies/sharded_spawn.py | 2 +- .../strategies/single_device.py | 2 +- .../strategies/single_hpu.py | 3 +- src/pytorch_lightning/strategies/strategy.py | 6 +- .../strategies/strategy_registry.py | 2 +- src/pytorch_lightning/strategies/tpu_spawn.py | 9 +- src/pytorch_lightning/strategies/utils.py | 2 +- src/pytorch_lightning/trainer/__init__.py | 2 +- .../trainer/configuration_validator.py | 2 +- .../connectors/accelerator_connector.py | 11 +- .../connectors/checkpoint_connector.py | 2 +- .../trainer/connectors/data_connector.py | 10 +- .../connectors/logger_connector/result.py | 2 +- src/pytorch_lightning/trainer/data_loading.py | 2 +- src/pytorch_lightning/trainer/optimizers.py | 2 +- src/pytorch_lightning/trainer/supporters.py | 2 +- src/pytorch_lightning/trainer/trainer.py | 9 +- .../tuner/auto_gpu_select.py | 2 +- src/pytorch_lightning/utilities/__init__.py | 10 +- .../utilities/auto_restart.py | 2 +- src/pytorch_lightning/utilities/cloud_io.py | 2 +- src/pytorch_lightning/utilities/data.py | 229 +------- src/pytorch_lightning/utilities/deepspeed.py | 2 +- .../utilities/device_parser.py | 351 ++---------- .../utilities/distributed.py | 338 +++--------- src/pytorch_lightning/utilities/enums.py | 80 +-- src/pytorch_lightning/utilities/exceptions.py | 4 +- src/pytorch_lightning/utilities/fetching.py | 2 +- src/pytorch_lightning/utilities/meta.py | 2 +- src/pytorch_lightning/utilities/optimizer.py | 31 +- src/pytorch_lightning/utilities/rank_zero.py | 39 +- src/pytorch_lightning/utilities/seed.py | 155 ++---- src/pytorch_lightning/utilities/types.py | 64 +-- .../utilities/upgrade_checkpoint.py | 2 +- src/pytorch_lightning/utilities/warnings.py | 12 +- src/pytorch_lightning/utilities/xla_device.py | 2 +- tests/tests_lite/conftest.py | 87 +++ tests/tests_lite/helpers/runif.py | 27 +- tests/tests_lite/helpers/utils.py | 31 ++ tests/tests_lite/utilities/test_data.py | 509 ++++++++++++++++++ .../utilities/test_device_parser.py | 31 ++ .../tests_lite/utilities/test_distributed.py | 63 +++ tests/tests_lite/utilities/test_enums.py | 9 + tests/tests_lite/utilities/test_imports.py | 81 +++ .../utilities/test_optimizer.py | 2 +- .../utilities/test_rank_zero.py | 19 +- tests/tests_lite/utilities/test_seed.py | 84 +++ tests/tests_lite/utilities/test_warnings.py | 78 +++ .../utilities/test_xla_device_utils.py | 6 +- .../tests_pytorch/accelerators/test_common.py | 2 +- .../core/test_metric_result_integration.py | 2 +- tests/tests_pytorch/core/test_results.py | 2 +- .../deprecated_api/test_remove_1-10.py | 127 ++++- .../deprecated_api/test_remove_1-8.py | 4 +- .../deprecated_api/test_remove_2-0.py | 4 +- tests/tests_pytorch/lite/test_lite.py | 4 +- tests/tests_pytorch/models/test_gpu.py | 6 +- tests/tests_pytorch/models/test_tpu.py | 2 +- .../overrides/test_distributed.py | 2 +- .../tests_pytorch/plugins/test_amp_plugins.py | 12 +- .../plugins/test_checkpoint_io_plugin.py | 2 +- .../plugins/test_cluster_integration.py | 4 +- .../strategies/test_bagua_strategy.py | 2 +- tests/tests_pytorch/strategies/test_common.py | 2 +- tests/tests_pytorch/strategies/test_ddp.py | 8 +- ..._ddp_fully_sharded_with_full_state_dict.py | 4 +- .../strategies/test_deepspeed_strategy.py | 2 +- tests/tests_pytorch/test_cli.py | 4 +- .../connectors/test_accelerator_connector.py | 52 +- 
.../trainer/connectors/test_data_connector.py | 2 +- .../trainer/flags/test_env_vars.py | 4 +- .../trainer/flags/test_min_max_epochs.py | 2 +- .../properties/test_auto_gpu_select.py | 4 +- .../test_estimated_stepping_batches.py | 2 +- .../trainer/test_config_validator.py | 4 +- .../tests_pytorch/trainer/test_dataloaders.py | 3 +- .../tests_pytorch/trainer/test_supporters.py | 4 +- tests/tests_pytorch/trainer/test_trainer.py | 8 +- .../tests_pytorch/trainer/test_trainer_cli.py | 3 +- .../utilities/test_all_gather_grad.py | 5 +- .../utilities/test_auto_restart.py | 3 +- tests/tests_pytorch/utilities/test_data.py | 361 +------------ .../utilities/test_device_parser.py | 2 +- .../utilities/test_distributed.py | 57 +- tests/tests_pytorch/utilities/test_enums.py | 10 +- tests/tests_pytorch/utilities/test_seed.py | 80 +-- tests/tests_pytorch/utilities/test_types.py | 2 +- .../tests_pytorch/utilities/test_warnings.py | 31 -- 161 files changed, 3120 insertions(+), 1922 deletions(-) create mode 100644 src/lightning_lite/utilities/data.py create mode 100644 src/lightning_lite/utilities/device_parser.py create mode 100644 src/lightning_lite/utilities/distributed.py create mode 100644 src/lightning_lite/utilities/enums.py create mode 100644 src/lightning_lite/utilities/exceptions.py create mode 100644 src/lightning_lite/utilities/imports.py create mode 100644 src/lightning_lite/utilities/optimizer.py create mode 100644 src/lightning_lite/utilities/rank_zero.py rename src/{pytorch_lightning => lightning_lite}/utilities/registry.py (100%) create mode 100644 src/lightning_lite/utilities/seed.py create mode 100644 src/lightning_lite/utilities/warnings.py create mode 100644 tests/tests_lite/helpers/utils.py create mode 100644 tests/tests_lite/utilities/test_data.py create mode 100644 tests/tests_lite/utilities/test_device_parser.py create mode 100644 tests/tests_lite/utilities/test_distributed.py create mode 100644 tests/tests_lite/utilities/test_enums.py create mode 100644 tests/tests_lite/utilities/test_imports.py rename tests/{tests_pytorch => tests_lite}/utilities/test_optimizer.py (93%) rename tests/{tests_pytorch => tests_lite}/utilities/test_rank_zero.py (65%) create mode 100644 tests/tests_lite/utilities/test_seed.py create mode 100644 tests/tests_lite/utilities/test_warnings.py diff --git a/.github/workflows/ci-lite-test-full.yml b/.github/workflows/ci-lite-test-full.yml index 896086b697..2830952e24 100644 --- a/.github/workflows/ci-lite-test-full.yml +++ b/.github/workflows/ci-lite-test-full.yml @@ -88,6 +88,13 @@ jobs: pip list shell: bash + - name: Testing Warnings + # the stacklevel can only be set on >=3.7 + if: matrix.python-version != '3.7' + working-directory: tests/tests_lite + # needs to run outside of `pytest` + run: python utilities/test_warnings.py + - name: Testing Lite working-directory: tests/tests_lite # NOTE: do not include coverage report here, see: https://github.com/nedbat/coveragepy/issues/1003 diff --git a/docs/source-pytorch/extensions/logging.rst b/docs/source-pytorch/extensions/logging.rst index f7fb3cfd6f..109445779f 100644 --- a/docs/source-pytorch/extensions/logging.rst +++ b/docs/source-pytorch/extensions/logging.rst @@ -231,7 +231,7 @@ Use the :func:`~pytorch_lightning.loggers.logger.rank_zero_experiment` and :func .. 
testcode:: from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment - from pytorch_lightning.utilities.distributed import rank_zero_only + from pytorch_lightning.utilities import rank_zero_only class MyLogger(Logger): diff --git a/examples/pl_basics/autoencoder.py b/examples/pl_basics/autoencoder.py index 0fd9ddae18..ae8c7b6611 100644 --- a/examples/pl_basics/autoencoder.py +++ b/examples/pl_basics/autoencoder.py @@ -26,8 +26,8 @@ from torch.utils.data import DataLoader, random_split from pytorch_lightning import callbacks, cli_lightning_logo, LightningDataModule, LightningModule, Trainer from pytorch_lightning.cli import LightningCLI from pytorch_lightning.demos.mnist_datamodule import MNIST +from pytorch_lightning.utilities import rank_zero_only from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE -from pytorch_lightning.utilities.rank_zero import rank_zero_only if _TORCHVISION_AVAILABLE: import torchvision diff --git a/examples/pl_domain_templates/computer_vision_fine_tuning.py b/examples/pl_domain_templates/computer_vision_fine_tuning.py index b33d63eb65..7a81df9839 100644 --- a/examples/pl_domain_templates/computer_vision_fine_tuning.py +++ b/examples/pl_domain_templates/computer_vision_fine_tuning.py @@ -57,7 +57,7 @@ from torchvision.datasets.utils import download_and_extract_archive from pytorch_lightning import cli_lightning_logo, LightningDataModule, LightningModule from pytorch_lightning.callbacks.finetuning import BaseFinetuning from pytorch_lightning.cli import LightningCLI -from pytorch_lightning.utilities.rank_zero import rank_zero_info +from pytorch_lightning.utilities import rank_zero_info log = logging.getLogger(__name__) DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip" diff --git a/pyproject.toml b/pyproject.toml index 5b62baf9ce..166447dd65 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ exclude = '(_notebooks/.*)' [tool.mypy] files = [ "src/pytorch_lightning", + "src/lightning_lite", # TODO: Check typing in app source # "src/lightning_app", ] @@ -57,5 +58,6 @@ module = [ "pytorch_lightning.trainer.trainer", "pytorch_lightning.tuner.batch_size_scaling", "pytorch_lightning.utilities.data", + "lightning_lite.utilities.data", ] ignore_errors = "True" diff --git a/requirements/lite/base.txt b/requirements/lite/base.txt index 4dbc213afe..eb130bc254 100644 --- a/requirements/lite/base.txt +++ b/requirements/lite/base.txt @@ -1,6 +1,7 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment +numpy>=1.17.2, <1.23.1 torch>=1.9.*, <1.13.0 fsspec[http]>=2021.05.0, !=2021.06.0, <2022.6.0 packaging>=17.0, <=21.3 diff --git a/src/lightning_lite/__init__.py b/src/lightning_lite/__init__.py index 5e0d0ad5cb..6c16dcbf6c 100644 --- a/src/lightning_lite/__init__.py +++ b/src/lightning_lite/__init__.py @@ -1,4 +1,21 @@ """Root package info.""" +import logging from lightning_lite.__about__ import * # noqa: F401, F403 from lightning_lite.__version__ import version as __version__ # noqa: F401 + +_root_logger = logging.getLogger() +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO) + +if not _root_logger.hasHandlers(): + _logger.addHandler(logging.StreamHandler()) + _logger.propagate = False + +from lightning_lite.lite import LightningLite # noqa: E402 +from lightning_lite.utilities.seed import 
seed_everything # noqa: E402 + +__all__ = ["LightningLite", "seed_everything"] + +# for compatibility with namespace packages +__import__("pkg_resources").declare_namespace(__name__) diff --git a/src/lightning_lite/utilities/__init__.py b/src/lightning_lite/utilities/__init__.py index e69de29bb2..edeab0cd5d 100644 --- a/src/lightning_lite/utilities/__init__.py +++ b/src/lightning_lite/utilities/__init__.py @@ -0,0 +1,40 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""General utilities.""" + +from lightning_lite.utilities.apply_func import move_data_to_device # noqa: F401 +from lightning_lite.utilities.distributed import AllGatherGrad # noqa: F401 +from lightning_lite.utilities.enums import _AcceleratorType, _StrategyType, AMPType, LightningEnum # noqa: F401 + +# TODO(lite): Avoid importing protected attributes in `__init__.py` files +from lightning_lite.utilities.imports import ( # noqa: F401 + _HIVEMIND_AVAILABLE, + _HOROVOD_AVAILABLE, + _HPU_AVAILABLE, + _IPU_AVAILABLE, + _IS_INTERACTIVE, + _IS_WINDOWS, + _POPTORCH_AVAILABLE, + _TORCH_GREATER_EQUAL_1_10, + _TORCH_GREATER_EQUAL_1_11, + _TORCH_GREATER_EQUAL_1_12, + _TPU_AVAILABLE, + _XLA_AVAILABLE, +) +from lightning_lite.utilities.rank_zero import ( # noqa: F401 + rank_zero_deprecation, + rank_zero_info, + rank_zero_only, + rank_zero_warn, +) diff --git a/src/lightning_lite/utilities/cloud_io.py b/src/lightning_lite/utilities/cloud_io.py index 99629bcda8..bdc20f7e3f 100644 --- a/src/lightning_lite/utilities/cloud_io.py +++ b/src/lightning_lite/utilities/cloud_io.py @@ -22,7 +22,7 @@ import torch from fsspec.core import url_to_fs from fsspec.implementations.local import AbstractFileSystem -from pytorch_lightning.utilities.types import _MAP_LOCATION_TYPE, _PATH +from lightning_lite.utilities.types import _MAP_LOCATION_TYPE, _PATH def load( diff --git a/src/lightning_lite/utilities/data.py b/src/lightning_lite/utilities/data.py new file mode 100644 index 0000000000..cdaf806a0c --- /dev/null +++ b/src/lightning_lite/utilities/data.py @@ -0,0 +1,411 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import functools +import inspect +import os +from collections import OrderedDict +from contextlib import contextmanager +from functools import partial +from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple, Type, Union + +from lightning_utilities.core.inheritance import get_all_subclasses +from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler + +from lightning_lite.utilities.enums import LightningEnum +from lightning_lite.utilities.exceptions import MisconfigurationException +from lightning_lite.utilities.rank_zero import rank_zero_warn +from lightning_lite.utilities.seed import pl_worker_init_function + + +class _WrapAttrTag(LightningEnum): + SET = "set" + DEL = "del" + + def __call__(self, *args): + if self == self.SET: + fn = setattr + else: + fn = delattr + return fn(*args) + + +def has_iterable_dataset(dataloader: DataLoader) -> bool: + return hasattr(dataloader, "dataset") and isinstance(dataloader.dataset, IterableDataset) + + +def has_len(dataloader: Union[DataLoader, Iterable]) -> bool: + """Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or + infinite dataloader.""" + try: + # try getting the length + if len(dataloader) == 0: + rank_zero_warn( + f"`{dataloader.__class__.__name__}` returned 0 length. Please make sure this was your intention." + ) + has_len = True + except (TypeError, NotImplementedError): + has_len = False + + if has_len and has_iterable_dataset(dataloader): + rank_zero_warn( + "Your `IterableDataset` has `__len__` defined." + " In combination with multi-process data loading (when num_workers > 1)," + " `__len__` could be inaccurate if each worker is not configured independently" + " to avoid having duplicate data." + ) + return has_len + + +def _update_dataloader(dataloader: DataLoader, sampler: Union[Sampler, Iterable]) -> DataLoader: + dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler) + dataloader = _reinstantiate_wrapped_cls(dataloader, *dl_args, **dl_kwargs) + return dataloader + + +def _get_dataloader_init_args_and_kwargs( + dataloader: DataLoader, + sampler: Optional[Sampler], + disallow_batch_sampler: bool = False, +) -> Tuple[Tuple[Any], Dict[str, Any]]: + if not isinstance(dataloader, DataLoader): + raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`") + + was_wrapped = hasattr(dataloader, "__pl_saved_args") + if was_wrapped: + dl_args = dataloader.__pl_saved_args + dl_kwargs = dataloader.__pl_saved_kwargs + arg_names = dataloader.__pl_saved_arg_names + original_dataset = dataloader.__dataset # we have this saved from _wrap_init + else: + # get the dataloader instance attributes + attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")} + # We cannot be 100% sure the class sets dataset argument. 
Let's set it to None to be safe + # and hope we can get it from the instance attributes + original_dataset = None + # not part of `vars` + attrs["multiprocessing_context"] = dataloader.multiprocessing_context + arg_names = () + + # get the dataloader instance `__init__` parameters + params = dict(inspect.signature(dataloader.__init__).parameters) + has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values()) + if has_variadic_kwargs: + # if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)` + + if was_wrapped: + # if the dataloader was wrapped in a hook, only take arguments with default values + # and assume user passes their kwargs correctly + params.update( + {k: v for k, v in inspect.signature(DataLoader.__init__).parameters.items() if v.default is not v.empty} + ) + else: + params.update(inspect.signature(DataLoader.__init__).parameters) + params.pop("self", None) + + if not was_wrapped: + # keep only the params whose default is different to the current attr value + non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]} + + # add `dataset` as it might have been replaced with `*args` + non_defaults.add("dataset") + # kwargs to re-construct the dataloader + dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults} + dl_args = () + + dataset = dl_kwargs.get("dataset", original_dataset) + if isinstance(dataset, IterableDataset): + dl_kwargs["batch_sampler"] = None + dl_kwargs["sampler"] = None + else: + dl_kwargs.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler, disallow_batch_sampler)) + + required_args = { + p.name + for p in params.values() + if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) + and p.default is p.empty + and p.name not in dl_kwargs + and p.name not in arg_names + } + # the dataloader has required args which we could not extract from the existing attributes + if required_args: + required_args = sorted(required_args) + dataloader_cls_name = dataloader.__class__.__name__ + missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in required_args) + raise MisconfigurationException( + f"Trying to inject custom `Sampler` into the `{dataloader_cls_name}` instance. " + "This would fail as some of the `__init__` arguments are not available as instance attributes. " + f"The missing attributes are {required_args}. If you instantiate your `{dataloader_cls_name}` inside a " + "`*_dataloader` hook of your module, we will do this for you." + f" Otherwise, define {missing_args_message} inside your `__init__`." + ) + + if not has_variadic_kwargs: + # the dataloader signature does not allow keyword arguments that need to be passed + missing_kwargs = (set(dl_kwargs) | set(arg_names)) - params.keys() + if missing_kwargs: + missing_kwargs = sorted(missing_kwargs) + dataloader_cls_name = dataloader.__class__.__name__ + raise TypeError( + f"Trying to inject parameters into the `{dataloader_cls_name}` instance. " + "This would fail as it doesn't expose all its attributes in the `__init__` signature. " + f"The missing arguments are {missing_kwargs}. 
HINT: If you wrote the `{dataloader_cls_name}` class, " + "add the `__init__` arguments or allow passing `**kwargs`" + ) + + return dl_args, dl_kwargs + + +def _dataloader_init_kwargs_resolve_sampler( + dataloader: DataLoader, + sampler: Optional[Sampler], + disallow_batch_sampler: bool = False, +) -> Dict[str, Any]: + """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its + re-instantiation. + + If there are multiple devices in IPU mode, it is necessary to disallow BatchSampler that isn't instantiated + automatically, since `poptorch.DataLoader` will try to increase the batch_size + """ + batch_sampler = getattr(dataloader, "batch_sampler") + + if batch_sampler is not None: + if disallow_batch_sampler: + # Check that we don't have a PyTorch default batch sampler that was instantiated in DataLoader __init__ + if not ( + type(batch_sampler) is BatchSampler + and batch_sampler.sampler == sampler + and dataloader.batch_size == batch_sampler.batch_size + ): + raise MisconfigurationException( + "It is not possible to have a batch sampler in your dataloader, " + "when running on multiple IPU devices." + ) + elif type(batch_sampler) is not BatchSampler: + batch_sampler_cls = type(batch_sampler) + if hasattr(batch_sampler, "__pl_saved_args"): + args = batch_sampler.__pl_saved_args + kwargs = batch_sampler.__pl_saved_kwargs + default_kwargs = batch_sampler.__pl_saved_default_kwargs + arg_names = batch_sampler.__pl_saved_arg_names + + success, args, kwargs = _replace_value_in_saved_args( + "sampler", sampler, args, kwargs, default_kwargs, arg_names + ) + if not success: + raise TypeError( + "Trying to inject a modified sampler into the batch sampler; however, it seems the class " + f"`{batch_sampler_cls.__qualname__}` does not have an argument called `sampler.` To mitigate " + "this, expose an argument `sampler` in the `__init__` method of your custom class." + ) + + batch_sampler = _reinstantiate_wrapped_cls(batch_sampler, *args, **kwargs) + else: + try: + batch_sampler = batch_sampler_cls( + sampler, + batch_size=batch_sampler.batch_size, + drop_last=batch_sampler.drop_last, + ) + except TypeError as e: + import re + + match = re.match(r".*__init__\(\) (got multiple values)|(missing \d required)", str(e)) + if not match: + # an unexpected `TypeError`, continue failure + raise + + # There could either be too few or too many arguments. Customizing the message based on this doesn't + # make much sense since our MisconfigurationException is going to be raised from the original one. + raise TypeError( + "We tried to re-instantiate your custom batch sampler and failed. " + "To mitigate this, either follow the API of `BatchSampler` or instantiate " + "your custom batch sampler inside `*_dataloader` hooks of your module." 
+ ) from e + + return { + "sampler": None, + "shuffle": False, + "batch_sampler": batch_sampler, + "batch_size": 1, + "drop_last": False, + } + + return {"sampler": sampler, "shuffle": False, "batch_sampler": None} + + +def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None: + if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None: + dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank) + + +def _reinstantiate_wrapped_cls(orig_object: Any, *args: Any, explicit_cls: Optional[Type] = None, **kwargs: Any) -> Any: + constructor = type(orig_object) if explicit_cls is None else explicit_cls + + try: + result = constructor(*args, **kwargs) + except TypeError as e: + # improve exception message due to an incorrect implementation of the `DataLoader` where multiple subclass + # `__init__` arguments map to one `DataLoader.__init__` argument + import re + + match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", str(e)) + if not match: + # an unexpected `TypeError`, continue failure + raise + argument = match.groups()[0] + message = ( + f"The {constructor.__name__} implementation has an error where more than one `__init__` argument" + f" can be passed to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing" + f" passing both a custom argument that will map to the `{argument}` argument as well as `**kwargs`." + f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key." + " This argument was automatically passed to your object by PyTorch Lightning." + ) + raise MisconfigurationException(message) from e + + attrs_record = getattr(orig_object, "__pl_attrs_record", list()) + for args, fn in attrs_record: + fn(result, *args) + + return result + + +def _wrap_init_method(init: Callable, store_explicit_arg: Optional[str] = None) -> Callable: + """Wraps the ``__init__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and + :class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses.""" + + @functools.wraps(init) + def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None: + # We need to inspect `init`, as inspecting `obj.__init__` + # can lead to inspecting the wrong function with multiple inheritance + old_inside_init = getattr(obj, "__pl_inside_init", False) + object.__setattr__(obj, "__pl_inside_init", True) + params = inspect.signature(init).parameters + + parameters_defaults = OrderedDict( + (param.name, param.default) + for param in params.values() + if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD) + ) + + param_names = tuple(parameters_defaults)[: len(args)] + + default_kwargs = { + name: value + for name, value in parameters_defaults.items() + if name not in kwargs and name not in param_names and value != inspect.Parameter.empty + } + + if not hasattr(obj, "__pl_saved_args"): + object.__setattr__(obj, "__pl_saved_args", args) + object.__setattr__(obj, "__pl_saved_kwargs", kwargs) + object.__setattr__(obj, "__pl_saved_arg_names", param_names) + object.__setattr__(obj, "__pl_saved_default_kwargs", default_kwargs) + + # We want to use the latest possible value for explicit argument (i.e. ideally what gets passed to base class) + # so that we can be sure, that it will not get changed anymore. 
+ # That is why we are setting this in every `__init__` + if store_explicit_arg is not None: + if store_explicit_arg in param_names: + object.__setattr__(obj, f"__{store_explicit_arg}", args[param_names.index(store_explicit_arg)]) + elif store_explicit_arg in kwargs: + object.__setattr__(obj, f"__{store_explicit_arg}", kwargs[store_explicit_arg]) + + init(obj, *args, **kwargs) + object.__setattr__(obj, "__pl_inside_init", old_inside_init) + + return wrapper + + +def _wrap_attr_method(method: Callable, tag: _WrapAttrTag) -> Callable: + """Wraps the ``__setattr__`` or ``__delattr__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and + :class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses.""" + + @functools.wraps(method) + def wrapper(obj: Any, *args: Any): + # First, let's find out if we're the first in inheritance chain calling the patched method. + name, *_ = args + prev_call_name, prev_call_method = getattr(obj, "__pl_current_call", (None, "method")) + first_call = not (prev_call_name == name and prev_call_method == tag) + + # Then mark the current called method + object.__setattr__(obj, "__pl_current_call", (name, tag)) + + # call original method + method(obj, *args) + if first_call and not getattr(obj, "__pl_inside_init", True): + # and save the value it was called with to the internal list, + # if we're outside of __init__ and the original call did not fail and we're the first call + attrs_record = getattr(obj, "__pl_attrs_record", list()) + attrs_record.append((args, tag)) + object.__setattr__(obj, "__pl_attrs_record", attrs_record) + object.__setattr__(obj, "__pl_current_call", (prev_call_name, prev_call_method)) + + return wrapper + + +@contextmanager +def _replace_dunder_methods(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]: + """This context manager is used to add support for re-instantiation of custom (subclasses) of `base_cls`. + + It patches the ``__init__``, ``__setattr__`` and ``__delattr__`` methods. + """ + classes = get_all_subclasses(base_cls) | {base_cls} + for cls in classes: + # Check that __init__ belongs to the class + # https://stackoverflow.com/a/5253424 + if "__init__" in cls.__dict__: + cls.__old__init__ = cls.__init__ + cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg) + + # we want at least one setattr/delattr in the chain to be patched and it can happen, that none of the subclasses + # implement `__setattr__`/`__delattr__`. 
Therefore, we are always patching the `base_cls` + for patch_fn_name, tag in (("__setattr__", _WrapAttrTag.SET), ("__delattr__", _WrapAttrTag.DEL)): + if patch_fn_name in cls.__dict__ or cls is base_cls: + saved_name = f"__old{patch_fn_name}" + setattr(cls, saved_name, getattr(cls, patch_fn_name)) + setattr(cls, patch_fn_name, _wrap_attr_method(getattr(cls, patch_fn_name), tag)) + yield + for cls in classes: + for patched_name in ("__setattr__", "__delattr__", "__init__"): + # Check that __old__{init,setattr,delattr} belongs to the class + # https://stackoverflow.com/a/5253424 + if f"__old{patched_name}" in cls.__dict__: + setattr(cls, patched_name, getattr(cls, f"__old{patched_name}")) + delattr(cls, f"__old{patched_name}") + + +def _replace_value_in_saved_args( + replace_key: str, + replace_value: Any, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + default_kwargs: Dict[str, Any], + arg_names: Tuple[str, ...], +) -> Tuple[bool, Tuple[Any, ...], Dict[str, Any]]: + """Tries to replace an argument value in a saved list of args and kwargs. + + Returns a tuple indicating success of the operation and modified saved args and kwargs + """ + + if replace_key in arg_names: + replace_index = arg_names.index(replace_key) + args = args[:replace_index] + (replace_value,) + args[replace_index + 1 :] + return True, args, kwargs + elif replace_key in kwargs or replace_key in default_kwargs: + kwargs[replace_key] = replace_value + return True, args, kwargs + + return False, args, kwargs diff --git a/src/lightning_lite/utilities/device_parser.py b/src/lightning_lite/utilities/device_parser.py new file mode 100644 index 0000000000..78bf8a9a8c --- /dev/null +++ b/src/lightning_lite/utilities/device_parser.py @@ -0,0 +1,316 @@ +import multiprocessing +import os +from typing import Any, List, MutableSequence, Optional, Tuple, Union + +import torch + +# TODO(lite): Fix the imports +# from lightning_lite.plugins.environments import TorchElasticEnvironment +# from lightning_lite.strategies.launchers.multiprocessing import _is_forking_disabled +from lightning_lite.utilities.exceptions import MisconfigurationException +from lightning_lite.utilities.types import _DEVICE + + +def determine_root_gpu_device(gpus: List[_DEVICE]) -> Optional[_DEVICE]: + """ + Args: + gpus: Non-empty list of ints representing which GPUs to use + + Returns: + Designated root GPU device id + + Raises: + TypeError: + If ``gpus`` is not a list + AssertionError: + If GPU list is empty + """ + if gpus is None: + return None + + if not isinstance(gpus, list): + raise TypeError("GPUs should be a list") + + assert len(gpus) > 0, "GPUs should be a non-empty list" + + # set root gpu + root_gpu = gpus[0] + + return root_gpu + + +def parse_gpu_ids( + gpus: Optional[Union[int, str, List[int]]], + include_cuda: bool = False, + include_mps: bool = False, +) -> Optional[List[int]]: + """ + Parses the GPU IDs given in the format as accepted by the + :class:`~pytorch_lightning.trainer.Trainer`. + + Args: + gpus: An int -1 or string '-1' indicate that all available GPUs should be used. + A list of unique ints or a string containing a list of comma separated unique integers + indicates specific GPUs to use. + An int of 0 means that no GPUs should be used. + Any int N > 0 indicates that GPUs [0..N) should be used. + include_cuda: A boolean value indicating whether to include CUDA devices for GPU parsing. + include_mps: A boolean value indicating whether to include MPS devices for GPU parsing. 
+ + Returns: + A list of GPUs to be used or ``None`` if no GPUs were requested + + Raises: + MisconfigurationException: + If no GPUs are available but the value of gpus variable indicates request for GPUs + + .. note:: + ``include_cuda`` and ``include_mps`` default to ``False`` so that you only + have to specify which device type to use and all other devices are not disabled. + """ + # Check that gpus param is None, Int, String or Sequence of Ints + _check_data_type(gpus) + + # Handle the case when no GPUs are requested + if gpus is None or (isinstance(gpus, int) and gpus == 0) or str(gpus).strip() in ("0", "[]"): + return None + + # We know the user requested GPUs therefore if some of the + # requested GPUs are not available an exception is thrown. + gpus = _normalize_parse_gpu_string_input(gpus) + gpus = _normalize_parse_gpu_input_to_list(gpus, include_cuda=include_cuda, include_mps=include_mps) + if not gpus: + raise MisconfigurationException("GPUs requested but none are available.") + + if ( + True # TorchElasticEnvironment.detect() # TODO(lite): Revert this once environments have moved + and len(gpus) != 1 + and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps)) == 1 + ): + # Omit sanity check on torchelastic because by default it shows one visible GPU per process + return gpus + + # Check that GPUs are unique. Duplicate GPUs are not supported by the backend. + _check_unique(gpus) + + return _sanitize_gpu_ids(gpus, include_cuda=include_cuda, include_mps=include_mps) + + +def parse_tpu_cores(tpu_cores: Optional[Union[int, str, List[int]]]) -> Optional[Union[int, List[int]]]: + """ + Parses the tpu_cores given in the format as accepted by the + :class:`~pytorch_lightning.trainer.Trainer`. + + Args: + tpu_cores: An int of 1 or string '1' indicates that 1 core with multi-processing should be used + An int 8 or string '8' indicates that all 8 cores with multi-processing should be used + A list of ints or a strings containing a list of comma separated integers + indicates the specific TPU core to use. + + Returns: + A list of tpu_cores to be used or ``None`` if no TPU cores were requested + + Raises: + MisconfigurationException: + If TPU cores aren't 1, 8 or [<1-8>] + """ + _check_data_type(tpu_cores) + + if isinstance(tpu_cores, str): + tpu_cores = _parse_tpu_cores_str(tpu_cores.strip()) + + if not _tpu_cores_valid(tpu_cores): + raise MisconfigurationException("`tpu_cores` can only be 1, 8 or [<1-8>]") + + return tpu_cores + + +def parse_cpu_cores(cpu_cores: Union[int, str, List[int]]) -> int: + """Parses the cpu_cores given in the format as accepted by the ``devices`` argument in the + :class:`~pytorch_lightning.trainer.Trainer`. + + Args: + cpu_cores: An int > 0. 
+ + Returns: + An int representing the number of processes + + Raises: + MisconfigurationException: + If cpu_cores is not an int > 0 + """ + if isinstance(cpu_cores, str) and cpu_cores.strip().isdigit(): + cpu_cores = int(cpu_cores) + + if not isinstance(cpu_cores, int) or cpu_cores <= 0: + raise MisconfigurationException("`devices` selected with `CPUAccelerator` should be an int > 0.") + + return cpu_cores + + +def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[int, List[int]]: + if not isinstance(s, str): + return s + if s == "-1": + return -1 + if "," in s: + return [int(x.strip()) for x in s.split(",") if len(x) > 0] + return int(s.strip()) + + +def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: bool = False) -> List[int]: + """Checks that each of the GPUs in the list is actually available. Raises a MisconfigurationException if any of + the GPUs is not available. + + Args: + gpus: List of ints corresponding to GPU indices + + Returns: + Unmodified gpus variable + + Raises: + MisconfigurationException: + If machine has fewer available GPUs than requested. + """ + if sum((include_cuda, include_mps)) == 0: + raise ValueError("At least one gpu type should be specified!") + all_available_gpus = _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps) + for gpu in gpus: + if gpu not in all_available_gpus: + raise MisconfigurationException( + f"You requested gpu: {gpus}\n But your machine only has: {all_available_gpus}" + ) + return gpus + + +def _normalize_parse_gpu_input_to_list( + gpus: Union[int, List[int], Tuple[int, ...]], include_cuda: bool, include_mps: bool +) -> Optional[List[int]]: + assert gpus is not None + if isinstance(gpus, (MutableSequence, tuple)): + return list(gpus) + + # must be an int + if not gpus: # gpus==0 + return None + if gpus == -1: + return _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps) + + return list(range(gpus)) + + +def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = False) -> List[int]: + """ + Returns: + A list of all available GPUs + """ + cuda_gpus = _get_all_available_cuda_gpus() if include_cuda else [] + mps_gpus = _get_all_available_mps_gpus() if include_mps else [] + return cuda_gpus + mps_gpus + + +def _get_all_available_mps_gpus() -> List[int]: + """ + Returns: + A list of all available MPS GPUs + """ + # lazy import to avoid circular dependencies + # from lightning_lite.accelerators.mps import _MPS_AVAILABLE + _MPS_AVAILABLE = False # TODO(lite): revert this once MPS utils have moved + return [0] if _MPS_AVAILABLE else [] + + +def _get_all_available_cuda_gpus() -> List[int]: + """ + Returns: + A list of all available CUDA GPUs + """ + return list(range(num_cuda_devices())) + + +def _check_unique(device_ids: List[int]) -> None: + """Checks that the device_ids are unique. + + Args: + device_ids: List of ints corresponding to GPUs indices + + Raises: + MisconfigurationException: + If ``device_ids`` of GPUs aren't unique + """ + if len(device_ids) != len(set(device_ids)): + raise MisconfigurationException("Device ID's (GPU) must be unique.") + + +def _check_data_type(device_ids: Any) -> None: + """Checks that the device_ids argument is one of the following: None, int, string, or sequence of integers. 
+ + Args: + device_ids: gpus/tpu_cores parameter as passed to the Trainer + + Raises: + MisconfigurationException: + If ``device_ids`` of GPU/TPUs aren't ``int``, ``str``, sequence of ``int`` or ``None`` + """ + msg = "Device IDs (GPU/TPU) must be an int, a string, a sequence of ints or None, but you passed" + + if device_ids is None: + return + elif isinstance(device_ids, (MutableSequence, tuple)): + for id_ in device_ids: + if type(id_) is not int: + raise MisconfigurationException(f"{msg} a sequence of {type(id_).__name__}.") + elif type(device_ids) not in (int, str): + raise MisconfigurationException(f"{msg} {type(device_ids).__name__}.") + + +def _tpu_cores_valid(tpu_cores: Any) -> bool: + # allow 1 or 8 cores + if tpu_cores in (1, 8, None): + return True + + # allow picking 1 of 8 indexes + if isinstance(tpu_cores, (list, tuple, set)): + has_1_tpu_idx = len(tpu_cores) == 1 + is_valid_tpu_idx = 1 <= list(tpu_cores)[0] <= 8 + + is_valid_tpu_core_choice = has_1_tpu_idx and is_valid_tpu_idx + return is_valid_tpu_core_choice + + return False + + +def _parse_tpu_cores_str(tpu_cores: str) -> Union[int, List[int]]: + if tpu_cores in ("1", "8"): + return int(tpu_cores) + return [int(x.strip()) for x in tpu_cores.split(",") if len(x) > 0] + + +def num_cuda_devices() -> int: + """Returns the number of GPUs available. + + Unlike :func:`torch.cuda.device_count`, this function does its best not to create a CUDA context for fork support, + if the platform allows it. + """ + if "fork" not in torch.multiprocessing.get_all_start_methods() or _is_forking_disabled(): + return torch.cuda.device_count() + with multiprocessing.get_context("fork").Pool(1) as pool: + return pool.apply(torch.cuda.device_count) + + +def is_cuda_available() -> bool: + """Returns a bool indicating if CUDA is currently available. + + Unlike :func:`torch.cuda.is_available`, this function does its best not to create a CUDA context for fork support, + if the platform allows it. 
+ """ + if "fork" not in torch.multiprocessing.get_all_start_methods() or _is_forking_disabled(): + return torch.cuda.is_available() + with multiprocessing.get_context("fork").Pool(1) as pool: + return pool.apply(torch.cuda.is_available) + + +# TODO(lite): move this back to launchers/multiprocessing.py once launchers have moved +def _is_forking_disabled() -> bool: + """Returns whether forking is disabled through the environment variable ``PL_DISABLE_FORK``.""" + return bool(int(os.environ.get("PL_DISABLE_FORK", "0"))) diff --git a/src/lightning_lite/utilities/distributed.py b/src/lightning_lite/utilities/distributed.py new file mode 100644 index 0000000000..77123c53ff --- /dev/null +++ b/src/lightning_lite/utilities/distributed.py @@ -0,0 +1,264 @@ +import logging +import os +from typing import Any, List, Optional, Tuple, Union + +import torch +from torch import Tensor +from torch.nn import functional as F + +from lightning_lite.utilities.imports import _HPU_AVAILABLE, _TPU_AVAILABLE +from lightning_lite.utilities.rank_zero import rank_zero_deprecation +from lightning_lite.utilities.rank_zero import rank_zero_info as new_rank_zero_info + +if _TPU_AVAILABLE: + import torch_xla.core.xla_model as xm + + +if torch.distributed.is_available(): + from torch.distributed import group, ReduceOp +else: + + class ReduceOp: # type: ignore # (see https://github.com/python/mypy/issues/1153) + SUM = None + + class group: # type: ignore + WORLD = None + + +log = logging.getLogger(__name__) + + +def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tensor]: + """Function to gather all tensors from several DDP processes onto a list that is broadcasted to all processes. + + Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case + tensors are padded, gathered and then trimmed to secure equal workload for all processes. + + Args: + result: The value to sync + group: The process group to gather results from. Defaults to all processes (world) + + Return: + gathered_result: List with size equal to the process group where + gathered_result[i] corresponds to result tensor from process i + """ + if group is None: + group = torch.distributed.group.WORLD + + # Convert tensors to contiguous format + result = result.contiguous() + + world_size = torch.distributed.get_world_size(group) + torch.distributed.barrier(group=group) + + # If the tensor is scalar, things are easy + if result.ndim == 0: + return _simple_gather_all_tensors(result, group, world_size) + + # 1. Gather sizes of all tensors + local_size = torch.tensor(result.shape, device=result.device) + local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)] + torch.distributed.all_gather(local_sizes, local_size, group=group) + max_size = torch.stack(local_sizes).max(dim=0).values + all_sizes_equal = all(all(ls == max_size) for ls in local_sizes) + + # 2. If shapes are all the same, then do a simple gather: + if all_sizes_equal: + return _simple_gather_all_tensors(result, group, world_size) + + # 3. 
If not, we need to pad each local tensor to maximum size, gather and then truncate + pad_dims = [] + pad_by = (max_size - local_size).detach().cpu() + for val in reversed(pad_by): + pad_dims.append(0) + pad_dims.append(val.item()) + result_padded = F.pad(result, pad_dims) + gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)] + torch.distributed.all_gather(gathered_result, result_padded, group) + for idx, item_size in enumerate(local_sizes): + slice_param = [slice(dim_size) for dim_size in item_size] + gathered_result[idx] = gathered_result[idx][slice_param] + return gathered_result + + +def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]: + gathered_result = [torch.zeros_like(result) for _ in range(world_size)] + torch.distributed.all_gather(gathered_result, result, group) + return gathered_result + + +def distributed_available() -> bool: + return torch.distributed.is_available() and torch.distributed.is_initialized() or tpu_distributed() + + +def sync_ddp_if_available( + result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None +) -> Tensor: + """Function to reduce a tensor across worker processes during distributed training. + + Args: + result: The value to sync and reduce (typically tensor or number) + group: The process group to gather results from. Defaults to all processes (world) + reduce_op: The reduction operation. Defaults to sum. + Can also be a string of 'avg', 'mean' to calculate the mean during reduction. + + Return: + reduced value + """ + if distributed_available(): + return sync_ddp(result, group=group, reduce_op=reduce_op) + return result + + +def sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> Tensor: + """Function to reduce the tensors from several DDP processes to one main process. + + Args: + result: The value to sync and reduce (typically tensor or number) + group: The process group to gather results from. Defaults to all processes (world) + reduce_op: The reduction operation. Defaults to sum. + Can also be a string of 'avg', 'mean' to calculate the mean during reduction. + + Return: + reduced value + """ + divide_by_world_size = False + + if group is None: + group = torch.distributed.group.WORLD + + op: Optional[ReduceOp] + if isinstance(reduce_op, str): + if reduce_op.lower() in ("avg", "mean"): + op = ReduceOp.SUM + divide_by_world_size = True + else: + op = getattr(ReduceOp, reduce_op.upper()) + else: + op = reduce_op + + # WA for HPU. 
HPU doesn't support Long types, forcefully set it to float + if _HPU_AVAILABLE: + is_hpu_backend = os.environ.get("HCCL_DISTRIBUTED_BACKEND") == "1" + if is_hpu_backend: + if (result.type() == "torch.LongTensor") or (result.type() == "torch.hpu.LongTensor"): + new_rank_zero_info("Long tensor unsupported on HPU, casting to float") + result = result.float() + + # Sync all processes before reduction + torch.distributed.barrier(group=group) + torch.distributed.all_reduce(result, op=op, group=group, async_op=False) + + if divide_by_world_size: + result = result / torch.distributed.get_world_size(group) + + return result + + +class AllGatherGrad(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + ctx: Any, + tensor: Tensor, + group: Optional["torch.distributed.ProcessGroup"] = group.WORLD, + ) -> Tensor: + ctx.group = group + + gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] + + torch.distributed.all_gather(gathered_tensor, tensor, group=group) + gathered_tensor = torch.stack(gathered_tensor, dim=0) + + return gathered_tensor + + @staticmethod + def backward(ctx: Any, *grad_output: Tensor) -> Tuple[Tensor, None]: + grad_output = torch.cat(grad_output) + + torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group) + + return grad_output[torch.distributed.get_rank()], None + + +def all_gather_ddp_if_available( + tensor: Tensor, group: Optional["torch.distributed.ProcessGroup"] = None, sync_grads: bool = False +) -> Tensor: + """Function to gather a tensor from several distributed processes. + + Args: + tensor: Tensor of shape (batch, ...) + group: The process group to gather results from. Defaults to all processes (world) + sync_grads: Flag that allows users to synchronize gradients for all_gather op + + Return: + A tensor of shape (world_size, batch, ...) + """ + group = group if group is not None else torch.distributed.group.WORLD + if distributed_available(): + if sync_grads: + return AllGatherGrad.apply(tensor, group) + with torch.no_grad(): + return AllGatherGrad.apply(tensor, group) + return tensor + + +def init_dist_connection( + # TODO(lite): Fix this type error + cluster_environment: "ClusterEnvironment", # type: ignore[name-defined] # noqa: F821 + torch_distributed_backend: str, + global_rank: Optional[int] = None, + world_size: Optional[int] = None, + **kwargs: Any, +) -> None: + """Utility function to initialize distributed connection by setting env variables and initializing the + distributed process group. + + Args: + cluster_environment: ``ClusterEnvironment`` instance + torch_distributed_backend: Backend to use (includes `nccl` and `gloo`) + global_rank: Rank of the current process + world_size: Number of processes in the group + kwargs: Kwargs for ``init_process_group`` + + Raises: + RuntimeError: + If ``torch.distributed`` is not available + """ + if not torch.distributed.is_available(): + raise RuntimeError("torch.distributed is not available. Cannot initialize distributed process group") + if torch.distributed.is_initialized(): + log.debug("torch.distributed is already initialized. 
Exiting early") + return + global_rank = global_rank if global_rank is not None else cluster_environment.global_rank() + world_size = world_size if world_size is not None else cluster_environment.world_size() + os.environ["MASTER_ADDR"] = cluster_environment.main_address + os.environ["MASTER_PORT"] = str(cluster_environment.main_port) + log.info(f"Initializing distributed: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}") + torch.distributed.init_process_group(torch_distributed_backend, rank=global_rank, world_size=world_size, **kwargs) + + # On rank=0 let everyone know training is starting + new_rank_zero_info( + f"{'-' * 100}\n" + f"distributed_backend={torch_distributed_backend}\n" + f"All distributed processes registered. Starting with {world_size} processes\n" + f"{'-' * 100}\n" + ) + + +def tpu_distributed() -> bool: + return _TPU_AVAILABLE and xm.xrt_world_size() > 1 + + +def get_default_process_group_backend_for_device(device: torch.device) -> str: + return "nccl" if device.type == "cuda" else "gloo" + + +def _get_process_group_backend_from_env() -> Optional[str]: + torch_backend = os.getenv("PL_TORCH_DISTRIBUTED_BACKEND") + if torch_backend is not None: + rank_zero_deprecation( + "Environment variable `PL_TORCH_DISTRIBUTED_BACKEND`" + " was deprecated in v1.6 and will be removed in v1.8." + " Specify `process_group_backend` directly on the strategy constructor." + ) + return torch_backend diff --git a/src/lightning_lite/utilities/enums.py b/src/lightning_lite/utilities/enums.py new file mode 100644 index 0000000000..567483b1e5 --- /dev/null +++ b/src/lightning_lite/utilities/enums.py @@ -0,0 +1,95 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Enumerated utilities.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +from lightning_utilities.core.enums import StrEnum + +if TYPE_CHECKING: + from enum import Enum + + # re-defined because `mypy` infers `StrEnum` as `Any` + class LightningEnum(StrEnum, Enum): + ... 
+ +else: + LightningEnum = StrEnum + + +class AMPType(LightningEnum): + """Type of Automatic Mixed Precission used for training.""" + + APEX = "apex" + NATIVE = "native" + + +class PrecisionType(LightningEnum): + """Type of precision used.""" + + HALF = "16" + FLOAT = "32" + FULL = "64" + BFLOAT = "bf16" + MIXED = "mixed" + + @staticmethod + def supported_type(precision: str | int) -> bool: + return any(x == precision for x in PrecisionType) + + @staticmethod + def supported_types() -> list[str]: + return [x.value for x in PrecisionType] + + +class _StrategyType(LightningEnum): + """Define type of training strategy.""" + + DP = "dp" + DDP = "ddp" + DDP_SPAWN = "ddp_spawn" + DDP_FORK = "ddp_fork" + TPU_SPAWN = "tpu_spawn" + DEEPSPEED = "deepspeed" + HOROVOD = "horovod" + DDP_SHARDED = "ddp_sharded" + DDP_SHARDED_SPAWN = "ddp_sharded_spawn" + DDP_FULLY_SHARDED = "ddp_fully_sharded" + BAGUA = "bagua" + HPU_PARALLEL = "hpu_parallel" + + @staticmethod + def interactive_compatible_types() -> list[_StrategyType]: + """Returns a list containing interactive compatible _StrategyTypes.""" + return [ + _StrategyType.DP, + _StrategyType.TPU_SPAWN, + _StrategyType.DDP_FORK, + ] + + def is_interactive_compatible(self) -> bool: + """Returns whether self is interactive compatible.""" + return self in _StrategyType.interactive_compatible_types() + + +class _AcceleratorType(LightningEnum): + """Define Accelerator type by its nature.""" + + CPU = "CPU" + CUDA = "CUDA" + IPU = "IPU" + TPU = "TPU" + HPU = "HPU" + MPS = "MPS" diff --git a/src/lightning_lite/utilities/exceptions.py b/src/lightning_lite/utilities/exceptions.py new file mode 100644 index 0000000000..7f6c3dd9b3 --- /dev/null +++ b/src/lightning_lite/utilities/exceptions.py @@ -0,0 +1,17 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class MisconfigurationException(Exception): + """Exception used to inform users of misuse with Lightning.""" diff --git a/src/lightning_lite/utilities/imports.py b/src/lightning_lite/utilities/imports.py new file mode 100644 index 0000000000..34e7b5ac5f --- /dev/null +++ b/src/lightning_lite/utilities/imports.py @@ -0,0 +1,61 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""General utilities.""" +import operator +import platform +import sys + +from lightning_utilities.core.imports import compare_version, module_available, package_available + +_IS_WINDOWS = platform.system() == "Windows" +_IS_INTERACTIVE = hasattr(sys, "ps1") # https://stackoverflow.com/a/64523765 +_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8) +_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10) +_TORCH_GREATER_EQUAL_1_9_1 = compare_version("torch", operator.ge, "1.9.1") +_TORCH_GREATER_EQUAL_1_10 = compare_version("torch", operator.ge, "1.10.0") +_TORCH_LESSER_EQUAL_1_10_2 = compare_version("torch", operator.le, "1.10.2") +_TORCH_GREATER_EQUAL_1_11 = compare_version("torch", operator.ge, "1.11.0") +_TORCH_GREATER_EQUAL_1_12 = compare_version("torch", operator.ge, "1.12.0") +_TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0", use_base_version=True) + +_APEX_AVAILABLE = module_available("apex.amp") +_HABANA_FRAMEWORK_AVAILABLE = package_available("habana_frameworks") +_HIVEMIND_AVAILABLE = package_available("hivemind") +_HOROVOD_AVAILABLE = module_available("horovod.torch") +_OMEGACONF_AVAILABLE = package_available("omegaconf") +_POPTORCH_AVAILABLE = package_available("poptorch") +_PSUTIL_AVAILABLE = package_available("psutil") +_XLA_AVAILABLE: bool = package_available("torch_xla") + +# TODO(lite): import this from the fairscale files once they move to lite package +_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and module_available("fairscale.nn") + + +from lightning_lite.utilities.xla_device import XLADeviceUtils # noqa: E402 + +_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() + +if _POPTORCH_AVAILABLE: + import poptorch + + _IPU_AVAILABLE = poptorch.ipuHardwareIsAvailable() +else: + _IPU_AVAILABLE = False + +if _HABANA_FRAMEWORK_AVAILABLE: + from habana_frameworks.torch.utils.library_loader import is_habana_avaialble + + _HPU_AVAILABLE = is_habana_avaialble() +else: + _HPU_AVAILABLE = False diff --git a/src/lightning_lite/utilities/optimizer.py b/src/lightning_lite/utilities/optimizer.py new file mode 100644 index 0000000000..c10c426bfe --- /dev/null +++ b/src/lightning_lite/utilities/optimizer.py @@ -0,0 +1,34 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Iterable + +from lightning_utilities.core.apply_func import apply_to_collection +from torch import Tensor +from torch.optim import Optimizer + +from lightning_lite.utilities.apply_func import move_data_to_device +from lightning_lite.utilities.types import _DEVICE + + +def optimizers_to_device(optimizers: Iterable[Optimizer], device: _DEVICE) -> None: + """Moves optimizer states for a sequence of optimizers to the device.""" + for opt in optimizers: + optimizer_to_device(opt, device) + + +def optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None: + """Moves the state of a single optimizer to the device.""" + for p, v in optimizer.state.items(): + optimizer.state[p] = apply_to_collection(v, Tensor, move_data_to_device, device) diff --git a/src/lightning_lite/utilities/rank_zero.py b/src/lightning_lite/utilities/rank_zero.py new file mode 100644 index 0000000000..db364dfd8f --- /dev/null +++ b/src/lightning_lite/utilities/rank_zero.py @@ -0,0 +1,60 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities that can be used for calling functions on a particular rank.""" +import logging +import os +from typing import Optional + +import lightning_utilities.core.rank_zero as rank_zero_module + +# note: we want to keep these indirections so the `rank_zero_only.rank` is set on import +from lightning_utilities.core.rank_zero import ( # noqa: F401 + rank_zero_debug, + rank_zero_deprecation, + rank_zero_info, + rank_zero_only, + rank_zero_warn, +) + +import lightning_lite + +rank_zero_module.log = logging.getLogger(__name__) + + +def _get_rank( + strategy: Optional["lightning_lite.strategies.Strategy"] = None, # type: ignore[name-defined] +) -> Optional[int]: + if strategy is not None: + return strategy.global_rank + # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing, + # therefore LOCAL_RANK needs to be checked first + rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") + for key in rank_keys: + rank = os.environ.get(key) + if rank is not None: + return int(rank) + # None to differentiate whether an environment variable was set at all + return None + + +# add the attribute to the function but don't overwrite in case Trainer has already set it +rank_zero_only.rank = getattr(rank_zero_only, "rank", _get_rank() or 0) + + +class LightningDeprecationWarning(DeprecationWarning): + """Deprecation warnings raised by Lightning.""" + + +rank_zero_module.rank_zero_deprecation_category = LightningDeprecationWarning diff --git a/src/pytorch_lightning/utilities/registry.py b/src/lightning_lite/utilities/registry.py similarity index 100% rename from src/pytorch_lightning/utilities/registry.py rename to src/lightning_lite/utilities/registry.py diff --git a/src/lightning_lite/utilities/seed.py b/src/lightning_lite/utilities/seed.py new file mode 100644 index 0000000000..a55b5e3dd8 --- /dev/null +++ b/src/lightning_lite/utilities/seed.py @@ -0,0 +1,127 @@ +import logging +import os +import 
random +from random import getstate as python_get_rng_state +from random import setstate as python_set_rng_state +from typing import Any, Dict, Optional + +import numpy as np +import torch +from lightning_utilities.core.rank_zero import rank_prefixed_message + +from lightning_lite.utilities.rank_zero import _get_rank, rank_zero_only, rank_zero_warn + +log = logging.getLogger(__name__) + +max_seed_value = np.iinfo(np.uint32).max +min_seed_value = np.iinfo(np.uint32).min + + +def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: + """Function that sets seed for pseudo-random number generators in: pytorch, numpy, python.random In addition, + sets the following environment variables: + + - `PL_GLOBAL_SEED`: will be passed to spawned subprocesses (e.g. ddp_spawn backend). + - `PL_SEED_WORKERS`: (optional) is set to 1 if ``workers=True``. + + Args: + seed: the integer value seed for global random state in Lightning. + If `None`, will read seed from `PL_GLOBAL_SEED` env variable + or select it randomly. + workers: if set to ``True``, will properly configure all dataloaders passed to the + Trainer with a ``worker_init_fn``. If the user already provides such a function + for their dataloaders, setting this argument will have no influence. See also: + :func:`~lightning_lite.utilities.seed.pl_worker_init_function`. + """ + if seed is None: + env_seed = os.environ.get("PL_GLOBAL_SEED") + if env_seed is None: + seed = _select_seed_randomly(min_seed_value, max_seed_value) + rank_zero_warn(f"No seed found, seed set to {seed}") + else: + try: + seed = int(env_seed) + except ValueError: + seed = _select_seed_randomly(min_seed_value, max_seed_value) + rank_zero_warn(f"Invalid seed found: {repr(env_seed)}, seed set to {seed}") + elif not isinstance(seed, int): + seed = int(seed) + + if not (min_seed_value <= seed <= max_seed_value): + rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}") + seed = _select_seed_randomly(min_seed_value, max_seed_value) + + log.info(rank_prefixed_message(f"Global seed set to {seed}", _get_rank())) + os.environ["PL_GLOBAL_SEED"] = str(seed) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + os.environ["PL_SEED_WORKERS"] = f"{int(workers)}" + + return seed + + +def _select_seed_randomly(min_seed_value: int = min_seed_value, max_seed_value: int = max_seed_value) -> int: + return random.randint(min_seed_value, max_seed_value) + + +def reset_seed() -> None: + """Reset the seed to the value that :func:`lightning_lite.utilities.seed.seed_everything` previously set. + + If :func:`lightning_lite.utilities.seed.seed_everything` is unused, this function will do nothing. + """ + seed = os.environ.get("PL_GLOBAL_SEED", None) + if seed is None: + return + workers = os.environ.get("PL_SEED_WORKERS", "0") + seed_everything(int(seed), workers=bool(int(workers))) + + +def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover + """The worker_init_fn that Lightning automatically adds to your dataloader if you previously set the seed with + ``seed_everything(seed, workers=True)``. + + See also the PyTorch documentation on + `randomness in DataLoaders `_. 
+ """ + # implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562 + global_rank = rank if rank is not None else rank_zero_only.rank + process_seed = torch.initial_seed() + # back out the base seed so we can use all the bits + base_seed = process_seed - worker_id + log.debug( + f"Initializing random number generators of process {global_rank} worker {worker_id} with base seed {base_seed}" + ) + ss = np.random.SeedSequence([base_seed, worker_id, global_rank]) + # use 128 bits (4 x 32-bit words) + np.random.seed(ss.generate_state(4)) + # Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module + torch_ss, stdlib_ss = ss.spawn(2) + torch.manual_seed(torch_ss.generate_state(1, dtype=np.uint64)[0]) + # use 128 bits expressed as an integer + stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum() + random.seed(stdlib_seed) + + +def _collect_rng_states() -> Dict[str, Any]: + """Collect the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python.""" + return { + "torch": torch.get_rng_state(), + "torch.cuda": torch.cuda.get_rng_state_all(), + "numpy": np.random.get_state(), + "python": python_get_rng_state(), + } + + +def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None: + """Set the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python in the current + process.""" + torch.set_rng_state(rng_state_dict["torch"]) + # torch.cuda rng_state is only included since v1.8. + if "torch.cuda" in rng_state_dict: + torch.cuda.set_rng_state_all(rng_state_dict["torch.cuda"]) + np.random.set_state(rng_state_dict["numpy"]) + version, state, gauss = rng_state_dict["python"] + python_set_rng_state((version, tuple(state), gauss)) diff --git a/src/lightning_lite/utilities/types.py b/src/lightning_lite/utilities/types.py index 900154e69c..950210925a 100644 --- a/src/lightning_lite/utilities/types.py +++ b/src/lightning_lite/utilities/types.py @@ -11,8 +11,69 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from pathlib import Path +from typing import Any, Callable, Dict, Iterator, List, Optional, TypeVar, Union import torch +from torch import Tensor +from torch.optim import Optimizer +from typing_extensions import Protocol, runtime_checkable +_PATH = Union[str, Path] _DEVICE = Union[torch.device, str, int] +_MAP_LOCATION_TYPE = Optional[Union[_DEVICE, Callable[[_DEVICE], _DEVICE], Dict[_DEVICE, _DEVICE]]] +_PARAMETERS = Iterator[torch.nn.Parameter] + + +_DictKey = TypeVar("_DictKey") + + +@runtime_checkable +class _Stateful(Protocol[_DictKey]): + """This class is used to detect if an object is stateful using `isinstance(obj, _Stateful)`.""" + + def state_dict(self) -> Dict[_DictKey, Any]: + ... + + def load_state_dict(self, state_dict: Dict[_DictKey, Any]) -> None: + ... + + +# Inferred from `torch.optim.lr_scheduler.pyi` +# Missing attributes were added to improve typing +@runtime_checkable +class _LRScheduler(_Stateful[str], Protocol): + optimizer: Optimizer + base_lrs: List[float] + + def __init__(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> None: + ... + + def step(self, epoch: Optional[int] = None) -> None: + ... 
+ + +# Inferred from `torch.optim.lr_scheduler.pyi` +# Missing attributes were added to improve typing +@runtime_checkable +class ReduceLROnPlateau(_Stateful[str], Protocol): + in_cooldown: bool + optimizer: Optimizer + + def __init__( + self, + optimizer: Optimizer, + mode: str = ..., + factor: float = ..., + patience: int = ..., + verbose: bool = ..., + threshold: float = ..., + threshold_mode: str = ..., + cooldown: int = ..., + min_lr: float = ..., + eps: float = ..., + ) -> None: + ... + + def step(self, metrics: Union[float, int, Tensor], epoch: Optional[int] = None) -> None: + ... diff --git a/src/lightning_lite/utilities/warnings.py b/src/lightning_lite/utilities/warnings.py new file mode 100644 index 0000000000..dfd298fd49 --- /dev/null +++ b/src/lightning_lite/utilities/warnings.py @@ -0,0 +1,24 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Warning-related utilities.""" +import warnings + +from lightning_lite.utilities.rank_zero import LightningDeprecationWarning + +# enable our warnings +warnings.simplefilter("default", category=LightningDeprecationWarning) + + +class PossibleUserWarning(UserWarning): + """Warnings that could be false positives.""" diff --git a/src/lightning_lite/utilities/xla_device.py b/src/lightning_lite/utilities/xla_device.py index 2feef71c56..cc0bfb7882 100644 --- a/src/lightning_lite/utilities/xla_device.py +++ b/src/lightning_lite/utilities/xla_device.py @@ -18,7 +18,7 @@ import traceback from multiprocessing import Process, Queue from typing import Any, Callable, Union -from pytorch_lightning.utilities.imports import _XLA_AVAILABLE +from lightning_lite.utilities.imports import _XLA_AVAILABLE if _XLA_AVAILABLE: import torch_xla.core.xla_model as xm diff --git a/src/pytorch_lightning/__init__.py b/src/pytorch_lightning/__init__.py index 5a009713e0..d1f7c29aae 100644 --- a/src/pytorch_lightning/__init__.py +++ b/src/pytorch_lightning/__init__.py @@ -31,10 +31,10 @@ if not _root_logger.hasHandlers(): _logger.addHandler(logging.StreamHandler()) _logger.propagate = False +from lightning_lite.utilities.seed import seed_everything # noqa: E402 from pytorch_lightning.callbacks import Callback # noqa: E402 from pytorch_lightning.core import LightningDataModule, LightningModule # noqa: E402 from pytorch_lightning.trainer import Trainer # noqa: E402 -from pytorch_lightning.utilities.seed import seed_everything # noqa: E402 __all__ = ["Trainer", "LightningDataModule", "LightningModule", "Callback", "seed_everything"] diff --git a/src/pytorch_lightning/accelerators/cpu.py b/src/pytorch_lightning/accelerators/cpu.py index d0981e7269..00eeac15ff 100644 --- a/src/pytorch_lightning/accelerators/cpu.py +++ b/src/pytorch_lightning/accelerators/cpu.py @@ -15,11 +15,11 @@ from typing import Any, Dict, List, Union import torch +from lightning_lite.utilities.device_parser import parse_cpu_cores +from lightning_lite.utilities.types import _DEVICE from pytorch_lightning.accelerators.accelerator import Accelerator -from 
pytorch_lightning.utilities.device_parser import parse_cpu_cores from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _PSUTIL_AVAILABLE -from pytorch_lightning.utilities.types import _DEVICE class CPUAccelerator(Accelerator): diff --git a/src/pytorch_lightning/accelerators/cuda.py b/src/pytorch_lightning/accelerators/cuda.py index 1c69015546..e5f939c69a 100644 --- a/src/pytorch_lightning/accelerators/cuda.py +++ b/src/pytorch_lightning/accelerators/cuda.py @@ -20,10 +20,10 @@ from typing import Any, Dict, List, Optional, Union import torch import pytorch_lightning as pl +from lightning_lite.utilities import device_parser +from lightning_lite.utilities.types import _DEVICE from pytorch_lightning.accelerators.accelerator import Accelerator -from pytorch_lightning.utilities import device_parser from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.types import _DEVICE _log = logging.getLogger(__name__) diff --git a/src/pytorch_lightning/accelerators/mps.py b/src/pytorch_lightning/accelerators/mps.py index 5ebcb37cd0..5610ba1549 100644 --- a/src/pytorch_lightning/accelerators/mps.py +++ b/src/pytorch_lightning/accelerators/mps.py @@ -16,11 +16,11 @@ from typing import Any, Dict, List, Optional, Union import torch +from lightning_lite.utilities import device_parser +from lightning_lite.utilities.types import _DEVICE from pytorch_lightning.accelerators.accelerator import Accelerator -from pytorch_lightning.utilities import device_parser from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _PSUTIL_AVAILABLE, _TORCH_GREATER_EQUAL_1_12 -from pytorch_lightning.utilities.types import _DEVICE # For using the `MPSAccelerator`, user's machine should have `torch>=1.12`, Metal programming framework and # the ARM-based Apple Silicon processors. 
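For context, a minimal usage sketch of the seeding helpers introduced above in src/lightning_lite/utilities/seed.py; the seed value and the workers flag are illustrative only:

    from lightning_lite.utilities.seed import reset_seed, seed_everything

    # seeds python.random, numpy and torch (CPU + CUDA), and exports PL_GLOBAL_SEED / PL_SEED_WORKERS
    seed = seed_everything(42, workers=True)
    # ... build dataloaders and model here ...
    # re-applies the seed stored in PL_GLOBAL_SEED; does nothing if seed_everything was never called
    reset_seed()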
diff --git a/src/pytorch_lightning/accelerators/registry.py b/src/pytorch_lightning/accelerators/registry.py index 992fa34b02..74a306df26 100644 --- a/src/pytorch_lightning/accelerators/registry.py +++ b/src/pytorch_lightning/accelerators/registry.py @@ -15,9 +15,9 @@ import importlib from inspect import getmembers, isclass from typing import Any, Callable, Dict, List, Optional +from lightning_lite.utilities.registry import _is_register_method_overridden from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.registry import _is_register_method_overridden class _AcceleratorRegistry(dict): diff --git a/src/pytorch_lightning/accelerators/tpu.py b/src/pytorch_lightning/accelerators/tpu.py index fa8bd007cb..89170e4c92 100644 --- a/src/pytorch_lightning/accelerators/tpu.py +++ b/src/pytorch_lightning/accelerators/tpu.py @@ -15,8 +15,8 @@ from typing import Any, Dict, List, Optional, Union import torch +from lightning_lite.utilities import device_parser from pytorch_lightning.accelerators.accelerator import Accelerator -from pytorch_lightning.utilities import device_parser from pytorch_lightning.utilities.imports import _TPU_AVAILABLE, _XLA_AVAILABLE if _XLA_AVAILABLE: diff --git a/src/pytorch_lightning/callbacks/base.py b/src/pytorch_lightning/callbacks/base.py index d0d564110a..0504249ea7 100644 --- a/src/pytorch_lightning/callbacks/base.py +++ b/src/pytorch_lightning/callbacks/base.py @@ -14,7 +14,7 @@ from typing import Any from pytorch_lightning.callbacks.callback import Callback as NewCallback -from pytorch_lightning.utilities import rank_zero_deprecation +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation class Callback(NewCallback): diff --git a/src/pytorch_lightning/callbacks/early_stopping.py b/src/pytorch_lightning/callbacks/early_stopping.py index 30ab05c76e..6c1a43e1d1 100644 --- a/src/pytorch_lightning/callbacks/early_stopping.py +++ b/src/pytorch_lightning/callbacks/early_stopping.py @@ -27,9 +27,10 @@ from lightning_utilities.core.rank_zero import rank_prefixed_message from torch import Tensor import pytorch_lightning as pl +from lightning_lite.utilities.rank_zero import _get_rank from pytorch_lightning.callbacks.callback import Callback from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_warn +from pytorch_lightning.utilities.rank_zero import rank_zero_warn log = logging.getLogger(__name__) diff --git a/src/pytorch_lightning/callbacks/fault_tolerance.py b/src/pytorch_lightning/callbacks/fault_tolerance.py index 9d04fc86b6..75347df01b 100644 --- a/src/pytorch_lightning/callbacks/fault_tolerance.py +++ b/src/pytorch_lightning/callbacks/fault_tolerance.py @@ -21,8 +21,8 @@ import os from typing import Any import pytorch_lightning as pl +from lightning_lite.utilities.types import _PATH from pytorch_lightning.callbacks import Checkpoint -from pytorch_lightning.utilities.types import _PATH class _FaultToleranceCheckpoint(Checkpoint): diff --git a/src/pytorch_lightning/callbacks/model_checkpoint.py b/src/pytorch_lightning/callbacks/model_checkpoint.py index a80c82447c..e484cfde5c 100644 --- a/src/pytorch_lightning/callbacks/model_checkpoint.py +++ b/src/pytorch_lightning/callbacks/model_checkpoint.py @@ -36,10 +36,11 @@ from torch import Tensor import pytorch_lightning as pl from lightning_lite.utilities.cloud_io import get_filesystem +from 
lightning_lite.utilities.types import _PATH from pytorch_lightning.callbacks import Checkpoint from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn -from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT +from pytorch_lightning.utilities.types import STEP_OUTPUT log = logging.getLogger(__name__) warning_cache = WarningCache() diff --git a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py index 51cbceb7f9..732c8831b2 100644 --- a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -23,12 +23,13 @@ from torch import nn, Tensor from torch.optim.swa_utils import SWALR import pytorch_lightning as pl +from lightning_lite.utilities.types import _LRScheduler from pytorch_lightning.callbacks.callback import Callback from pytorch_lightning.strategies import DDPFullyShardedStrategy, DeepSpeedStrategy from pytorch_lightning.strategies.fully_sharded_native import DDPFullyShardedNativeStrategy from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn -from pytorch_lightning.utilities.types import _LRScheduler, LRSchedulerConfig +from pytorch_lightning.utilities.types import LRSchedulerConfig _AVG_FN = Callable[[Tensor, Tensor, Tensor], Tensor] diff --git a/src/pytorch_lightning/cli.py b/src/pytorch_lightning/cli.py index ee53236508..82156c6b4a 100644 --- a/src/pytorch_lightning/cli.py +++ b/src/pytorch_lightning/cli.py @@ -287,7 +287,7 @@ class LightningCLI: this argument will not be configurable from a configuration file and will always be present for this particular CLI. Alternatively, configurable callbacks can be added as explained in :ref:`the CLI docs `. - seed_everything_default: Value for the :func:`~pytorch_lightning.utilities.seed.seed_everything` + seed_everything_default: Value for the :func:`~lightning_lite.utilities.seed.seed_everything` seed argument. Set to True to automatically choose a valid seed. Setting it to False will not call seed_everything. description: Description of the tool shown when running ``--help``. 
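As a brief illustration of the runtime-checkable _Stateful protocol defined in the relocated lightning_lite.utilities.types module above: any object exposing state_dict/load_state_dict satisfies it structurally, without inheriting from it. The Counter class below is a hypothetical stand-in:

    from typing import Any, Dict

    from lightning_lite.utilities.types import _Stateful


    class Counter:
        def __init__(self) -> None:
            self.count = 0

        def state_dict(self) -> Dict[str, Any]:
            return {"count": self.count}

        def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
            self.count = state_dict["count"]


    # structural check via Protocol + runtime_checkable; no subclassing of _Stateful required
    assert isinstance(Counter(), _Stateful)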
diff --git a/src/pytorch_lightning/core/datamodule.py b/src/pytorch_lightning/core/datamodule.py index e4adf9b1ca..6a5ea13013 100644 --- a/src/pytorch_lightning/core/datamodule.py +++ b/src/pytorch_lightning/core/datamodule.py @@ -19,6 +19,7 @@ from typing import Any, Dict, IO, List, Mapping, Optional, Sequence, Tuple, Unio from torch.utils.data import DataLoader, Dataset, IterableDataset import pytorch_lightning as pl +from lightning_lite.utilities.types import _PATH from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks from pytorch_lightning.core.mixins import HyperparametersMixin from pytorch_lightning.core.saving import _load_from_checkpoint @@ -28,7 +29,7 @@ from pytorch_lightning.utilities.argparse import ( get_init_arguments_and_types, parse_argparser, ) -from pytorch_lightning.utilities.types import _ADD_ARGPARSE_RETURN, _PATH, EVAL_DATALOADERS, TRAIN_DATALOADERS +from pytorch_lightning.utilities.types import _ADD_ARGPARSE_RETURN, EVAL_DATALOADERS, TRAIN_DATALOADERS class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin): diff --git a/src/pytorch_lightning/core/lightning.py b/src/pytorch_lightning/core/lightning.py index bf6fe19c7d..974cecb39e 100644 --- a/src/pytorch_lightning/core/lightning.py +++ b/src/pytorch_lightning/core/lightning.py @@ -14,7 +14,7 @@ from typing import Any from pytorch_lightning.core.module import LightningModule as NewLightningModule -from pytorch_lightning.utilities import rank_zero_deprecation +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation class LightningModule(NewLightningModule): diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py index 6776f8ab95..ab655adb4d 100644 --- a/src/pytorch_lightning/core/module.py +++ b/src/pytorch_lightning/core/module.py @@ -37,6 +37,7 @@ import pytorch_lightning as pl from lightning_lite.utilities.apply_func import convert_to_tensors from lightning_lite.utilities.cloud_io import get_filesystem from lightning_lite.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin +from lightning_lite.utilities.distributed import distributed_available, sync_ddp from pytorch_lightning.callbacks.callback import Callback from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks from pytorch_lightning.core.mixins import HyperparametersMixin @@ -45,7 +46,6 @@ from pytorch_lightning.core.saving import ModelIO from pytorch_lightning.loggers import Logger, LoggerCollection from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_10, GradClipAlgorithmType -from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_1_13 from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_deprecation, rank_zero_warn diff --git a/src/pytorch_lightning/core/optimizer.py b/src/pytorch_lightning/core/optimizer.py index b96cfabd83..e1a834f8c8 100644 --- a/src/pytorch_lightning/core/optimizer.py +++ b/src/pytorch_lightning/core/optimizer.py @@ -21,10 +21,11 @@ from torch import optim from torch.optim import Optimizer import pytorch_lightning as pl +from lightning_lite.utilities.types import _Stateful, ReduceLROnPlateau from pytorch_lightning.utilities.exceptions import MisconfigurationException from 
pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.rank_zero import rank_zero_warn -from pytorch_lightning.utilities.types import _Stateful, LRSchedulerConfig, LRSchedulerTypeTuple, ReduceLROnPlateau +from pytorch_lightning.utilities.types import LRSchedulerConfig, LRSchedulerTypeTuple def do_nothing_closure() -> None: diff --git a/src/pytorch_lightning/core/saving.py b/src/pytorch_lightning/core/saving.py index 7d999eebb4..1bec607fad 100644 --- a/src/pytorch_lightning/core/saving.py +++ b/src/pytorch_lightning/core/saving.py @@ -29,11 +29,11 @@ from lightning_utilities.core.apply_func import apply_to_collection import pytorch_lightning as pl from lightning_lite.utilities.cloud_io import get_filesystem from lightning_lite.utilities.cloud_io import load as pl_load +from lightning_lite.utilities.types import _MAP_LOCATION_TYPE, _PATH from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict from pytorch_lightning.utilities.migration import pl_legacy_patch from pytorch_lightning.utilities.parsing import parse_class_init_keys from pytorch_lightning.utilities.rank_zero import rank_zero_warn -from pytorch_lightning.utilities.types import _MAP_LOCATION_TYPE, _PATH log = logging.getLogger(__name__) PRIMITIVE_TYPES = (bool, int, float, str) diff --git a/src/pytorch_lightning/lite/lite.py b/src/pytorch_lightning/lite/lite.py index 0ec9cf5c2d..c301f71d44 100644 --- a/src/pytorch_lightning/lite/lite.py +++ b/src/pytorch_lightning/lite/lite.py @@ -25,7 +25,15 @@ from torch import Tensor from torch.optim import Optimizer from torch.utils.data import BatchSampler, DataLoader, DistributedSampler +from lightning_lite.utilities import _AcceleratorType, _StrategyType, move_data_to_device from lightning_lite.utilities.apply_func import convert_to_tensors +from lightning_lite.utilities.data import ( + _auto_add_worker_init_fn, + _replace_dunder_methods, + _update_dataloader, + has_iterable_dataset, +) +from lightning_lite.utilities.seed import seed_everything from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer from pytorch_lightning.overrides.distributed import DistributedSamplerWrapper @@ -33,15 +41,7 @@ from pytorch_lightning.plugins import PLUGIN_INPUT from pytorch_lightning.strategies import DeepSpeedStrategy, Strategy, TPUSpawnStrategy from pytorch_lightning.strategies.strategy import TBroadcast from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector -from pytorch_lightning.utilities import _AcceleratorType, _StrategyType, move_data_to_device -from pytorch_lightning.utilities.data import ( - _auto_add_worker_init_fn, - _replace_dunder_methods, - _update_dataloader, - has_iterable_dataset, -) from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.seed import seed_everything class LightningLite(ABC): diff --git a/src/pytorch_lightning/loops/utilities.py b/src/pytorch_lightning/loops/utilities.py index 9b8ec84ba3..d5824c431c 100644 --- a/src/pytorch_lightning/loops/utilities.py +++ b/src/pytorch_lightning/loops/utilities.py @@ -23,16 +23,16 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader import pytorch_lightning as pl +from lightning_lite.utilities.warnings import PossibleUserWarning from pytorch_lightning.callbacks.timer import Timer from pytorch_lightning.loops import Loop from pytorch_lightning.strategies import 
ParallelStrategy, Strategy from pytorch_lightning.trainer.progress import BaseProgress -from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import recursive_detach +from pytorch_lightning.utilities.rank_zero import rank_zero_warn from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.types import STEP_OUTPUT -from pytorch_lightning.utilities.warnings import PossibleUserWarning def check_finite_loss(loss: Optional[Tensor]) -> None: diff --git a/src/pytorch_lightning/overrides/base.py b/src/pytorch_lightning/overrides/base.py index bd2a904de6..10ab5c06b2 100644 --- a/src/pytorch_lightning/overrides/base.py +++ b/src/pytorch_lightning/overrides/base.py @@ -20,7 +20,7 @@ from torch.nn.parallel import DistributedDataParallel import pytorch_lightning as pl from lightning_lite.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin -from pytorch_lightning.utilities import rank_zero_deprecation +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation class _LightningPrecisionModuleWrapperBase(_DeviceDtypeModuleMixin, torch.nn.Module): diff --git a/src/pytorch_lightning/overrides/fairscale.py b/src/pytorch_lightning/overrides/fairscale.py index 572efd277d..0a35f9ddd4 100644 --- a/src/pytorch_lightning/overrides/fairscale.py +++ b/src/pytorch_lightning/overrides/fairscale.py @@ -22,8 +22,8 @@ from pytorch_lightning.overrides.base import ( _LightningPrecisionModuleWrapperBase, unwrap_lightning_module, ) -from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.imports import _IS_WINDOWS +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and module_available("fairscale.nn") diff --git a/src/pytorch_lightning/plugins/io/checkpoint_plugin.py b/src/pytorch_lightning/plugins/io/checkpoint_plugin.py index 7dcc850424..04ace9945f 100644 --- a/src/pytorch_lightning/plugins/io/checkpoint_plugin.py +++ b/src/pytorch_lightning/plugins/io/checkpoint_plugin.py @@ -14,7 +14,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, Optional -from pytorch_lightning.utilities.types import _PATH +from lightning_lite.utilities.types import _PATH class CheckpointIO(ABC): diff --git a/src/pytorch_lightning/plugins/io/hpu_plugin.py b/src/pytorch_lightning/plugins/io/hpu_plugin.py index 59dfa93219..9fb564cda7 100644 --- a/src/pytorch_lightning/plugins/io/hpu_plugin.py +++ b/src/pytorch_lightning/plugins/io/hpu_plugin.py @@ -19,8 +19,8 @@ import torch from lightning_lite.utilities.apply_func import move_data_to_device from lightning_lite.utilities.cloud_io import atomic_save, get_filesystem +from lightning_lite.utilities.types import _PATH from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO -from pytorch_lightning.utilities.types import _PATH class HPUCheckpointIO(TorchCheckpointIO): diff --git a/src/pytorch_lightning/plugins/io/torch_plugin.py b/src/pytorch_lightning/plugins/io/torch_plugin.py index ccdc4874a1..723900864c 100644 --- a/src/pytorch_lightning/plugins/io/torch_plugin.py +++ b/src/pytorch_lightning/plugins/io/torch_plugin.py @@ -18,9 +18,9 @@ from typing import Any, Callable, Dict, Optional import pytorch_lightning as pl from lightning_lite.utilities.cloud_io import atomic_save, get_filesystem from lightning_lite.utilities.cloud_io import load as pl_load +from 
lightning_lite.utilities.types import _PATH from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.utilities.rank_zero import rank_zero_warn -from pytorch_lightning.utilities.types import _PATH log = logging.getLogger(__name__) diff --git a/src/pytorch_lightning/plugins/io/xla_plugin.py b/src/pytorch_lightning/plugins/io/xla_plugin.py index 791e1e0683..88d8c2bcb7 100644 --- a/src/pytorch_lightning/plugins/io/xla_plugin.py +++ b/src/pytorch_lightning/plugins/io/xla_plugin.py @@ -17,9 +17,9 @@ from typing import Any, Dict, Optional from lightning_utilities.core.apply_func import apply_to_collection from lightning_lite.utilities.cloud_io import get_filesystem +from lightning_lite.utilities.types import _PATH from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE -from pytorch_lightning.utilities.types import _PATH if _TPU_AVAILABLE: import torch_xla.core.xla_model as xm diff --git a/src/pytorch_lightning/plugins/precision/apex_amp.py b/src/pytorch_lightning/plugins/precision/apex_amp.py index d85dceb53a..0416e216f6 100644 --- a/src/pytorch_lightning/plugins/precision/apex_amp.py +++ b/src/pytorch_lightning/plugins/precision/apex_amp.py @@ -18,10 +18,10 @@ from torch.nn import Module from torch.optim import LBFGS, Optimizer import pytorch_lightning as pl +from lightning_lite.utilities.types import _PARAMETERS from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.types import _PARAMETERS if _APEX_AVAILABLE: from apex import amp diff --git a/src/pytorch_lightning/plugins/precision/deepspeed.py b/src/pytorch_lightning/plugins/precision/deepspeed.py index 1a83e9538d..658e66cd1b 100644 --- a/src/pytorch_lightning/plugins/precision/deepspeed.py +++ b/src/pytorch_lightning/plugins/precision/deepspeed.py @@ -20,9 +20,9 @@ from torch.nn import Module from torch.optim import LBFGS, Optimizer import pytorch_lightning as pl +from lightning_lite.utilities.enums import AMPType, PrecisionType from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.utilities import GradClipAlgorithmType -from pytorch_lightning.utilities.enums import AMPType, PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _APEX_AVAILABLE from pytorch_lightning.utilities.model_helpers import is_overridden diff --git a/src/pytorch_lightning/plugins/precision/fsdp_native_native_amp.py b/src/pytorch_lightning/plugins/precision/fsdp_native_native_amp.py index a5b26d7dec..ce372a1f04 100644 --- a/src/pytorch_lightning/plugins/precision/fsdp_native_native_amp.py +++ b/src/pytorch_lightning/plugins/precision/fsdp_native_native_amp.py @@ -15,8 +15,8 @@ from typing import Any, Optional, Union import torch +from lightning_lite.utilities.enums import PrecisionType from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin -from pytorch_lightning.utilities.enums import PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12 diff --git a/src/pytorch_lightning/plugins/precision/hpu.py b/src/pytorch_lightning/plugins/precision/hpu.py index 4f8db7dabb..170372ad4e 
100644 --- a/src/pytorch_lightning/plugins/precision/hpu.py +++ b/src/pytorch_lightning/plugins/precision/hpu.py @@ -13,8 +13,8 @@ # limitations under the License. from typing import Optional, Union +from lightning_lite.utilities.enums import PrecisionType from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin -from pytorch_lightning.utilities.enums import PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _HPU_AVAILABLE diff --git a/src/pytorch_lightning/plugins/precision/ipu.py b/src/pytorch_lightning/plugins/precision/ipu.py index 34ad358793..2b01dd010f 100644 --- a/src/pytorch_lightning/plugins/precision/ipu.py +++ b/src/pytorch_lightning/plugins/precision/ipu.py @@ -18,9 +18,9 @@ from torch.nn import Module from torch.optim import LBFGS, Optimizer import pytorch_lightning as pl +from lightning_lite.utilities.enums import PrecisionType from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.utilities import GradClipAlgorithmType -from pytorch_lightning.utilities.enums import PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden diff --git a/src/pytorch_lightning/plugins/precision/precision_plugin.py b/src/pytorch_lightning/plugins/precision/precision_plugin.py index 285a0f31e3..063c8cabb7 100644 --- a/src/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/src/pytorch_lightning/plugins/precision/precision_plugin.py @@ -21,9 +21,9 @@ from torch.nn import Module from torch.optim import Optimizer import pytorch_lightning as pl +from lightning_lite.utilities.types import _PARAMETERS from pytorch_lightning.core.hooks import CheckpointHooks from pytorch_lightning.utilities import grad_norm, GradClipAlgorithmType -from pytorch_lightning.utilities.types import _PARAMETERS class PrecisionPlugin(CheckpointHooks): diff --git a/src/pytorch_lightning/profiler/advanced.py b/src/pytorch_lightning/profiler/advanced.py index 1d2bbed5d9..d0456f7afa 100644 --- a/src/pytorch_lightning/profiler/advanced.py +++ b/src/pytorch_lightning/profiler/advanced.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from pytorch_lightning.profilers.advanced import AdvancedProfiler as NewAdvancedProfiler -from pytorch_lightning.utilities import rank_zero_deprecation +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation class AdvancedProfiler(NewAdvancedProfiler): diff --git a/src/pytorch_lightning/profiler/profiler.py b/src/pytorch_lightning/profiler/profiler.py index 84bea3ecae..40d18e79a3 100644 --- a/src/pytorch_lightning/profiler/profiler.py +++ b/src/pytorch_lightning/profiler/profiler.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from pytorch_lightning.profilers.profiler import Profiler as NewProfiler -from pytorch_lightning.utilities import rank_zero_deprecation +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation class Profiler(NewProfiler): diff --git a/src/pytorch_lightning/profiler/pytorch.py b/src/pytorch_lightning/profiler/pytorch.py index d443059912..488ce3b654 100644 --- a/src/pytorch_lightning/profiler/pytorch.py +++ b/src/pytorch_lightning/profiler/pytorch.py @@ -14,7 +14,7 @@ from pytorch_lightning.profilers.pytorch import PyTorchProfiler as NewPyTorchProfiler from pytorch_lightning.profilers.pytorch import RegisterRecordFunction as NewRegisterRecordFuncion from pytorch_lightning.profilers.pytorch import ScheduleWrapper as NewScheduleWrapper -from pytorch_lightning.utilities import rank_zero_deprecation +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation class RegisterRecordFunction(NewRegisterRecordFuncion): diff --git a/src/pytorch_lightning/profiler/simple.py b/src/pytorch_lightning/profiler/simple.py index 61ef7da8ae..9438f516b2 100644 --- a/src/pytorch_lightning/profiler/simple.py +++ b/src/pytorch_lightning/profiler/simple.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from pytorch_lightning.profilers.simple import SimpleProfiler as NewSimpleProfiler -from pytorch_lightning.utilities import rank_zero_deprecation +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation class SimpleProfiler(NewSimpleProfiler): diff --git a/src/pytorch_lightning/profiler/xla.py b/src/pytorch_lightning/profiler/xla.py index dde858e99e..0cdc019600 100644 --- a/src/pytorch_lightning/profiler/xla.py +++ b/src/pytorch_lightning/profiler/xla.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from pytorch_lightning.profilers.xla import XLAProfiler as NewXLAProfiler -from pytorch_lightning.utilities import rank_zero_deprecation +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation class XLAProfiler(NewXLAProfiler): diff --git a/src/pytorch_lightning/profilers/pytorch.py b/src/pytorch_lightning/profilers/pytorch.py index f410230668..c7f34fdc79 100644 --- a/src/pytorch_lightning/profilers/pytorch.py +++ b/src/pytorch_lightning/profilers/pytorch.py @@ -24,8 +24,8 @@ from lightning_utilities.core.rank_zero import WarningCache from torch import nn, Tensor from torch.autograd.profiler import record_function +from lightning_lite.utilities.device_parser import is_cuda_available from pytorch_lightning.profilers.profiler import Profiler -from pytorch_lightning.utilities.device_parser import is_cuda_available from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE from pytorch_lightning.utilities.rank_zero import rank_zero_warn diff --git a/src/pytorch_lightning/strategies/bagua.py b/src/pytorch_lightning/strategies/bagua.py index fd5f7b1319..a54267a32b 100644 --- a/src/pytorch_lightning/strategies/bagua.py +++ b/src/pytorch_lightning/strategies/bagua.py @@ -8,6 +8,9 @@ from torch import Tensor from torch.nn import Module import pytorch_lightning as pl +from lightning_lite.utilities.distributed import ReduceOp +from lightning_lite.utilities.optimizer import optimizers_to_device +from lightning_lite.utilities.seed import reset_seed from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO @@ -15,10 +18,7 @@ from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.ddp import DDPStrategy from pytorch_lightning.strategies.strategy import TBroadcast from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.optimizer import optimizers_to_device -from pytorch_lightning.utilities.seed import reset_seed _BAGUA_AVAILABLE = package_available("bagua") diff --git a/src/pytorch_lightning/strategies/ddp.py b/src/pytorch_lightning/strategies/ddp.py index 2cfdbab357..c0eaf47ff8 100644 --- a/src/pytorch_lightning/strategies/ddp.py +++ b/src/pytorch_lightning/strategies/ddp.py @@ -29,6 +29,15 @@ from torch.nn.parallel.distributed import DistributedDataParallel from torch.optim.optimizer import Optimizer import pytorch_lightning as pl +from lightning_lite.utilities.distributed import ( + _get_process_group_backend_from_env, + distributed_available, + get_default_process_group_backend_for_device, +) +from lightning_lite.utilities.distributed import group as _group +from lightning_lite.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available +from lightning_lite.utilities.optimizer import optimizers_to_device +from lightning_lite.utilities.seed import reset_seed from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase @@ -41,23 +50,10 @@ from pytorch_lightning.strategies.launchers.subprocess_script import 
_Subprocess from pytorch_lightning.strategies.parallel import ParallelStrategy from pytorch_lightning.strategies.strategy import TBroadcast from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities.distributed import ( - _get_process_group_backend_from_env, - distributed_available, - get_default_process_group_backend_for_device, -) -from pytorch_lightning.utilities.distributed import group as _group -from pytorch_lightning.utilities.distributed import ( - init_dist_connection, - ReduceOp, - register_ddp_comm_hook, - sync_ddp_if_available, -) +from pytorch_lightning.utilities.distributed import register_ddp_comm_hook from pytorch_lightning.utilities.exceptions import DeadlockDetectedException from pytorch_lightning.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_10, _TORCH_GREATER_EQUAL_1_11 -from pytorch_lightning.utilities.optimizer import optimizers_to_device from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn -from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep if _FAIRSCALE_AVAILABLE: diff --git a/src/pytorch_lightning/strategies/ddp_spawn.py b/src/pytorch_lightning/strategies/ddp_spawn.py index 2eea8f11f1..35d9049813 100644 --- a/src/pytorch_lightning/strategies/ddp_spawn.py +++ b/src/pytorch_lightning/strategies/ddp_spawn.py @@ -24,6 +24,14 @@ from torch.nn.parallel.distributed import DistributedDataParallel from typing_extensions import Literal import pytorch_lightning as pl +from lightning_lite.utilities.distributed import ( + _get_process_group_backend_from_env, + distributed_available, + get_default_process_group_backend_for_device, +) +from lightning_lite.utilities.distributed import group as _group +from lightning_lite.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available +from lightning_lite.utilities.optimizer import optimizers_to_device from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase from pytorch_lightning.overrides.distributed import prepare_for_backward @@ -34,20 +42,8 @@ from pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcess from pytorch_lightning.strategies.parallel import ParallelStrategy from pytorch_lightning.strategies.strategy import TBroadcast from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities.distributed import ( - _get_process_group_backend_from_env, - distributed_available, - get_default_process_group_backend_for_device, -) -from pytorch_lightning.utilities.distributed import group as _group -from pytorch_lightning.utilities.distributed import ( - init_dist_connection, - ReduceOp, - register_ddp_comm_hook, - sync_ddp_if_available, -) +from pytorch_lightning.utilities.distributed import register_ddp_comm_hook from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11 -from pytorch_lightning.utilities.optimizer import optimizers_to_device from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only from pytorch_lightning.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep diff --git a/src/pytorch_lightning/strategies/deepspeed.py b/src/pytorch_lightning/strategies/deepspeed.py index 1d1c687507..46634f0012 100644 --- a/src/pytorch_lightning/strategies/deepspeed.py +++ b/src/pytorch_lightning/strategies/deepspeed.py @@ -30,6 +30,15 
@@ from torch.nn import Module from torch.optim import Optimizer import pytorch_lightning as pl +from lightning_lite.utilities.distributed import ( + _get_process_group_backend_from_env, + get_default_process_group_backend_for_device, + log, +) +from lightning_lite.utilities.enums import AMPType, PrecisionType +from lightning_lite.utilities.optimizer import optimizers_to_device +from lightning_lite.utilities.seed import reset_seed +from lightning_lite.utilities.types import _LRScheduler, _PATH, ReduceLROnPlateau from pytorch_lightning.accelerators.cuda import CUDAAccelerator from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase @@ -39,18 +48,10 @@ from pytorch_lightning.strategies.ddp import DDPStrategy from pytorch_lightning.strategies.utils import _fp_to_half from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import GradClipAlgorithmType -from pytorch_lightning.utilities.distributed import ( - _get_process_group_backend_from_env, - get_default_process_group_backend_for_device, - log, -) -from pytorch_lightning.utilities.enums import AMPType, PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.optimizer import optimizers_to_device from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn -from pytorch_lightning.utilities.seed import reset_seed -from pytorch_lightning.utilities.types import _LRScheduler, _PATH, LRSchedulerConfig, ReduceLROnPlateau, STEP_OUTPUT +from pytorch_lightning.utilities.types import LRSchedulerConfig, STEP_OUTPUT warning_cache = WarningCache() diff --git a/src/pytorch_lightning/strategies/dp.py b/src/pytorch_lightning/strategies/dp.py index a377171982..1724f0021d 100644 --- a/src/pytorch_lightning/strategies/dp.py +++ b/src/pytorch_lightning/strategies/dp.py @@ -19,13 +19,13 @@ from torch import Tensor from torch.nn import DataParallel, Module import pytorch_lightning as pl +from lightning_lite.utilities.distributed import ReduceOp from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase from pytorch_lightning.overrides.data_parallel import LightningParallelModule from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.parallel import ParallelStrategy from pytorch_lightning.strategies.strategy import TBroadcast, TReduce -from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import STEP_OUTPUT diff --git a/src/pytorch_lightning/strategies/fully_sharded.py b/src/pytorch_lightning/strategies/fully_sharded.py index a364d7d19a..add78dc35e 100644 --- a/src/pytorch_lightning/strategies/fully_sharded.py +++ b/src/pytorch_lightning/strategies/fully_sharded.py @@ -18,6 +18,8 @@ from typing import Any, Dict, Generator, List, Optional import torch import pytorch_lightning as pl +from lightning_lite.utilities.enums import PrecisionType +from lightning_lite.utilities.optimizer import optimizers_to_device from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE from 
pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment @@ -25,10 +27,8 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.ddp import DDPStrategy from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities.enums import PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.optimizer import optimizers_to_device from pytorch_lightning.utilities.rank_zero import rank_zero_info from pytorch_lightning.utilities.types import STEP_OUTPUT diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py index b32f460ee1..ed7c237c9b 100644 --- a/src/pytorch_lightning/strategies/fully_sharded_native.py +++ b/src/pytorch_lightning/strategies/fully_sharded_native.py @@ -19,6 +19,14 @@ import torch from torch import Tensor import pytorch_lightning as pl +from lightning_lite.utilities.distributed import ( + _get_process_group_backend_from_env, + get_default_process_group_backend_for_device, +) +from lightning_lite.utilities.distributed import group as _group +from lightning_lite.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available +from lightning_lite.utilities.optimizer import optimizers_to_device +from lightning_lite.utilities.seed import reset_seed from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO @@ -28,19 +36,10 @@ from pytorch_lightning.strategies.launchers.subprocess_script import _Subprocess from pytorch_lightning.strategies.parallel import ParallelStrategy from pytorch_lightning.strategies.strategy import TBroadcast from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities import rank_zero_only -from pytorch_lightning.utilities.distributed import ( - _get_process_group_backend_from_env, - get_default_process_group_backend_for_device, -) -from pytorch_lightning.utilities.distributed import group as _group -from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12 from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.optimizer import optimizers_to_device -from pytorch_lightning.utilities.rank_zero import rank_zero_info -from pytorch_lightning.utilities.seed import reset_seed +from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only from pytorch_lightning.utilities.types import ProcessGroup, STEP_OUTPUT _distributed_available = torch.distributed.is_available() diff --git a/src/pytorch_lightning/strategies/hivemind.py b/src/pytorch_lightning/strategies/hivemind.py index b258fe7f73..7cad027ac6 100644 --- a/src/pytorch_lightning/strategies/hivemind.py +++ b/src/pytorch_lightning/strategies/hivemind.py @@ -8,14 +8,14 @@ import torch from torch import Tensor import pytorch_lightning as pl +from lightning_lite.utilities.enums import PrecisionType +from lightning_lite.utilities.types import _LRScheduler, ReduceLROnPlateau 
from pytorch_lightning.strategies.strategy import Strategy, TBroadcast -from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.data import extract_batch_size -from pytorch_lightning.utilities.enums import PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _HIVEMIND_AVAILABLE from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.types import _LRScheduler, ReduceLROnPlateau +from pytorch_lightning.utilities.rank_zero import rank_zero_warn if _HIVEMIND_AVAILABLE: import hivemind diff --git a/src/pytorch_lightning/strategies/horovod.py b/src/pytorch_lightning/strategies/horovod.py index 6329d1e409..27793306fb 100644 --- a/src/pytorch_lightning/strategies/horovod.py +++ b/src/pytorch_lightning/strategies/horovod.py @@ -20,14 +20,14 @@ from torch import Tensor from torch.optim import Optimizer import pytorch_lightning as pl +from lightning_lite.utilities.distributed import distributed_available +from lightning_lite.utilities.distributed import group as dist_group +from lightning_lite.utilities.distributed import ReduceOp from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.parallel import ParallelStrategy from pytorch_lightning.strategies.strategy import TBroadcast -from pytorch_lightning.utilities.distributed import distributed_available -from pytorch_lightning.utilities.distributed import group as dist_group -from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE from pytorch_lightning.utilities.rank_zero import rank_zero_only diff --git a/src/pytorch_lightning/strategies/hpu_parallel.py b/src/pytorch_lightning/strategies/hpu_parallel.py index 3e6f8e932e..e7c18d3471 100644 --- a/src/pytorch_lightning/strategies/hpu_parallel.py +++ b/src/pytorch_lightning/strategies/hpu_parallel.py @@ -18,6 +18,7 @@ from typing import Any, Callable, Dict, List, Optional import torch.distributed import pytorch_lightning as pl +from lightning_lite.utilities.distributed import group as _group from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.torch_distributed import broadcast_object_list from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment @@ -26,7 +27,6 @@ from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.ddp import DDPStrategy -from pytorch_lightning.utilities.distributed import group as _group from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _HPU_AVAILABLE, _TORCH_LESSER_EQUAL_1_10_2 from pytorch_lightning.utilities.types import STEP_OUTPUT diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py index 2d976e545d..69de604971 100644 --- a/src/pytorch_lightning/strategies/ipu.py +++ b/src/pytorch_lightning/strategies/ipu.py @@ -22,6 +22,7 @@ from torch.utils.data import DataLoader, Sampler import pytorch_lightning as pl from 
lightning_lite.utilities.cloud_io import get_filesystem +from lightning_lite.utilities.enums import PrecisionType from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO @@ -32,7 +33,6 @@ from pytorch_lightning.strategies.utils import _fp_to_half from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE, rank_zero_warn from pytorch_lightning.utilities.data import _get_dataloader_init_args_and_kwargs, _reinstantiate_wrapped_cls -from pytorch_lightning.utilities.enums import PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation diff --git a/src/pytorch_lightning/strategies/launchers/multiprocessing.py b/src/pytorch_lightning/strategies/launchers/multiprocessing.py index 6bf81eb72d..31508067ab 100644 --- a/src/pytorch_lightning/strategies/launchers/multiprocessing.py +++ b/src/pytorch_lightning/strategies/launchers/multiprocessing.py @@ -27,13 +27,13 @@ from typing_extensions import Literal import pytorch_lightning as pl from lightning_lite.utilities.apply_func import move_data_to_device +from lightning_lite.utilities.seed import _collect_rng_states, _set_rng_states +from lightning_lite.utilities.types import _PATH from pytorch_lightning.strategies.launchers.base import _Launcher from pytorch_lightning.strategies.strategy import Strategy from pytorch_lightning.trainer.states import TrainerFn, TrainerState from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11 from pytorch_lightning.utilities.rank_zero import rank_zero_debug -from pytorch_lightning.utilities.seed import _collect_rng_states, _set_rng_states -from pytorch_lightning.utilities.types import _PATH class _MultiProcessingLauncher(_Launcher): diff --git a/src/pytorch_lightning/strategies/parallel.py b/src/pytorch_lightning/strategies/parallel.py index 0790b5e75e..124d01f362 100644 --- a/src/pytorch_lightning/strategies/parallel.py +++ b/src/pytorch_lightning/strategies/parallel.py @@ -19,17 +19,17 @@ import torch from torch import Tensor import pytorch_lightning as pl -from pytorch_lightning.plugins import LayerSync -from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment -from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO -from pytorch_lightning.plugins.precision import PrecisionPlugin -from pytorch_lightning.strategies.strategy import Strategy -from pytorch_lightning.utilities.distributed import ( +from lightning_lite.utilities.distributed import ( _get_process_group_backend_from_env, all_gather_ddp_if_available, get_default_process_group_backend_for_device, ReduceOp, ) +from pytorch_lightning.plugins import LayerSync +from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin +from pytorch_lightning.strategies.strategy import Strategy from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation diff --git a/src/pytorch_lightning/strategies/sharded.py b/src/pytorch_lightning/strategies/sharded.py index 
22a1c22e96..df0d126385 100644 --- a/src/pytorch_lightning/strategies/sharded.py +++ b/src/pytorch_lightning/strategies/sharded.py @@ -19,14 +19,14 @@ from torch.nn import Module from torch.optim import Optimizer import pytorch_lightning as pl +from lightning_lite.utilities.enums import PrecisionType +from lightning_lite.utilities.optimizer import optimizers_to_device from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.strategies.ddp import DDPStrategy from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities.enums import PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.optimizer import optimizers_to_device if _FAIRSCALE_AVAILABLE: from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel diff --git a/src/pytorch_lightning/strategies/sharded_spawn.py b/src/pytorch_lightning/strategies/sharded_spawn.py index b5cd9497a3..438f6d5eb6 100644 --- a/src/pytorch_lightning/strategies/sharded_spawn.py +++ b/src/pytorch_lightning/strategies/sharded_spawn.py @@ -19,12 +19,12 @@ from torch.nn import Module from torch.optim import Optimizer import pytorch_lightning as pl +from lightning_lite.utilities.optimizer import optimizers_to_device from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.optimizer import optimizers_to_device if _FAIRSCALE_AVAILABLE: from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel diff --git a/src/pytorch_lightning/strategies/single_device.py b/src/pytorch_lightning/strategies/single_device.py index cb436fded8..a9d5d7ca87 100644 --- a/src/pytorch_lightning/strategies/single_device.py +++ b/src/pytorch_lightning/strategies/single_device.py @@ -19,10 +19,10 @@ import torch from torch import Tensor import pytorch_lightning as pl +from lightning_lite.utilities.types import _DEVICE from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.strategy import Strategy, TBroadcast -from pytorch_lightning.utilities.types import _DEVICE class SingleDeviceStrategy(Strategy): diff --git a/src/pytorch_lightning/strategies/single_hpu.py b/src/pytorch_lightning/strategies/single_hpu.py index 45eb8c58f2..5c29829fa6 100644 --- a/src/pytorch_lightning/strategies/single_hpu.py +++ b/src/pytorch_lightning/strategies/single_hpu.py @@ -15,6 +15,7 @@ from typing import Dict, Optional import pytorch_lightning as pl +from lightning_lite.utilities.types import _DEVICE from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO @@ -22,7 +23,7 @@ from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.single_device import SingleDeviceStrategy from pytorch_lightning.utilities import _HPU_AVAILABLE from 
pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.types import _DEVICE, STEP_OUTPUT +from pytorch_lightning.utilities.types import STEP_OUTPUT if _HPU_AVAILABLE: import habana_frameworks.torch.core as htcore diff --git a/src/pytorch_lightning/strategies/strategy.py b/src/pytorch_lightning/strategies/strategy.py index 0a10722166..bb63c60269 100644 --- a/src/pytorch_lightning/strategies/strategy.py +++ b/src/pytorch_lightning/strategies/strategy.py @@ -24,6 +24,9 @@ from torch.utils.data import DataLoader import pytorch_lightning as pl from lightning_lite.utilities.apply_func import move_data_to_device +from lightning_lite.utilities.distributed import ReduceOp +from lightning_lite.utilities.optimizer import optimizer_to_device, optimizers_to_device +from lightning_lite.utilities.types import _PATH from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers, LightningOptimizer from pytorch_lightning.plugins import TorchCheckpointIO from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO @@ -31,10 +34,7 @@ from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.launchers.base import _Launcher from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities.distributed import ReduceOp -from pytorch_lightning.utilities.optimizer import optimizer_to_device, optimizers_to_device from pytorch_lightning.utilities.types import ( - _PATH, LRSchedulerConfig, PredictStep, STEP_OUTPUT, diff --git a/src/pytorch_lightning/strategies/strategy_registry.py b/src/pytorch_lightning/strategies/strategy_registry.py index 7dee7146d4..43089b735a 100644 --- a/src/pytorch_lightning/strategies/strategy_registry.py +++ b/src/pytorch_lightning/strategies/strategy_registry.py @@ -15,9 +15,9 @@ import importlib from inspect import getmembers, isclass from typing import Any, Callable, Dict, List, Optional +from lightning_lite.utilities.registry import _is_register_method_overridden from pytorch_lightning.strategies.strategy import Strategy from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.registry import _is_register_method_overridden class _StrategyRegistry(dict): diff --git a/src/pytorch_lightning/strategies/tpu_spawn.py b/src/pytorch_lightning/strategies/tpu_spawn.py index 748406479b..52dec94ac3 100644 --- a/src/pytorch_lightning/strategies/tpu_spawn.py +++ b/src/pytorch_lightning/strategies/tpu_spawn.py @@ -22,6 +22,10 @@ from torch.nn import Module from torch.utils.data import DataLoader import pytorch_lightning as pl +from lightning_lite.utilities.data import has_len +from lightning_lite.utilities.distributed import ReduceOp +from lightning_lite.utilities.optimizer import optimizers_to_device +from lightning_lite.utilities.types import _PATH from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.plugins.environments import XLAEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO @@ -34,12 +38,9 @@ from pytorch_lightning.strategies.strategy import TBroadcast from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters -from pytorch_lightning.utilities.data import has_len -from 
pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.optimizer import optimizers_to_device from pytorch_lightning.utilities.rank_zero import rank_zero_only -from pytorch_lightning.utilities.types import _PATH, EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS +from pytorch_lightning.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS if _TPU_AVAILABLE: import torch_xla.core.xla_env_vars as xenv diff --git a/src/pytorch_lightning/strategies/utils.py b/src/pytorch_lightning/strategies/utils.py index ec7a1bd6ff..3c3ebbe241 100644 --- a/src/pytorch_lightning/strategies/utils.py +++ b/src/pytorch_lightning/strategies/utils.py @@ -15,7 +15,7 @@ import os import torch -from pytorch_lightning.utilities.enums import PrecisionType +from lightning_lite.utilities.enums import PrecisionType from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation diff --git a/src/pytorch_lightning/trainer/__init__.py b/src/pytorch_lightning/trainer/__init__.py index 6226a75de4..b53effd6e7 100644 --- a/src/pytorch_lightning/trainer/__init__.py +++ b/src/pytorch_lightning/trainer/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. """""" +from lightning_lite.utilities.seed import seed_everything from pytorch_lightning.trainer.trainer import Trainer -from pytorch_lightning.utilities.seed import seed_everything __all__ = ["Trainer", "seed_everything"] diff --git a/src/pytorch_lightning/trainer/configuration_validator.py b/src/pytorch_lightning/trainer/configuration_validator.py index 6ec2b15a11..f1d86995d1 100644 --- a/src/pytorch_lightning/trainer/configuration_validator.py +++ b/src/pytorch_lightning/trainer/configuration_validator.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
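The `trainer/__init__.py` hunk above is a good illustration of how the public API stays stable: `seed_everything` now lives in `lightning_lite.utilities.seed`, and `pytorch_lightning.trainer` simply re-exports it, so `from pytorch_lightning.trainer import seed_everything` keeps resolving to the same function. A minimal, runnable sketch of that re-export pattern; the module names `new_pkg`/`old_pkg` are hypothetical stand-ins, and `sys.modules` is only used so the example fits in one script:

    import sys
    import types

    # "New" home of the utility (stand-in for lightning_lite.utilities.seed).
    new_pkg = types.ModuleType("new_pkg")

    def seed_everything(seed: int) -> int:
        """Toy stand-in: pretend to seed every RNG and return the seed."""
        return seed

    new_pkg.seed_everything = seed_everything
    sys.modules["new_pkg"] = new_pkg

    # "Old" package re-exports the function, so existing imports stay valid.
    old_pkg = types.ModuleType("old_pkg")
    old_pkg.seed_everything = new_pkg.seed_everything
    old_pkg.__all__ = ["seed_everything"]
    sys.modules["old_pkg"] = old_pkg

    from new_pkg import seed_everything as new_seed
    from old_pkg import seed_everything as old_seed

    # Both import paths resolve to the exact same function object.
    assert old_seed is new_seed
    print(old_seed(42))  # 42

In the actual hunk the re-export is just the one-line import plus keeping the name in `__all__`; no extra indirection is needed.
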
import pytorch_lightning as pl +from lightning_lite.utilities.warnings import PossibleUserWarning from pytorch_lightning.accelerators.ipu import IPUAccelerator from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.strategies import DataParallelStrategy @@ -20,7 +21,6 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature -from pytorch_lightning.utilities.warnings import PossibleUserWarning def verify_loop_configurations(trainer: "pl.Trainer") -> None: diff --git a/src/pytorch_lightning/trainer/connectors/accelerator_connector.py b/src/pytorch_lightning/trainer/connectors/accelerator_connector.py index e9183f7f52..f3be6caa5b 100644 --- a/src/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/src/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -20,6 +20,7 @@ from typing import Dict, List, Optional, Union import torch from typing_extensions import Literal +from lightning_lite.utilities import _StrategyType, AMPType, device_parser, LightningEnum from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.accelerators.cpu import CPUAccelerator from pytorch_lightning.accelerators.cuda import CUDAAccelerator @@ -75,15 +76,6 @@ from pytorch_lightning.strategies import ( from pytorch_lightning.strategies.ddp_spawn import _DDP_FORK_ALIASES from pytorch_lightning.strategies.launchers.multiprocessing import _is_forking_disabled from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus -from pytorch_lightning.utilities import ( - _StrategyType, - AMPType, - device_parser, - LightningEnum, - rank_zero_deprecation, - rank_zero_info, - rank_zero_warn, -) from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import ( _HOROVOD_AVAILABLE, @@ -93,6 +85,7 @@ from pytorch_lightning.utilities.imports import ( _TORCH_GREATER_EQUAL_1_11, _TPU_AVAILABLE, ) +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn log = logging.getLogger(__name__) diff --git a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 647d505eb3..300f3c1292 100644 --- a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -24,6 +24,7 @@ from torchmetrics import Metric import pytorch_lightning as pl from lightning_lite.utilities.cloud_io import get_filesystem +from lightning_lite.utilities.types import _PATH from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE @@ -31,7 +32,6 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training from pytorch_lightning.utilities.migration import pl_legacy_patch from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info -from pytorch_lightning.utilities.types import _PATH from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as 
DEPRECATED_CHECKPOINT_KEYS if _OMEGACONF_AVAILABLE: diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index bfb26228e3..56ba809e10 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -23,20 +23,14 @@ from torch.utils.data import BatchSampler, DataLoader, Sampler, SequentialSample from torch.utils.data.distributed import DistributedSampler import pytorch_lightning as pl +from lightning_lite.utilities.data import _auto_add_worker_init_fn, _replace_dunder_methods, has_iterable_dataset from pytorch_lightning.accelerators.ipu import IPUAccelerator from pytorch_lightning.overrides.distributed import DistributedSamplerWrapper, UnrepeatedDistributedSamplerWrapper from pytorch_lightning.strategies import DDPSpawnStrategy from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator from pytorch_lightning.utilities.auto_restart import _validate_fault_tolerant_automatic -from pytorch_lightning.utilities.data import ( - _auto_add_worker_init_fn, - _is_dataloader_shuffled, - _replace_dunder_methods, - _update_dataloader, - has_iterable_dataset, - has_len_all_ranks, -) +from pytorch_lightning.utilities.data import _is_dataloader_shuffled, _update_dataloader, has_len_all_ranks from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training from pytorch_lightning.utilities.model_helpers import is_overridden diff --git a/src/pytorch_lightning/trainer/connectors/logger_connector/result.py b/src/pytorch_lightning/trainer/connectors/logger_connector/result.py index 9f1be4ba4b..4d856223d5 100644 --- a/src/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/src/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -24,8 +24,8 @@ from typing_extensions import TypedDict from lightning_lite.utilities.apply_func import move_data_to_device from lightning_lite.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin +from lightning_lite.utilities.distributed import distributed_available from pytorch_lightning.utilities.data import extract_batch_size -from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training from pytorch_lightning.utilities.memory import recursive_detach diff --git a/src/pytorch_lightning/trainer/data_loading.py b/src/pytorch_lightning/trainer/data_loading.py index e3a2fd4785..3163e7660c 100644 --- a/src/pytorch_lightning/trainer/data_loading.py +++ b/src/pytorch_lightning/trainer/data_loading.py @@ -18,7 +18,7 @@ from torch.utils.data import DataLoader import pytorch_lightning as pl from pytorch_lightning.trainer.states import RunningStage -from pytorch_lightning.utilities import rank_zero_deprecation +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation class TrainerDataLoadingMixin(ABC): diff --git a/src/pytorch_lightning/trainer/optimizers.py b/src/pytorch_lightning/trainer/optimizers.py index 8e25fb5ac6..fcd37c4e27 100644 --- a/src/pytorch_lightning/trainer/optimizers.py +++ b/src/pytorch_lightning/trainer/optimizers.py @@ -19,7 +19,7 @@ from torch.optim import Optimizer import pytorch_lightning as pl from pytorch_lightning.core.optimizer import 
_init_optimizers_and_lr_schedulers, LightningOptimizer -from pytorch_lightning.utilities import rank_zero_deprecation +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation class TrainerOptimizersMixin(ABC): diff --git a/src/pytorch_lightning/trainer/supporters.py b/src/pytorch_lightning/trainer/supporters.py index 3be43e37fe..e183bdcc64 100644 --- a/src/pytorch_lightning/trainer/supporters.py +++ b/src/pytorch_lightning/trainer/supporters.py @@ -21,13 +21,13 @@ from torch.utils.data import Dataset from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, DataLoader from torch.utils.data.dataset import IterableDataset +from lightning_lite.utilities.distributed import distributed_available from pytorch_lightning.utilities.auto_restart import ( _reload_dataloader_state_dict, MergedIteratorState, patch_dataloader_iterator, ) from pytorch_lightning.utilities.data import get_len -from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 72caec4117..1859eab265 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -39,6 +39,10 @@ from torch.utils.data import DataLoader import pytorch_lightning as pl from lightning_lite.utilities.cloud_io import get_filesystem +from lightning_lite.utilities.data import _auto_add_worker_init_fn +from lightning_lite.utilities.distributed import distributed_available +from lightning_lite.utilities.types import _PATH +from lightning_lite.utilities.warnings import PossibleUserWarning from pytorch_lightning.accelerators import ( Accelerator, CUDAAccelerator, @@ -102,8 +106,7 @@ from pytorch_lightning.utilities.argparse import ( parse_env_variables, ) from pytorch_lightning.utilities.auto_restart import _add_capture_metadata_collate -from pytorch_lightning.utilities.data import _auto_add_worker_init_fn, has_len_all_ranks -from pytorch_lightning.utilities.distributed import distributed_available +from pytorch_lightning.utilities.data import has_len_all_ranks from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training from pytorch_lightning.utilities.model_helpers import is_overridden @@ -111,13 +114,11 @@ from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_ze from pytorch_lightning.utilities.seed import isolate_rng from pytorch_lightning.utilities.types import ( _EVALUATE_OUTPUT, - _PATH, _PREDICT_OUTPUT, EVAL_DATALOADERS, LRSchedulerConfig, TRAIN_DATALOADERS, ) -from pytorch_lightning.utilities.warnings import PossibleUserWarning log = logging.getLogger(__name__) # warnings to ignore in trainer diff --git a/src/pytorch_lightning/tuner/auto_gpu_select.py b/src/pytorch_lightning/tuner/auto_gpu_select.py index a42e55a613..5b165c9d94 100644 --- a/src/pytorch_lightning/tuner/auto_gpu_select.py +++ b/src/pytorch_lightning/tuner/auto_gpu_select.py @@ -15,7 +15,7 @@ from typing import List import torch -from pytorch_lightning.utilities import device_parser +from lightning_lite.utilities import device_parser from pytorch_lightning.utilities.exceptions import MisconfigurationException diff --git a/src/pytorch_lightning/utilities/__init__.py 
b/src/pytorch_lightning/utilities/__init__.py index a0baa3a85f..c29ed71bd8 100644 --- a/src/pytorch_lightning/utilities/__init__.py +++ b/src/pytorch_lightning/utilities/__init__.py @@ -15,15 +15,9 @@ import numpy +from lightning_lite.utilities import AllGatherGrad, AMPType, LightningEnum # noqa: F401 from lightning_lite.utilities.apply_func import move_data_to_device # noqa: F401 -from pytorch_lightning.utilities.distributed import AllGatherGrad # noqa: F401 -from pytorch_lightning.utilities.enums import ( # noqa: F401 - _AcceleratorType, - _StrategyType, - AMPType, - GradClipAlgorithmType, - LightningEnum, -) +from pytorch_lightning.utilities.enums import GradClipAlgorithmType # noqa: F401 from pytorch_lightning.utilities.grads import grad_norm # noqa: F401 from pytorch_lightning.utilities.imports import ( # noqa: F401 _APEX_AVAILABLE, diff --git a/src/pytorch_lightning/utilities/auto_restart.py b/src/pytorch_lightning/utilities/auto_restart.py index 0f6fadfb26..d9d8c5da38 100644 --- a/src/pytorch_lightning/utilities/auto_restart.py +++ b/src/pytorch_lightning/utilities/auto_restart.py @@ -29,11 +29,11 @@ from torch.utils.data.dataloader import ( from typing_extensions import TypedDict import pytorch_lightning as pl +from lightning_lite.utilities.types import _Stateful from pytorch_lightning.utilities.distributed import _collect_states_on_rank_zero from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.seed import _collect_rng_states, _set_rng_states -from pytorch_lightning.utilities.types import _Stateful class _IteratorStateDict(TypedDict): diff --git a/src/pytorch_lightning/utilities/cloud_io.py b/src/pytorch_lightning/utilities/cloud_io.py index 735b2e95ed..4993b8d3d0 100644 --- a/src/pytorch_lightning/utilities/cloud_io.py +++ b/src/pytorch_lightning/utilities/cloud_io.py @@ -18,7 +18,7 @@ from typing import Any from lightning_lite.utilities.cloud_io import atomic_save as new_atomic_save from lightning_lite.utilities.cloud_io import get_filesystem as new_get_filesystem from lightning_lite.utilities.cloud_io import load as new_load -from pytorch_lightning.utilities import rank_zero_deprecation # TODO(lite): change to lightning_lite.utilities +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation def atomic_save(*args: Any, **kwargs: Any) -> Any: diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index 59068b1a15..cf07949461 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -11,18 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
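The `utilities/data.py` hunk that starts here shows the backward-compatibility scheme used throughout the patch: the real implementations move to `lightning_lite.utilities.data` and are imported back under `new_*` aliases (`has_len as new_has_len`, ...), while the old names remain as thin shims that call `rank_zero_deprecation` and forward `*args/**kwargs` (see the `has_iterable_dataset`/`has_len` wrappers at the end of the hunk, and the same treatment in `device_parser.py` and `distributed.py` further down). The patch spells each shim out by hand; the sketch below compresses the same idea into a helper, with `warnings.warn` standing in for `rank_zero_deprecation` and a toy `has_len`:

    import warnings
    from functools import wraps
    from typing import Any, Callable


    def _deprecated_alias(new_fn: Callable, old_path: str, new_path: str) -> Callable:
        """Return a shim that warns about the old import path and forwards the call."""

        @wraps(new_fn)
        def shim(*args: Any, **kwargs: Any) -> Any:
            warnings.warn(
                f"`{old_path}` has been deprecated and will be removed in a future release."
                f" Please use `{new_path}` instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            return new_fn(*args, **kwargs)

        return shim


    def has_len(obj: Any) -> bool:
        """Toy version of the relocated implementation."""
        try:
            return len(obj) >= 0
        except TypeError:
            return False


    # What the old module would expose: same name, but it now warns and forwards.
    deprecated_has_len = _deprecated_alias(
        has_len,
        "pytorch_lightning.utilities.data.has_len",
        "lightning_lite.utilities.data.has_len",
    )
    print(deprecated_has_len([1, 2, 3]))  # True, plus a DeprecationWarning
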
-import functools import inspect -import os -from collections import OrderedDict -from contextlib import contextmanager from dataclasses import fields -from functools import partial -from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Tuple, Type, Union +from typing import Any, Dict, Generator, Iterable, Mapping, Optional, Tuple, Union import torch from lightning_utilities.core.apply_func import is_dataclass_instance -from lightning_utilities.core.inheritance import get_all_subclasses from lightning_utilities.core.rank_zero import WarningCache from torch import Tensor from torch.utils.data import ( @@ -36,13 +30,16 @@ from torch.utils.data import ( ) import pytorch_lightning as pl +from lightning_lite.utilities import LightningEnum +from lightning_lite.utilities.data import _reinstantiate_wrapped_cls, _replace_value_in_saved_args +from lightning_lite.utilities.data import has_iterable_dataset as new_has_iterable_dataset +from lightning_lite.utilities.data import has_len as new_has_len from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset, FastForwardSampler -from pytorch_lightning.utilities.enums import _FaultTolerantMode, LightningEnum +from pytorch_lightning.utilities.enums import _FaultTolerantMode from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.rank_zero import rank_zero_warn -from pytorch_lightning.utilities.seed import pl_worker_init_function +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]] @@ -110,33 +107,6 @@ def extract_batch_size(batch: BType) -> int: return batch_size -def has_iterable_dataset(dataloader: DataLoader) -> bool: - return hasattr(dataloader, "dataset") and isinstance(dataloader.dataset, IterableDataset) - - -def has_len(dataloader: Union[DataLoader, Iterable]) -> bool: - """Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or - infinite dataloader.""" - try: - # try getting the length - if len(dataloader) == 0: - rank_zero_warn( - f"`{dataloader.__class__.__name__}` returned 0 length. Please make sure this was your intention." - ) - has_len = True - except (TypeError, NotImplementedError): - has_len = False - - if has_len and has_iterable_dataset(dataloader): - rank_zero_warn( - "Your `IterableDataset` has `__len__` defined." - " In combination with multi-process data loading (when num_workers > 1)," - " `__len__` could be inaccurate if each worker is not configured independently" - " to avoid having duplicate data." - ) - return has_len - - def has_len_all_ranks( dataloader: DataLoader, strategy: "pl.Strategy", @@ -171,7 +141,7 @@ def has_len_all_ranks( except (TypeError, NotImplementedError): has_len = False - if has_len and has_iterable_dataset(dataloader): + if has_len and new_has_iterable_dataset(dataloader): rank_zero_warn( "Your `IterableDataset` has `__len__` defined." " In combination with multi-process data loading (when num_workers > 1)," @@ -187,7 +157,7 @@ def get_len(dataloader: DataLoader) -> Union[int, float]: If ``__len__`` method is not implemented, return float('inf'). 
""" - if has_len(dataloader): + if new_has_len(dataloader): return len(dataloader) return float("inf") @@ -409,171 +379,6 @@ def _dataloader_init_kwargs_resolve_sampler( return {"sampler": sampler, "shuffle": False, "batch_sampler": None} -def _replace_value_in_saved_args( - replace_key: str, - replace_value: Any, - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - default_kwargs: Dict[str, Any], - arg_names: Tuple[str, ...], -) -> Tuple[bool, Tuple[Any, ...], Dict[str, Any]]: - """Tries to replace an argument value in a saved list of args and kwargs. - - Returns a tuple indicating success of the operation and modified saved args and kwargs - """ - - if replace_key in arg_names: - replace_index = arg_names.index(replace_key) - args = args[:replace_index] + (replace_value,) + args[replace_index + 1 :] - return True, args, kwargs - elif replace_key in kwargs or replace_key in default_kwargs: - kwargs[replace_key] = replace_value - return True, args, kwargs - - return False, args, kwargs - - -def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None: - if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None: - dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank) - - -def _reinstantiate_wrapped_cls(orig_object: Any, *args: Any, explicit_cls: Optional[Type] = None, **kwargs: Any) -> Any: - constructor = type(orig_object) if explicit_cls is None else explicit_cls - - try: - result = constructor(*args, **kwargs) - except TypeError as e: - # improve exception message due to an incorrect implementation of the `DataLoader` where multiple subclass - # `__init__` arguments map to one `DataLoader.__init__` argument - import re - - match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", str(e)) - if not match: - # an unexpected `TypeError`, continue failure - raise - argument = match.groups()[0] - message = ( - f"The {constructor.__name__} implementation has an error where more than one `__init__` argument" - f" can be passed to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing" - f" passing both a custom argument that will map to the `{argument}` argument as well as `**kwargs`." - f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key." - " This argument was automatically passed to your object by PyTorch Lightning." 
- ) - raise MisconfigurationException(message) from e - - attrs_record = getattr(orig_object, "__pl_attrs_record", list()) - for args, fn in attrs_record: - fn(result, *args) - - return result - - -def _wrap_init_method(init: Callable, store_explicit_arg: Optional[str] = None) -> Callable: - """Wraps the ``__init__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and - :class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses.""" - - @functools.wraps(init) - def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None: - # We need to inspect `init`, as inspecting `obj.__init__` - # can lead to inspecting the wrong function with multiple inheritance - old_inside_init = getattr(obj, "__pl_inside_init", False) - object.__setattr__(obj, "__pl_inside_init", True) - params = inspect.signature(init).parameters - - parameters_defaults = OrderedDict( - (param.name, param.default) - for param in params.values() - if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD) - ) - - param_names = tuple(parameters_defaults)[: len(args)] - - default_kwargs = { - name: value - for name, value in parameters_defaults.items() - if name not in kwargs and name not in param_names and value != inspect.Parameter.empty - } - - if not hasattr(obj, "__pl_saved_args"): - object.__setattr__(obj, "__pl_saved_args", args) - object.__setattr__(obj, "__pl_saved_kwargs", kwargs) - object.__setattr__(obj, "__pl_saved_arg_names", param_names) - object.__setattr__(obj, "__pl_saved_default_kwargs", default_kwargs) - - # We want to use the latest possible value for explicit argument (i.e. ideally what gets passed to base class) - # so that we can be sure, that it will not get changed anymore. - # That is why we are setting this in every `__init__` - if store_explicit_arg is not None: - if store_explicit_arg in param_names: - object.__setattr__(obj, f"__{store_explicit_arg}", args[param_names.index(store_explicit_arg)]) - elif store_explicit_arg in kwargs: - object.__setattr__(obj, f"__{store_explicit_arg}", kwargs[store_explicit_arg]) - - init(obj, *args, **kwargs) - object.__setattr__(obj, "__pl_inside_init", old_inside_init) - - return wrapper - - -def _wrap_attr_method(method: Callable, tag: _WrapAttrTag) -> Callable: - """Wraps the ``__setattr__`` or ``__delattr__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and - :class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses.""" - - @functools.wraps(method) - def wrapper(obj: Any, *args: Any): - # First, let's find out if we're the first in inheritance chain calling the patched method. 
- name, *_ = args - prev_call_name, prev_call_method = getattr(obj, "__pl_current_call", (None, "method")) - first_call = not (prev_call_name == name and prev_call_method == tag) - - # Then mark the current called method - object.__setattr__(obj, "__pl_current_call", (name, tag)) - - # call original method - method(obj, *args) - if first_call and not getattr(obj, "__pl_inside_init", True): - # and save the value it was called with to the internal list, - # if we're outside of __init__ and the original call did not fail and we're the first call - attrs_record = getattr(obj, "__pl_attrs_record", list()) - attrs_record.append((args, tag)) - object.__setattr__(obj, "__pl_attrs_record", attrs_record) - object.__setattr__(obj, "__pl_current_call", (prev_call_name, prev_call_method)) - - return wrapper - - -@contextmanager -def _replace_dunder_methods(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]: - """This context manager is used to add support for re-instantiation of custom (subclasses) of `base_cls`. - - It patches the ``__init__``, ``__setattr__`` and ``__delattr__`` methods. - """ - classes = get_all_subclasses(base_cls) | {base_cls} - for cls in classes: - # Check that __init__ belongs to the class - # https://stackoverflow.com/a/5253424 - if "__init__" in cls.__dict__: - cls.__old__init__ = cls.__init__ - cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg) - - # we want at least one setattr/delattr in the chain to be patched and it can happen, that none of the subclasses - # implement `__setattr__`/`__delattr__`. Therefore, we are always patching the `base_cls` - for patch_fn_name, tag in (("__setattr__", _WrapAttrTag.SET), ("__delattr__", _WrapAttrTag.DEL)): - if patch_fn_name in cls.__dict__ or cls is base_cls: - saved_name = f"__old{patch_fn_name}" - setattr(cls, saved_name, getattr(cls, patch_fn_name)) - setattr(cls, patch_fn_name, _wrap_attr_method(getattr(cls, patch_fn_name), tag)) - yield - for cls in classes: - for patched_name in ("__setattr__", "__delattr__", "__init__"): - # Check that __old__{init,setattr,delattr} belongs to the class - # https://stackoverflow.com/a/5253424 - if f"__old{patched_name}" in cls.__dict__: - setattr(cls, patched_name, getattr(cls, f"__old{patched_name}")) - delattr(cls, f"__old{patched_name}") - - def _wrap_with_capture_dataset(dataset: Dataset) -> Dataset: if isinstance(dataset, IterableDataset): # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states. @@ -627,3 +432,19 @@ def _is_dataloader_shuffled(dataloader: object) -> bool: if isinstance(sampler, SequentialSampler): return False return isinstance(sampler, RandomSampler) + + +def has_iterable_dataset(*args: Any, **kwargs: Any) -> Any: + rank_zero_deprecation( + "`pytorch_lightning.utilities.data.has_iterable_dataset` has been deprecated in v1.8.0 and will be" + " removed in v1.10.0. Please use `lightning_lite.utilities.data.has_iterable_dataset` instead." + ) + return new_has_iterable_dataset(*args, **kwargs) + + +def has_len(*args: Any, **kwargs: Any) -> Any: + rank_zero_deprecation( + "`pytorch_lightning.utilities.data.has_len` has been deprecated in v1.8.0 and will be" + " removed in v1.10.0. Please use `lightning_lite.utilities.data.has_len` instead." 
+ ) + return new_has_len(*args, **kwargs) diff --git a/src/pytorch_lightning/utilities/deepspeed.py b/src/pytorch_lightning/utilities/deepspeed.py index cfa4e6a2f4..3d5b77e774 100644 --- a/src/pytorch_lightning/utilities/deepspeed.py +++ b/src/pytorch_lightning/utilities/deepspeed.py @@ -19,8 +19,8 @@ import os import torch +from lightning_lite.utilities.types import _PATH from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE -from pytorch_lightning.utilities.types import _PATH if _DEEPSPEED_AVAILABLE: from deepspeed.utils.zero_to_fp32 import ( diff --git a/src/pytorch_lightning/utilities/device_parser.py b/src/pytorch_lightning/utilities/device_parser.py index 32f370b5b2..b1337c2554 100644 --- a/src/pytorch_lightning/utilities/device_parser.py +++ b/src/pytorch_lightning/utilities/device_parser.py @@ -11,291 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import multiprocessing -from typing import Any, List, MutableSequence, Optional, Tuple, Union +from typing import Any, List, Optional, Union -import torch -import torch.cuda - -from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.strategies.launchers.multiprocessing import _is_forking_disabled +from lightning_lite.utilities.device_parser import determine_root_gpu_device as new_determine_root_gpu_device +from lightning_lite.utilities.device_parser import is_cuda_available as new_is_cuda_available +from lightning_lite.utilities.device_parser import num_cuda_devices as new_num_cuda_devices +from lightning_lite.utilities.device_parser import parse_cpu_cores as new_parse_cpu_cores +from lightning_lite.utilities.device_parser import parse_gpu_ids as new_parse_gpu_ids +from lightning_lite.utilities.device_parser import parse_tpu_cores as new_parse_tpu_cores from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.types import _DEVICE - - -def determine_root_gpu_device(gpus: List[_DEVICE]) -> Optional[_DEVICE]: - """ - Args: - gpus: non-empty list of ints representing which gpus to use - - Returns: - designated root GPU device id - - Raises: - TypeError: - If ``gpus`` is not a list - AssertionError: - If GPU list is empty - """ - if gpus is None: - return None - - if not isinstance(gpus, list): - raise TypeError("gpus should be a list") - - assert len(gpus) > 0, "gpus should be a non empty list" - - # set root gpu - root_gpu = gpus[0] - - return root_gpu - - -def parse_gpu_ids( - gpus: Optional[Union[int, str, List[int]]], - include_cuda: bool = False, - include_mps: bool = False, -) -> Optional[List[int]]: - """ - Parses the GPU ids given in the format as accepted by the - :class:`~pytorch_lightning.trainer.Trainer`. - - Args: - gpus: An int -1 or string '-1' indicate that all available GPUs should be used. - A list of unique ints or a string containing list of comma separated unique integers - indicates specific GPUs to use. - An int 0 means that no GPUs should be used. - Any int N > 0 indicates that GPUs [0..N) should be used. - include_cuda: A boolean indicating whether to include cuda devices for gpu parsing. - include_mps: A boolean indicating whether to include mps devices for gpu parsing. 
- - Returns: - a list of gpus to be used or ``None`` if no GPUs were requested - - Raises: - MisconfigurationException: - If no GPUs are available but the value of gpus variable indicates request for GPUs - - .. note:: - ``include_cuda`` and ``include_mps`` default to ``False`` so that you only - have to specify which device type to use and not disabling all the others. - """ - # Check that gpus param is None, Int, String or Sequence of Ints - _check_data_type(gpus) - - # Handle the case when no gpus are requested - if gpus is None or (isinstance(gpus, int) and gpus == 0) or str(gpus).strip() in ("0", "[]"): - return None - - # We know user requested GPUs therefore if some of the - # requested GPUs are not available an exception is thrown. - gpus = _normalize_parse_gpu_string_input(gpus) - gpus = _normalize_parse_gpu_input_to_list(gpus, include_cuda=include_cuda, include_mps=include_mps) - if not gpus: - raise MisconfigurationException("GPUs requested but none are available.") - - if ( - TorchElasticEnvironment.detect() - and len(gpus) != 1 - and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps)) == 1 - ): - # omit sanity check on torchelastic as by default shows one visible GPU per process - return gpus - - # Check that gpus are unique. Duplicate gpus are not supported by the backend. - _check_unique(gpus) - - return _sanitize_gpu_ids(gpus, include_cuda=include_cuda, include_mps=include_mps) - - -def parse_tpu_cores(tpu_cores: Optional[Union[int, str, List[int]]]) -> Optional[Union[int, List[int]]]: - """ - Parses the tpu_cores given in the format as accepted by the - :class:`~pytorch_lightning.trainer.Trainer`. - - Args: - tpu_cores: An int 1 or string '1' indicate that 1 core with multi-processing should be used - An int 8 or string '8' indicate that all 8 cores with multi-processing should be used - A list of int or a string containing list of comma separated integer - indicates specific TPU core to use. - - Returns: - a list of tpu_cores to be used or ``None`` if no TPU cores were requested - - Raises: - MisconfigurationException: - If TPU cores aren't 1, 8 or [<1-8>] - """ - _check_data_type(tpu_cores) - - if isinstance(tpu_cores, str): - tpu_cores = _parse_tpu_cores_str(tpu_cores.strip()) - - if not _tpu_cores_valid(tpu_cores): - raise MisconfigurationException("`tpu_cores` can only be 1, 8 or [<1-8>]") - - return tpu_cores - - -def parse_cpu_cores(cpu_cores: Union[int, str, List[int]]) -> int: - """Parses the cpu_cores given in the format as accepted by the ``devices`` argument in the - :class:`~pytorch_lightning.trainer.Trainer`. - - Args: - cpu_cores: An int > 0. - - Returns: - an int representing the number of processes - - Raises: - MisconfigurationException: - If cpu_cores is not an int > 0 - """ - if isinstance(cpu_cores, str) and cpu_cores.strip().isdigit(): - cpu_cores = int(cpu_cores) - - if not isinstance(cpu_cores, int) or cpu_cores <= 0: - raise MisconfigurationException("`devices` selected with `CPUAccelerator` should be an int > 0.") - - return cpu_cores - - -def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[int, List[int]]: - if not isinstance(s, str): - return s - if s == "-1": - return -1 - if "," in s: - return [int(x.strip()) for x in s.split(",") if len(x) > 0] - return int(s.strip()) - - -def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: bool = False) -> List[int]: - """Checks that each of the GPUs in the list is actually available. 
Raises a MisconfigurationException if any of - the GPUs is not available. - - Args: - gpus: list of ints corresponding to GPU indices - - Returns: - unmodified gpus variable - - Raises: - MisconfigurationException: - If machine has fewer available GPUs than requested. - """ - if sum((include_cuda, include_mps)) == 0: - raise ValueError("At least one gpu type should be specified!") - all_available_gpus = _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps) - for gpu in gpus: - if gpu not in all_available_gpus: - raise MisconfigurationException( - f"You requested gpu: {gpus}\n But your machine only has: {all_available_gpus}" - ) - return gpus - - -def _normalize_parse_gpu_input_to_list( - gpus: Union[int, List[int], Tuple[int, ...]], include_cuda: bool, include_mps: bool -) -> Optional[List[int]]: - assert gpus is not None - if isinstance(gpus, (MutableSequence, tuple)): - return list(gpus) - - # must be an int - if not gpus: # gpus==0 - return None - if gpus == -1: - return _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps) - - return list(range(gpus)) - - -def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = False) -> List[int]: - """ - Returns: - a list of all available gpus - """ - cuda_gpus = _get_all_available_cuda_gpus() if include_cuda else [] - mps_gpus = _get_all_available_mps_gpus() if include_mps else [] - return cuda_gpus + mps_gpus - - -def _get_all_available_mps_gpus() -> List[int]: - """ - Returns: - a list of all available MPS gpus - """ - # lazy import to avoid circular dependencies - from pytorch_lightning.accelerators.mps import _MPS_AVAILABLE - - return [0] if _MPS_AVAILABLE else [] - - -def _get_all_available_cuda_gpus() -> List[int]: - """ - Returns: - a list of all available CUDA gpus - """ - return list(range(num_cuda_devices())) - - -def _check_unique(device_ids: List[int]) -> None: - """Checks that the device_ids are unique. - - Args: - device_ids: list of ints corresponding to gpus indices - - Raises: - MisconfigurationException: - If ``device_ids`` of GPUs aren't unique - """ - if len(device_ids) != len(set(device_ids)): - raise MisconfigurationException("Device ID's (GPU) must be unique.") - - -def _check_data_type(device_ids: Any) -> None: - """Checks that the device_ids argument is one of None, int, string, or sequence of integers. 
- - Args: - device_ids: gpus/tpu_cores parameter as passed to the Trainer - - Raises: - MisconfigurationException: - If ``device_ids`` of GPU/TPUs aren't ``int``, ``str``, sequence of ``int`` or ``None`` - """ - msg = "Device IDs (GPU/TPU) must be an int, a string, a sequence of ints or None, but you passed" - - if device_ids is None: - return - elif isinstance(device_ids, (MutableSequence, tuple)): - for id_ in device_ids: - if type(id_) is not int: - raise MisconfigurationException(f"{msg} a sequence of {type(id_).__name__}.") - elif type(device_ids) not in (int, str): - raise MisconfigurationException(f"{msg} {type(device_ids).__name__}.") - - -def _tpu_cores_valid(tpu_cores: Any) -> bool: - # allow 1 or 8 cores - if tpu_cores in (1, 8, None): - return True - - # allow picking 1 of 8 indexes - if isinstance(tpu_cores, (list, tuple, set)): - has_1_tpu_idx = len(tpu_cores) == 1 - is_valid_tpu_idx = 1 <= list(tpu_cores)[0] <= 8 - - is_valid_tpu_core_choice = has_1_tpu_idx and is_valid_tpu_idx - return is_valid_tpu_core_choice - - return False - - -def _parse_tpu_cores_str(tpu_cores: str) -> Union[int, List[int]]: - if tpu_cores in ("1", "8"): - return int(tpu_cores) - return [int(x.strip()) for x in tpu_cores.split(",") if len(x) > 0] +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation def parse_hpus(devices: Optional[Union[int, str, List[int]]]) -> Optional[int]: @@ -319,25 +44,49 @@ def parse_hpus(devices: Optional[Union[int, str, List[int]]]) -> Optional[int]: return int(devices) if isinstance(devices, str) else devices -def num_cuda_devices() -> int: - """Returns the number of GPUs available. - - Unlike :func:`torch.cuda.device_count`, this function will do its best not to create a CUDA context for fork - support, if the platform allows it. - """ - if "fork" not in torch.multiprocessing.get_all_start_methods() or _is_forking_disabled(): - return torch.cuda.device_count() - with multiprocessing.get_context("fork").Pool(1) as pool: - return pool.apply(torch.cuda.device_count) +def determine_root_gpu_device(*args: Any, **kwargs: Any) -> Any: + rank_zero_deprecation( + "`pytorch_lightning.utilities.device_parser.determine_root_gpu_device` has been deprecated in v1.8.0 and will" + " be removed in v1.10.0. Please use `lightning_lite.utilities.device_parser.determine_root_gpu_device` instead." + ) + return new_determine_root_gpu_device(*args, **kwargs) def is_cuda_available() -> bool: - """Returns a bool indicating if CUDA is currently available. + rank_zero_deprecation( + "`pytorch_lightning.utilities.device_parser.is_cuda_available` has been deprecated in v1.8.0 and will" + " be removed in v1.10.0. Please use `lightning_lite.utilities.device_parser.is_cuda_available` instead." + ) + return new_is_cuda_available() - Unlike :func:`torch.cuda.is_available`, this function will do its best not to create a CUDA context for fork - support, if the platform allows it. - """ - if "fork" not in torch.multiprocessing.get_all_start_methods() or _is_forking_disabled(): - return torch.cuda.is_available() - with multiprocessing.get_context("fork").Pool(1) as pool: - return pool.apply(torch.cuda.is_available) + +def num_cuda_devices() -> int: + rank_zero_deprecation( + "`pytorch_lightning.utilities.device_parser.num_cuda_devices` has been deprecated in v1.8.0 and will" + " be removed in v1.10.0. Please use `lightning_lite.utilities.device_parser.num_cuda_devices` instead." 
+ ) + return new_num_cuda_devices() + + +def parse_cpu_cores(*args: Any, **kwargs: Any) -> Any: + rank_zero_deprecation( + "`pytorch_lightning.utilities.device_parser.parse_cpu_cores` has been deprecated in v1.8.0 and will" + " be removed in v1.10.0. Please use `lightning_lite.utilities.device_parser.parse_cpu_cores` instead." + ) + return new_parse_cpu_cores(*args, **kwargs) + + +def parse_gpu_ids(*args: Any, **kwargs: Any) -> Any: + rank_zero_deprecation( + "`pytorch_lightning.utilities.device_parser.parse_gpu_ids` has been deprecated in v1.8.0 and will" + " be removed in v1.10.0. Please use `lightning_lite.utilities.device_parser.parse_gpu_ids` instead." + ) + return new_parse_gpu_ids(*args, **kwargs) + + +def parse_tpu_cores(*args: Any, **kwargs: Any) -> Any: + rank_zero_deprecation( + "`pytorch_lightning.utilities.device_parser.parse_tpu_cores` has been deprecated in v1.8.0 and will" + " be removed in v1.10.0. Please use `lightning_lite.utilities.device_parser.parse_tpu_cores` instead." + ) + return new_parse_tpu_cores(*args, **kwargs) diff --git a/src/pytorch_lightning/utilities/distributed.py b/src/pytorch_lightning/utilities/distributed.py index 7b33cb38b6..6f01a1a5b4 100644 --- a/src/pytorch_lightning/utilities/distributed.py +++ b/src/pytorch_lightning/utilities/distributed.py @@ -12,211 +12,23 @@ # limitations under the License. """Utilities that can be used with distributed training.""" -import logging -import os -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Optional import torch -import torch.nn.functional as F -from torch import Tensor from torch.nn.parallel.distributed import DistributedDataParallel -import pytorch_lightning as pl -from pytorch_lightning.utilities.imports import _HPU_AVAILABLE, _TPU_AVAILABLE -from pytorch_lightning.utilities.rank_zero import rank_zero_only # noqa: F401 +from lightning_lite.utilities.distributed import all_gather_ddp_if_available as new_all_gather_ddp_if_available +from lightning_lite.utilities.distributed import distributed_available as new_distributed_available +from lightning_lite.utilities.distributed import gather_all_tensors as new_gather_all_tensors +from lightning_lite.utilities.distributed import ( + get_default_process_group_backend_for_device as new_get_default_process_group_backend_for_device, +) +from lightning_lite.utilities.distributed import init_dist_connection as new_init_dist_connection +from lightning_lite.utilities.distributed import sync_ddp as new_sync_ddp +from lightning_lite.utilities.distributed import sync_ddp_if_available as new_sync_ddp_if_available +from lightning_lite.utilities.distributed import tpu_distributed as new_tpu_distributed from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_deprecation, rank_zero_info -if _TPU_AVAILABLE: - import torch_xla.core.xla_model as xm - -if torch.distributed.is_available(): - from torch.distributed import group, ReduceOp - -else: - - class ReduceOp: # type: ignore # (see https://github.com/python/mypy/issues/1153) - SUM = None - - class group: # type: ignore - WORLD = None - - -log = logging.getLogger(__name__) - - -def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tensor]: - """Function to gather all tensors from several ddp processes onto a list that is broadcasted to all processes. - - Works on tensors that have the same number of dimensions, but where each dimension may differ. 
In this case - tensors are padded, gathered and then trimmed to secure equal workload for all processes. - - Args: - result: the value to sync - group: the process group to gather results from. Defaults to all processes (world) - - Return: - gathered_result: list with size equal to the process group where - gathered_result[i] corresponds to result tensor from process i - """ - if group is None: - group = torch.distributed.group.WORLD - - # convert tensors to contiguous format - result = result.contiguous() - - world_size = torch.distributed.get_world_size(group) - torch.distributed.barrier(group=group) - - # if the tensor is scalar, things are easy - if result.ndim == 0: - return _simple_gather_all_tensors(result, group, world_size) - - # 1. Gather sizes of all tensors - local_size = torch.tensor(result.shape, device=result.device) - local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)] - torch.distributed.all_gather(local_sizes, local_size, group=group) - max_size = torch.stack(local_sizes).max(dim=0).values - all_sizes_equal = all(all(ls == max_size) for ls in local_sizes) - - # 2. If shapes are all the same, then do a simple gather: - if all_sizes_equal: - return _simple_gather_all_tensors(result, group, world_size) - - # 3. If not, we need to pad each local tensor to maximum size, gather and then truncate - pad_dims = [] - pad_by = (max_size - local_size).detach().cpu() - for val in reversed(pad_by): - pad_dims.append(0) - pad_dims.append(val.item()) - result_padded = F.pad(result, pad_dims) - gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)] - torch.distributed.all_gather(gathered_result, result_padded, group) - for idx, item_size in enumerate(local_sizes): - slice_param = [slice(dim_size) for dim_size in item_size] - gathered_result[idx] = gathered_result[idx][slice_param] - return gathered_result - - -def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]: - gathered_result = [torch.zeros_like(result) for _ in range(world_size)] - torch.distributed.all_gather(gathered_result, result, group) - return gathered_result - - -def distributed_available() -> bool: - return torch.distributed.is_available() and torch.distributed.is_initialized() or tpu_distributed() - - -def sync_ddp_if_available( - result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None -) -> Tensor: - """Function to reduce a tensor across worker processes during distributed training. - - Args: - result: the value to sync and reduce (typically tensor or number) - group: the process group to gather results from. Defaults to all processes (world) - reduce_op: the reduction operation. Defaults to sum. - Can also be a string of 'avg', 'mean' to calculate the mean during reduction. - - Return: - reduced value - """ - if distributed_available(): - return sync_ddp(result, group=group, reduce_op=reduce_op) - return result - - -def sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> Tensor: - """Function to reduce the tensors from several ddp processes to one main process. - - Args: - result: the value to sync and reduce (typically tensor or number) - group: the process group to gather results from. Defaults to all processes (world) - reduce_op: the reduction operation. Defaults to sum. - Can also be a string of 'avg', 'mean' to calculate the mean during reduction. 
- - Return: - reduced value - """ - divide_by_world_size = False - - if group is None: - group = torch.distributed.group.WORLD - - op: Optional[ReduceOp] - if isinstance(reduce_op, str): - if reduce_op.lower() in ("avg", "mean"): - op = ReduceOp.SUM - divide_by_world_size = True - else: - op = getattr(ReduceOp, reduce_op.upper()) - else: - op = reduce_op - - # WA for HPU. HPU doesn't support Long types, forcefully set it to float - if _HPU_AVAILABLE: - is_hpu_backend = os.environ.get("HCCL_DISTRIBUTED_BACKEND") == "1" - if is_hpu_backend: - if (result.type() == "torch.LongTensor") or (result.type() == "torch.hpu.LongTensor"): - rank_zero_info("Long tensor unsupported on HPU, casting to float") - result = result.float() - - # sync all processes before reduction - torch.distributed.barrier(group=group) - torch.distributed.all_reduce(result, op=op, group=group, async_op=False) - - if divide_by_world_size: - result = result / torch.distributed.get_world_size(group) - - return result - - -class AllGatherGrad(torch.autograd.Function): - @staticmethod - def forward( # type: ignore[override] - ctx: Any, - tensor: Tensor, - group: Optional["torch.distributed.ProcessGroup"] = group.WORLD, - ) -> Tensor: - ctx.group = group - - gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] - - torch.distributed.all_gather(gathered_tensor, tensor, group=group) - gathered_tensor = torch.stack(gathered_tensor, dim=0) - - return gathered_tensor - - @staticmethod - def backward(ctx: Any, *grad_output: Tensor) -> Tuple[Tensor, None]: - grad_output = torch.cat(grad_output) - - torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group) - - return grad_output[torch.distributed.get_rank()], None - - -def all_gather_ddp_if_available( - tensor: Tensor, group: Optional["torch.distributed.ProcessGroup"] = None, sync_grads: bool = False -) -> Tensor: - """Function to gather a tensor from several distributed processes. - - Args: - tensor: tensor of shape (batch, ...) - group: the process group to gather results from. Defaults to all processes (world) - sync_grads: flag that allows users to synchronize gradients for all_gather op - - Return: - A tensor of shape (world_size, batch, ...) - """ - group = group if group is not None else torch.distributed.group.WORLD - if distributed_available(): - if sync_grads: - return AllGatherGrad.apply(tensor, group) - with torch.no_grad(): - return AllGatherGrad.apply(tensor, group) - return tensor - def register_ddp_comm_hook( model: DistributedDataParallel, @@ -319,67 +131,6 @@ def register_ddp_comm_hook( model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook) # type: ignore[operator] -def tpu_distributed() -> bool: - return _TPU_AVAILABLE and xm.xrt_world_size() > 1 - - -def get_default_process_group_backend_for_device(device: torch.device) -> str: - return "nccl" if device.type == "cuda" else "gloo" - - -def _get_process_group_backend_from_env() -> Optional[str]: - torch_backend = os.getenv("PL_TORCH_DISTRIBUTED_BACKEND") - if torch_backend is not None: - rank_zero_deprecation( - "Environment variable `PL_TORCH_DISTRIBUTED_BACKEND`" - " was deprecated in v1.6 and will be removed in v1.8." - " Specify `process_group_backend` directly on the strategy constructor." 
- ) - return torch_backend - - -def init_dist_connection( - cluster_environment: "pl.plugins.environments.ClusterEnvironment", - torch_distributed_backend: str, - global_rank: Optional[int] = None, - world_size: Optional[int] = None, - **kwargs: Any, -) -> None: - """Utility function to initialize distributed connection by setting env variables and initializing the - distributed process group. - - Args: - cluster_environment: ``ClusterEnvironment`` instance - torch_distributed_backend: backend to use (includes `nccl` and `gloo`) - global_rank: rank of the current process - world_size: number of processes in the group - kwargs: kwargs for ``init_process_group`` - - Raises: - RuntimeError: - If ``torch.distributed`` is not available - """ - if not torch.distributed.is_available(): - raise RuntimeError("torch.distributed is not available. Cannot initialize distributed process group") - if torch.distributed.is_initialized(): - log.debug("torch.distributed is already initialized. Exiting early") - return - global_rank = global_rank if global_rank is not None else cluster_environment.global_rank() - world_size = world_size if world_size is not None else cluster_environment.world_size() - os.environ["MASTER_ADDR"] = cluster_environment.main_address - os.environ["MASTER_PORT"] = str(cluster_environment.main_port) - log.info(f"Initializing distributed: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}") - torch.distributed.init_process_group(torch_distributed_backend, rank=global_rank, world_size=world_size, **kwargs) - - # on rank=0 let everyone know training is starting - rank_zero_info( - f"{'-' * 100}\n" - f"distributed_backend={torch_distributed_backend}\n" - f"All distributed processes registered. Starting with {world_size} processes\n" - f"{'-' * 100}\n" - ) - - def _broadcast_object_list(obj: Any, rank: int) -> Any: objects = [obj if torch.distributed.get_rank() == rank else None] torch.distributed.broadcast_object_list(objects, src=rank) @@ -397,6 +148,71 @@ def _collect_states_on_rank_zero(state: Dict[str, Any]) -> Dict[int, Any]: states: On global rank 0, a dictionary where the primary keys are the process rank and the values their associated states. Otherwise, returns None. """ - if not distributed_available(): + if not new_distributed_available(): return {0: state} return {rank: _broadcast_object_list(state, rank) for rank in range(torch.distributed.get_world_size())} + + +def all_gather_ddp_if_available(*args: Any, **kwargs: Any) -> Any: + rank_zero_deprecation( + "`pytorch_lightning.utilities.distributed.all_gather_ddp_if_available` has been deprecated in v1.8.0 and will" + " be removed in v1.10.0. Please use `lightning_lite.utilities.distributed.all_gather_ddp_if_available` instead." + ) + return new_all_gather_ddp_if_available(*args, **kwargs) + + +def distributed_available() -> Any: + rank_zero_deprecation( + "`pytorch_lightning.utilities.distributed.distributed_available` has been deprecated in v1.8.0 and will" + " be removed in v1.10.0. Please use `lightning_lite.utilities.distributed.distributed_available` instead." + ) + return new_distributed_available() + + +def gather_all_tensors(*args: Any, **kwargs: Any) -> Any: + rank_zero_deprecation( + "`pytorch_lightning.utilities.distributed.gather_all_tensors` has been deprecated in v1.8.0 and will" + " be removed in v1.10.0. Please use `lightning_lite.utilities.distributed.gather_all_tensors` instead." 
+ ) + return new_gather_all_tensors(*args, **kwargs) + + +def get_default_process_group_backend_for_device(*args: Any, **kwargs: Any) -> Any: + rank_zero_deprecation( + "`pytorch_lightning.utilities.distributed.get_default_process_group_backend_for_device` has been deprecated" + " in v1.8.0 and will be removed in v1.10.0. Please use" + " `lightning_lite.utilities.distributed.get_default_process_group_backend_for_device` instead." + ) + return new_get_default_process_group_backend_for_device(*args, **kwargs) + + +def init_dist_connection(*args: Any, **kwargs: Any) -> Any: + rank_zero_deprecation( + "`pytorch_lightning.utilities.distributed.init_dist_connection` has been deprecated in v1.8.0 and will" + " be removed in v1.10.0. Please use `lightning_lite.utilities.distributed.init_dist_connection` instead." + ) + return new_init_dist_connection(*args, **kwargs) + + +def sync_ddp(*args: Any, **kwargs: Any) -> Any: + rank_zero_deprecation( + "`pytorch_lightning.utilities.distributed.sync_ddp` has been deprecated in v1.8.0 and will" + " be removed in v1.10.0. Please use `lightning_lite.utilities.distributed.sync_ddp` instead." + ) + return new_sync_ddp(*args, **kwargs) + + +def sync_ddp_if_available(*args: Any, **kwargs: Any) -> Any: + rank_zero_deprecation( + "`pytorch_lightning.utilities.distributed.sync_ddp_if_available` has been deprecated in v1.8.0 and will" + " be removed in v1.10.0. Please use `lightning_lite.utilities.distributed.sync_ddp_if_available` instead." + ) + return new_sync_ddp_if_available(*args, **kwargs) + + +def tpu_distributed() -> bool: + rank_zero_deprecation( + "`pytorch_lightning.utilities.distributed.tpu_distributed` has been deprecated in v1.8.0 and will" + " be removed in v1.10.0. Please use `lightning_lite.utilities.distributed.tpu_distributed` instead." + ) + return new_tpu_distributed() diff --git a/src/pytorch_lightning/utilities/enums.py b/src/pytorch_lightning/utilities/enums.py index 03d9b8782e..8a5fe0e35d 100644 --- a/src/pytorch_lightning/utilities/enums.py +++ b/src/pytorch_lightning/utilities/enums.py @@ -15,47 +15,10 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING - -from lightning_utilities.core.enums import StrEnum +from lightning_lite.utilities.enums import AMPType, LightningEnum, PrecisionType # noqa: F401 from pytorch_lightning.utilities.exceptions import MisconfigurationException -if TYPE_CHECKING: - from enum import Enum - - # re-defined because `mypy` infers `StrEnum` as `Any` - class LightningEnum(StrEnum, Enum): - ... - -else: - LightningEnum = StrEnum - - -class AMPType(LightningEnum): - """Type of Automatic Mixed Precission used for training.""" - - APEX = "apex" - NATIVE = "native" - - -class PrecisionType(LightningEnum): - """Type of precision used.""" - - HALF = "16" - FLOAT = "32" - FULL = "64" - BFLOAT = "bf16" - MIXED = "mixed" - - @staticmethod - def supported_type(precision: str | int) -> bool: - return any(x == precision for x in PrecisionType) - - @staticmethod - def supported_types() -> list[str]: - return [x.value for x in PrecisionType] - class GradClipAlgorithmType(LightningEnum): """Define gradient_clip_algorithm types - training-tricks. 
@@ -85,47 +48,6 @@ class AutoRestartBatchKeys(LightningEnum): PL_RESTART_META = "__pl_restart_meta" -class _StrategyType(LightningEnum): - """Define type of training strategy.""" - - DP = "dp" - DDP = "ddp" - DDP_SPAWN = "ddp_spawn" - DDP_FORK = "ddp_fork" - TPU_SPAWN = "tpu_spawn" - DEEPSPEED = "deepspeed" - HOROVOD = "horovod" - DDP_SHARDED = "ddp_sharded" - DDP_SHARDED_SPAWN = "ddp_sharded_spawn" - DDP_FULLY_SHARDED = "ddp_fully_sharded" - BAGUA = "bagua" - HPU_PARALLEL = "hpu_parallel" - - @staticmethod - def interactive_compatible_types() -> list[_StrategyType]: - """Returns a list containing interactive compatible _StrategyTypes.""" - return [ - _StrategyType.DP, - _StrategyType.TPU_SPAWN, - _StrategyType.DDP_FORK, - ] - - def is_interactive_compatible(self) -> bool: - """Returns whether self is interactive compatible.""" - return self in _StrategyType.interactive_compatible_types() - - -class _AcceleratorType(LightningEnum): - """Define Accelerator type by its nature.""" - - CPU = "CPU" - CUDA = "CUDA" - IPU = "IPU" - TPU = "TPU" - HPU = "HPU" - MPS = "MPS" - - class _FaultTolerantMode(LightningEnum): DISABLED = "disabled" diff --git a/src/pytorch_lightning/utilities/exceptions.py b/src/pytorch_lightning/utilities/exceptions.py index ece4629819..7a3e20034e 100644 --- a/src/pytorch_lightning/utilities/exceptions.py +++ b/src/pytorch_lightning/utilities/exceptions.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - -class MisconfigurationException(Exception): - """Exception used to inform users of misuse with PyTorch Lightning.""" +from lightning_lite.utilities.exceptions import MisconfigurationException # noqa: F401 class DeadlockDetectedException(Exception): diff --git a/src/pytorch_lightning/utilities/fetching.py b/src/pytorch_lightning/utilities/fetching.py index ba44e2132a..5dd068af53 100644 --- a/src/pytorch_lightning/utilities/fetching.py +++ b/src/pytorch_lightning/utilities/fetching.py @@ -20,6 +20,7 @@ import torch from lightning_utilities.core.apply_func import apply_to_collection, apply_to_collections from torch.utils.data.dataloader import DataLoader +from lightning_lite.utilities.data import has_len from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator from pytorch_lightning.utilities.auto_restart import ( _add_capture_metadata_collate, @@ -29,7 +30,6 @@ from pytorch_lightning.utilities.auto_restart import ( MergedIteratorState, patch_dataloader_iterator, ) -from pytorch_lightning.utilities.data import has_len from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training diff --git a/src/pytorch_lightning/utilities/meta.py b/src/pytorch_lightning/utilities/meta.py index b1359df852..6670dc7a63 100644 --- a/src/pytorch_lightning/utilities/meta.py +++ b/src/pytorch_lightning/utilities/meta.py @@ -18,7 +18,7 @@ from lightning_utilities.core.imports import module_available from torch import Tensor from torch.nn import Module, Parameter -from pytorch_lightning.utilities import rank_zero_deprecation +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation def is_meta_init() -> bool: diff --git a/src/pytorch_lightning/utilities/optimizer.py b/src/pytorch_lightning/utilities/optimizer.py index b13baf2552..9b5fe9273f 100644 --- a/src/pytorch_lightning/utilities/optimizer.py +++ b/src/pytorch_lightning/utilities/optimizer.py @@ -12,23 +12,24 @@ # See the License for the specific 
language governing permissions and # limitations under the License. -from typing import Iterable +from typing import Any -from lightning_utilities.core.apply_func import apply_to_collection -from torch import Tensor -from torch.optim import Optimizer - -from lightning_lite.utilities.apply_func import move_data_to_device -from pytorch_lightning.utilities.types import _DEVICE +from lightning_lite.utilities.optimizer import optimizer_to_device as new_optimizer_to_device +from lightning_lite.utilities.optimizer import optimizers_to_device as new_optimizers_to_device +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation -def optimizers_to_device(optimizers: Iterable[Optimizer], device: _DEVICE) -> None: - """Moves optimizer states for a sequence of optimizers to the device.""" - for opt in optimizers: - optimizer_to_device(opt, device) +def optimizers_to_device(*args: Any, **kwargs: Any) -> None: + rank_zero_deprecation( + "`pytorch_lightning.utilities.optimizer.optimizers_to_device` has been deprecated in v1.8.0 and will be" + " removed in v1.10.0. Please use `lightning_lite.utilities.optimizer.optimizers_to_device` instead." + ) + return new_optimizers_to_device(*args, **kwargs) -def optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None: - """Moves the state of a single optimizer to the device.""" - for p, v in optimizer.state.items(): - optimizer.state[p] = apply_to_collection(v, Tensor, move_data_to_device, device) +def optimizer_to_device(*args: Any, **kwargs: Any) -> None: + rank_zero_deprecation( + "`pytorch_lightning.utilities.optimizer.optimizer_to_device` has been deprecated in v1.8.0 and will be" + " removed in v1.10.0. Please use `lightning_lite.utilities.optimizer.optimizer_to_device` instead." + ) + return new_optimizer_to_device(*args, **kwargs) diff --git a/src/pytorch_lightning/utilities/rank_zero.py b/src/pytorch_lightning/utilities/rank_zero.py index 156c7c98c5..70550e43a4 100644 --- a/src/pytorch_lightning/utilities/rank_zero.py +++ b/src/pytorch_lightning/utilities/rank_zero.py @@ -11,48 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
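# The modules above are reduced to thin compatibility shims: the old
# `pytorch_lightning.utilities.*` entry points keep working, but emit a
# deprecation message and forward to the relocated `lightning_lite`
# implementations. A minimal, self-contained sketch of that forwarding
# pattern (illustrative names only, not the actual Lightning helpers):
import warnings
from typing import Any, Callable


def _deprecated_alias(new_fn: Callable, old_path: str, new_path: str) -> Callable:
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # Emit the deprecation message, then delegate to the new implementation.
        warnings.warn(
            f"`{old_path}` has been deprecated. Please use `{new_path}` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return new_fn(*args, **kwargs)

    return wrapper


def _new_add(a: int, b: int) -> int:
    return a + b


add = _deprecated_alias(_new_add, "old_module.add", "new_module.add")
assert add(1, 2) == 3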
- """Utilities that can be used for calling functions on a particular rank.""" import logging -import os -from typing import Optional -import lightning_utilities.core.rank_zero as rank_zero_module - -# note: we want to keep these indirections so the `rank_zero_only.rank` is set (on import) for PL users -from lightning_utilities.core.rank_zero import ( # noqa: F401 +# note: we want to keep these indirections so the `rank_zero_module.log` is set (on import) for PL users +# backwards compatibility +from lightning_lite.utilities.rank_zero import LightningDeprecationWarning # noqa: F401 +from lightning_lite.utilities.rank_zero import ( # noqa: F401 rank_zero_debug, rank_zero_deprecation, rank_zero_info, + rank_zero_module, rank_zero_only, rank_zero_warn, ) -import pytorch_lightning as pl - rank_zero_module.log = logging.getLogger(__name__) - - -def _get_rank(trainer: Optional["pl.Trainer"] = None) -> Optional[int]: - if trainer is not None: - return trainer.global_rank - # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing, - # therefore LOCAL_RANK needs to be checked first - rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") - for key in rank_keys: - rank = os.environ.get(key) - if rank is not None: - return int(rank) - # None to differentiate whether an environment variable was set at all - return None - - -# add the attribute to the function but don't overwrite in case Trainer has already set it -rank_zero_only.rank = getattr(rank_zero_only, "rank", _get_rank() or 0) - - -class LightningDeprecationWarning(DeprecationWarning): - """Deprecation warnings raised by PyTorch Lightning.""" - - -rank_zero_module.rank_zero_deprecation_category = LightningDeprecationWarning diff --git a/src/pytorch_lightning/utilities/seed.py b/src/pytorch_lightning/utilities/seed.py index 5c33214cf4..221ed9d114 100644 --- a/src/pytorch_lightning/utilities/seed.py +++ b/src/pytorch_lightning/utilities/seed.py @@ -12,135 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. """Utilities to help with reproducibility of models.""" - -import logging -import os -import random from contextlib import contextmanager -from random import getstate as python_get_rng_state -from random import setstate as python_set_rng_state -from typing import Any, Dict, Generator, Optional +from typing import Any, Generator -import numpy as np import torch -from lightning_utilities.core.rank_zero import rank_prefixed_message -from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_only, rank_zero_warn - -log = logging.getLogger(__name__) - -max_seed_value = np.iinfo(np.uint32).max -min_seed_value = np.iinfo(np.uint32).min - - -def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: - """Function that sets seed for pseudo-random number generators in: pytorch, numpy, python.random In addition, - sets the following environment variables: - - - `PL_GLOBAL_SEED`: will be passed to spawned subprocesses (e.g. ddp_spawn backend). - - `PL_SEED_WORKERS`: (optional) is set to 1 if ``workers=True``. - - Args: - seed: the integer value seed for global random state in Lightning. - If `None`, will read seed from `PL_GLOBAL_SEED` env variable - or select it randomly. - workers: if set to ``True``, will properly configure all dataloaders passed to the - Trainer with a ``worker_init_fn``. If the user already provides such a function - for their dataloaders, setting this argument will have no influence. 
See also: - :func:`~pytorch_lightning.utilities.seed.pl_worker_init_function`. - """ - if seed is None: - env_seed = os.environ.get("PL_GLOBAL_SEED") - if env_seed is None: - seed = _select_seed_randomly(min_seed_value, max_seed_value) - rank_zero_warn(f"No seed found, seed set to {seed}") - else: - try: - seed = int(env_seed) - except ValueError: - seed = _select_seed_randomly(min_seed_value, max_seed_value) - rank_zero_warn(f"Invalid seed found: {repr(env_seed)}, seed set to {seed}") - elif not isinstance(seed, int): - seed = int(seed) - - if not (min_seed_value <= seed <= max_seed_value): - rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}") - seed = _select_seed_randomly(min_seed_value, max_seed_value) - - log.info(rank_prefixed_message(f"Global seed set to {seed}", _get_rank())) - os.environ["PL_GLOBAL_SEED"] = str(seed) - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - os.environ["PL_SEED_WORKERS"] = f"{int(workers)}" - - return seed - - -def _select_seed_randomly(min_seed_value: int = min_seed_value, max_seed_value: int = max_seed_value) -> int: - return random.randint(min_seed_value, max_seed_value) - - -def reset_seed() -> None: - """Reset the seed to the value that :func:`pytorch_lightning.utilities.seed.seed_everything` previously set. - - If :func:`pytorch_lightning.utilities.seed.seed_everything` is unused, this function will do nothing. - """ - seed = os.environ.get("PL_GLOBAL_SEED", None) - if seed is None: - return - workers = os.environ.get("PL_SEED_WORKERS", "0") - seed_everything(int(seed), workers=bool(int(workers))) - - -def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover - """The worker_init_fn that Lightning automatically adds to your dataloader if you previously set the seed with - ``seed_everything(seed, workers=True)``. - - See also the PyTorch documentation on - `randomness in DataLoaders `_. - """ - # implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562 - global_rank = rank if rank is not None else rank_zero_only.rank - process_seed = torch.initial_seed() - # back out the base seed so we can use all the bits - base_seed = process_seed - worker_id - log.debug( - f"Initializing random number generators of process {global_rank} worker {worker_id} with base seed {base_seed}" - ) - ss = np.random.SeedSequence([base_seed, worker_id, global_rank]) - # use 128 bits (4 x 32-bit words) - np.random.seed(ss.generate_state(4)) - # Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module - torch_ss, stdlib_ss = ss.spawn(2) - torch.manual_seed(torch_ss.generate_state(1, dtype=np.uint64)[0]) - # use 128 bits expressed as an integer - stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum() - random.seed(stdlib_seed) - - -def _collect_rng_states() -> Dict[str, Any]: - """Collect the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python.""" - return { - "torch": torch.get_rng_state(), - "torch.cuda": torch.cuda.get_rng_state_all(), - "numpy": np.random.get_state(), - "python": python_get_rng_state(), - } - - -def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None: - """Set the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python in the current - process.""" - torch.set_rng_state(rng_state_dict["torch"]) - # torch.cuda rng_state is only included since v1.8. 
- if "torch.cuda" in rng_state_dict: - torch.cuda.set_rng_state_all(rng_state_dict["torch.cuda"]) - np.random.set_state(rng_state_dict["numpy"]) - version, state, gauss = rng_state_dict["python"] - python_set_rng_state((version, tuple(state), gauss)) +from lightning_lite.utilities.seed import _collect_rng_states, _set_rng_states +from lightning_lite.utilities.seed import pl_worker_init_function as new_pl_worker_init_function +from lightning_lite.utilities.seed import reset_seed as new_reset_seed +from lightning_lite.utilities.seed import seed_everything as new_seed_everything +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation @contextmanager @@ -161,3 +42,27 @@ def isolate_rng() -> Generator[None, None, None]: states = _collect_rng_states() yield _set_rng_states(states) + + +def seed_everything(*args: Any, **kwargs: Any) -> Any: + rank_zero_deprecation( + "`pytorch_lightning.utilities.seed.seed_everything` has been deprecated in v1.8.0 and will be" + " removed in v1.10.0. Please use `lightning_lite.utilities.seed.seed_everything` instead." + ) + return new_seed_everything(*args, **kwargs) + + +def reset_seed() -> None: + rank_zero_deprecation( + "`pytorch_lightning.utilities.seed.reset_seed` has been deprecated in v1.8.0 and will be" + " removed in v1.10.0. Please use `lightning_lite.utilities.seed.reset_seed` instead." + ) + return new_reset_seed() + + +def pl_worker_init_function(*args: Any, **kwargs: Any) -> None: + rank_zero_deprecation( + "`pytorch_lightning.utilities.seed.pl_worker_init_function` has been deprecated in v1.8.0 and will be" + " removed in v1.10.0. Please use `lightning_lite.utilities.seed.pl_worker_init_function` instead." + ) + return new_pl_worker_init_function(*args, **kwargs) diff --git a/src/pytorch_lightning/utilities/types.py b/src/pytorch_lightning/utilities/types.py index 39b5074809..06dea2eebb 100644 --- a/src/pytorch_lightning/utilities/types.py +++ b/src/pytorch_lightning/utilities/types.py @@ -19,16 +19,16 @@ Convention: from argparse import _ArgumentGroup, ArgumentParser from contextlib import contextmanager from dataclasses import dataclass -from pathlib import Path -from typing import Any, Callable, Dict, Generator, Iterator, List, Mapping, Optional, Sequence, Type, TypeVar, Union +from typing import Any, Dict, Generator, List, Mapping, Optional, Sequence, Type, Union import torch from torch import Tensor -from torch.optim import Optimizer from torch.utils.data import DataLoader from torchmetrics import Metric from typing_extensions import Protocol, runtime_checkable +from lightning_lite.utilities.types import _LRScheduler, ReduceLROnPlateau + if torch.distributed.is_available(): from torch._C._distributed_c10d import ProcessGroup else: @@ -41,8 +41,6 @@ STEP_OUTPUT = Union[Tensor, Dict[str, Any]] EPOCH_OUTPUT = List[STEP_OUTPUT] _EVALUATE_OUTPUT = List[Dict[str, float]] # 1 dict per DataLoader _PREDICT_OUTPUT = Union[List[Any], List[List[Any]]] -_PARAMETERS = Iterator[torch.nn.Parameter] -_PATH = Union[str, Path] TRAIN_DATALOADERS = Union[ DataLoader, Sequence[DataLoader], @@ -53,8 +51,6 @@ TRAIN_DATALOADERS = Union[ Dict[str, Sequence[DataLoader]], ] EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]] -_DEVICE = Union[torch.device, str, int] -_MAP_LOCATION_TYPE = Optional[Union[_DEVICE, Callable[[_DEVICE], _DEVICE], Dict[_DEVICE, _DEVICE]]] _ADD_ARGPARSE_RETURN = Union[_ArgumentGroup, ArgumentParser] @@ -94,60 +90,6 @@ class PredictStep(Protocol): ... 
-_DictKey = TypeVar("_DictKey") - - -@runtime_checkable -class _Stateful(Protocol[_DictKey]): - """This class is used to detect if an object is stateful using `isinstance(obj, _Stateful)`.""" - - def state_dict(self) -> Dict[_DictKey, Any]: - ... - - def load_state_dict(self, state_dict: Dict[_DictKey, Any]) -> None: - ... - - -# Inferred from `torch.optim.lr_scheduler.pyi` -# Missing attributes were added to improve typing -@runtime_checkable -class _LRScheduler(_Stateful[str], Protocol): - optimizer: Optimizer - base_lrs: List[float] - - def __init__(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> None: - ... - - def step(self, epoch: Optional[int] = None) -> None: - ... - - -# Inferred from `torch.optim.lr_scheduler.pyi` -# Missing attributes were added to improve typing -@runtime_checkable -class ReduceLROnPlateau(_Stateful[str], Protocol): - in_cooldown: bool - optimizer: Optimizer - - def __init__( - self, - optimizer: Optimizer, - mode: str = ..., - factor: float = ..., - patience: int = ..., - verbose: bool = ..., - threshold: float = ..., - threshold_mode: str = ..., - cooldown: int = ..., - min_lr: float = ..., - eps: float = ..., - ) -> None: - ... - - def step(self, metrics: Union[float, int, Tensor], epoch: Optional[int] = None) -> None: - ... - - # Inferred from `torch.nn.parallel.distributed.pyi` # Missing attributes were added to improve typing @runtime_checkable diff --git a/src/pytorch_lightning/utilities/upgrade_checkpoint.py b/src/pytorch_lightning/utilities/upgrade_checkpoint.py index 804714d82c..6f4dd5ca93 100644 --- a/src/pytorch_lightning/utilities/upgrade_checkpoint.py +++ b/src/pytorch_lightning/utilities/upgrade_checkpoint.py @@ -17,9 +17,9 @@ from shutil import copyfile import torch +from lightning_lite.utilities.types import _PATH from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.utilities.migration import pl_legacy_patch -from pytorch_lightning.utilities.types import _PATH KEYS_MAPPING = { "checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"), diff --git a/src/pytorch_lightning/utilities/warnings.py b/src/pytorch_lightning/utilities/warnings.py index ae608bdbcc..57b56ba068 100644 --- a/src/pytorch_lightning/utilities/warnings.py +++ b/src/pytorch_lightning/utilities/warnings.py @@ -12,13 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. 
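# The `_Stateful`, `_LRScheduler` and `ReduceLROnPlateau` protocols moved above
# are `runtime_checkable`, so `isinstance` checks pass for any object that
# merely implements the required methods, with no inheritance involved. A
# simplified sketch of that structural check (illustrative classes, not the
# Lightning definitions):
from typing import Any, Dict, Protocol, runtime_checkable


@runtime_checkable
class Stateful(Protocol):
    def state_dict(self) -> Dict[str, Any]:
        ...

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        ...


class Counter:
    def __init__(self) -> None:
        self.count = 0

    def state_dict(self) -> Dict[str, Any]:
        return {"count": self.count}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.count = state_dict["count"]


assert isinstance(Counter(), Stateful)  # structural match, no subclassing needed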
"""Warning-related utilities.""" -import warnings - -from pytorch_lightning.utilities.rank_zero import LightningDeprecationWarning - -# enable our warnings -warnings.simplefilter("default", category=LightningDeprecationWarning) - - -class PossibleUserWarning(UserWarning): - """Warnings that could be false positives.""" +# backwards compatibility +from lightning_lite.utilities.warnings import PossibleUserWarning # noqa: F401 diff --git a/src/pytorch_lightning/utilities/xla_device.py b/src/pytorch_lightning/utilities/xla_device.py index 1d6347c6e6..a515058a63 100644 --- a/src/pytorch_lightning/utilities/xla_device.py +++ b/src/pytorch_lightning/utilities/xla_device.py @@ -18,7 +18,7 @@ from typing import Any, Callable from lightning_lite.utilities.xla_device import inner_f as new_inner_f from lightning_lite.utilities.xla_device import pl_multi_process as new_pl_multi_process from lightning_lite.utilities.xla_device import XLADeviceUtils as NewXLADeviceUtils -from pytorch_lightning.utilities import rank_zero_deprecation # TODO(lite): update to lightning_lite.utilities +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation def inner_f(queue: Queue, func: Callable, *args: Any, **kwargs: Any) -> None: # pragma: no cover diff --git a/tests/tests_lite/conftest.py b/tests/tests_lite/conftest.py index fab4ff7e17..209d6869a1 100644 --- a/tests/tests_lite/conftest.py +++ b/tests/tests_lite/conftest.py @@ -1,7 +1,94 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import os from typing import List import pytest +import torch.distributed + + +@pytest.fixture(scope="function", autouse=True) +def preserve_global_rank_variable(): + """Ensures that the rank_zero_only.rank global variable gets reset in each test.""" + from lightning_lite.utilities.rank_zero import rank_zero_only + + rank = getattr(rank_zero_only, "rank", None) + yield + if rank is not None: + setattr(rank_zero_only, "rank", rank) + + +@pytest.fixture(scope="function", autouse=True) +def restore_env_variables(): + """Ensures that environment variables set during the test do not leak out.""" + env_backup = os.environ.copy() + yield + leaked_vars = os.environ.keys() - env_backup.keys() + # restore environment as it was before running the test + os.environ.clear() + os.environ.update(env_backup) + # these are currently known leakers - ideally these would not be allowed + # TODO(lite): this list can be trimmed, maybe PL's too after moving tests + allowlist = { + "CUDA_DEVICE_ORDER", + "LOCAL_RANK", + "NODE_RANK", + "WORLD_SIZE", + "MASTER_ADDR", + "MASTER_PORT", + "PL_GLOBAL_SEED", + "PL_SEED_WORKERS", + "HOROVOD_FUSION_THRESHOLD", + "RANK", # set by DeepSpeed + "POPLAR_ENGINE_OPTIONS", # set by IPUStrategy + # set by XLA + "TF2_BEHAVIOR", + "XRT_MESH_SERVICE_ADDRESS", + "XRT_TORCH_DIST_ROOT", + "XRT_MULTI_PROCESSING_DEVICE", + "XRT_SHARD_WORLD_SIZE", + "XRT_LOCAL_WORKER", + "XRT_HOST_WORLD_SIZE", + "XRT_SHARD_ORDINAL", + "XRT_SHARD_LOCAL_ORDINAL", + "TF_CPP_MIN_LOG_LEVEL", + } + leaked_vars.difference_update(allowlist) + assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}" + + +@pytest.fixture(scope="function", autouse=True) +def teardown_process_group(): + """Ensures that the distributed process group gets closed before the next test runs.""" + yield + if torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + +@pytest.fixture +def caplog(caplog): + """Workaround for https://github.com/pytest-dev/pytest/issues/3697. + + Setting ``filterwarnings`` with pytest breaks ``caplog`` when ``not logger.propagate``. + """ + import logging + + lightning_logger = logging.getLogger("lightning_lite") + propagate = lightning_logger.propagate + lightning_logger.propagate = True + yield caplog + lightning_logger.propagate = propagate def pytest_collection_modifyitems(items: List[pytest.Function], config: pytest.Config) -> None: diff --git a/tests/tests_lite/helpers/runif.py b/tests/tests_lite/helpers/runif.py index fcdca0f9a6..7dd9aaf729 100644 --- a/tests/tests_lite/helpers/runif.py +++ b/tests/tests_lite/helpers/runif.py @@ -20,8 +20,9 @@ import torch from packaging.version import Version from pkg_resources import get_distribution +from lightning_lite.utilities.imports import _FAIRSCALE_AVAILABLE, _PSUTIL_AVAILABLE, _TPU_AVAILABLE + -# TODO(lite): Add all RunIf conditions once the relevant utilities have moved to lite source dir class RunIf: """RunIf wrapper for simple marking specific cases, fully compatible with pytest.mark:: @@ -38,8 +39,11 @@ class RunIf: min_torch: Optional[str] = None, max_torch: Optional[str] = None, min_python: Optional[str] = None, + tpu: bool = False, skip_windows: bool = False, standalone: bool = False, + fairscale: bool = False, + psutil: bool = False, **kwargs, ): """ @@ -49,9 +53,12 @@ class RunIf: min_torch: Require that PyTorch is greater or equal than this version. max_torch: Require that PyTorch is less than this version. 
min_python: Require that Python is greater or equal than this version. + tpu: Require that TPU is available. skip_windows: Skip for Windows platform. standalone: Mark the test as standalone, our CI will run it in a separate process. This requires that the ``PL_RUN_STANDALONE_TESTS=1`` environment variable is set. + fairscale: Require that facebookresearch/fairscale is installed. + psutil: Require that psutil is installed. **kwargs: Any :class:`pytest.mark.skipif` keyword arguments. """ conditions = [] @@ -82,6 +89,12 @@ class RunIf: conditions.append(sys.platform == "win32") reasons.append("unimplemented on Windows") + if tpu: + conditions.append(not _TPU_AVAILABLE) + reasons.append("TPU") + # used in conftest.py::pytest_collection_modifyitems + kwargs["tpu"] = True + if standalone: env_flag = os.getenv("PL_RUN_STANDALONE_TESTS", "0") conditions.append(env_flag != "1") @@ -89,6 +102,18 @@ class RunIf: # used in conftest.py::pytest_collection_modifyitems kwargs["standalone"] = True + if fairscale: + if skip_windows: + raise ValueError( + "`skip_windows` is not necessary when `fairscale` is set as it does not support Windows." + ) + conditions.append(not _FAIRSCALE_AVAILABLE) + reasons.append("Fairscale") + + if psutil: + conditions.append(not _PSUTIL_AVAILABLE) + reasons.append("psutil") + reasons = [rs for cond, rs in zip(conditions, reasons) if cond] return pytest.mark.skipif( *args, condition=any(conditions), reason=f"Requires: [{' + '.join(reasons)}]", **kwargs diff --git a/tests/tests_lite/helpers/utils.py b/tests/tests_lite/helpers/utils.py new file mode 100644 index 0000000000..2a8294d4d8 --- /dev/null +++ b/tests/tests_lite/helpers/utils.py @@ -0,0 +1,31 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
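# `RunIf` above folds every requested requirement into a condition/reason pair
# and returns a single `pytest.mark.skipif`. A stripped-down sketch of that
# composition (not the real helper, just its shape):
import sys
from typing import Optional

import pytest


def run_if(*, skip_windows: bool = False, min_python: Optional[str] = None):
    conditions, reasons = [], []
    if skip_windows:
        conditions.append(sys.platform == "win32")
        reasons.append("unimplemented on Windows")
    if min_python is not None:
        major, minor = (int(part) for part in min_python.split(".")[:2])
        conditions.append(sys.version_info < (major, minor))
        reasons.append(f"python>={min_python}")
    active = [reason for condition, reason in zip(conditions, reasons) if condition]
    return pytest.mark.skipif(any(conditions), reason=f"Requires: [{' + '.join(active)}]")


@run_if(skip_windows=True)
def test_posix_only():
    assert True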
+import os
+
+import numpy as np
+
+from lightning_lite.utilities.seed import seed_everything
+
+# generate a pool of random ports to draw from in each test
+RANDOM_PORTS = list(np.random.randint(12000, 19000, 1000))
+
+
+def reset_seed(seed=0):
+    seed_everything(seed)
+
+
+def set_random_main_port():
+    reset_seed()
+    port = RANDOM_PORTS.pop()
+    os.environ["MASTER_PORT"] = str(port)
diff --git a/tests/tests_lite/utilities/test_data.py b/tests/tests_lite/utilities/test_data.py
new file mode 100644
index 0000000000..8946ab5001
--- /dev/null
+++ b/tests/tests_lite/utilities/test_data.py
@@ -0,0 +1,509 @@
+import random
+
+import pytest
+import torch
+from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
+
+from lightning_lite.utilities.data import (
+    _dataloader_init_kwargs_resolve_sampler,
+    _get_dataloader_init_args_and_kwargs,
+    _replace_dunder_methods,
+    _replace_value_in_saved_args,
+    _update_dataloader,
+    _WrapAttrTag,
+    has_iterable_dataset,
+    has_len,
+)
+from lightning_lite.utilities.exceptions import MisconfigurationException
+
+# TODO(lite): provide boring classes in Lite
+from pytorch_lightning.demos.boring_classes import RandomDataset, RandomIterableDataset
+
+
+def test_has_iterable_dataset():
+    assert has_iterable_dataset(DataLoader(RandomIterableDataset(1, 1)))
+
+    assert not has_iterable_dataset(DataLoader(RandomDataset(1, 1)))
+
+    class MockDatasetWithoutIterableDataset(RandomDataset):
+        def __iter__(self):
+            yield 1
+            return self
+
+    assert not has_iterable_dataset(DataLoader(MockDatasetWithoutIterableDataset(1, 1)))
+
+
+def test_has_len():
+    assert has_len(DataLoader(RandomDataset(1, 1)))
+
+    with pytest.warns(UserWarning, match="`DataLoader` returned 0 length."):
+        assert has_len(DataLoader(RandomDataset(0, 0)))
+
+    assert not has_len(DataLoader(RandomIterableDataset(1, 1)))
+
+
+def test_replace_dunder_methods_multiple_loaders_without_init():
+    """In the case of a class that inherits from a class we are patching but doesn't define its own `__init__`
+    method (the one we are wrapping), it can happen that `hasattr(cls, "__old__init__")` is True because of the
+    parent class, but the attribute is impossible to delete because it is owned by the parent class. Furthermore,
+    the error occurred only sometimes because it depends on the order in which we iterate over the set of classes
+    we are patching.
+
+    This test simulates the behavior by generating a sufficient number of dummy classes, which do not define
+    `__init__` and are children of `DataLoader`. We are testing that a) the context manager `_replace_dunder_methods`
+    exits cleanly, and b) the mechanism checking for the presence of `__old__init__` works as expected.
+ """ + classes = [DataLoader] + for i in range(100): + classes.append(type(f"DataLoader_{i}", (random.choice(classes),), {})) + + before = {cls: cls.__init__ for cls in classes} + + with _replace_dunder_methods(DataLoader, "dataset"): + for cls in classes[1:]: # First one is `DataLoader` + assert "__old__init__" not in cls.__dict__ + assert hasattr(cls, "__old__init__") + + assert "__old__init__" in DataLoader.__dict__ + assert hasattr(DataLoader, "__old__init__") + + for cls in classes: + assert before[cls] == cls.__init__ + + +class MyBaseDataLoader(DataLoader): + pass + + +class DataLoaderSubclass1(DataLoader): + def __init__(self, attribute1, *args, **kwargs): + self.at1 = attribute1 + super().__init__(*args, **kwargs) + + +class DataLoaderSubclass2(DataLoaderSubclass1): + def __init__(self, attribute2, *args, **kwargs): + self.at2 = attribute2 + super().__init__(attribute2 + "-2", *args, **kwargs) + + +class MyDataLoader(MyBaseDataLoader): + def __init__(self, data: torch.Tensor, *args, **kwargs): + self.data = data + super().__init__(range(data.size(0)), *args, **kwargs) + + +test3_data = torch.randn((10, 20)) + + +class PoptorchDataLoader(DataLoader): + def __init__(self, options, *args, **kwargs): + super().__init__(*args, **kwargs) + self._options = options + + @property + def options(self): + return self._options + + +class IncompleteDataLoader(DataLoader): + def __init__(self, dataset, batch_size, **kwargs): + batch_size = max(batch_size - 5, 0) + super().__init__(dataset, batch_size=batch_size, **kwargs) + + +class WeirdDataLoader1(DataLoader): + def __init__(self, arg1, arg2, **kwargs): + self.arg1 = arg1 + super().__init__(arg2, **kwargs) + + +class WeirdDataLoader2(DataLoader): + def __init__(self, data_part1, data_part2, **kwargs): + data = list(data_part1) + list(data_part2) + super().__init__(data, **kwargs) + + +class NoneDataLoader(DataLoader): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class ChangingDataLoader(DataLoader): + def __init__(self, dataset, **kwargs): + super().__init__(list(dataset) + list(range(5, 10)), **kwargs) + + +@pytest.mark.parametrize( + ["cls", "args", "kwargs", "arg_names", "dataset", "checked_values"], + [ + pytest.param( + DataLoaderSubclass1, + ("attribute1",), + dict(dataset=range(4), batch_size=2), + ("attribute1",), + range(4), + dict(batch_size=2, at1="attribute1"), + id="test1", + ), + pytest.param( + DataLoaderSubclass2, + ("attribute2",), + dict(dataset=range(4), batch_size=2), + ("attribute2",), + range(4), + dict(batch_size=2, at1="attribute2-2", at2="attribute2"), + id="test2", + ), + pytest.param( + MyDataLoader, + (test3_data,), + dict(batch_size=2), + ("data",), + range(10), + dict(batch_size=2, data=test3_data), + id="test3", + ), + pytest.param(PoptorchDataLoader, (123, [1]), dict(), ("options",), [1], dict(options=123), id="test4"), + pytest.param( + IncompleteDataLoader, + (range(10),), + dict(batch_size=10), + ("dataset",), + range(10), + dict(batch_size=5), + id="test5", + ), + pytest.param( + WeirdDataLoader1, + (10, range(10)), + dict(batch_size=10), + ("arg1", "arg2"), + range(10), + dict(arg1=10, batch_size=10), + id="test6", + ), + pytest.param( + WeirdDataLoader2, + (range(10), range(10, 20)), + dict(batch_size=10), + ("data_part1", "data_part2"), + list(range(20)), + dict(batch_size=10), + id="test7", + ), + pytest.param(NoneDataLoader, (None,), dict(), (), None, dict(), id="test8"), + pytest.param(ChangingDataLoader, (range(5),), dict(), ("dataset",), list(range(10)), dict(), 
id="test9"), + ], +) +def test_replace_dunder_methods_dataloader(cls, args, kwargs, arg_names, dataset, checked_values): + with _replace_dunder_methods(DataLoader, "dataset"): + dataloader = cls(*args, **kwargs) + + assert dataloader.__pl_saved_args == args + assert dataloader.__pl_saved_kwargs == kwargs + assert dataloader.__pl_saved_arg_names == arg_names + assert dataloader.__pl_saved_default_kwargs == {} + assert dataloader.__dataset == dataset + + assert dataloader.dataset == dataset + + for key, value in checked_values.items(): + dataloader_value = getattr(dataloader, key) + if isinstance(dataloader_value, torch.Tensor): + assert dataloader_value is value + else: + assert dataloader_value == value + + dataloader = _update_dataloader(dataloader, dataloader.sampler) + + assert isinstance(dataloader, cls) + assert not hasattr(dataloader, "__pl_saved_kwargs") + assert not hasattr(dataloader, "__pl_saved_arg_names") + assert not hasattr(dataloader, "__pl_saved_args") + assert not hasattr(dataloader, "__pl_saved_default_kwargs") + assert not hasattr(dataloader, "__dataset") + + assert dataloader.dataset == dataset + + for key, value in checked_values.items(): + dataloader_value = getattr(dataloader, key) + if isinstance(dataloader_value, torch.Tensor): + assert dataloader_value is value + else: + assert dataloader_value == value + + +def test_replace_dunder_methods_extra_kwargs(): + class LoaderSubclass(DataLoader): + def __init__(self, dataset, *args, batch_size=10, **kwargs): + super().__init__(dataset, *args, batch_size=batch_size, **kwargs) + + with _replace_dunder_methods(DataLoader, "dataset"): + dataloader = LoaderSubclass(range(10)) + + assert dataloader.__pl_saved_args == (range(10),) + assert dataloader.__pl_saved_kwargs == {} + assert dataloader.__pl_saved_arg_names == ("dataset",) + assert dataloader.__pl_saved_default_kwargs == {"batch_size": 10} + assert dataloader.__dataset == range(10) + + +def test_replace_dunder_methods_attrs(): + """This test checks, that all the calls from setting and deleting attributes within `_replace_dunder_methods` + are correctly preserved even after reinstantiation. 
+ + It also includes a custom `__setattr__` + """ + + class Loader(DataLoader): + def __setattr__(self, attr, val): + if attr == "custom_arg": + val = val + 2 + super().__setattr__(attr, val) + + with _replace_dunder_methods(DataLoader, "dataset"): + dataloader = Loader(range(10)) + dataloader.custom_arg = 5 + dataloader.my_arg = 10 + dataloader.another_arg = 100 + del dataloader.dataset + try: + del dataloader.abc_arg + except AttributeError: + pass + + assert dataloader.__pl_saved_args == (range(10),) + assert dataloader.__pl_saved_kwargs == {} + assert dataloader.__pl_saved_arg_names == ("dataset",) + assert dataloader.__dataset == range(10) + assert dataloader.custom_arg == 7 + assert dataloader.my_arg == 10 + assert dataloader.another_arg == 100 + assert not hasattr(dataloader, "dataset") + assert dataloader.__pl_attrs_record == [ + (("custom_arg", 5), _WrapAttrTag.SET), + (("my_arg", 10), _WrapAttrTag.SET), + (("another_arg", 100), _WrapAttrTag.SET), + (("dataset",), _WrapAttrTag.DEL), + ] + + dataloader = _update_dataloader(dataloader, dataloader.sampler) + assert dataloader.custom_arg == 7 + assert dataloader.my_arg == 10 + assert dataloader.another_arg == 100 + assert not hasattr(dataloader, "dataset") + + +def test_replace_dunder_methods_restore_methods(): + """This tests checks whether are all dunder methods restored to their original versions.""" + + class Init(DataLoader): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + class SetAttr(DataLoader): + def __setattr__(self, *args): + return super().__setattr__(*args) + + class DelAttr(DataLoader): + def __delattr__(self, *args): + return super().__delattr__(*args) + + class InitAndSetAttr(Init, SetAttr): + pass + + class InitAndDelAttr(Init, DelAttr): + pass + + class SetAttrAndDelAttr(SetAttr, DelAttr): + pass + + class AllDunder(Init, SetAttr, DelAttr): + pass + + before = dict() + for cls in (Init, SetAttr, DelAttr, InitAndSetAttr, InitAndDelAttr, SetAttrAndDelAttr, AllDunder): + before[cls] = {"init": cls.__init__, "setattr": cls.__setattr__, "delattr": cls.__delattr__} + + with _replace_dunder_methods(DataLoader, "dataset"): + pass + + for cls in (Init, SetAttr, DelAttr, InitAndSetAttr, InitAndDelAttr, SetAttrAndDelAttr, AllDunder): + assert before[cls] == {"init": cls.__init__, "setattr": cls.__setattr__, "delattr": cls.__delattr__} + + +@pytest.mark.parametrize( + [ + "args", + "kwargs", + "default_kwargs", + "arg_names", + "replace_key", + "replace_value", + "expected_status", + "expected_args", + "expected_kwargs", + ], + [ + pytest.param((), {}, {}, [], "a", 1, False, (), {}, id="empty"), + pytest.param((1,), {}, {}, ["a"], "a", 2, True, (2,), {}, id="simple1"), + pytest.param((1, 2, 3), {}, {}, ["a", "b", "c"], "b", False, True, (1, False, 3), {}, id="simple2"), + pytest.param((1, 2, 3), {"a": 1}, {}, ["b", "c", "d"], "a", 2, True, (1, 2, 3), {"a": 2}, id="simple_kwargs"), + pytest.param( + (1, 2, 3), + {"a": 1}, + {"e": 5}, + ["b", "c", "d"], + "e", + 2, + True, + (1, 2, 3), + {"a": 1, "e": 2}, + id="default_kwargs", + ), + ], +) +def test_replace_value_in_args( + args, kwargs, default_kwargs, arg_names, replace_key, replace_value, expected_status, expected_args, expected_kwargs +): + assert _replace_value_in_saved_args(replace_key, replace_value, args, kwargs, default_kwargs, arg_names) == ( + expected_status, + expected_args, + expected_kwargs, + ) + + +def test_update_dataloader_typerror_custom_exception(): + class BadStandaloneGoodHookImpl(DataLoader): + def __init__(self, foo, *args, 
**kwargs): + self.foo = foo + # positional conflict with `dataset` + super().__init__(foo, *args, **kwargs) + + dataloader = BadStandaloneGoodHookImpl([1, 2, 3]) + with pytest.raises(MisconfigurationException, match="implementation has an error.*`dataset`"): + _update_dataloader(dataloader, dataloader.sampler) + + with _replace_dunder_methods(DataLoader, "dataset"): + dataloader = BadStandaloneGoodHookImpl([1, 2, 3]) + new_dataloader = _update_dataloader(dataloader, dataloader.sampler) + assert isinstance(new_dataloader, BadStandaloneGoodHookImpl) + + class BadImpl(DataLoader): + def __init__(self, randomize, *args, **kwargs): + self.randomize = randomize + # keyword conflict with `shuffle` + super().__init__(*args, shuffle=randomize, **kwargs) + + dataloader = BadImpl(False, []) + with pytest.raises(MisconfigurationException, match="implementation has an error.*`shuffle`"): + _update_dataloader(dataloader, dataloader.sampler) + + class GoodImpl(DataLoader): + def __init__(self, randomize, *args, **kwargs): + # fixed implementation, kwargs are filtered + self.randomize = randomize or kwargs.pop("shuffle", False) + super().__init__(*args, shuffle=randomize, **kwargs) + + dataloader = GoodImpl(False, []) + new_dataloader = _update_dataloader(dataloader, dataloader.sampler) + assert isinstance(new_dataloader, GoodImpl) + + +def test_custom_batch_sampler(): + """This test asserts, that custom `BatchSampler`, with all the arguments, that are required in order to + properly reinstantiate the class, is invoked properly. + + It also asserts, that during the reinstantiation, the wrapper of `__init__` method is not present anymore, therefore + not setting `__pl_saved_{args,arg_names,kwargs}` attributes. + """ + + class MyBatchSampler(BatchSampler): + # Custom Batch sampler with extra argument and default value + def __init__(self, sampler, extra_arg, drop_last=True): + self.extra_arg = extra_arg + super().__init__(sampler, 10, drop_last) + + sampler = RandomSampler(range(10)) + with _replace_dunder_methods(BatchSampler): + # instantiate within `_replace_dunder_method` context manager, simulating `*_dataloader` hooks + batch_sampler = MyBatchSampler(sampler, "random_str") + + dataloader = DataLoader(range(10), batch_sampler=batch_sampler) + + # assert that passed information got saved + assert dataloader.batch_sampler.__pl_saved_args == (sampler, "random_str") + assert dataloader.batch_sampler.__pl_saved_kwargs == {} + assert dataloader.batch_sampler.__pl_saved_arg_names == ("sampler", "extra_arg") + assert dataloader.batch_sampler.__pl_saved_default_kwargs == {"drop_last": True} + + # updating dataloader, what happens on access of the dataloaders. + # This should not fail, and would fail before support for custom args. 
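# `_update_dataloader` on the next line rebuilds the loader and its custom
# batch sampler from the `__pl_saved_args` / `__pl_saved_kwargs` /
# `__pl_saved_default_kwargs` record captured while `__init__` was wrapped. A
# toy sketch of that record-and-replay idea (illustrative names, not the
# Lightning internals):
class Recorder:
    def __init__(self, *args, **kwargs):
        self.saved_args, self.saved_kwargs = args, kwargs


def reinstantiate(obj, **overrides):
    kwargs = {**obj.saved_kwargs, **overrides}
    return type(obj)(*obj.saved_args, **kwargs)


original = Recorder(1, 2, flag=True)
rebuilt = reinstantiate(original, flag=False)
assert rebuilt.saved_args == (1, 2)
assert rebuilt.saved_kwargs == {"flag": False}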
+ dataloader = _update_dataloader(dataloader, dataloader.sampler) + + # Assert the `__init__` method is not replaced anymore and everything is instantiated to correct types + batch_sampler = dataloader.batch_sampler + + assert isinstance(batch_sampler, MyBatchSampler) + + assert batch_sampler.extra_arg == "random_str" + assert not hasattr(batch_sampler, "__pl_saved_kwargs") + assert not hasattr(batch_sampler, "__pl_saved_arg_names") + assert not hasattr(batch_sampler, "__pl_saved_args") + assert not hasattr(batch_sampler, "__pl_saved_default_kwargs") + + +def test_custom_batch_sampler_no_sampler(): + """Tests whether appropriate error is raised when the custom `BatchSampler` does not support sampler + argument.""" + + class MyBatchSampler(BatchSampler): + # Custom batch sampler, without sampler argument. + def __init__(self, extra_arg): + self.extra_arg = extra_arg + super().__init__(RandomSampler(range(10)), 10, False) + + with _replace_dunder_methods(BatchSampler): + # instantiate within `_replace_dunder_method` context manager, simulating `*_dataloader` hooks + batch_sampler = MyBatchSampler("random_str") + dataloader = DataLoader(range(10), batch_sampler=batch_sampler) + + # assert that passed information got saved + assert dataloader.batch_sampler.__pl_saved_args == ("random_str",) + assert dataloader.batch_sampler.__pl_saved_kwargs == {} + assert dataloader.batch_sampler.__pl_saved_arg_names == ("extra_arg",) + assert dataloader.batch_sampler.__pl_saved_default_kwargs == {} + + # Assert that error is raised + with pytest.raises(TypeError, match="sampler into the batch sampler"): + dataloader = _update_dataloader(dataloader, dataloader.sampler) + + +def test_dataloader_disallow_batch_sampler(): + dataset = RandomDataset(5, 100) + dataloader = DataLoader(dataset, batch_size=10) + + # This should not raise + _dataloader_init_kwargs_resolve_sampler(dataloader, dataloader.sampler, disallow_batch_sampler=True) + + dataset = RandomDataset(5, 100) + sampler = SequentialSampler(dataset) + batch_sampler = BatchSampler(sampler, batch_size=10, drop_last=False) + dataloader = DataLoader(dataset, batch_sampler=batch_sampler) + + # this should raise - using batch sampler, that was not automatically instantiated by DataLoader + with pytest.raises(MisconfigurationException, match="when running on multiple IPU devices"): + _dataloader_init_kwargs_resolve_sampler(dataloader, dataloader.sampler, disallow_batch_sampler=True) + + +def test_dataloader_kwargs_replacement_with_iterable_dataset(): + """Test that DataLoader kwargs are not replaced when using Iterable Dataset.""" + dataset = RandomIterableDataset(7, 100) + dataloader = DataLoader(dataset, batch_size=32) + _, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, dataloader.sampler) + assert dl_kwargs["sampler"] is None + assert dl_kwargs["batch_sampler"] is None + assert dl_kwargs["batch_size"] is dataloader.batch_size + assert dl_kwargs["dataset"] is dataloader.dataset + assert dl_kwargs["collate_fn"] is dataloader.collate_fn diff --git a/tests/tests_lite/utilities/test_device_parser.py b/tests/tests_lite/utilities/test_device_parser.py new file mode 100644 index 0000000000..bb6e1665ef --- /dev/null +++ b/tests/tests_lite/utilities/test_device_parser.py @@ -0,0 +1,31 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest import mock + +import pytest +import torch + +import lightning_lite.utilities.device_parser + + +@pytest.mark.skipif( + "fork" in torch.multiprocessing.get_all_start_methods(), reason="Requires platform without forking support" +) +@mock.patch("torch.cuda.is_available", return_value=True) +@mock.patch("torch.cuda.device_count", return_value=2) +def test_num_cuda_devices_without_forking(*_): + """This merely tests that on platforms without fork support our helper functions fall back to the default + implementation for determining cuda availability.""" + assert lightning_lite.utilities.device_parser.is_cuda_available() + assert lightning_lite.utilities.device_parser.num_cuda_devices() == 2 diff --git a/tests/tests_lite/utilities/test_distributed.py b/tests/tests_lite/utilities/test_distributed.py new file mode 100644 index 0000000000..b09c0487bd --- /dev/null +++ b/tests/tests_lite/utilities/test_distributed.py @@ -0,0 +1,63 @@ +import os + +import pytest +import tests_lite.helpers.utils as tutils +import torch +from tests_lite.helpers.runif import RunIf +from torch import multiprocessing as mp + +from lightning_lite.utilities.distributed import gather_all_tensors + + +def _test_all_gather_uneven_tensors(rank, world_size, backend): + os.environ["MASTER_ADDR"] = "localhost" + + if backend == "nccl": + device = torch.device("cuda", rank) + torch.cuda.set_device(device) + else: + device = torch.device("cpu") + + # initialize the process group + torch.distributed.init_process_group(backend, rank=rank, world_size=world_size) + + tensor = torch.ones(rank, device=device) + result = gather_all_tensors(tensor) + assert len(result) == world_size + for idx in range(world_size): + assert len(result[idx]) == idx + assert (result[idx] == torch.ones_like(result[idx])).all() + + +def _test_all_gather_uneven_tensors_multidim(rank, world_size, backend): + os.environ["MASTER_ADDR"] = "localhost" + + if backend == "nccl": + device = torch.device("cuda", rank) + torch.cuda.set_device(device) + else: + device = torch.device("cpu") + + # initialize the process group + torch.distributed.init_process_group(backend, rank=rank, world_size=world_size) + tensor = torch.ones(rank + 1, 2 - rank, device=device) + result = gather_all_tensors(tensor) + assert len(result) == world_size + for idx in range(world_size): + val = result[idx] + assert val.shape == (idx + 1, 2 - idx) + assert (val == torch.ones_like(val)).all() + + +@RunIf(min_torch="1.10", skip_windows=True) +@pytest.mark.parametrize( + "process", + [ + _test_all_gather_uneven_tensors_multidim, + _test_all_gather_uneven_tensors, + ], +) +@pytest.mark.parametrize("backend", [pytest.param("nccl", marks=RunIf(min_cuda_gpus=2)), "gloo"]) +def test_gather_all_tensors(backend, process): + tutils.set_random_main_port() + mp.spawn(process, args=(2, backend), nprocs=2) diff --git a/tests/tests_lite/utilities/test_enums.py b/tests/tests_lite/utilities/test_enums.py new file mode 100644 index 0000000000..38a556e5dc --- /dev/null +++ b/tests/tests_lite/utilities/test_enums.py @@ -0,0 +1,9 @@ +from lightning_lite.utilities.enums import PrecisionType + + 
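For reference, here is a tiny stand-in that mirrors the `PrecisionType` behaviour asserted by the test below. The class name `_DemoPrecisionType` is made up for illustration; this is only a sketch of the expected API, not the actual `lightning_lite` enum implementation.

from enum import Enum


class _DemoPrecisionType(str, Enum):
    # String-valued members so that both ints and strings can be matched via str()
    HALF = "16"
    FLOAT = "32"
    FULL = "64"
    BFLOAT = "bf16"
    MIXED = "mixed"

    @classmethod
    def supported_types(cls):
        return [t.value for t in cls]

    @classmethod
    def supported_type(cls, precision):
        return any(str(precision) == t.value for t in cls)


assert _DemoPrecisionType.supported_types() == ["16", "32", "64", "bf16", "mixed"]
assert _DemoPrecisionType.supported_type(16) and _DemoPrecisionType.supported_type("bf16")
assert not _DemoPrecisionType.supported_type("invalid")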
+def test_precision_supported_types(): + assert PrecisionType.supported_types() == ["16", "32", "64", "bf16", "mixed"] + assert PrecisionType.supported_type(16) + assert PrecisionType.supported_type("16") + assert not PrecisionType.supported_type(1) + assert not PrecisionType.supported_type("invalid") diff --git a/tests/tests_lite/utilities/test_imports.py b/tests/tests_lite/utilities/test_imports.py new file mode 100644 index 0000000000..3a8444ef72 --- /dev/null +++ b/tests/tests_lite/utilities/test_imports.py @@ -0,0 +1,81 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from lightning_lite.utilities.imports import ( + _APEX_AVAILABLE, + _FAIRSCALE_AVAILABLE, + _HOROVOD_AVAILABLE, + _OMEGACONF_AVAILABLE, + _POPTORCH_AVAILABLE, +) + + +def test_imports(): + try: + import apex # noqa + except ModuleNotFoundError: + assert not _APEX_AVAILABLE + else: + assert _APEX_AVAILABLE + + # TODO(lite): re-enable these once deepspeed strategy has moved + # try: + # import deepspeed + # except ModuleNotFoundError: + # assert not _DEEPSPEED_AVAILABLE + # else: + # assert _DEEPSPEED_AVAILABLE + + try: + import fairscale.nn # noqa + except ModuleNotFoundError: + assert not _FAIRSCALE_AVAILABLE + else: + assert _FAIRSCALE_AVAILABLE + + try: + import horovod.torch # noqa + except ModuleNotFoundError: + assert not _HOROVOD_AVAILABLE + else: + assert _HOROVOD_AVAILABLE + + try: + import omegaconf # noqa + except ModuleNotFoundError: + assert not _OMEGACONF_AVAILABLE + else: + assert _OMEGACONF_AVAILABLE + + try: + import poptorch # noqa + except ModuleNotFoundError: + assert not _POPTORCH_AVAILABLE + else: + assert _POPTORCH_AVAILABLE diff --git a/tests/tests_pytorch/utilities/test_optimizer.py b/tests/tests_lite/utilities/test_optimizer.py similarity index 93% rename from tests/tests_pytorch/utilities/test_optimizer.py rename to tests/tests_lite/utilities/test_optimizer.py index 6d4c0ec54e..09a37a6403 100644 --- a/tests/tests_pytorch/utilities/test_optimizer.py +++ b/tests/tests_lite/utilities/test_optimizer.py @@ -2,7 +2,7 @@ import collections import torch -from pytorch_lightning.utilities.optimizer import optimizer_to_device +from lightning_lite.utilities.optimizer import optimizer_to_device def test_optimizer_to_device(): diff --git a/tests/tests_pytorch/utilities/test_rank_zero.py b/tests/tests_lite/utilities/test_rank_zero.py similarity index 65% rename from 
tests/tests_pytorch/utilities/test_rank_zero.py rename to tests/tests_lite/utilities/test_rank_zero.py index c4c15b28e5..edf85d7b34 100644 --- a/tests/tests_pytorch/utilities/test_rank_zero.py +++ b/tests/tests_lite/utilities/test_rank_zero.py @@ -1,23 +1,10 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. import os import sys from unittest import mock import pytest -from pytorch_lightning.utilities.rank_zero import _get_rank +from lightning_lite.utilities.rank_zero import _get_rank @pytest.mark.parametrize( @@ -39,8 +26,8 @@ def test_rank_zero_known_environment_variables(env_vars, expected): with mock.patch.dict(os.environ, env_vars): # force module reload to re-trigger the rank_zero_only.rank global computation sys.modules.pop("lightning_utilities.core.rank_zero", None) - sys.modules.pop("pytorch_lightning.utilities.rank_zero", None) - from pytorch_lightning.utilities.rank_zero import rank_zero_only + sys.modules.pop("lightning_lite.utilities.rank_zero", None) + from lightning_lite.utilities.rank_zero import rank_zero_only @rank_zero_only def foo(): diff --git a/tests/tests_lite/utilities/test_seed.py b/tests/tests_lite/utilities/test_seed.py new file mode 100644 index 0000000000..b03aa6d049 --- /dev/null +++ b/tests/tests_lite/utilities/test_seed.py @@ -0,0 +1,84 @@ +import os +from unittest import mock + +import pytest +import torch + +import lightning_lite.utilities +from lightning_lite.utilities import seed as seed_utils +from lightning_lite.utilities.seed import _collect_rng_states, _set_rng_states + + +@mock.patch.dict(os.environ, {}, clear=True) +def test_seed_stays_same_with_multiple_seed_everything_calls(): + """Ensure that after the initial seed everything, the seed stays the same for the same run.""" + with pytest.warns(UserWarning, match="No seed found"): + lightning_lite.utilities.seed.seed_everything() + initial_seed = os.environ.get("PL_GLOBAL_SEED") + + with pytest.warns(None) as record: + lightning_lite.utilities.seed.seed_everything() + assert not record # does not warn + seed = os.environ.get("PL_GLOBAL_SEED") + + assert initial_seed == seed + + +@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "2020"}, clear=True) +def test_correct_seed_with_environment_variable(): + """Ensure that the PL_GLOBAL_SEED environment is read.""" + assert lightning_lite.utilities.seed.seed_everything() == 2020 + + +@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "invalid"}, clear=True) +@mock.patch.object(seed_utils, attribute="_select_seed_randomly", new=lambda *_: 123) +def test_invalid_seed(): + """Ensure that we still fix the seed even if an invalid seed is given.""" + with pytest.warns(UserWarning, match="Invalid seed found"): + seed = lightning_lite.utilities.seed.seed_everything() + assert seed == 123 + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch.object(seed_utils, attribute="_select_seed_randomly", new=lambda *_: 123) +@pytest.mark.parametrize("seed", (10e9, -10e9)) +def test_out_of_bounds_seed(seed): + """Ensure that we still 
fix the seed even if an out-of-bounds seed is given.""" + with pytest.warns(UserWarning, match="is not in bounds"): + actual = lightning_lite.utilities.seed.seed_everything(seed) + assert actual == 123 + + +def test_reset_seed_no_op(): + """Test that the reset_seed function is a no-op when seed_everything() was not used.""" + assert "PL_GLOBAL_SEED" not in os.environ + seed_before = torch.initial_seed() + lightning_lite.utilities.seed.reset_seed() + assert torch.initial_seed() == seed_before + assert "PL_GLOBAL_SEED" not in os.environ + + +@pytest.mark.parametrize("workers", (True, False)) +def test_reset_seed_everything(workers): + """Test that we can reset the seed to the initial value set by seed_everything()""" + assert "PL_GLOBAL_SEED" not in os.environ + assert "PL_SEED_WORKERS" not in os.environ + + lightning_lite.utilities.seed.seed_everything(123, workers) + before = torch.rand(1) + assert os.environ["PL_GLOBAL_SEED"] == "123" + assert os.environ["PL_SEED_WORKERS"] == str(int(workers)) + + lightning_lite.utilities.seed.reset_seed() + after = torch.rand(1) + assert os.environ["PL_GLOBAL_SEED"] == "123" + assert os.environ["PL_SEED_WORKERS"] == str(int(workers)) + assert torch.allclose(before, after) + + +def test_backward_compatibility_rng_states_dict(): + """Test that an older rng_states_dict without the "torch.cuda" key does not crash.""" + states = _collect_rng_states() + assert "torch.cuda" in states + states.pop("torch.cuda") + _set_rng_states(states) diff --git a/tests/tests_lite/utilities/test_warnings.py b/tests/tests_lite/utilities/test_warnings.py new file mode 100644 index 0000000000..e951ff53ea --- /dev/null +++ b/tests/tests_lite/utilities/test_warnings.py @@ -0,0 +1,78 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test that the warnings actually appear and they have the correct `stacklevel` + +Needs to be run outside of `pytest` as it captures all the warnings. 
+""" +from contextlib import redirect_stderr +from io import StringIO + +from lightning_utilities.core.rank_zero import _warn, WarningCache + +from lightning_lite.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn + +if __name__ == "__main__": + stderr = StringIO() + # recording + with redirect_stderr(stderr): + _warn("test1") + _warn("test2", category=DeprecationWarning) + + rank_zero_warn("test3") + rank_zero_warn("test4", category=DeprecationWarning) + + rank_zero_deprecation("test5") + + cache = WarningCache() + cache.warn("test6") + cache.deprecation("test7") + + output = stderr.getvalue() + assert "test_warnings.py:29: UserWarning: test1" in output + assert "test_warnings.py:30: DeprecationWarning: test2" in output + + assert "test_warnings.py:32: UserWarning: test3" in output + assert "test_warnings.py:33: DeprecationWarning: test4" in output + + assert "test_warnings.py:35: LightningDeprecationWarning: test5" in output + + assert "test_warnings.py:38: UserWarning: test6" in output + assert "test_warnings.py:39: LightningDeprecationWarning: test7" in output + + # check that logging is properly configured + import logging + + root_logger = logging.getLogger() + lightning_logger = logging.getLogger("lightning_lite") + # should have a `StreamHandler` + assert lightning_logger.hasHandlers() and len(lightning_logger.handlers) == 1 + # set our own stream for testing + handler = lightning_logger.handlers[0] + assert isinstance(handler, logging.StreamHandler) + stderr = StringIO() + # necessary with `propagate = False` + lightning_logger.handlers[0].stream = stderr + + # necessary with `propagate = True` + with redirect_stderr(stderr): + # Lightning should not configure the root `logging` logger by default + logging.info("test1") + root_logger.info("test1") + # but our logger instance + lightning_logger.info("test2") + # level is set to INFO + lightning_logger.debug("test3") + + output = stderr.getvalue() + assert output == "test2\n", repr(output) diff --git a/tests/tests_lite/utilities/test_xla_device_utils.py b/tests/tests_lite/utilities/test_xla_device_utils.py index d8f6003c6a..87c92b772c 100644 --- a/tests/tests_lite/utilities/test_xla_device_utils.py +++ b/tests/tests_lite/utilities/test_xla_device_utils.py @@ -15,10 +15,10 @@ import time from unittest.mock import patch import pytest +from tests_lite.helpers.runif import RunIf import lightning_lite.utilities.xla_device as xla_utils -from pytorch_lightning.utilities import _XLA_AVAILABLE -from tests_pytorch.helpers.runif import RunIf +from lightning_lite.utilities.imports import _XLA_AVAILABLE @pytest.mark.skipif(_XLA_AVAILABLE, reason="test requires torch_xla to be absent") @@ -38,7 +38,7 @@ def sleep_fn(sleep_time: float) -> bool: return True -@patch("pytorch_lightning.utilities.xla_device.TPU_CHECK_TIMEOUT", 3) +@patch("lightning_lite.utilities.xla_device.TPU_CHECK_TIMEOUT", 3) @pytest.mark.skipif(not _XLA_AVAILABLE, reason="test requires torch_xla to be present") def test_result_returns_within_timeout_seconds(): """Check that pl_multi_process returns within 3 seconds.""" diff --git a/tests/tests_pytorch/accelerators/test_common.py b/tests/tests_pytorch/accelerators/test_common.py index 8c4ac8f3fd..05fb76f1cc 100644 --- a/tests/tests_pytorch/accelerators/test_common.py +++ b/tests/tests_pytorch/accelerators/test_common.py @@ -18,7 +18,7 @@ from pytorch_lightning.accelerators import Accelerator, CPUAccelerator, CUDAAcce from pytorch_lightning.strategies import DDPStrategy 
-@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=2) def test_auto_device_count(_): assert CPUAccelerator.auto_device_count() == 1 assert CUDAAccelerator.auto_device_count() == 2 diff --git a/tests/tests_pytorch/core/test_metric_result_integration.py b/tests/tests_pytorch/core/test_metric_result_integration.py index 9672bb75b5..8eb4abca00 100644 --- a/tests/tests_pytorch/core/test_metric_result_integration.py +++ b/tests/tests_pytorch/core/test_metric_result_integration.py @@ -27,6 +27,7 @@ from torchmetrics import Metric, MetricCollection import pytorch_lightning as pl import tests_pytorch.helpers.utils as tutils +from lightning_lite.utilities.warnings import PossibleUserWarning from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.demos.boring_classes import BoringModel @@ -36,7 +37,6 @@ from pytorch_lightning.trainer.connectors.logger_connector.result import ( _ResultMetric, _Sync, ) -from pytorch_lightning.utilities.warnings import PossibleUserWarning from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.utils import no_warning_call diff --git a/tests/tests_pytorch/core/test_results.py b/tests/tests_pytorch/core/test_results.py index dc4c2ac065..543437c28f 100644 --- a/tests/tests_pytorch/core/test_results.py +++ b/tests/tests_pytorch/core/test_results.py @@ -16,8 +16,8 @@ import torch.distributed as dist import torch.multiprocessing as mp import tests_pytorch.helpers.utils as tutils +from lightning_lite.utilities.distributed import sync_ddp_if_available from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync -from pytorch_lightning.utilities.distributed import sync_ddp_if_available from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-10.py b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py index 2193085255..a48c6a7884 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-10.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py @@ -17,13 +17,15 @@ from unittest import mock import numpy import pytest import torch +from torch.utils.data import DataLoader from pytorch_lightning import Trainer from pytorch_lightning.core.mixins.device_dtype_mixin import DeviceDtypeModuleMixin -from pytorch_lightning.demos.boring_classes import BoringModel +from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset from pytorch_lightning.overrides import LightningDistributedModule, LightningParallelModule from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel, unwrap_lightning_module_sharded +from pytorch_lightning.plugins.environments import LightningEnvironment from pytorch_lightning.strategies.bagua import LightningBaguaModule from pytorch_lightning.strategies.deepspeed import LightningDeepSpeedModule from pytorch_lightning.strategies.ipu import LightningIPUModule @@ -38,6 +40,27 @@ from pytorch_lightning.utilities.apply_func import ( TransferableDataType, ) from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem, load +from pytorch_lightning.utilities.data import has_iterable_dataset, has_len +from pytorch_lightning.utilities.device_parser import ( + determine_root_gpu_device, + is_cuda_available, + num_cuda_devices, + parse_cpu_cores, + parse_gpu_ids, + 
parse_tpu_cores, +) +from pytorch_lightning.utilities.distributed import ( + all_gather_ddp_if_available, + distributed_available, + gather_all_tensors, + get_default_process_group_backend_for_device, + init_dist_connection, + sync_ddp, + sync_ddp_if_available, + tpu_distributed, +) +from pytorch_lightning.utilities.optimizer import optimizer_to_device, optimizers_to_device +from pytorch_lightning.utilities.seed import pl_worker_init_function, reset_seed, seed_everything from pytorch_lightning.utilities.xla_device import inner_f, pl_multi_process, XLADeviceUtils from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.utils import no_warning_call @@ -112,17 +135,6 @@ def test_v1_10_deprecated_xla_device_utilities(): XLADeviceUtils.tpu_device_exists() -def test_v1_10_deprecated_cloud_io_utilities(tmpdir): - with pytest.deprecated_call(match="cloud_io.atomic_save` has been deprecated in v1.8.0"): - atomic_save({}, tmpdir / "atomic_save.ckpt") - - with pytest.deprecated_call(match="cloud_io.get_filesystem` has been deprecated in v1.8.0"): - get_filesystem(tmpdir) - - with pytest.deprecated_call(match="cloud_io.load` has been deprecated in v1.8.0"): - load(str(tmpdir / "atomic_save.ckpt")) - - def test_v1_10_deprecated_apply_func_utilities(): with pytest.deprecated_call(match="apply_func.apply_to_collection` has been deprecated in v1.8.0"): apply_to_collection([], dtype=object, function=(lambda x: x)) @@ -147,3 +159,94 @@ def test_v1_10_deprecated_apply_func_utilities(): with pytest.deprecated_call(match="apply_func.TransferableDataType` has been deprecated in v1.8.0"): MyModule() + + +def test_v1_10_deprecated_cloud_io_utilities(tmpdir): + with pytest.deprecated_call(match="cloud_io.atomic_save` has been deprecated in v1.8.0"): + atomic_save({}, tmpdir / "atomic_save.ckpt") + + with pytest.deprecated_call(match="cloud_io.get_filesystem` has been deprecated in v1.8.0"): + get_filesystem(tmpdir) + + with pytest.deprecated_call(match="cloud_io.load` has been deprecated in v1.8.0"): + load(str(tmpdir / "atomic_save.ckpt")) + + +def test_v1_10_deprecated_data_utilities(): + with pytest.deprecated_call(match="data.has_iterable_dataset` has been deprecated in v1.8.0"): + has_iterable_dataset(DataLoader(RandomDataset(2, 4))) + + with pytest.deprecated_call(match="data.has_len` has been deprecated in v1.8.0"): + has_len(DataLoader(RandomDataset(2, 4))) + + +def test_v1_10_deprecated_device_parser_utilities(): + with pytest.deprecated_call(match="device_parser.determine_root_gpu_device` has been deprecated in v1.8.0"): + determine_root_gpu_device(None) + + with pytest.deprecated_call(match="device_parser.is_cuda_available` has been deprecated in v1.8.0"): + is_cuda_available() + + with pytest.deprecated_call(match="device_parser.num_cuda_devices` has been deprecated in v1.8.0"): + num_cuda_devices() + + with pytest.deprecated_call(match="device_parser.parse_cpu_cores` has been deprecated in v1.8.0"): + parse_cpu_cores(1) + + with pytest.deprecated_call(match="device_parser.parse_gpu_ids` has been deprecated in v1.8.0"): + parse_gpu_ids(None) + + with pytest.deprecated_call(match="device_parser.parse_tpu_cores` has been deprecated in v1.8.0"): + parse_tpu_cores(None) + + +def test_v1_10_deprecated_distributed_utilities(): + with pytest.deprecated_call(match="distributed.all_gather_ddp_if_available` has been deprecated in v1.8.0"): + all_gather_ddp_if_available(torch.tensor(1)) + + with pytest.deprecated_call(match="distributed.distributed_available` has been deprecated in v1.8.0"): + 
distributed_available() + + with mock.patch("torch.distributed.get_world_size", return_value=2), mock.patch( + "torch.distributed.barrier" + ), mock.patch("torch.distributed.all_gather"): + with pytest.deprecated_call(match="distributed.gather_all_tensors` has been deprecated in v1.8.0"): + gather_all_tensors(torch.tensor(1)) + + with pytest.deprecated_call( + match="distributed.get_default_process_group_backend_for_device` has been deprecated in v1.8.0" + ): + get_default_process_group_backend_for_device(torch.device("cpu")) + + with mock.patch("torch.distributed.is_initialized", return_value=True): + with pytest.deprecated_call(match="distributed.init_dist_connection` has been deprecated in v1.8.0"): + init_dist_connection(LightningEnvironment(), "gloo") + + with pytest.deprecated_call(match="distributed.sync_ddp_if_available` has been deprecated in v1.8.0"): + sync_ddp_if_available(torch.tensor(1)) + + with mock.patch("torch.distributed.barrier"), mock.patch("torch.distributed.all_reduce"): + with pytest.deprecated_call(match="distributed.sync_ddp` has been deprecated in v1.8.0"): + sync_ddp(torch.tensor(1)) + + with pytest.deprecated_call(match="distributed.tpu_distributed` has been deprecated in v1.8.0"): + tpu_distributed() + + +def test_v1_10_deprecated_optimizer_utilities(): + with pytest.deprecated_call(match="optimizer.optimizers_to_device` has been deprecated in v1.8.0"): + optimizers_to_device([torch.optim.Adam(torch.nn.Linear(1, 1).parameters())], "cpu") + + with pytest.deprecated_call(match="optimizer.optimizer_to_device` has been deprecated in v1.8.0"): + optimizer_to_device(torch.optim.Adam(torch.nn.Linear(1, 1).parameters()), "cpu") + + +def test_v1_10_deprecated_seed_utilities(): + with pytest.deprecated_call(match="seed.seed_everything` has been deprecated in v1.8.0"): + seed_everything(1) + + with pytest.deprecated_call(match="seed.reset_seed` has been deprecated in v1.8.0"): + reset_seed() + + with pytest.deprecated_call(match="seed.pl_worker_init_function` has been deprecated in v1.8.0"): + pl_worker_init_function(0) diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-8.py b/tests/tests_pytorch/deprecated_api/test_remove_1-8.py index 489ef38f0c..b9e36df94d 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-8.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-8.py @@ -22,6 +22,7 @@ import pytest import torch import pytorch_lightning +from lightning_lite.utilities import device_parser from pytorch_lightning import Callback, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel @@ -33,7 +34,6 @@ from pytorch_lightning.strategies import ParallelStrategy from pytorch_lightning.strategies.ipu import LightningIPUModule from pytorch_lightning.trainer.configuration_validator import _check_datamodule_checkpoint_hooks from pytorch_lightning.trainer.states import RunningStage -from pytorch_lightning.utilities import device_parser from pytorch_lightning.utilities.rank_zero import rank_zero_only from tests_pytorch.helpers.runif import RunIf @@ -547,7 +547,7 @@ def test_v1_8_0_lightning_module_use_amp(): @mock.patch.dict(os.environ, {"PL_TORCH_DISTRIBUTED_BACKEND": "foo"}) def test_v1_8_0_torch_distributed_backend_env(): - from pytorch_lightning.utilities.distributed import _get_process_group_backend_from_env + from lightning_lite.utilities.distributed import _get_process_group_backend_from_env with pytest.deprecated_call( match="Environment variable 
`PL_TORCH_DISTRIBUTED_BACKEND`" diff --git a/tests/tests_pytorch/deprecated_api/test_remove_2-0.py b/tests/tests_pytorch/deprecated_api/test_remove_2-0.py index b39c6dafc1..bd359cc323 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_2-0.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_2-0.py @@ -28,8 +28,8 @@ def test_v2_0_0_deprecated_num_processes(): _ = Trainer(num_processes=2) -@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) -@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2) +@mock.patch("lightning_lite.utilities.device_parser.is_cuda_available", return_value=True) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=2) def test_v2_0_0_deprecated_gpus(*_): with pytest.deprecated_call(match=r"is deprecated in v1.7 and will be removed in v2.0."): _ = Trainer(gpus=0) diff --git a/tests/tests_pytorch/lite/test_lite.py b/tests/tests_pytorch/lite/test_lite.py index d45046f249..e7b5c61a67 100644 --- a/tests/tests_pytorch/lite/test_lite.py +++ b/tests/tests_pytorch/lite/test_lite.py @@ -23,13 +23,13 @@ import torch.nn.functional from torch import nn from torch.utils.data import DataLoader, DistributedSampler, Sampler +from lightning_lite.utilities import _StrategyType +from lightning_lite.utilities.seed import pl_worker_init_function from pytorch_lightning.lite import LightningLite from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer from pytorch_lightning.plugins import PrecisionPlugin from pytorch_lightning.strategies import DeepSpeedStrategy, Strategy -from pytorch_lightning.utilities import _StrategyType from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.seed import pl_worker_init_function from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/models/test_gpu.py b/tests/tests_pytorch/models/test_gpu.py index 5eded60d20..1f15f2a596 100644 --- a/tests/tests_pytorch/models/test_gpu.py +++ b/tests/tests_pytorch/models/test_gpu.py @@ -21,11 +21,11 @@ import torch import tests_pytorch.helpers.pipelines as tpipes import tests_pytorch.helpers.utils as tutils +from lightning_lite.utilities import device_parser from pytorch_lightning import Trainer from pytorch_lightning.accelerators import CPUAccelerator, CUDAAccelerator from pytorch_lightning.demos.boring_classes import BoringModel from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.utilities import device_parser from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf @@ -181,8 +181,8 @@ def test_parse_gpu_returns_none_when_no_devices_are_available(mocked_device_coun "TORCHELASTIC_RUN_ID": "1", }, ) -@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=1) -@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=1) +@mock.patch("lightning_lite.utilities.device_parser.is_cuda_available", return_value=True) @pytest.mark.parametrize("gpus", [[0, 1, 2], 2, "0", [0, 2]]) def test_torchelastic_gpu_parsing(mocked_device_count, mocked_is_available, gpus): """Ensure when using torchelastic and nproc_per_node is set to the default of 1 per GPU device That we omit diff --git 
a/tests/tests_pytorch/models/test_tpu.py b/tests/tests_pytorch/models/test_tpu.py index a41ba7429c..1265f6c90f 100644 --- a/tests/tests_pytorch/models/test_tpu.py +++ b/tests/tests_pytorch/models/test_tpu.py @@ -21,6 +21,7 @@ from torch.utils.data import DataLoader import tests_pytorch.helpers.pipelines as tpipes import tests_pytorch.helpers.utils as tutils +from lightning_lite.utilities.distributed import ReduceOp from pytorch_lightning import Trainer from pytorch_lightning.accelerators import TPUAccelerator from pytorch_lightning.callbacks import EarlyStopping @@ -29,7 +30,6 @@ from pytorch_lightning.strategies import TPUSpawnStrategy from pytorch_lightning.strategies.launchers.xla import _save_spawn from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync from pytorch_lightning.utilities import _TPU_AVAILABLE -from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/overrides/test_distributed.py b/tests/tests_pytorch/overrides/test_distributed.py index 77c3eb40bf..bae31e8fc0 100644 --- a/tests/tests_pytorch/overrides/test_distributed.py +++ b/tests/tests_pytorch/overrides/test_distributed.py @@ -16,9 +16,9 @@ from typing import Iterable import pytest from torch.utils.data import BatchSampler, SequentialSampler +from lightning_lite.utilities.data import has_len from pytorch_lightning import seed_everything from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper, UnrepeatedDistributedSampler -from pytorch_lightning.utilities.data import has_len @pytest.mark.parametrize("shuffle", [False, True]) diff --git a/tests/tests_pytorch/plugins/test_amp_plugins.py b/tests/tests_pytorch/plugins/test_amp_plugins.py index 974964e5b9..a7efe0ec75 100644 --- a/tests/tests_pytorch/plugins/test_amp_plugins.py +++ b/tests/tests_pytorch/plugins/test_amp_plugins.py @@ -51,8 +51,8 @@ class MyApexPlugin(ApexMixedPrecisionPlugin): "SLURM_LOCALID": "0", }, ) -@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) -@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2) +@mock.patch("lightning_lite.utilities.device_parser.is_cuda_available", return_value=True) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=2) @pytest.mark.parametrize("strategy,devices", [("ddp", 2), ("ddp_spawn", 2)]) @pytest.mark.parametrize( "amp,custom_plugin,plugin_cls", @@ -278,16 +278,16 @@ def test_precision_selection_raises(monkeypatch): with pytest.raises(MisconfigurationException, match=r"amp_type='apex', precision='bf16'\)` but it's not supported"): Trainer(amp_backend="apex", precision="bf16") - with mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=1), pytest.raises( + with mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=1), pytest.raises( MisconfigurationException, match="Sharded plugins are not supported with apex" ): - with mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True): + with mock.patch("lightning_lite.utilities.device_parser.is_cuda_available", return_value=True): Trainer(amp_backend="apex", precision=16, accelerator="gpu", devices=1, strategy="ddp_fully_sharded") import pytorch_lightning.plugins.precision.apex_amp as apex monkeypatch.setattr(apex, "_APEX_AVAILABLE", False) - with 
mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=1), mock.patch( - "pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True + with mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=1), mock.patch( + "lightning_lite.utilities.device_parser.is_cuda_available", return_value=True ), pytest.raises(MisconfigurationException, match="asked for Apex AMP but `apex` is not installed"): Trainer(amp_backend="apex", precision=16, accelerator="gpu", devices=1) diff --git a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py index ae618ffa33..21a94d33bb 100644 --- a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py +++ b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py @@ -18,6 +18,7 @@ from unittest.mock import MagicMock, Mock import torch +from lightning_lite.utilities.types import _PATH from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.demos.boring_classes import BoringModel @@ -25,7 +26,6 @@ from pytorch_lightning.plugins.io.async_plugin import AsyncCheckpointIO from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO from pytorch_lightning.strategies import SingleDeviceStrategy -from pytorch_lightning.utilities.types import _PATH class CustomCheckpointIO(CheckpointIO): diff --git a/tests/tests_pytorch/plugins/test_cluster_integration.py b/tests/tests_pytorch/plugins/test_cluster_integration.py index b9f39336d1..be8f87d643 100644 --- a/tests/tests_pytorch/plugins/test_cluster_integration.py +++ b/tests/tests_pytorch/plugins/test_cluster_integration.py @@ -85,8 +85,8 @@ def test_ranks_available_manual_strategy_selection(mock_gpu_acc_available, strat dict(strategy="ddp_spawn", accelerator="gpu", devices=[1, 2]), ], ) -@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) -@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=4) +@mock.patch("lightning_lite.utilities.device_parser.is_cuda_available", return_value=True) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=4) def test_ranks_available_automatic_strategy_selection(mock0, mock1, trainer_kwargs): """Test that the rank information is readily available after Trainer initialization.""" num_nodes = 2 diff --git a/tests/tests_pytorch/strategies/test_bagua_strategy.py b/tests/tests_pytorch/strategies/test_bagua_strategy.py index 79ec701964..3e9aba79dd 100644 --- a/tests/tests_pytorch/strategies/test_bagua_strategy.py +++ b/tests/tests_pytorch/strategies/test_bagua_strategy.py @@ -118,6 +118,6 @@ def test_bagua_not_available(monkeypatch): import pytorch_lightning.strategies.bagua as imports monkeypatch.setattr(imports, "_BAGUA_AVAILABLE", False) - with mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=1): + with mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=1): with pytest.raises(MisconfigurationException, match="you must have `Bagua` installed"): Trainer(strategy="bagua", accelerator="gpu", devices=1) diff --git a/tests/tests_pytorch/strategies/test_common.py b/tests/tests_pytorch/strategies/test_common.py index 479b222e25..d696ce8118 100644 --- a/tests/tests_pytorch/strategies/test_common.py +++ b/tests/tests_pytorch/strategies/test_common.py @@ -15,11 
+15,11 @@ import pytest import torch import tests_pytorch.helpers.utils as tutils +from lightning_lite.utilities.seed import seed_everything from pytorch_lightning import Trainer from pytorch_lightning.demos.boring_classes import BoringModel from pytorch_lightning.strategies import DDPStrategy from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12 -from pytorch_lightning.utilities.seed import seed_everything from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.strategies.test_dp import CustomClassificationModelDP diff --git a/tests/tests_pytorch/strategies/test_ddp.py b/tests/tests_pytorch/strategies/test_ddp.py index dbde198b6e..19317bfe30 100644 --- a/tests/tests_pytorch/strategies/test_ddp.py +++ b/tests/tests_pytorch/strategies/test_ddp.py @@ -60,12 +60,12 @@ def test_multi_gpu_model_ddp_fit_test(tmpdir): @RunIf(skip_windows=True) @pytest.mark.skipif(torch.cuda.is_available(), reason="test doesn't requires GPU machine") -@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) +@mock.patch("lightning_lite.utilities.device_parser.is_cuda_available", return_value=True) def test_torch_distributed_backend_env_variables(tmpdir): """This test set `undefined` as torch backend and should raise an `Backend.UNDEFINED` ValueError.""" _environ = {"PL_TORCH_DISTRIBUTED_BACKEND": "undefined", "CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2"} with patch.dict(os.environ, _environ), patch( - "pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2 + "lightning_lite.utilities.device_parser.num_cuda_devices", return_value=2 ): with pytest.deprecated_call(match="Environment variable `PL_TORCH_DISTRIBUTED_BACKEND` was deprecated in v1.6"): with pytest.raises(ValueError, match="Invalid backend: 'undefined'"): @@ -83,8 +83,8 @@ def test_torch_distributed_backend_env_variables(tmpdir): @RunIf(skip_windows=True) @mock.patch("torch.cuda.set_device") -@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) -@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=1) +@mock.patch("lightning_lite.utilities.device_parser.is_cuda_available", return_value=True) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=1) @mock.patch("pytorch_lightning.accelerators.gpu.CUDAAccelerator.is_available", return_value=True) @mock.patch.dict(os.environ, {"PL_TORCH_DISTRIBUTED_BACKEND": "gloo"}, clear=True) def test_ddp_torch_dist_is_available_in_setup( diff --git a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py index 88a07a78ef..bb3b63ea57 100644 --- a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py +++ b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py @@ -29,8 +29,8 @@ def test_invalid_on_cpu(tmpdir): @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"}) -@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=1) -@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=1) +@mock.patch("lightning_lite.utilities.device_parser.is_cuda_available", return_value=True) @RunIf(fairscale=True) def test_fsdp_with_sharded_amp(device_count_mock, mock_cuda_available, 
tmpdir): """Test to ensure that plugin native amp plugin is correctly chosen when using sharded.""" diff --git a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py index 857abaa8df..70af274e2f 100644 --- a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py +++ b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py @@ -169,7 +169,7 @@ def test_deepspeed_strategy_env(tmpdir, monkeypatch, deepspeed_config): @RunIf(deepspeed=True) -@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=1) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=1) @pytest.mark.parametrize( "amp_backend", ["native", pytest.param("apex", marks=RunIf(amp_apex=True))], diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index e37f799888..46fc7e9b62 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -197,8 +197,8 @@ def test_parse_args_parsing_complex_types(cli_args, expected, instantiate): ) def test_parse_args_parsing_gpus(monkeypatch, cli_args, expected_gpu): """Test parsing of gpus and instantiation of Trainer.""" - monkeypatch.setattr("pytorch_lightning.utilities.device_parser.num_cuda_devices", lambda: 2) - monkeypatch.setattr("pytorch_lightning.utilities.device_parser.is_cuda_available", lambda: True) + monkeypatch.setattr("lightning_lite.utilities.device_parser.num_cuda_devices", lambda: 2) + monkeypatch.setattr("lightning_lite.utilities.device_parser.is_cuda_available", lambda: True) cli_args = cli_args.split(" ") if cli_args else [] with mock.patch("sys.argv", ["any.py"] + cli_args): parser = LightningArgumentParser(add_help=False, parse_as_dict=False) diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py index 434607cab7..6625f191c3 100644 --- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py @@ -98,7 +98,7 @@ def _test_strategy_choice_ddp_and_cpu(tmpdir, ddp_strategy_class): "SLURM_LOCALID": "0", }, ) -@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=0) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=0) def test_custom_cluster_environment_in_slurm_environment(_, tmpdir): """Test that we choose the custom cluster even when SLURM or TE flags are around.""" @@ -135,7 +135,7 @@ def test_custom_cluster_environment_in_slurm_environment(_, tmpdir): "SLURM_LOCALID": "0", }, ) -@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=0) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=0) @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True) def test_custom_accelerator(device_count_mock, setup_distributed_mock): class Accel(Accelerator): @@ -194,7 +194,7 @@ def test_custom_accelerator(device_count_mock, setup_distributed_mock): "SLURM_LOCALID": "0", }, ) -@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=0) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=0) @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True) def test_dist_backend_accelerator_mapping(*_): trainer = Trainer(fast_dev_run=True, strategy="ddp_spawn", accelerator="cpu", devices=2) 
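The hunks above and below move the `mock.patch` targets from `pytorch_lightning.utilities.device_parser` to `lightning_lite.utilities.device_parser`. `unittest.mock` has to patch the module through which the attribute is looked up at call time, and since the accelerator code now resolves these helpers via `lightning_lite`, patching the old path would no longer affect the code under test. A minimal sketch, assuming `lightning_lite` is importable:

from unittest import mock

import lightning_lite.utilities.device_parser as device_parser

# Patch where the function is looked up, not where it happened to live before the move.
with mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=2):
    assert device_parser.num_cuda_devices() == 2  # callers resolving through this module see the mock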
@@ -203,7 +203,7 @@ def test_dist_backend_accelerator_mapping(*_): assert trainer.strategy.local_rank == 0 -@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=2) def test_ipython_incompatible_backend_error(_, monkeypatch): monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True) with pytest.raises(MisconfigurationException, match=r"strategy='ddp'\)`.*is not compatible"): @@ -220,7 +220,7 @@ def test_ipython_incompatible_backend_error(_, monkeypatch): Trainer(strategy="dp") -@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=2) def test_ipython_compatible_dp_strategy_gpu(_, monkeypatch): monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True) trainer = Trainer(strategy="dp", accelerator="gpu") @@ -253,8 +253,8 @@ def test_ipython_compatible_strategy_ddp_fork(monkeypatch): ], ) @pytest.mark.parametrize("devices", [1, 2]) -@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) -@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2) +@mock.patch("lightning_lite.utilities.device_parser.is_cuda_available", return_value=True) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=2) def test_accelerator_choice_multi_node_gpu( mock_is_available, mock_device_count, tmpdir, strategy, strategy_class, devices ): @@ -284,8 +284,8 @@ def test_accelerator_cpu(_): Trainer(accelerator="cpu", gpus=1) -@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2) -@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=2) +@mock.patch("lightning_lite.utilities.device_parser.is_cuda_available", return_value=True) @pytest.mark.parametrize("device_count", (["0"], [0, "1"], ["GPU"], [["0", "1"], [0, 1]], [False])) def test_accelererator_invalid_type_devices(mock_is_available, mock_device_count, device_count): with pytest.raises( @@ -449,8 +449,8 @@ def test_strategy_choice_ddp_fork_cpu(): @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"}) -@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2) -@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=2) +@mock.patch("lightning_lite.utilities.device_parser.is_cuda_available", return_value=True) def test_strategy_choice_ddp(*_): trainer = Trainer(fast_dev_run=True, strategy="ddp", accelerator="gpu", devices=1) assert isinstance(trainer.accelerator, CUDAAccelerator) @@ -459,8 +459,8 @@ def test_strategy_choice_ddp(*_): @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"}) -@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2) -@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=2) +@mock.patch("lightning_lite.utilities.device_parser.is_cuda_available", return_value=True) def test_strategy_choice_ddp_spawn(cuda_available_mock, device_count_mock): trainer = Trainer(fast_dev_run=True, strategy="ddp_spawn", 
accelerator="gpu", devices=1) assert isinstance(trainer.accelerator, CUDAAccelerator) @@ -505,10 +505,10 @@ def test_strategy_choice_ddp_slurm(setup_distributed_mock, strategy): }, ) @mock.patch("torch.cuda.set_device") -@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2) -@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=2) +@mock.patch("lightning_lite.utilities.device_parser.is_cuda_available", return_value=True) @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True) -@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) +@mock.patch("lightning_lite.utilities.device_parser.is_cuda_available", return_value=True) def test_strategy_choice_ddp_te(*_): trainer = Trainer(fast_dev_run=True, strategy="ddp", accelerator="gpu", devices=2) assert isinstance(trainer.accelerator, CUDAAccelerator) @@ -529,7 +529,7 @@ def test_strategy_choice_ddp_te(*_): "TORCHELASTIC_RUN_ID": "1", }, ) -@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=0) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=0) @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True) def test_strategy_choice_ddp_cpu_te(*_): trainer = Trainer(fast_dev_run=True, strategy="ddp_spawn", accelerator="cpu", devices=2) @@ -552,8 +552,8 @@ def test_strategy_choice_ddp_cpu_te(*_): }, ) @mock.patch("torch.cuda.set_device") -@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=1) -@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=1) +@mock.patch("lightning_lite.utilities.device_parser.is_cuda_available", return_value=True) @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True) def test_strategy_choice_ddp_kubeflow(*_): trainer = Trainer(fast_dev_run=True, strategy="ddp", accelerator="gpu", devices=1) @@ -574,7 +574,7 @@ def test_strategy_choice_ddp_kubeflow(*_): "RANK": "1", }, ) -@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=0) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=0) @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True) def test_strategy_choice_ddp_cpu_kubeflow(*_): trainer = Trainer(fast_dev_run=True, strategy="ddp_spawn", accelerator="cpu", devices=2) @@ -596,7 +596,7 @@ def test_strategy_choice_ddp_cpu_kubeflow(*_): "SLURM_LOCALID": "0", }, ) -@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=0) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=0) @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True) @pytest.mark.parametrize("strategy", ["ddp", DDPStrategy()]) def test_strategy_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock, strategy): @@ -646,7 +646,7 @@ def test_unsupported_ipu_choice(mock_ipu_acc_avail, monkeypatch): Trainer(accelerator="ipu", precision=64) -@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=0) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=0) 
@mock.patch("pytorch_lightning.utilities.imports._TPU_AVAILABLE", return_value=False) @mock.patch("pytorch_lightning.utilities.imports._IPU_AVAILABLE", return_value=False) @mock.patch("pytorch_lightning.utilities.imports._HPU_AVAILABLE", return_value=False) @@ -655,8 +655,8 @@ def test_devices_auto_choice_cpu(*_): assert trainer.num_devices == 1 -@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) -@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2) +@mock.patch("lightning_lite.utilities.device_parser.is_cuda_available", return_value=True) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=2) @RunIf(mps=False) def test_devices_auto_choice_gpu(is_gpu_available_mock, device_count_mock): trainer = Trainer(accelerator="auto", devices="auto") @@ -769,7 +769,7 @@ def test_gpu_accelerator_backend_choice(expected_accelerator_flag, expected_acce assert isinstance(trainer.accelerator, expected_accelerator_class) -@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=1) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=1) def test_gpu_accelerator_backend_choice_cuda(_): trainer = Trainer(accelerator="gpu") @@ -777,6 +777,8 @@ def test_gpu_accelerator_backend_choice_cuda(_): assert isinstance(trainer.accelerator, CUDAAccelerator) +# TODO(lite): remove skip once MPS utils have moved +@pytest.mark.skip(reason="Utils in Lite rely on MPS accelerator file, but refactor is not yet finished") @mock.patch("pytorch_lightning.accelerators.mps._MPS_AVAILABLE", return_value=True) @mock.patch("torch.device", return_value="mps") # necessary because torch doesn't allow creation of mps devices def test_gpu_accelerator_backend_choice_mps(*_): diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py index 379a3248a1..703ce8f053 100644 --- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py @@ -21,6 +21,7 @@ import pytest from torch import Tensor from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, Sampler, SequentialSampler +from lightning_lite.utilities.warnings import PossibleUserWarning from pytorch_lightning import Trainer from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset from pytorch_lightning.overrides.distributed import DistributedSamplerWrapper @@ -30,7 +31,6 @@ from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.trainer.supporters import CombinedLoader from pytorch_lightning.utilities.data import _update_dataloader from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.warnings import PossibleUserWarning from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.utils import no_warning_call diff --git a/tests/tests_pytorch/trainer/flags/test_env_vars.py b/tests/tests_pytorch/trainer/flags/test_env_vars.py index cfac06c8d7..a6415d5e90 100644 --- a/tests/tests_pytorch/trainer/flags/test_env_vars.py +++ b/tests/tests_pytorch/trainer/flags/test_env_vars.py @@ -49,8 +49,8 @@ def test_passing_env_variables_defaults(): @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "PL_TRAINER_DEVICES": "2"}) -@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2) 
-@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=2) +@mock.patch("lightning_lite.utilities.device_parser.is_cuda_available", return_value=True) def test_passing_env_variables_devices(cuda_available_mock, device_count_mock): """Testing overwriting trainer arguments.""" trainer = Trainer() diff --git a/tests/tests_pytorch/trainer/flags/test_min_max_epochs.py b/tests/tests_pytorch/trainer/flags/test_min_max_epochs.py index 25f2dfdab2..ed3c9952b8 100644 --- a/tests/tests_pytorch/trainer/flags/test_min_max_epochs.py +++ b/tests/tests_pytorch/trainer/flags/test_min_max_epochs.py @@ -1,8 +1,8 @@ import pytest +from lightning_lite.utilities.warnings import PossibleUserWarning from pytorch_lightning import Trainer from pytorch_lightning.demos.boring_classes import BoringModel -from pytorch_lightning.utilities.warnings import PossibleUserWarning from tests_pytorch.helpers.utils import no_warning_call diff --git a/tests/tests_pytorch/trainer/properties/test_auto_gpu_select.py b/tests/tests_pytorch/trainer/properties/test_auto_gpu_select.py index aa9f15bc43..05ee9d2ab3 100644 --- a/tests/tests_pytorch/trainer/properties/test_auto_gpu_select.py +++ b/tests/tests_pytorch/trainer/properties/test_auto_gpu_select.py @@ -42,13 +42,13 @@ def test_pick_multiple_gpus(nb, expected_gpu_idxs, expected_error): assert expected_gpu_idxs == pick_multiple_gpus(nb) -@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=1) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=1) def test_pick_multiple_gpus_more_than_available(*_): with pytest.raises(MisconfigurationException, match="You requested 3 GPUs but your machine only has 1 GPUs"): pick_multiple_gpus(3) -@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=2) @mock.patch("pytorch_lightning.trainer.connectors.accelerator_connector.pick_multiple_gpus", return_value=[1]) def test_auto_select_gpus(*_): diff --git a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py index 72c07ec079..0cd31008ea 100644 --- a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py +++ b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py @@ -20,11 +20,11 @@ import pytest import torch from torch.utils.data import DataLoader +from lightning_lite.utilities import device_parser from pytorch_lightning import Trainer from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler from pytorch_lightning.demos.boring_classes import BoringModel, RandomIterableDataset from pytorch_lightning.strategies.ipu import IPUStrategy -from pytorch_lightning.utilities import device_parser from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/test_config_validator.py b/tests/tests_pytorch/trainer/test_config_validator.py index f6508c181e..7cc742eea8 100644 --- a/tests/tests_pytorch/trainer/test_config_validator.py +++ b/tests/tests_pytorch/trainer/test_config_validator.py @@ -15,11 +15,11 @@ import pytest import torch import pytorch_lightning as pl +from lightning_lite.utilities import device_parser +from 
lightning_lite.utilities.warnings import PossibleUserWarning from pytorch_lightning import LightningDataModule, LightningModule, Trainer from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset -from pytorch_lightning.utilities import device_parser from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.warnings import PossibleUserWarning def test_wrong_train_setting(tmpdir): diff --git a/tests/tests_pytorch/trainer/test_dataloaders.py b/tests/tests_pytorch/trainer/test_dataloaders.py index 317a35af3d..08e81e5915 100644 --- a/tests/tests_pytorch/trainer/test_dataloaders.py +++ b/tests/tests_pytorch/trainer/test_dataloaders.py @@ -23,6 +23,7 @@ from torch.utils.data.dataset import Dataset, IterableDataset from torch.utils.data.distributed import DistributedSampler from torch.utils.data.sampler import SequentialSampler +from lightning_lite.utilities.data import _auto_add_worker_init_fn, has_iterable_dataset from pytorch_lightning import Callback, seed_everything, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.demos.boring_classes import ( @@ -32,7 +33,7 @@ from pytorch_lightning.demos.boring_classes import ( RandomIterableDatasetWithLen, ) from pytorch_lightning.trainer.states import RunningStage -from pytorch_lightning.utilities.data import _auto_add_worker_init_fn, has_iterable_dataset, has_len_all_ranks +from pytorch_lightning.utilities.data import has_len_all_ranks from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests_pytorch.helpers.dataloaders import CustomInfDataloader, CustomNotImplementedErrorDataloader from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/test_supporters.py b/tests/tests_pytorch/trainer/test_supporters.py index fec8466748..d9beabda43 100644 --- a/tests/tests_pytorch/trainer/test_supporters.py +++ b/tests/tests_pytorch/trainer/test_supporters.py @@ -314,8 +314,8 @@ def test_nested_calc_num_data(input_data, compute_func, expected_length): @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"}) -@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=2) -@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=2) +@mock.patch("lightning_lite.utilities.device_parser.is_cuda_available", return_value=True) @pytest.mark.parametrize("use_fault_tolerant", [False, True]) @pytest.mark.parametrize("replace_sampler_ddp", [False, True]) def test_combined_data_loader_validation_test( diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index d1b1ef6cf9..da6aedebbe 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -35,7 +35,9 @@ from torch.utils.data import DataLoader, IterableDataset import pytorch_lightning import tests_pytorch.helpers.utils as tutils +from lightning_lite.utilities import device_parser from lightning_lite.utilities.cloud_io import load as pl_load +from lightning_lite.utilities.seed import seed_everything from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer from pytorch_lightning.accelerators import CPUAccelerator, CUDAAccelerator from pytorch_lightning.callbacks import EarlyStopping, GradientAccumulationScheduler, ModelCheckpoint, Timer @@ -60,10 +62,8 @@ from 
pytorch_lightning.strategies import ( SingleDeviceStrategy, ) from pytorch_lightning.trainer.states import RunningStage, TrainerFn -from pytorch_lightning.utilities import device_parser from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE, _TORCH_GREATER_EQUAL_1_12 -from pytorch_lightning.utilities.seed import seed_everything from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel @@ -1258,8 +1258,8 @@ def test_trainer_subclassing(): "trainer_params", [{"max_epochs": 1, "accelerator": "gpu", "devices": 1}, {"max_epochs": 1, "accelerator": "gpu", "devices": [0]}], ) -@mock.patch("pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True) -@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=1) +@mock.patch("lightning_lite.utilities.device_parser.is_cuda_available", return_value=True) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=1) def test_trainer_omegaconf(_, __, trainer_params): config = OmegaConf.create(trainer_params) Trainer(**config) diff --git a/tests/tests_pytorch/trainer/test_trainer_cli.py b/tests/tests_pytorch/trainer/test_trainer_cli.py index 468650e234..6613f0b1bc 100644 --- a/tests/tests_pytorch/trainer/test_trainer_cli.py +++ b/tests/tests_pytorch/trainer/test_trainer_cli.py @@ -19,8 +19,9 @@ from unittest import mock import pytest import tests_pytorch.helpers.utils as tutils +from lightning_lite.utilities import device_parser from pytorch_lightning import Trainer -from pytorch_lightning.utilities import argparse, device_parser +from pytorch_lightning.utilities import argparse @mock.patch("argparse.ArgumentParser.parse_args") diff --git a/tests/tests_pytorch/utilities/test_all_gather_grad.py b/tests/tests_pytorch/utilities/test_all_gather_grad.py index 49d86aca9c..7e00bc74a5 100644 --- a/tests/tests_pytorch/utilities/test_all_gather_grad.py +++ b/tests/tests_pytorch/utilities/test_all_gather_grad.py @@ -17,9 +17,10 @@ import sys import numpy as np import torch -from pytorch_lightning import seed_everything, Trainer +from lightning_lite.utilities import AllGatherGrad +from lightning_lite.utilities.seed import seed_everything +from pytorch_lightning import Trainer from pytorch_lightning.demos.boring_classes import BoringModel -from pytorch_lightning.utilities import AllGatherGrad from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/utilities/test_auto_restart.py b/tests/tests_pytorch/utilities/test_auto_restart.py index a3bf115313..b399ba8b35 100644 --- a/tests/tests_pytorch/utilities/test_auto_restart.py +++ b/tests/tests_pytorch/utilities/test_auto_restart.py @@ -35,7 +35,8 @@ from torch.utils.data.dataloader import DataLoader, default_collate from torch.utils.data.dataset import Dataset, IterableDataset import tests_pytorch.helpers.utils as tutils -from pytorch_lightning import Callback, LightningModule, seed_everything, Trainer +from lightning_lite.utilities.seed import seed_everything +from pytorch_lightning import Callback, LightningModule, Trainer from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset from pytorch_lightning.trainer.states import RunningStage, TrainerState from pytorch_lightning.trainer.supporters import CombinedLoader diff --git a/tests/tests_pytorch/utilities/test_data.py 
b/tests/tests_pytorch/utilities/test_data.py index 9e3d04ae65..28743324c2 100644 --- a/tests/tests_pytorch/utilities/test_data.py +++ b/tests/tests_pytorch/utilities/test_data.py @@ -1,4 +1,3 @@ -import random from dataclasses import dataclass import pytest @@ -6,6 +5,7 @@ import torch from torch import Tensor from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler +from lightning_lite.utilities.data import _replace_dunder_methods from pytorch_lightning import Trainer from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper @@ -13,14 +13,9 @@ from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.data import ( _dataloader_init_kwargs_resolve_sampler, _get_dataloader_init_args_and_kwargs, - _replace_dunder_methods, - _replace_value_in_saved_args, _update_dataloader, - _WrapAttrTag, extract_batch_size, get_len, - has_iterable_dataset, - has_len, has_len_all_ranks, warning_cache, ) @@ -96,28 +91,6 @@ def test_extract_batch_size(): _check_error_raised(data) -def test_has_iterable_dataset(): - assert has_iterable_dataset(DataLoader(RandomIterableDataset(1, 1))) - - assert not has_iterable_dataset(DataLoader(RandomDataset(1, 1))) - - class MockDatasetWithoutIterableDataset(RandomDataset): - def __iter__(self): - yield 1 - return self - - assert not has_iterable_dataset(DataLoader(MockDatasetWithoutIterableDataset(1, 1))) - - -def test_has_len(): - assert has_len(DataLoader(RandomDataset(1, 1))) - - with pytest.warns(UserWarning, match="`DataLoader` returned 0 length."): - assert has_len(DataLoader(RandomDataset(0, 0))) - - assert not has_len(DataLoader(RandomIterableDataset(1, 1))) - - def test_get_len(): assert get_len(DataLoader(RandomDataset(1, 1))) == 1 @@ -174,297 +147,6 @@ def test_update_dataloader_typerror_custom_exception(): assert isinstance(new_dataloader, GoodImpl) -def test_replace_dunder_methods_multiple_loaders_without_init(): - """In case of a class, that inherits from a class that we are patching, but doesn't define its own `__init__` - method (the one we are wrapping), it can happen, that `hasattr(cls, "__old__init__")` is True because of parent - class, but it is impossible to delete, because that method is owned by parent class. Furthermore, the error - occured only sometimes because it depends on the order in which we are iterating over a set of classes we are - patching. - - This test simulates the behavior by generating sufficient number of dummy classes, which do not define `__init__` - and are children of `DataLoader`. We are testing that a) context manager `_replace_dunder_method` exits cleanly, and - b) the mechanism checking for presence of `__old__init__` works as expected. 
- """ - classes = [DataLoader] - for i in range(100): - classes.append(type(f"DataLoader_{i}", (random.choice(classes),), {})) - - before = {cls: cls.__init__ for cls in classes} - - with _replace_dunder_methods(DataLoader, "dataset"): - for cls in classes[1:]: # First one is `DataLoader` - assert "__old__init__" not in cls.__dict__ - assert hasattr(cls, "__old__init__") - - assert "__old__init__" in DataLoader.__dict__ - assert hasattr(DataLoader, "__old__init__") - - for cls in classes: - assert before[cls] == cls.__init__ - - -class DataLoaderSubclass1(DataLoader): - def __init__(self, attribute1, *args, **kwargs): - self.at1 = attribute1 - super().__init__(*args, **kwargs) - - -class DataLoaderSubclass2(DataLoaderSubclass1): - def __init__(self, attribute2, *args, **kwargs): - self.at2 = attribute2 - super().__init__(attribute2 + "-2", *args, **kwargs) - - -class MyBaseDataLoader(DataLoader): - pass - - -class MyDataLoader(MyBaseDataLoader): - def __init__(self, data: torch.Tensor, *args, **kwargs): - self.data = data - super().__init__(range(data.size(0)), *args, **kwargs) - - -test3_data = torch.randn((10, 20)) - - -class PoptorchDataLoader(DataLoader): - def __init__(self, options, *args, **kwargs): - super().__init__(*args, **kwargs) - self._options = options - - @property - def options(self): - return self._options - - -class IncompleteDataLoader(DataLoader): - def __init__(self, dataset, batch_size, **kwargs): - batch_size = max(batch_size - 5, 0) - super().__init__(dataset, batch_size=batch_size, **kwargs) - - -class WeirdDataLoader1(DataLoader): - def __init__(self, arg1, arg2, **kwargs): - self.arg1 = arg1 - super().__init__(arg2, **kwargs) - - -class WeirdDataLoader2(DataLoader): - def __init__(self, data_part1, data_part2, **kwargs): - data = list(data_part1) + list(data_part2) - super().__init__(data, **kwargs) - - -class NoneDataLoader(DataLoader): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - -class ChangingDataLoader(DataLoader): - def __init__(self, dataset, **kwargs): - super().__init__(list(dataset) + list(range(5, 10)), **kwargs) - - -@pytest.mark.parametrize( - ["cls", "args", "kwargs", "arg_names", "dataset", "checked_values"], - [ - pytest.param( - DataLoaderSubclass1, - ("attribute1",), - dict(dataset=range(4), batch_size=2), - ("attribute1",), - range(4), - dict(batch_size=2, at1="attribute1"), - id="test1", - ), - pytest.param( - DataLoaderSubclass2, - ("attribute2",), - dict(dataset=range(4), batch_size=2), - ("attribute2",), - range(4), - dict(batch_size=2, at1="attribute2-2", at2="attribute2"), - id="test2", - ), - pytest.param( - MyDataLoader, - (test3_data,), - dict(batch_size=2), - ("data",), - range(10), - dict(batch_size=2, data=test3_data), - id="test3", - ), - pytest.param(PoptorchDataLoader, (123, [1]), dict(), ("options",), [1], dict(options=123), id="test4"), - pytest.param( - IncompleteDataLoader, - (range(10),), - dict(batch_size=10), - ("dataset",), - range(10), - dict(batch_size=5), - id="test5", - ), - pytest.param( - WeirdDataLoader1, - (10, range(10)), - dict(batch_size=10), - ("arg1", "arg2"), - range(10), - dict(arg1=10, batch_size=10), - id="test6", - ), - pytest.param( - WeirdDataLoader2, - (range(10), range(10, 20)), - dict(batch_size=10), - ("data_part1", "data_part2"), - list(range(20)), - dict(batch_size=10), - id="test7", - ), - pytest.param(NoneDataLoader, (None,), dict(), (), None, dict(), id="test8"), - pytest.param(ChangingDataLoader, (range(5),), dict(), ("dataset",), list(range(10)), dict(), 
id="test9"), - ], -) -def test_replace_dunder_methods_dataloader(cls, args, kwargs, arg_names, dataset, checked_values): - with _replace_dunder_methods(DataLoader, "dataset"): - dataloader = cls(*args, **kwargs) - - assert dataloader.__pl_saved_args == args - assert dataloader.__pl_saved_kwargs == kwargs - assert dataloader.__pl_saved_arg_names == arg_names - assert dataloader.__pl_saved_default_kwargs == {} - assert dataloader.__dataset == dataset - - assert dataloader.dataset == dataset - - for key, value in checked_values.items(): - dataloader_value = getattr(dataloader, key) - if isinstance(dataloader_value, torch.Tensor): - assert dataloader_value is value - else: - assert dataloader_value == value - - dataloader = _update_dataloader(dataloader, dataloader.sampler) - - assert isinstance(dataloader, cls) - assert not hasattr(dataloader, "__pl_saved_kwargs") - assert not hasattr(dataloader, "__pl_saved_arg_names") - assert not hasattr(dataloader, "__pl_saved_args") - assert not hasattr(dataloader, "__pl_saved_default_kwargs") - assert not hasattr(dataloader, "__dataset") - - assert dataloader.dataset == dataset - - for key, value in checked_values.items(): - dataloader_value = getattr(dataloader, key) - if isinstance(dataloader_value, torch.Tensor): - assert dataloader_value is value - else: - assert dataloader_value == value - - -def test_replace_dunder_methods_extra_kwargs(): - class LoaderSubclass(DataLoader): - def __init__(self, dataset, *args, batch_size=10, **kwargs): - super().__init__(dataset, *args, batch_size=batch_size, **kwargs) - - with _replace_dunder_methods(DataLoader, "dataset"): - dataloader = LoaderSubclass(range(10)) - - assert dataloader.__pl_saved_args == (range(10),) - assert dataloader.__pl_saved_kwargs == {} - assert dataloader.__pl_saved_arg_names == ("dataset",) - assert dataloader.__pl_saved_default_kwargs == {"batch_size": 10} - assert dataloader.__dataset == range(10) - - -def test_replace_dunder_methods_attrs(): - """This test checks, that all the calls from setting and deleting attributes within `_replace_dunder_methods` - are correctly preserved even after reinstantiation. 
- - It also includes a custom `__setattr__` - """ - - class Loader(DataLoader): - def __setattr__(self, attr, val): - if attr == "custom_arg": - val = val + 2 - super().__setattr__(attr, val) - - with _replace_dunder_methods(DataLoader, "dataset"): - dataloader = Loader(range(10)) - dataloader.custom_arg = 5 - dataloader.my_arg = 10 - dataloader.another_arg = 100 - del dataloader.dataset - try: - del dataloader.abc_arg - except AttributeError: - pass - - assert dataloader.__pl_saved_args == (range(10),) - assert dataloader.__pl_saved_kwargs == {} - assert dataloader.__pl_saved_arg_names == ("dataset",) - assert dataloader.__dataset == range(10) - assert dataloader.custom_arg == 7 - assert dataloader.my_arg == 10 - assert dataloader.another_arg == 100 - assert not hasattr(dataloader, "dataset") - assert dataloader.__pl_attrs_record == [ - (("custom_arg", 5), _WrapAttrTag.SET), - (("my_arg", 10), _WrapAttrTag.SET), - (("another_arg", 100), _WrapAttrTag.SET), - (("dataset",), _WrapAttrTag.DEL), - ] - - dataloader = _update_dataloader(dataloader, dataloader.sampler) - assert dataloader.custom_arg == 7 - assert dataloader.my_arg == 10 - assert dataloader.another_arg == 100 - assert not hasattr(dataloader, "dataset") - - -def test_replace_dunder_methods_restore_methods(): - """This tests checks whether are all dunder methods restored to their original versions.""" - - class Init(DataLoader): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - class SetAttr(DataLoader): - def __setattr__(self, *args): - return super().__setattr__(*args) - - class DelAttr(DataLoader): - def __delattr__(self, *args): - return super().__delattr__(*args) - - class InitAndSetAttr(Init, SetAttr): - pass - - class InitAndDelAttr(Init, DelAttr): - pass - - class SetAttrAndDelAttr(SetAttr, DelAttr): - pass - - class AllDunder(Init, SetAttr, DelAttr): - pass - - before = dict() - for cls in (Init, SetAttr, DelAttr, InitAndSetAttr, InitAndDelAttr, SetAttrAndDelAttr, AllDunder): - before[cls] = {"init": cls.__init__, "setattr": cls.__setattr__, "delattr": cls.__delattr__} - - with _replace_dunder_methods(DataLoader, "dataset"): - pass - - for cls in (Init, SetAttr, DelAttr, InitAndSetAttr, InitAndDelAttr, SetAttrAndDelAttr, AllDunder): - assert before[cls] == {"init": cls.__init__, "setattr": cls.__setattr__, "delattr": cls.__delattr__} - - @pytest.mark.parametrize("predicting", [True, False]) def test_custom_batch_sampler(predicting): """This test asserts, that custom `BatchSampler`, with all the arguments, that are required in order to @@ -570,47 +252,6 @@ def test_custom_batch_sampler_no_sampler(): dataloader = _update_dataloader(dataloader, dataloader.sampler, mode=RunningStage.PREDICTING) -@pytest.mark.parametrize( - [ - "args", - "kwargs", - "default_kwargs", - "arg_names", - "replace_key", - "replace_value", - "expected_status", - "expected_args", - "expected_kwargs", - ], - [ - pytest.param((), {}, {}, [], "a", 1, False, (), {}, id="empty"), - pytest.param((1,), {}, {}, ["a"], "a", 2, True, (2,), {}, id="simple1"), - pytest.param((1, 2, 3), {}, {}, ["a", "b", "c"], "b", False, True, (1, False, 3), {}, id="simple2"), - pytest.param((1, 2, 3), {"a": 1}, {}, ["b", "c", "d"], "a", 2, True, (1, 2, 3), {"a": 2}, id="simple_kwargs"), - pytest.param( - (1, 2, 3), - {"a": 1}, - {"e": 5}, - ["b", "c", "d"], - "e", - 2, - True, - (1, 2, 3), - {"a": 1, "e": 2}, - id="default_kwargs", - ), - ], -) -def test_replace_value_in_args( - args, kwargs, default_kwargs, arg_names, replace_key, replace_value, 
expected_status, expected_args, expected_kwargs -): - assert _replace_value_in_saved_args(replace_key, replace_value, args, kwargs, default_kwargs, arg_names) == ( - expected_status, - expected_args, - expected_kwargs, - ) - - def test_dataloader_disallow_batch_sampler(): dataset = RandomDataset(5, 100) dataloader = DataLoader(dataset, batch_size=10) diff --git a/tests/tests_pytorch/utilities/test_device_parser.py b/tests/tests_pytorch/utilities/test_device_parser.py index d496db487f..a4a84892a6 100644 --- a/tests/tests_pytorch/utilities/test_device_parser.py +++ b/tests/tests_pytorch/utilities/test_device_parser.py @@ -16,7 +16,7 @@ from unittest import mock import pytest import torch -from pytorch_lightning.utilities import device_parser +from lightning_lite.utilities import device_parser @pytest.mark.skipif( diff --git a/tests/tests_pytorch/utilities/test_distributed.py b/tests/tests_pytorch/utilities/test_distributed.py index c3c90b5da6..2e2c88dd7a 100644 --- a/tests/tests_pytorch/utilities/test_distributed.py +++ b/tests/tests_pytorch/utilities/test_distributed.py @@ -13,13 +13,12 @@ # limitations under the License. import os -import pytest import torch import torch.distributed import torch.multiprocessing as mp import tests_pytorch.helpers.utils as tutils -from pytorch_lightning.utilities.distributed import _collect_states_on_rank_zero, gather_all_tensors +from pytorch_lightning.utilities.distributed import _collect_states_on_rank_zero from tests_pytorch.helpers.runif import RunIf @@ -44,57 +43,3 @@ def test_collect_states(): """ tutils.set_random_main_port() mp.spawn(_test_collect_states, args=(2,), nprocs=2) - - -def _test_all_gather_uneven_tensors(rank, world_size, backend): - os.environ["MASTER_ADDR"] = "localhost" - - if backend == "nccl": - device = torch.device("cuda", rank) - torch.cuda.set_device(device) - else: - device = torch.device("cpu") - - # initialize the process group - torch.distributed.init_process_group(backend, rank=rank, world_size=world_size) - - tensor = torch.ones(rank, device=device) - result = gather_all_tensors(tensor) - assert len(result) == world_size - for idx in range(world_size): - assert len(result[idx]) == idx - assert (result[idx] == torch.ones_like(result[idx])).all() - - -def _test_all_gather_uneven_tensors_multidim(rank, world_size, backend): - os.environ["MASTER_ADDR"] = "localhost" - - if backend == "nccl": - device = torch.device("cuda", rank) - torch.cuda.set_device(device) - else: - device = torch.device("cpu") - - # initialize the process group - torch.distributed.init_process_group(backend, rank=rank, world_size=world_size) - tensor = torch.ones(rank + 1, 2 - rank, device=device) - result = gather_all_tensors(tensor) - assert len(result) == world_size - for idx in range(world_size): - val = result[idx] - assert val.shape == (idx + 1, 2 - idx) - assert (val == torch.ones_like(val)).all() - - -@RunIf(min_torch="1.10", skip_windows=True) -@pytest.mark.parametrize( - "process", - [ - _test_all_gather_uneven_tensors_multidim, - _test_all_gather_uneven_tensors, - ], -) -@pytest.mark.parametrize("backend", [pytest.param("nccl", marks=RunIf(min_cuda_gpus=2)), "gloo"]) -def test_gather_all_tensors(backend, process): - tutils.set_random_main_port() - mp.spawn(process, args=(2, backend), nprocs=2) diff --git a/tests/tests_pytorch/utilities/test_enums.py b/tests/tests_pytorch/utilities/test_enums.py index 1519e17721..83b6c7b116 100644 --- a/tests/tests_pytorch/utilities/test_enums.py +++ b/tests/tests_pytorch/utilities/test_enums.py @@ -11,15 +11,7 
@@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.utilities.enums import GradClipAlgorithmType, PrecisionType - - -def test_precision_supported_types(): - assert PrecisionType.supported_types() == ["16", "32", "64", "bf16", "mixed"] - assert PrecisionType.supported_type(16) - assert PrecisionType.supported_type("16") - assert not PrecisionType.supported_type(1) - assert not PrecisionType.supported_type("invalid") +from pytorch_lightning.utilities.enums import GradClipAlgorithmType def test_gradient_clip_algorithms(): diff --git a/tests/tests_pytorch/utilities/test_seed.py b/tests/tests_pytorch/utilities/test_seed.py index 502febcaa9..ac76725616 100644 --- a/tests/tests_pytorch/utilities/test_seed.py +++ b/tests/tests_pytorch/utilities/test_seed.py @@ -1,83 +1,13 @@ -import os import random -from unittest import mock import numpy as np import pytest import torch -import pytorch_lightning.utilities.seed as seed_utils -from pytorch_lightning.utilities.seed import _collect_rng_states, _set_rng_states, isolate_rng +from pytorch_lightning.utilities.seed import isolate_rng from tests_pytorch.helpers.runif import RunIf -@mock.patch.dict(os.environ, {}, clear=True) -def test_seed_stays_same_with_multiple_seed_everything_calls(): - """Ensure that after the initial seed everything, the seed stays the same for the same run.""" - with pytest.warns(UserWarning, match="No seed found"): - seed_utils.seed_everything() - initial_seed = os.environ.get("PL_GLOBAL_SEED") - - with pytest.warns(None) as record: - seed_utils.seed_everything() - assert not record # does not warn - seed = os.environ.get("PL_GLOBAL_SEED") - - assert initial_seed == seed - - -@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "2020"}, clear=True) -def test_correct_seed_with_environment_variable(): - """Ensure that the PL_GLOBAL_SEED environment is read.""" - assert seed_utils.seed_everything() == 2020 - - -@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "invalid"}, clear=True) -@mock.patch.object(seed_utils, attribute="_select_seed_randomly", new=lambda *_: 123) -def test_invalid_seed(): - """Ensure that we still fix the seed even if an invalid seed is given.""" - with pytest.warns(UserWarning, match="Invalid seed found"): - seed = seed_utils.seed_everything() - assert seed == 123 - - -@mock.patch.dict(os.environ, {}, clear=True) -@mock.patch.object(seed_utils, attribute="_select_seed_randomly", new=lambda *_: 123) -@pytest.mark.parametrize("seed", (10e9, -10e9)) -def test_out_of_bounds_seed(seed): - """Ensure that we still fix the seed even if an out-of-bounds seed is given.""" - with pytest.warns(UserWarning, match="is not in bounds"): - actual = seed_utils.seed_everything(seed) - assert actual == 123 - - -def test_reset_seed_no_op(): - """Test that the reset_seed function is a no-op when seed_everything() was not used.""" - assert "PL_GLOBAL_SEED" not in os.environ - seed_before = torch.initial_seed() - seed_utils.reset_seed() - assert torch.initial_seed() == seed_before - assert "PL_GLOBAL_SEED" not in os.environ - - -@pytest.mark.parametrize("workers", (True, False)) -def test_reset_seed_everything(workers): - """Test that we can reset the seed to the initial value set by seed_everything()""" - assert "PL_GLOBAL_SEED" not in os.environ - assert "PL_SEED_WORKERS" not in os.environ - - seed_utils.seed_everything(123, workers) - before = torch.rand(1) - assert 
os.environ["PL_GLOBAL_SEED"] == "123" - assert os.environ["PL_SEED_WORKERS"] == str(int(workers)) - - seed_utils.reset_seed() - after = torch.rand(1) - assert os.environ["PL_GLOBAL_SEED"] == "123" - assert os.environ["PL_SEED_WORKERS"] == str(int(workers)) - assert torch.allclose(before, after) - - @pytest.mark.parametrize("with_torch_cuda", [False, pytest.param(True, marks=RunIf(min_cuda_gpus=1))]) def test_isolate_rng(with_torch_cuda): """Test that the isolate_rng context manager isolates the random state from the outer scope.""" @@ -105,11 +35,3 @@ def test_isolate_rng(with_torch_cuda): with isolate_rng(): generated = [random.random() for _ in range(3)] assert random.random() == generated[0] - - -def test_backward_compatibility_rng_states_dict(): - """Test that an older rng_states_dict without the "torch.cuda" key does not crash.""" - states = _collect_rng_states() - assert "torch.cuda" in states - states.pop("torch.cuda") - _set_rng_states(states) diff --git a/tests/tests_pytorch/utilities/test_types.py b/tests/tests_pytorch/utilities/test_types.py index 5b523a43dc..0782d3bc2e 100644 --- a/tests/tests_pytorch/utilities/test_types.py +++ b/tests/tests_pytorch/utilities/test_types.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.utilities.types import _Stateful +from lightning_lite.utilities.types import _Stateful def test_stateful_protocol(): diff --git a/tests/tests_pytorch/utilities/test_warnings.py b/tests/tests_pytorch/utilities/test_warnings.py index 223cd4e59f..e95a342327 100644 --- a/tests/tests_pytorch/utilities/test_warnings.py +++ b/tests/tests_pytorch/utilities/test_warnings.py @@ -18,38 +18,7 @@ Needs to be run outside of `pytest` as it captures all the warnings. from contextlib import redirect_stderr from io import StringIO -from lightning_utilities.core.rank_zero import _warn, WarningCache - -from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn - if __name__ == "__main__": - stderr = StringIO() - # recording - with redirect_stderr(stderr): - _warn("test1") - _warn("test2", category=DeprecationWarning) - - rank_zero_warn("test3") - rank_zero_warn("test4", category=DeprecationWarning) - - rank_zero_deprecation("test5") - - cache = WarningCache() - cache.warn("test6") - cache.deprecation("test7") - - output = stderr.getvalue() - assert "test_warnings.py:29: UserWarning: test1" in output - assert "test_warnings.py:30: DeprecationWarning: test2" in output - - assert "test_warnings.py:32: UserWarning: test3" in output - assert "test_warnings.py:33: DeprecationWarning: test4" in output - - assert "test_warnings.py:35: LightningDeprecationWarning: test5" in output - - assert "test_warnings.py:38: UserWarning: test6" in output - assert "test_warnings.py:39: LightningDeprecationWarning: test7" in output - # check that logging is properly configured import logging