diff --git a/.actions/assistant.py b/.actions/assistant.py
index 15a20e63c6..4bd7a97726 100644
--- a/.actions/assistant.py
+++ b/.actions/assistant.py
@@ -18,10 +18,11 @@ import re
 import shutil
 import tempfile
 import urllib.request
+from collections.abc import Iterable, Iterator, Sequence
 from itertools import chain
 from os.path import dirname, isfile
 from pathlib import Path
-from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 from packaging.requirements import Requirement
 from packaging.version import Version
diff --git a/examples/fabric/build_your_own_trainer/trainer.py b/examples/fabric/build_your_own_trainer/trainer.py
index 7af01ede05..a7089a8244 100644
--- a/examples/fabric/build_your_own_trainer/trainer.py
+++ b/examples/fabric/build_your_own_trainer/trainer.py
@@ -1,7 +1,7 @@
 import os
-from collections.abc import Mapping
+from collections.abc import Iterable, Mapping
 from functools import partial
-from typing import Any, Iterable, List, Literal, Optional, Tuple, Union, cast
+from typing import Any, List, Literal, Optional, Tuple, Union, cast

 import lightning as L
 import torch
diff --git a/examples/pytorch/domain_templates/reinforce_learn_Qnet.py b/examples/pytorch/domain_templates/reinforce_learn_Qnet.py
index 497cb658c2..4769bb066b 100644
--- a/examples/pytorch/domain_templates/reinforce_learn_Qnet.py
+++ b/examples/pytorch/domain_templates/reinforce_learn_Qnet.py
@@ -35,7 +35,8 @@ Second-Edition/blob/master/Chapter06/02_dqn_pong.py
 import argparse
 import random
 from collections import OrderedDict, deque, namedtuple
-from typing import Iterator, List, Tuple
+from collections.abc import Iterator
+from typing import List, Tuple

 import gym
 import torch
diff --git a/examples/pytorch/domain_templates/reinforce_learn_ppo.py b/examples/pytorch/domain_templates/reinforce_learn_ppo.py
index bc3f8c1b9b..f83e7b3df1 100644
--- a/examples/pytorch/domain_templates/reinforce_learn_ppo.py
+++ b/examples/pytorch/domain_templates/reinforce_learn_ppo.py
@@ -30,7 +30,8 @@ References
 """

 import argparse
-from typing import Callable, Iterator, List, Tuple
+from collections.abc import Iterator
+from typing import Callable, List, Tuple

 import gym
 import torch
diff --git a/setup.py b/setup.py
index bfc329bb8f..92f0265eaf 100755
--- a/setup.py
+++ b/setup.py
@@ -45,9 +45,10 @@ import glob
 import logging
 import os
 import tempfile
+from collections.abc import Generator, Mapping
 from importlib.util import module_from_spec, spec_from_file_location
 from types import ModuleType
-from typing import Generator, Mapping, Optional
+from typing import Optional

 import setuptools
 import setuptools.command.egg_info
diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py
index 0ff5b04b30..5c5ec33527 100644
--- a/src/lightning/fabric/fabric.py
+++ b/src/lightning/fabric/fabric.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import inspect
 import os
+from collections.abc import Generator, Mapping, Sequence
 from contextlib import contextmanager, nullcontext
 from functools import partial
 from pathlib import Path
@@ -21,11 +22,8 @@ from typing import (
     Callable,
     ContextManager,
     Dict,
-    Generator,
     List,
-    Mapping,
     Optional,
-    Sequence,
     Tuple,
     Union,
     cast,
diff --git a/src/lightning/fabric/loggers/tensorboard.py b/src/lightning/fabric/loggers/tensorboard.py
index 685c832088..7e769dffa9 100644
--- a/src/lightning/fabric/loggers/tensorboard.py
+++ b/src/lightning/fabric/loggers/tensorboard.py
@@ -14,7 +14,8 @@
 import os
 from argparse import Namespace
-from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union
+from collections.abc import Mapping
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union

 from lightning_utilities.core.imports import RequirementCache
 from torch import Tensor
diff --git a/src/lightning/fabric/plugins/precision/bitsandbytes.py b/src/lightning/fabric/plugins/precision/bitsandbytes.py
index 3944154528..41711c53a5 100644
--- a/src/lightning/fabric/plugins/precision/bitsandbytes.py
+++ b/src/lightning/fabric/plugins/precision/bitsandbytes.py
@@ -16,10 +16,11 @@ import logging
 import math
 import os
 import warnings
+from collections import OrderedDict
 from contextlib import ExitStack
 from functools import partial
 from types import ModuleType
-from typing import Any, Callable, ContextManager, Literal, Optional, OrderedDict, Set, Tuple, Type, cast
+from typing import Any, Callable, ContextManager, Literal, Optional, Set, Tuple, Type, cast

 import torch
 from lightning_utilities import apply_to_collection
diff --git a/src/lightning/fabric/plugins/precision/transformer_engine.py b/src/lightning/fabric/plugins/precision/transformer_engine.py
index cb5296b21f..ddc30c2e1b 100644
--- a/src/lightning/fabric/plugins/precision/transformer_engine.py
+++ b/src/lightning/fabric/plugins/precision/transformer_engine.py
@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from collections.abc import Mapping
 from contextlib import ExitStack
-from typing import TYPE_CHECKING, Any, ContextManager, Literal, Mapping, Optional, Union
+from typing import TYPE_CHECKING, Any, ContextManager, Literal, Optional, Union

 import torch
 from lightning_utilities import apply_to_collection
diff --git a/src/lightning/fabric/plugins/precision/utils.py b/src/lightning/fabric/plugins/precision/utils.py
index 887dbc937a..3f939b59a8 100644
--- a/src/lightning/fabric/plugins/precision/utils.py
+++ b/src/lightning/fabric/plugins/precision/utils.py
@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Mapping, Type, Union
+from collections.abc import Mapping
+from typing import Any, Type, Union

 import torch
 from torch import Tensor
diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py
index e71b8e2db3..d9bd5880ae 100644
--- a/src/lightning/fabric/strategies/deepspeed.py
+++ b/src/lightning/fabric/strategies/deepspeed.py
@@ -16,10 +16,11 @@ import json
 import logging
 import os
 import platform
+from collections.abc import Mapping
 from contextlib import ExitStack
 from itertools import chain
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, List, Mapping, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, List, Optional, Tuple, Union

 import torch
 from lightning_utilities.core.imports import RequirementCache
diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py
index e7fdd29f62..6b48f0493e 100644
--- a/src/lightning/fabric/strategies/fsdp.py
+++ b/src/lightning/fabric/strategies/fsdp.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import shutil
 import warnings
+from collections.abc import Generator
 from contextlib import ExitStack, nullcontext
 from datetime import timedelta
 from functools import partial
@@ -23,7 +24,6 @@ from typing import (
     Callable,
     ContextManager,
     Dict,
-    Generator,
     List,
     Literal,
     Optional,
diff --git a/src/lightning/fabric/strategies/launchers/subprocess_script.py b/src/lightning/fabric/strategies/launchers/subprocess_script.py
index 63ae8b0bee..1e17cfcf5b 100644
--- a/src/lightning/fabric/strategies/launchers/subprocess_script.py
+++ b/src/lightning/fabric/strategies/launchers/subprocess_script.py
@@ -18,7 +18,8 @@ import subprocess
 import sys
 import threading
 import time
-from typing import Any, Callable, List, Optional, Sequence, Tuple
+from collections.abc import Sequence
+from typing import Any, Callable, List, Optional, Tuple

 from lightning_utilities.core.imports import RequirementCache
 from typing_extensions import override
diff --git a/src/lightning/fabric/strategies/model_parallel.py b/src/lightning/fabric/strategies/model_parallel.py
index 86b93d35e6..28c0ac3a48 100644
--- a/src/lightning/fabric/strategies/model_parallel.py
+++ b/src/lightning/fabric/strategies/model_parallel.py
@@ -13,10 +13,11 @@
 # limitations under the License.
 import itertools
 import shutil
+from collections.abc import Generator
 from contextlib import ExitStack
 from datetime import timedelta
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generator, Literal, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Literal, Optional, TypeVar, Union

 import torch
 from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
diff --git a/src/lightning/fabric/strategies/strategy.py b/src/lightning/fabric/strategies/strategy.py
index 6bfed6a270..ff0794d63e 100644
--- a/src/lightning/fabric/strategies/strategy.py
+++ b/src/lightning/fabric/strategies/strategy.py
@@ -13,8 +13,9 @@
 # limitations under the License.
 import logging
 from abc import ABC, abstractmethod
+from collections.abc import Iterable
 from contextlib import ExitStack
-from typing import Any, Callable, ContextManager, Dict, Iterable, List, Optional, Tuple, TypeVar, Union
+from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple, TypeVar, Union

 import torch
 from torch import Tensor
diff --git a/src/lightning/fabric/utilities/data.py b/src/lightning/fabric/utilities/data.py
index 1ec0edce38..a68bba3c0f 100644
--- a/src/lightning/fabric/utilities/data.py
+++ b/src/lightning/fabric/utilities/data.py
@@ -16,9 +16,10 @@ import functools
 import inspect
 import os
 from collections import OrderedDict
+from collections.abc import Generator, Iterable, Sized
 from contextlib import contextmanager
 from functools import partial
-from typing import Any, Callable, Dict, Generator, Iterable, Optional, Sized, Tuple, Type, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Type, Union

 from lightning_utilities.core.inheritance import get_all_subclasses
 from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler
diff --git a/src/lightning/fabric/utilities/device_parser.py b/src/lightning/fabric/utilities/device_parser.py
index 16965d944c..d387e5b139 100644
--- a/src/lightning/fabric/utilities/device_parser.py
+++ b/src/lightning/fabric/utilities/device_parser.py
@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, MutableSequence, Optional, Tuple, Union
+from collections.abc import MutableSequence
+from typing import List, Optional, Tuple, Union

 import torch
diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py
index 0e6c52dfb0..00d23c338e 100644
--- a/src/lightning/fabric/utilities/distributed.py
+++ b/src/lightning/fabric/utilities/distributed.py
@@ -4,10 +4,11 @@ import logging
 import os
 import signal
 import time
+from collections.abc import Iterable, Iterator, Sized
 from contextlib import nullcontext
 from datetime import timedelta
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sized, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Union

 import torch
 import torch.nn.functional as F
diff --git a/src/lightning/fabric/utilities/init.py b/src/lightning/fabric/utilities/init.py
index c92dfd8c2e..ef01f04987 100644
--- a/src/lightning/fabric/utilities/init.py
+++ b/src/lightning/fabric/utilities/init.py
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
-from typing import Any, Callable, Dict, Optional, Sequence, Union
+from collections.abc import Sequence
+from typing import Any, Callable, Dict, Optional, Union

 import torch
 from torch.nn import Module, Parameter
diff --git a/src/lightning/fabric/utilities/load.py b/src/lightning/fabric/utilities/load.py
index a1c3b6933b..6d888f1c9c 100644
--- a/src/lightning/fabric/utilities/load.py
+++ b/src/lightning/fabric/utilities/load.py
@@ -13,10 +13,12 @@ import os
 import pickle
 import warnings
+from collections import OrderedDict
+from collections.abc import Sequence
 from functools import partial
 from io import BytesIO
 from pathlib import Path
-from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, OrderedDict, Sequence, Set, Union
+from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, Set, Union

 import torch
 from lightning_utilities.core.apply_func import apply_to_collection
diff --git a/src/lightning/fabric/utilities/logger.py b/src/lightning/fabric/utilities/logger.py
index 07b76ad9b0..948e9aa800 100644
--- a/src/lightning/fabric/utilities/logger.py
+++ b/src/lightning/fabric/utilities/logger.py
@@ -15,8 +15,9 @@ import inspect
 import json
 from argparse import Namespace
+from collections.abc import Mapping, MutableMapping
 from dataclasses import asdict, is_dataclass
-from typing import Any, Dict, Mapping, MutableMapping, Optional, Union
+from typing import Any, Dict, Optional, Union

 from torch import Tensor
diff --git a/src/lightning/fabric/utilities/optimizer.py b/src/lightning/fabric/utilities/optimizer.py
index 2c57ec9d1f..df83f9b1ca 100644
--- a/src/lightning/fabric/utilities/optimizer.py
+++ b/src/lightning/fabric/utilities/optimizer.py
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from collections.abc import MutableMapping
-from typing import Iterable
+from collections.abc import Iterable, MutableMapping

 from torch import Tensor
 from torch.optim import Optimizer
diff --git a/src/lightning/fabric/utilities/types.py b/src/lightning/fabric/utilities/types.py
index 2e18dc89b0..8108caf397 100644
--- a/src/lightning/fabric/utilities/types.py
+++ b/src/lightning/fabric/utilities/types.py
@@ -11,13 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections.abc import Iterator
 from pathlib import Path
 from typing import (
     Any,
     Callable,
     DefaultDict,
     Dict,
-    Iterator,
     List,
     Optional,
     Protocol,
diff --git a/src/lightning/fabric/wrappers.py b/src/lightning/fabric/wrappers.py
index c57f1974a6..537437444e 100644
--- a/src/lightning/fabric/wrappers.py
+++ b/src/lightning/fabric/wrappers.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+from collections.abc import Generator, Iterator, Mapping
 from copy import deepcopy
 from functools import partial, wraps
 from types import MethodType
@@ -19,10 +20,7 @@ from typing import (
     Any,
     Callable,
     Dict,
-    Generator,
-    Iterator,
     List,
-    Mapping,
     Optional,
     Tuple,
     TypeVar,
diff --git a/src/lightning/pytorch/callbacks/finetuning.py b/src/lightning/pytorch/callbacks/finetuning.py
index 46a90986c0..d67b78eaa9 100644
--- a/src/lightning/pytorch/callbacks/finetuning.py
+++ b/src/lightning/pytorch/callbacks/finetuning.py
@@ -19,7 +19,8 @@ Freeze and unfreeze models for finetuning purposes.
""" import logging -from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union +from collections.abc import Generator, Iterable +from typing import Any, Callable, Dict, List, Optional, Union import torch from torch.nn import Module, ModuleDict diff --git a/src/lightning/pytorch/callbacks/prediction_writer.py b/src/lightning/pytorch/callbacks/prediction_writer.py index 7f782fb81c..ce6342c7aa 100644 --- a/src/lightning/pytorch/callbacks/prediction_writer.py +++ b/src/lightning/pytorch/callbacks/prediction_writer.py @@ -18,7 +18,8 @@ BasePredictionWriter Aids in saving predictions """ -from typing import Any, Literal, Optional, Sequence +from collections.abc import Sequence +from typing import Any, Literal, Optional from typing_extensions import override diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 896de71267..86205d0e65 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import math +from collections.abc import Generator from dataclasses import dataclass from datetime import timedelta -from typing import Any, Dict, Generator, Optional, Union, cast +from typing import Any, Dict, Optional, Union, cast from lightning_utilities.core.imports import RequirementCache from typing_extensions import override diff --git a/src/lightning/pytorch/callbacks/pruning.py b/src/lightning/pytorch/callbacks/pruning.py index e83a9de063..1e2b8ba912 100644 --- a/src/lightning/pytorch/callbacks/pruning.py +++ b/src/lightning/pytorch/callbacks/pruning.py @@ -18,9 +18,10 @@ ModelPruning import inspect import logging +from collections.abc import Sequence from copy import deepcopy from functools import partial -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch.nn.utils.prune as pytorch_prune from lightning_utilities.core.apply_func import apply_to_collection diff --git a/src/lightning/pytorch/callbacks/spike.py b/src/lightning/pytorch/callbacks/spike.py index 725d6f6433..b006acd44d 100644 --- a/src/lightning/pytorch/callbacks/spike.py +++ b/src/lightning/pytorch/callbacks/spike.py @@ -1,5 +1,6 @@ import os -from typing import Any, Mapping, Union +from collections.abc import Mapping +from typing import Any, Union import torch diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py index 26af335f7b..62245647b7 100644 --- a/src/lightning/pytorch/cli.py +++ b/src/lightning/pytorch/cli.py @@ -14,9 +14,10 @@ import inspect import os import sys +from collections.abc import Iterable from functools import partial, update_wrapper from types import MethodType -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, TypeVar, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union import torch import yaml diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index 6cb8f79f09..b92dd1cce4 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -14,7 +14,8 @@ """LightningDataModule for loading DataLoaders with ease.""" import inspect -from typing import IO, Any, Dict, Iterable, Optional, Union, cast +from collections.abc import Iterable 
+from typing import IO, Any, Dict, Optional, Union, cast

 from lightning_utilities import apply_to_collection
 from torch.utils.data import DataLoader, Dataset, IterableDataset
diff --git a/src/lightning/pytorch/core/mixins/hparams_mixin.py b/src/lightning/pytorch/core/mixins/hparams_mixin.py
index 94ece0039d..ec509062b1 100644
--- a/src/lightning/pytorch/core/mixins/hparams_mixin.py
+++ b/src/lightning/pytorch/core/mixins/hparams_mixin.py
@@ -15,9 +15,10 @@ import copy
 import inspect
 import types
 from argparse import Namespace
+from collections.abc import Iterator, MutableMapping, Sequence
 from contextlib import contextmanager
 from contextvars import ContextVar
-from typing import Any, Iterator, List, MutableMapping, Optional, Sequence, Union
+from typing import Any, List, Optional, Union

 from lightning.fabric.utilities.data import AttributeDict
 from lightning.pytorch.utilities.parsing import save_hyperparameters
diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py
index d8374ef7ea..61eeedf947 100644
--- a/src/lightning/pytorch/core/module.py
+++ b/src/lightning/pytorch/core/module.py
@@ -16,6 +16,7 @@ import logging
 import numbers
 import weakref
+from collections.abc import Generator, Mapping, Sequence
 from contextlib import contextmanager
 from io import BytesIO
 from pathlib import Path
@@ -25,12 +26,9 @@ from typing import (
     Any,
     Callable,
     Dict,
-    Generator,
     List,
     Literal,
-    Mapping,
     Optional,
-    Sequence,
     Tuple,
     Union,
     cast,
diff --git a/src/lightning/pytorch/core/optimizer.py b/src/lightning/pytorch/core/optimizer.py
index 777dca0b51..ae3b0a8403 100644
--- a/src/lightning/pytorch/core/optimizer.py
+++ b/src/lightning/pytorch/core/optimizer.py
@@ -11,9 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections.abc import Generator
 from contextlib import contextmanager
 from dataclasses import fields
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union, overload
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, overload
 from weakref import proxy

 import torch
diff --git a/src/lightning/pytorch/demos/boring_classes.py b/src/lightning/pytorch/demos/boring_classes.py
index fd26602281..9f7f86890d 100644
--- a/src/lightning/pytorch/demos/boring_classes.py
+++ b/src/lightning/pytorch/demos/boring_classes.py
@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Iterator, List, Optional, Tuple
+from collections.abc import Iterator
+from typing import Any, Dict, List, Optional, Tuple

 import torch
 import torch.nn as nn
diff --git a/src/lightning/pytorch/demos/lstm.py b/src/lightning/pytorch/demos/lstm.py
index 672b61ad0e..f17f90359c 100644
--- a/src/lightning/pytorch/demos/lstm.py
+++ b/src/lightning/pytorch/demos/lstm.py
@@ -5,7 +5,8 @@ https://github.com/pytorch/examples/blob/main/word_language_model
 """

-from typing import Iterator, List, Optional, Sized, Tuple
+from collections.abc import Iterator, Sized
+from typing import List, Optional, Tuple

 import torch
 import torch.nn as nn
diff --git a/src/lightning/pytorch/demos/mnist_datamodule.py b/src/lightning/pytorch/demos/mnist_datamodule.py
index 992527ab67..304e03d0d3 100644
--- a/src/lightning/pytorch/demos/mnist_datamodule.py
+++ b/src/lightning/pytorch/demos/mnist_datamodule.py
@@ -16,7 +16,8 @@ import os
 import random
 import time
 import urllib
-from typing import Any, Callable, Optional, Sized, Tuple, Union
+from collections.abc import Sized
+from typing import Any, Callable, Optional, Tuple, Union
 from urllib.error import HTTPError
 from warnings import warn
diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py
index 277af5c85f..25d51d2ec0 100644
--- a/src/lightning/pytorch/loggers/comet.py
+++ b/src/lightning/pytorch/loggers/comet.py
@@ -19,7 +19,8 @@ Comet Logger
 import logging
 import os
 from argparse import Namespace
-from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union
+from collections.abc import Mapping
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union

 from lightning_utilities.core.imports import RequirementCache
 from torch import Tensor
diff --git a/src/lightning/pytorch/loggers/logger.py b/src/lightning/pytorch/loggers/logger.py
index 40e8ed8c4a..bb7d93c90b 100644
--- a/src/lightning/pytorch/loggers/logger.py
+++ b/src/lightning/pytorch/loggers/logger.py
@@ -18,7 +18,8 @@ import operator
 import statistics
 from abc import ABC
 from collections import defaultdict
-from typing import Any, Callable, Dict, Mapping, Optional, Sequence
+from collections.abc import Mapping, Sequence
+from typing import Any, Callable, Dict, Optional

 from typing_extensions import override
diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py
index 1b15014cd0..b194ae6a64 100644
--- a/src/lightning/pytorch/loggers/mlflow.py
+++ b/src/lightning/pytorch/loggers/mlflow.py
@@ -21,9 +21,10 @@ import os
 import re
 import tempfile
 from argparse import Namespace
+from collections.abc import Mapping
 from pathlib import Path
 from time import time
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Mapping, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Union

 import yaml
 from lightning_utilities.core.imports import RequirementCache
diff --git a/src/lightning/pytorch/loggers/neptune.py b/src/lightning/pytorch/loggers/neptune.py
index 691dbe0ba2..852f5f9116 100644
--- a/src/lightning/pytorch/loggers/neptune.py
+++ b/src/lightning/pytorch/loggers/neptune.py
@@ -20,8 +20,9 @@ import contextlib
 import logging
 import os
 from argparse import Namespace
+from collections.abc import Generator
 from functools import wraps
-from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Set, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Union

 from lightning_utilities.core.imports import RequirementCache
 from torch import Tensor
diff --git a/src/lightning/pytorch/loggers/wandb.py b/src/lightning/pytorch/loggers/wandb.py
index 20f8d02a7a..b38a69d938 100644
--- a/src/lightning/pytorch/loggers/wandb.py
+++ b/src/lightning/pytorch/loggers/wandb.py
@@ -18,8 +18,9 @@ Weights and Biases Logger

 import os
 from argparse import Namespace
+from collections.abc import Mapping
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union

 import torch.nn as nn
 from lightning_utilities.core.imports import RequirementCache
diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py
index 0ab3901cf0..cb4f4bbe23 100644
--- a/src/lightning/pytorch/loops/evaluation_loop.py
+++ b/src/lightning/pytorch/loops/evaluation_loop.py
@@ -15,7 +15,8 @@ import os
 import shutil
 import sys
 from collections import ChainMap, OrderedDict, defaultdict
-from typing import Any, DefaultDict, Iterable, Iterator, List, Optional, Tuple, Union
+from collections.abc import Iterable, Iterator
+from typing import Any, DefaultDict, List, Optional, Tuple, Union

 from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor
diff --git a/src/lightning/pytorch/loops/fetchers.py b/src/lightning/pytorch/loops/fetchers.py
index e699321a4d..50ab916388 100644
--- a/src/lightning/pytorch/loops/fetchers.py
+++ b/src/lightning/pytorch/loops/fetchers.py
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Iterator, List, Optional
+from collections.abc import Iterator
+from typing import Any, List, Optional

 from typing_extensions import override
diff --git a/src/lightning/pytorch/loops/optimization/automatic.py b/src/lightning/pytorch/loops/optimization/automatic.py
index 2ce6acab11..3238c486e1 100644
--- a/src/lightning/pytorch/loops/optimization/automatic.py
+++ b/src/lightning/pytorch/loops/optimization/automatic.py
@@ -11,9 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections import OrderedDict
+from collections.abc import Mapping
 from dataclasses import dataclass, field
 from functools import partial
-from typing import Any, Callable, Dict, Mapping, Optional, OrderedDict
+from typing import Any, Callable, Dict, Optional

 import torch
 from torch import Tensor
diff --git a/src/lightning/pytorch/loops/prediction_loop.py b/src/lightning/pytorch/loops/prediction_loop.py
index 9002e6280f..878c9acc2f 100644
--- a/src/lightning/pytorch/loops/prediction_loop.py
+++ b/src/lightning/pytorch/loops/prediction_loop.py
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from collections import OrderedDict
-from typing import Any, Iterator, List, Optional, Union
+from collections.abc import Iterator
+from typing import Any, List, Optional, Union

 import torch
 from lightning_utilities import WarningCache
diff --git a/src/lightning/pytorch/loops/utilities.py b/src/lightning/pytorch/loops/utilities.py
index 99ea5c4254..74eae25b87 100644
--- a/src/lightning/pytorch/loops/utilities.py
+++ b/src/lightning/pytorch/loops/utilities.py
@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+from collections.abc import Generator
 from contextlib import contextmanager
-from typing import Any, Callable, ContextManager, Generator, Optional, Tuple, Type
+from typing import Any, Callable, ContextManager, Optional, Tuple, Type

 import torch
 import torch.distributed as dist
diff --git a/src/lightning/pytorch/overrides/distributed.py b/src/lightning/pytorch/overrides/distributed.py
index e4b6528553..6763b39d77 100644
--- a/src/lightning/pytorch/overrides/distributed.py
+++ b/src/lightning/pytorch/overrides/distributed.py
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
-from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Sized, Union, cast
+from collections.abc import Iterable, Iterator, Sized
+from typing import Any, Callable, Dict, List, Optional, Union, cast

 import torch
 from torch import Tensor
diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py
index e63ccd6912..427dfc3acb 100644
--- a/src/lightning/pytorch/plugins/precision/amp.py
+++ b/src/lightning/pytorch/plugins/precision/amp.py
@@ -9,8 +9,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections.abc import Generator
 from contextlib import contextmanager
-from typing import Any, Callable, Dict, Generator, Literal, Optional, Union
+from typing import Any, Callable, Dict, Literal, Optional, Union

 import torch
 from torch import Tensor
diff --git a/src/lightning/pytorch/plugins/precision/double.py b/src/lightning/pytorch/plugins/precision/double.py
index 20f493bb7b..5d0af8b992 100644
--- a/src/lightning/pytorch/plugins/precision/double.py
+++ b/src/lightning/pytorch/plugins/precision/double.py
@@ -11,8 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections.abc import Generator
 from contextlib import contextmanager
-from typing import Any, ContextManager, Generator, Literal
+from typing import Any, ContextManager, Literal

 import torch
 import torch.nn as nn
diff --git a/src/lightning/pytorch/plugins/precision/half.py b/src/lightning/pytorch/plugins/precision/half.py
index 22dc29b580..2ad30de2b8 100644
--- a/src/lightning/pytorch/plugins/precision/half.py
+++ b/src/lightning/pytorch/plugins/precision/half.py
@@ -11,8 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections.abc import Generator
 from contextlib import contextmanager
-from typing import Any, ContextManager, Generator, Literal
+from typing import Any, ContextManager, Literal

 import torch
 from lightning_utilities import apply_to_collection
diff --git a/src/lightning/pytorch/plugins/precision/precision.py b/src/lightning/pytorch/plugins/precision/precision.py
index 51bdddb18f..e391ac1f68 100644
--- a/src/lightning/pytorch/plugins/precision/precision.py
+++ b/src/lightning/pytorch/plugins/precision/precision.py
@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import contextlib
+from collections.abc import Generator
 from functools import partial
-from typing import Any, Callable, Generator, List, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union

 import torch
 from torch import Tensor
diff --git a/src/lightning/pytorch/profilers/advanced.py b/src/lightning/pytorch/profilers/advanced.py
index 467b47124e..461d43dc8e 100644
--- a/src/lightning/pytorch/profilers/advanced.py
+++ b/src/lightning/pytorch/profilers/advanced.py
@@ -89,9 +89,10 @@ class AdvancedProfiler(Profiler):
         dst_fs = get_filesystem(dst_filepath)
         dst_fs.mkdirs(self.dirpath, exist_ok=True)
         # temporarily save to local since pstats can only dump into a local file
-        with tempfile.TemporaryDirectory(
-            prefix="test", suffix=str(rank_zero_only.rank), dir=os.getcwd()
-        ) as tmp_dir, dst_fs.open(dst_filepath, "wb") as dst_file:
+        with (
+            tempfile.TemporaryDirectory(prefix="test", suffix=str(rank_zero_only.rank), dir=os.getcwd()) as tmp_dir,
+            dst_fs.open(dst_filepath, "wb") as dst_file,
+        ):
             src_filepath = os.path.join(tmp_dir, "tmp.prof")
             profile.dump_stats(src_filepath)
             src_fs = get_filesystem(src_filepath)
diff --git a/src/lightning/pytorch/profilers/profiler.py b/src/lightning/pytorch/profilers/profiler.py
index fb44832157..9ab15f4ff9 100644
--- a/src/lightning/pytorch/profilers/profiler.py
+++ b/src/lightning/pytorch/profilers/profiler.py
@@ -16,9 +16,10 @@ import logging
 import os
 from abc import ABC, abstractmethod
+from collections.abc import Generator
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Callable, Dict, Generator, Optional, TextIO, Union
+from typing import Any, Callable, Dict, Optional, TextIO, Union

 from lightning.fabric.utilities.cloud_io import get_filesystem
diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py
index 1eaa5bab75..8add75e659 100644
--- a/src/lightning/pytorch/strategies/deepspeed.py
+++ b/src/lightning/pytorch/strategies/deepspeed.py
@@ -17,9 +17,10 @@ import logging
 import os
 import platform
 from collections import OrderedDict
+from collections.abc import Generator, Mapping
 from contextlib import contextmanager
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

 import torch
 from torch.nn import Module
diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py
index ab6e579c30..0a89916e49 100644
--- a/src/lightning/pytorch/strategies/fsdp.py
+++ b/src/lightning/pytorch/strategies/fsdp.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import logging
 import shutil
+from collections.abc import Generator, Mapping
 from contextlib import contextmanager, nullcontext
 from datetime import timedelta
 from pathlib import Path
@@ -21,10 +22,8 @@ from typing import (
     Any,
     Callable,
     Dict,
-    Generator,
     List,
     Literal,
-    Mapping,
     Optional,
     Set,
     Tuple,
diff --git a/src/lightning/pytorch/strategies/model_parallel.py b/src/lightning/pytorch/strategies/model_parallel.py
index fb45166378..d00ac25101 100644
--- a/src/lightning/pytorch/strategies/model_parallel.py
+++ b/src/lightning/pytorch/strategies/model_parallel.py
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import shutil
+from collections.abc import Generator, Mapping
 from contextlib import contextmanager, nullcontext
 from datetime import timedelta
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Mapping, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union

 import torch
 from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
diff --git a/src/lightning/pytorch/strategies/parallel.py b/src/lightning/pytorch/strategies/parallel.py
index 5658438cd3..337153f827 100644
--- a/src/lightning/pytorch/strategies/parallel.py
+++ b/src/lightning/pytorch/strategies/parallel.py
@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import ABC, abstractmethod
+from collections.abc import Generator
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, List, Optional
+from typing import Any, Dict, List, Optional

 import torch
 from torch import Tensor
diff --git a/src/lightning/pytorch/strategies/strategy.py b/src/lightning/pytorch/strategies/strategy.py
index 314007f497..3201ee549e 100644
--- a/src/lightning/pytorch/strategies/strategy.py
+++ b/src/lightning/pytorch/strategies/strategy.py
@@ -13,8 +13,9 @@
 # limitations under the License.
 import logging
 from abc import ABC, abstractmethod
+from collections.abc import Generator, Mapping
 from contextlib import contextmanager
-from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, TypeVar, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union

 import torch
 from torch import Tensor
diff --git a/src/lightning/pytorch/trainer/call.py b/src/lightning/pytorch/trainer/call.py
index 4c3bc5ef41..85ccabd332 100644
--- a/src/lightning/pytorch/trainer/call.py
+++ b/src/lightning/pytorch/trainer/call.py
@@ -115,7 +115,11 @@ def _call_configure_model(trainer: "pl.Trainer") -> None:
     # we don't normally check for this before calling the hook. it is done here to avoid instantiating the context
     # managers
     if is_overridden("configure_model", trainer.lightning_module):
-        with trainer.strategy.tensor_init_context(), trainer.strategy.model_sharded_context(), trainer.precision_plugin.module_init_context():  # noqa: E501
+        with (
+            trainer.strategy.tensor_init_context(),
+            trainer.strategy.model_sharded_context(),
+            trainer.precision_plugin.module_init_context(),
+        ):  # noqa: E501
             _call_lightning_module_hook(trainer, "configure_model")
diff --git a/src/lightning/pytorch/trainer/connectors/callback_connector.py b/src/lightning/pytorch/trainer/connectors/callback_connector.py
index 2f2b619290..5f0dd55df7 100644
--- a/src/lightning/pytorch/trainer/connectors/callback_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/callback_connector.py
@@ -14,8 +14,9 @@

 import logging
 import os
+from collections.abc import Sequence
 from datetime import timedelta
-from typing import Dict, List, Optional, Sequence, Union
+from typing import Dict, List, Optional, Union

 import lightning.pytorch as pl
 from lightning.fabric.utilities.registry import _load_external_callbacks
diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py
index 1e84a2ebd0..8f9d8977f6 100644
--- a/src/lightning/pytorch/trainer/connectors/data_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from collections.abc import Iterable
 from dataclasses import dataclass, field
-from typing import Any, Iterable, Optional, Tuple, Union
+from typing import Any, Optional, Tuple, Union

 import torch.multiprocessing as mp
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sampler, SequentialSampler
diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py b/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py
index c4ab11632b..ffc99a9772 100644
--- a/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py
@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Iterable, Optional, Union
+from collections.abc import Iterable
+from typing import Any, Optional, Union

 from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor
diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
index 62cc7844d3..3a0b8f07a4 100644
--- a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
+++ b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
@@ -11,9 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections.abc import Generator
 from dataclasses import dataclass
 from functools import partial, wraps
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union, cast
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast

 import torch
 from lightning_utilities.core.apply_func import apply_to_collection
diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index 406f686efe..72493a5b36 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -23,9 +23,10 @@ import logging
 import math
 import os
+from collections.abc import Generator, Iterable
 from contextlib import contextmanager
 from datetime import timedelta
-from typing import Any, Dict, Generator, Iterable, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 from weakref import proxy

 import torch
diff --git a/src/lightning/pytorch/utilities/combined_loader.py b/src/lightning/pytorch/utilities/combined_loader.py
index 9b0ceb0288..8730b3c914 100644
--- a/src/lightning/pytorch/utilities/combined_loader.py
+++ b/src/lightning/pytorch/utilities/combined_loader.py
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import contextlib
-from collections.abc import Iterable
-from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Tuple, Type, Union
+from collections.abc import Iterable, Iterator
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union

 from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter
 from typing_extensions import Self, TypedDict, override
diff --git a/src/lightning/pytorch/utilities/data.py b/src/lightning/pytorch/utilities/data.py
index 41c5ea86e5..eae20d7358 100644
--- a/src/lightning/pytorch/utilities/data.py
+++ b/src/lightning/pytorch/utilities/data.py
@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+from collections.abc import Generator, Iterable, Mapping, Sized
 from dataclasses import fields
-from typing import Any, Dict, Generator, Iterable, Mapping, Optional, Sized, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union

 import torch
 from lightning_utilities.core.apply_func import is_dataclass_instance
diff --git a/src/lightning/pytorch/utilities/parsing.py b/src/lightning/pytorch/utilities/parsing.py
index 0f4460a3d5..b280c6e16e 100644
--- a/src/lightning/pytorch/utilities/parsing.py
+++ b/src/lightning/pytorch/utilities/parsing.py
@@ -17,8 +17,9 @@ import copy
 import inspect
 import pickle
 import types
+from collections.abc import MutableMapping, Sequence
 from dataclasses import fields, is_dataclass
-from typing import Any, Dict, List, Literal, MutableMapping, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union

 from torch import nn
diff --git a/src/lightning/pytorch/utilities/seed.py b/src/lightning/pytorch/utilities/seed.py
index 4ba9e7f0f9..7250ba5936 100644
--- a/src/lightning/pytorch/utilities/seed.py
+++ b/src/lightning/pytorch/utilities/seed.py
@@ -13,8 +13,8 @@
 # limitations under the License.
"""Utilities to help with reproducibility of models.""" +from collections.abc import Generator from contextlib import contextmanager -from typing import Generator from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states diff --git a/src/lightning/pytorch/utilities/types.py b/src/lightning/pytorch/utilities/types.py index c1b971e924..20b2bd0e5e 100644 --- a/src/lightning/pytorch/utilities/types.py +++ b/src/lightning/pytorch/utilities/types.py @@ -17,17 +17,14 @@ Convention: - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no leading `_`) """ +from collections.abc import Generator, Iterator, Mapping, Sequence from contextlib import contextmanager from dataclasses import dataclass from typing import ( Any, - Generator, - Iterator, List, - Mapping, Optional, Protocol, - Sequence, Tuple, Type, TypedDict, diff --git a/tests/tests_fabric/accelerators/test_cuda.py b/tests/tests_fabric/accelerators/test_cuda.py index e323ada908..0aed3675d9 100644 --- a/tests/tests_fabric/accelerators/test_cuda.py +++ b/tests/tests_fabric/accelerators/test_cuda.py @@ -121,27 +121,32 @@ def test_tf32_message(_, __, ___, caplog, monkeypatch): def test_find_usable_cuda_devices_error_handling(): """Test error handling for edge cases when using `find_usable_cuda_devices`.""" # Asking for GPUs if no GPUs visible - with mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=0), pytest.raises( - ValueError, match="You requested to find 2 devices but there are no visible CUDA" + with ( + mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=0), + pytest.raises(ValueError, match="You requested to find 2 devices but there are no visible CUDA"), ): find_usable_cuda_devices(2) # Asking for more GPUs than are visible - with mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=1), pytest.raises( - ValueError, match="this machine only has 1 GPUs" + with ( + mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=1), + pytest.raises(ValueError, match="this machine only has 1 GPUs"), ): find_usable_cuda_devices(2) # All GPUs are unusable tensor_mock = Mock(side_effect=RuntimeError) # simulate device placement fails - with mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2), mock.patch( - "lightning.fabric.accelerators.cuda.torch.tensor", tensor_mock - ), pytest.raises(RuntimeError, match=escape("The devices [0, 1] are occupied by other processes")): + with ( + mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2), + mock.patch("lightning.fabric.accelerators.cuda.torch.tensor", tensor_mock), + pytest.raises(RuntimeError, match=escape("The devices [0, 1] are occupied by other processes")), + ): find_usable_cuda_devices(2) # Request for as many GPUs as there are, no error should be raised - with mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=5), mock.patch( - "lightning.fabric.accelerators.cuda.torch.tensor" + with ( + mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=5), + mock.patch("lightning.fabric.accelerators.cuda.torch.tensor"), ): assert find_usable_cuda_devices(-1) == [0, 1, 2, 3, 4] diff --git a/tests/tests_fabric/helpers/datasets.py b/tests/tests_fabric/helpers/datasets.py index 211e1f36a9..ee14b21dc5 100644 --- a/tests/tests_fabric/helpers/datasets.py +++ b/tests/tests_fabric/helpers/datasets.py @@ -1,4 +1,4 @@ -from typing import Iterator 
+from collections.abc import Iterator

 import torch
 from torch import Tensor
diff --git a/tests/tests_fabric/plugins/collectives/test_torch_collective.py b/tests/tests_fabric/plugins/collectives/test_torch_collective.py
index c8deb9d335..b4c223e770 100644
--- a/tests/tests_fabric/plugins/collectives/test_torch_collective.py
+++ b/tests/tests_fabric/plugins/collectives/test_torch_collective.py
@@ -29,13 +29,16 @@ PASSED_OBJECT = mock.Mock()

 @contextlib.contextmanager
 def check_destroy_group():
-    with mock.patch(
-        "lightning.fabric.plugins.collectives.torch_collective.TorchCollective.new_group",
-        wraps=TorchCollective.new_group,
-    ) as mock_new, mock.patch(
-        "lightning.fabric.plugins.collectives.torch_collective.TorchCollective.destroy_group",
-        wraps=TorchCollective.destroy_group,
-    ) as mock_destroy:
+    with (
+        mock.patch(
+            "lightning.fabric.plugins.collectives.torch_collective.TorchCollective.new_group",
+            wraps=TorchCollective.new_group,
+        ) as mock_new,
+        mock.patch(
+            "lightning.fabric.plugins.collectives.torch_collective.TorchCollective.destroy_group",
+            wraps=TorchCollective.destroy_group,
+        ) as mock_destroy,
+    ):
         yield
     # 0 to account for tests that mock distributed
     # -1 to account for destroying the default process group
@@ -155,9 +158,10 @@ def test_repeated_create_and_destroy():
     with pytest.raises(RuntimeError, match="TorchCollective` already owns a group"):
         collective.create_group()

-    with mock.patch.dict("torch.distributed.distributed_c10d._pg_map", {collective.group: ("", None)}), mock.patch(
-        "torch.distributed.destroy_process_group"
-    ) as destroy_mock:
+    with (
+        mock.patch.dict("torch.distributed.distributed_c10d._pg_map", {collective.group: ("", None)}),
+        mock.patch("torch.distributed.destroy_process_group") as destroy_mock,
+    ):
         collective.teardown()
     # this would be called twice if `init_process_group` wasn't patched. once for the group and once for the default
     # group
@@ -300,9 +304,11 @@ def test_collective_manages_default_group():
     assert TorchCollective.manages_default_group

-    with mock.patch.object(collective, "_group") as mock_group, mock.patch.dict(
-        "torch.distributed.distributed_c10d._pg_map", {mock_group: ("", None)}
-    ), mock.patch("torch.distributed.destroy_process_group") as destroy_mock:
+    with (
+        mock.patch.object(collective, "_group") as mock_group,
+        mock.patch.dict("torch.distributed.distributed_c10d._pg_map", {mock_group: ("", None)}),
+        mock.patch("torch.distributed.destroy_process_group") as destroy_mock,
+    ):
         collective.teardown()
         destroy_mock.assert_called_once_with(mock_group)
diff --git a/tests/tests_fabric/plugins/environments/test_lsf.py b/tests/tests_fabric/plugins/environments/test_lsf.py
index b444f6fc4d..4e60d968dc 100644
--- a/tests/tests_fabric/plugins/environments/test_lsf.py
+++ b/tests/tests_fabric/plugins/environments/test_lsf.py
@@ -41,8 +41,9 @@ def test_empty_lsb_djob_rankfile():

 def test_missing_lsb_job_id(tmp_path):
     """Test an error when the job id cannot be found."""
-    with mock.patch.dict(os.environ, {"LSB_DJOB_RANKFILE": _make_rankfile(tmp_path)}), pytest.raises(
-        ValueError, match="Could not find job id in environment variable LSB_JOBID"
-    ):
+    with (
+        mock.patch.dict(os.environ, {"LSB_DJOB_RANKFILE": _make_rankfile(tmp_path)}),
+        pytest.raises(ValueError, match="Could not find job id in environment variable LSB_JOBID"),
+    ):
         LSFEnvironment()
diff --git a/tests/tests_fabric/plugins/environments/test_slurm.py b/tests/tests_fabric/plugins/environments/test_slurm.py
index 73457ede41..f237478a53 100644
--- a/tests/tests_fabric/plugins/environments/test_slurm.py
+++ b/tests/tests_fabric/plugins/environments/test_slurm.py
@@ -155,8 +155,9 @@ def test_srun_variable_validation():
     """Test that we raise useful errors when `srun` variables are misconfigured."""
     with mock.patch.dict(os.environ, {"SLURM_NTASKS": "1"}):
         SLURMEnvironment()
-    with mock.patch.dict(os.environ, {"SLURM_NTASKS": "2"}), pytest.raises(
-        RuntimeError, match="You set `--ntasks=2` in your SLURM"
-    ):
+    with (
+        mock.patch.dict(os.environ, {"SLURM_NTASKS": "2"}),
+        pytest.raises(RuntimeError, match="You set `--ntasks=2` in your SLURM"),
+    ):
         SLURMEnvironment()
diff --git a/tests/tests_fabric/strategies/launchers/test_multiprocessing.py b/tests/tests_fabric/strategies/launchers/test_multiprocessing.py
index b63d443098..6c595fba7a 100644
--- a/tests/tests_fabric/strategies/launchers/test_multiprocessing.py
+++ b/tests/tests_fabric/strategies/launchers/test_multiprocessing.py
@@ -100,8 +100,11 @@ def test_check_for_bad_cuda_fork(mp_mock, _, start_method):

 def test_check_for_missing_main_guard():
     launcher = _MultiProcessingLauncher(strategy=Mock(), start_method="spawn")
-    with mock.patch(
-        "lightning.fabric.strategies.launchers.multiprocessing.mp.current_process",
-        return_value=Mock(_inheriting=True),  # pretend that main is importing itself
-    ), pytest.raises(RuntimeError, match="requires that your script guards the main"):
+    with (
+        mock.patch(
+            "lightning.fabric.strategies.launchers.multiprocessing.mp.current_process",
+            return_value=Mock(_inheriting=True),  # pretend that main is importing itself
+        ),
+        pytest.raises(RuntimeError, match="requires that your script guards the main"),
+    ):
         launcher.launch(function=Mock())
diff --git a/tests/tests_fabric/strategies/test_ddp.py b/tests/tests_fabric/strategies/test_ddp.py
index 56d9875dfe..b98d5f8226 100644
--- a/tests/tests_fabric/strategies/test_ddp.py
+++ b/tests/tests_fabric/strategies/test_ddp.py
@@ -58,9 +58,12 @@ def test_ddp_no_backward_sync():
     strategy = DDPStrategy()
     assert isinstance(strategy._backward_sync_control, _DDPBackwardSyncControl)

-    with pytest.raises(
-        TypeError, match="is only possible if the module passed to .* is wrapped in `DistributedDataParallel`"
-    ), strategy._backward_sync_control.no_backward_sync(Mock(), True):
+    with (
+        pytest.raises(
+            TypeError, match="is only possible if the module passed to .* is wrapped in `DistributedDataParallel`"
+        ),
+        strategy._backward_sync_control.no_backward_sync(Mock(), True),
+    ):
         pass

     module = MagicMock(spec=DistributedDataParallel)
diff --git a/tests/tests_fabric/strategies/test_deepspeed_integration.py b/tests/tests_fabric/strategies/test_deepspeed_integration.py
index 3be535effa..4811599ed0 100644
--- a/tests/tests_fabric/strategies/test_deepspeed_integration.py
+++ b/tests/tests_fabric/strategies/test_deepspeed_integration.py
@@ -404,9 +404,11 @@ def test_deepspeed_init_module_with_stages_1_2(stage, empty_init):
     fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy, precision="bf16-true")
     fabric.launch()

-    with mock.patch("deepspeed.zero.Init") as zero_init_mock, mock.patch(
-        "torch.Tensor.uniform_"
-    ) as init_mock, fabric.init_module(empty_init=empty_init):
+    with (
+        mock.patch("deepspeed.zero.Init") as zero_init_mock,
+        mock.patch("torch.Tensor.uniform_") as init_mock,
+        fabric.init_module(empty_init=empty_init),
+    ):
         model = BoringModel()

     zero_init_mock.assert_called_with(enabled=False, remote_device=None, config_dict_or_path=ANY)
diff --git a/tests/tests_fabric/strategies/test_fsdp.py b/tests/tests_fabric/strategies/test_fsdp.py
index 0c46e7ac17..cb6542cdb6 100644
--- a/tests/tests_fabric/strategies/test_fsdp.py
+++ b/tests/tests_fabric/strategies/test_fsdp.py
@@ -133,9 +133,12 @@ def test_no_backward_sync():
     strategy = FSDPStrategy()
     assert isinstance(strategy._backward_sync_control, _FSDPBackwardSyncControl)

-    with pytest.raises(
-        TypeError, match="is only possible if the module passed to .* is wrapped in `FullyShardedDataParallel`"
-    ), strategy._backward_sync_control.no_backward_sync(Mock(), True):
+    with (
+        pytest.raises(
+            TypeError, match="is only possible if the module passed to .* is wrapped in `FullyShardedDataParallel`"
+        ),
+        strategy._backward_sync_control.no_backward_sync(Mock(), True),
+    ):
         pass

     module = MagicMock(spec=FullyShardedDataParallel)
@@ -172,9 +175,12 @@ def test_activation_checkpointing():
     assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)

     strategy._parallel_devices = [torch.device("cuda", 0)]
-    with mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock), mock.patch(
-        "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing"
-    ) as apply_mock:
+    with (
+        mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock),
+        mock.patch(
+            "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing"
+        ) as apply_mock,
+    ):
         wrapped = strategy.setup_module(Model())
     apply_mock.assert_called_with(wrapped, checkpoint_wrapper_fn=ANY, **strategy._activation_checkpointing_kwargs)
diff --git a/tests/tests_fabric/strategies/test_xla_fsdp.py b/tests/tests_fabric/strategies/test_xla_fsdp.py
index e2864b684c..879a55cf77 100644
--- a/tests/tests_fabric/strategies/test_xla_fsdp.py
+++ b/tests/tests_fabric/strategies/test_xla_fsdp.py
@@ -48,9 +48,12 @@ def test_xla_fsdp_no_backward_sync():
     strategy = XLAFSDPStrategy()
     assert isinstance(strategy._backward_sync_control, _XLAFSDPBackwardSyncControl)

-    with pytest.raises(
-        TypeError, match="is only possible if the module passed to .* is wrapped in `XlaFullyShardedDataParallel`"
-    ), strategy._backward_sync_control.no_backward_sync(object(), True):
+    with (
+        pytest.raises(
+            TypeError, match="is only possible if the module passed to .* is wrapped in `XlaFullyShardedDataParallel`"
+        ),
+        strategy._backward_sync_control.no_backward_sync(object(), True),
+    ):
         pass

     module = MagicMock(spec=XlaFullyShardedDataParallel)
diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py
index 08d6dbb45e..6cd93f096a 100644
--- a/tests/tests_fabric/test_connector.py
+++ b/tests/tests_fabric/test_connector.py
@@ -960,28 +960,33 @@ def test_arguments_from_environment_collision():
     with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"}):
         _Connector(accelerator="cuda")

-    with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu", "LT_CLI_USED": "1"}), pytest.raises(
-        ValueError, match="`Fabric\\(accelerator='cuda', ...\\)` but .* `--accelerator=cpu`"
-    ):
+    with (
+        mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu", "LT_CLI_USED": "1"}),
+        pytest.raises(ValueError, match="`Fabric\\(accelerator='cuda', ...\\)` but .* `--accelerator=cpu`"),
+    ):
         _Connector(accelerator="cuda")

-    with mock.patch.dict(os.environ, {"LT_STRATEGY": "ddp", "LT_CLI_USED": "1"}), pytest.raises(
-        ValueError, match="`Fabric\\(strategy='ddp_spawn', ...\\)` but .* `--strategy=ddp`"
-    ):
+    with (
+        mock.patch.dict(os.environ, {"LT_STRATEGY": "ddp", "LT_CLI_USED": "1"}),
+        pytest.raises(ValueError, match="`Fabric\\(strategy='ddp_spawn', ...\\)` but .* `--strategy=ddp`"),
+    ):
         _Connector(strategy="ddp_spawn")

-    with mock.patch.dict(os.environ, {"LT_DEVICES": "2", "LT_CLI_USED": "1"}), pytest.raises(
-        ValueError, match="`Fabric\\(devices=3, ...\\)` but .* `--devices=2`"
-    ):
+    with (
+        mock.patch.dict(os.environ, {"LT_DEVICES": "2", "LT_CLI_USED": "1"}),
+        pytest.raises(ValueError, match="`Fabric\\(devices=3, ...\\)` but .* `--devices=2`"),
+    ):
         _Connector(devices=3)

-    with mock.patch.dict(os.environ, {"LT_NUM_NODES": "3", "LT_CLI_USED": "1"}), pytest.raises(
-        ValueError, match="`Fabric\\(num_nodes=2, ...\\)` but .* `--num_nodes=3`"
-    ):
+    with (
+        mock.patch.dict(os.environ, {"LT_NUM_NODES": "3", "LT_CLI_USED": "1"}),
+        pytest.raises(ValueError, match="`Fabric\\(num_nodes=2, ...\\)` but .* `--num_nodes=3`"),
+    ):
         _Connector(num_nodes=2)

-    with mock.patch.dict(os.environ, {"LT_PRECISION": "16-mixed", "LT_CLI_USED": "1"}), pytest.raises(
-        ValueError, match="`Fabric\\(precision='64-true', ...\\)` but .* `--precision=16-mixed`"
-    ):
+    with (
+        mock.patch.dict(os.environ, {"LT_PRECISION": "16-mixed", "LT_CLI_USED": "1"}),
+        pytest.raises(ValueError, match="`Fabric\\(precision='64-true', ...\\)` but .* `--precision=16-mixed`"),
+    ):
         _Connector(precision="64-true")
diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py
index 70d04d5431..7bb6b29ece 100644
--- a/tests/tests_fabric/test_fabric.py
+++ b/tests/tests_fabric/test_fabric.py
@@ -746,9 +746,10 @@ def test_no_backward_sync():

     # pretend that the strategy does not support skipping backward sync
     fabric._strategy = Mock(spec=ParallelStrategy, _backward_sync_control=None)
-    with pytest.warns(
-        PossibleUserWarning, match="The `ParallelStrategy` does not support skipping the"
-    ), fabric.no_backward_sync(model):
+    with (
+        pytest.warns(PossibleUserWarning, match="The `ParallelStrategy` does not support skipping the"),
+ fabric.no_backward_sync(model), + ): pass # for single-device strategies, it becomes a no-op without warning diff --git a/tests/tests_fabric/utilities/test_consolidate_checkpoint.py b/tests/tests_fabric/utilities/test_consolidate_checkpoint.py index 216b77e6b9..2584aab8bd 100644 --- a/tests/tests_fabric/utilities/test_consolidate_checkpoint.py +++ b/tests/tests_fabric/utilities/test_consolidate_checkpoint.py @@ -41,8 +41,9 @@ def test_parse_cli_args(args, expected): def test_process_cli_args(tmp_path, caplog, monkeypatch): # PyTorch version < 2.3 monkeypatch.setattr(lightning.fabric.utilities.consolidate_checkpoint, "_TORCH_GREATER_EQUAL_2_3", False) - with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises( - SystemExit + with ( + caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), + pytest.raises(SystemExit), ): _process_cli_args(Namespace()) assert "requires PyTorch >= 2.3." in caplog.text @@ -51,8 +52,9 @@ def test_process_cli_args(tmp_path, caplog, monkeypatch): # Checkpoint does not exist checkpoint_folder = Path("does/not/exist") - with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises( - SystemExit + with ( + caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), + pytest.raises(SystemExit), ): _process_cli_args(Namespace(checkpoint_folder=checkpoint_folder)) assert f"checkpoint folder does not exist: {checkpoint_folder}" in caplog.text @@ -61,8 +63,9 @@ def test_process_cli_args(tmp_path, caplog, monkeypatch): # Checkpoint exists but is not a folder file = tmp_path / "checkpoint_file" file.touch() - with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises( - SystemExit + with ( + caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), + pytest.raises(SystemExit), ): _process_cli_args(Namespace(checkpoint_folder=file)) assert "checkpoint path must be a folder" in caplog.text @@ -71,8 +74,9 @@ def test_process_cli_args(tmp_path, caplog, monkeypatch): # Checkpoint exists but is not an FSDP checkpoint folder = tmp_path / "checkpoint_folder" folder.mkdir() - with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises( - SystemExit + with ( + caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), + pytest.raises(SystemExit), ): _process_cli_args(Namespace(checkpoint_folder=folder)) assert "Only FSDP-sharded checkpoints saved with Lightning are supported" in caplog.text @@ -89,8 +93,9 @@ def test_process_cli_args(tmp_path, caplog, monkeypatch): # Checkpoint is a FSDP folder, output file already exists file = tmp_path / "ouput_file" file.touch() - with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises( - SystemExit + with ( + caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), + pytest.raises(SystemExit), ): _process_cli_args(Namespace(checkpoint_folder=folder, output_file=file)) assert "path for the converted checkpoint already exists" in caplog.text diff --git a/tests/tests_fabric/utilities/test_distributed.py b/tests/tests_fabric/utilities/test_distributed.py index cc6c23bddb..f5a78a1529 100644 --- a/tests/tests_fabric/utilities/test_distributed.py +++ b/tests/tests_fabric/utilities/test_distributed.py @@ -215,9 +215,10 @@ def 
test_infinite_barrier(): # distributed available barrier = _InfiniteBarrier() - with mock.patch( - "lightning.fabric.utilities.distributed._distributed_is_initialized", return_value=True - ), mock.patch("lightning.fabric.utilities.distributed.torch.distributed") as dist_mock: + with ( + mock.patch("lightning.fabric.utilities.distributed._distributed_is_initialized", return_value=True), + mock.patch("lightning.fabric.utilities.distributed.torch.distributed") as dist_mock, + ): barrier.__enter__() dist_mock.new_group.assert_called_once() assert barrier.barrier == barrier.group.monitored_barrier diff --git a/tests/tests_fabric/utilities/test_throughput.py b/tests/tests_fabric/utilities/test_throughput.py index eefadb285a..d410d0766d 100644 --- a/tests/tests_fabric/utilities/test_throughput.py +++ b/tests/tests_fabric/utilities/test_throughput.py @@ -39,8 +39,9 @@ def test_get_available_flops(xla_available): with pytest.warns(match="not found for 'CocoNut"), mock.patch("torch.cuda.get_device_name", return_value="CocoNut"): assert get_available_flops(torch.device("cuda"), torch.bfloat16) is None - with pytest.warns(match="t4' does not support torch.bfloat"), mock.patch( - "torch.cuda.get_device_name", return_value="t4" + with ( + pytest.warns(match="t4' does not support torch.bfloat"), + mock.patch("torch.cuda.get_device_name", return_value="t4"), ): assert get_available_flops(torch.device("cuda"), torch.bfloat16) is None diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py index b8d3d6d36c..89c1effe83 100644 --- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py @@ -141,9 +141,12 @@ def test_rich_progress_bar_keyboard_interrupt(tmp_path): model = TestModel() - with mock.patch( - "lightning.pytorch.callbacks.progress.rich_progress.Progress.stop", autospec=True - ) as mock_progress_stop, pytest.raises(SystemExit): + with ( + mock.patch( + "lightning.pytorch.callbacks.progress.rich_progress.Progress.stop", autospec=True + ) as mock_progress_stop, + pytest.raises(SystemExit), + ): progress_bar = RichProgressBar() trainer = Trainer( default_root_dir=tmp_path, diff --git a/tests/tests_pytorch/callbacks/test_throughput_monitor.py b/tests/tests_pytorch/callbacks/test_throughput_monitor.py index a74efba758..4867134a85 100644 --- a/tests/tests_pytorch/callbacks/test_throughput_monitor.py +++ b/tests/tests_pytorch/callbacks/test_throughput_monitor.py @@ -43,8 +43,9 @@ def test_throughput_monitor_fit(tmp_path): ) # these timing results are meant to precisely match the `test_throughput_monitor` test in fabric timings = [0.0] + [0.5 + i for i in range(1, 6)] - with mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100), mock.patch( - "time.perf_counter", side_effect=timings + with ( + mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100), + mock.patch("time.perf_counter", side_effect=timings), ): trainer.fit(model) @@ -179,8 +180,9 @@ def test_throughput_monitor_fit_gradient_accumulation(log_every_n_steps, tmp_pat enable_progress_bar=False, ) timings = [0.0] + [0.5 + i for i in range(1, 11)] - with mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100), mock.patch( - "time.perf_counter", side_effect=timings + with ( + mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", 
return_value=100), + mock.patch("time.perf_counter", side_effect=timings), ): trainer.fit(model) diff --git a/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py b/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py index 3ae7d6be49..c07400eaf8 100644 --- a/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py @@ -92,9 +92,10 @@ def test_trainer_save_checkpoint_storage_options(tmp_path, xla_available): io_mock.assert_called_with(ANY, instance_path, storage_options=None) checkpoint_mock = Mock() - with mock.patch.object(trainer.strategy, "save_checkpoint") as save_mock, mock.patch.object( - trainer._checkpoint_connector, "dump_checkpoint", return_value=checkpoint_mock - ) as dump_mock: + with ( + mock.patch.object(trainer.strategy, "save_checkpoint") as save_mock, + mock.patch.object(trainer._checkpoint_connector, "dump_checkpoint", return_value=checkpoint_mock) as dump_mock, + ): trainer.save_checkpoint(instance_path, True) dump_mock.assert_called_with(True) save_mock.assert_called_with(checkpoint_mock, instance_path, storage_options=None) diff --git a/tests/tests_pytorch/core/test_lightning_optimizer.py b/tests/tests_pytorch/core/test_lightning_optimizer.py index 8ab6eca907..b25b7ae648 100644 --- a/tests/tests_pytorch/core/test_lightning_optimizer.py +++ b/tests/tests_pytorch/core/test_lightning_optimizer.py @@ -110,9 +110,10 @@ def test_lightning_optimizer_manual_optimization_and_accumulated_gradients(tmp_p default_root_dir=tmp_path, limit_train_batches=8, limit_val_batches=1, max_epochs=1, enable_model_summary=False ) - with patch.multiple(torch.optim.SGD, zero_grad=DEFAULT, step=DEFAULT) as sgd, patch.multiple( - torch.optim.Adam, zero_grad=DEFAULT, step=DEFAULT - ) as adam: + with ( + patch.multiple(torch.optim.SGD, zero_grad=DEFAULT, step=DEFAULT) as sgd, + patch.multiple(torch.optim.Adam, zero_grad=DEFAULT, step=DEFAULT) as adam, + ): trainer.fit(model) assert sgd["step"].call_count == 4 diff --git a/tests/tests_pytorch/core/test_metric_result_integration.py b/tests/tests_pytorch/core/test_metric_result_integration.py index 004d979fd1..dcb3f71c74 100644 --- a/tests/tests_pytorch/core/test_metric_result_integration.py +++ b/tests/tests_pytorch/core/test_metric_result_integration.py @@ -625,8 +625,9 @@ def test_logger_sync_dist(distributed_env, log_val): else nullcontext() ) - with warning_ctx( - PossibleUserWarning, match=r"recommended to use `self.log\('bar', ..., sync_dist=True\)`" - ), patch_ctx: + with ( + warning_ctx(PossibleUserWarning, match=r"recommended to use `self.log\('bar', ..., sync_dist=True\)`"), + patch_ctx, + ): value = _ResultCollection._get_cache(result_metric, on_step=False) assert value == 0.5 diff --git a/tests/tests_pytorch/helpers/datasets.py b/tests/tests_pytorch/helpers/datasets.py index 9b1d4ec735..0519680e7d 100644 --- a/tests/tests_pytorch/helpers/datasets.py +++ b/tests/tests_pytorch/helpers/datasets.py @@ -16,7 +16,8 @@ import os import random import time import urllib.request -from typing import Optional, Sequence, Tuple +from collections.abc import Sequence +from typing import Optional, Tuple import torch from torch import Tensor diff --git a/tests/tests_pytorch/loops/optimization/test_automatic_loop.py b/tests/tests_pytorch/loops/optimization/test_automatic_loop.py index 0ea6290586..a5f51783ca 100644 --- a/tests/tests_pytorch/loops/optimization/test_automatic_loop.py +++ b/tests/tests_pytorch/loops/optimization/test_automatic_loop.py @@ -11,8 +11,9 @@ # 
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Iterator, Mapping from contextlib import nullcontext -from typing import Dict, Generic, Iterator, Mapping, TypeVar +from typing import Dict, Generic, TypeVar import pytest import torch diff --git a/tests/tests_pytorch/loops/test_fetchers.py b/tests/tests_pytorch/loops/test_fetchers.py index 763a6ded14..75b25e3d98 100644 --- a/tests/tests_pytorch/loops/test_fetchers.py +++ b/tests/tests_pytorch/loops/test_fetchers.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import Counter -from typing import Any, Iterator +from collections.abc import Iterator +from typing import Any import pytest import torch diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index ff317cd2e1..44a3d0f60c 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from collections.abc import Iterator from copy import deepcopy from dataclasses import dataclass -from typing import Dict, Iterator +from typing import Dict from unittest.mock import ANY, Mock import pytest diff --git a/tests/tests_pytorch/models/test_restore.py b/tests/tests_pytorch/models/test_restore.py index fe7e3fbbab..64f70b176a 100644 --- a/tests/tests_pytorch/models/test_restore.py +++ b/tests/tests_pytorch/models/test_restore.py @@ -15,8 +15,9 @@ import glob import logging as log import os import pickle +from collections.abc import Mapping from copy import deepcopy -from typing import Generic, Mapping, TypeVar +from typing import Generic, TypeVar import cloudpickle import pytest diff --git a/tests/tests_pytorch/overrides/test_distributed.py b/tests/tests_pytorch/overrides/test_distributed.py index 29eb6d6d6d..3e2fba54bc 100644 --- a/tests/tests_pytorch/overrides/test_distributed.py +++ b/tests/tests_pytorch/overrides/test_distributed.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Iterable +from collections.abc import Iterable import pytest import torch diff --git a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py index b0462c0105..394d827058 100644 --- a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py +++ b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py @@ -209,10 +209,13 @@ def test_memory_sharing_disabled(tmp_path): def test_check_for_missing_main_guard(): launcher = _MultiProcessingLauncher(strategy=Mock(), start_method="spawn") - with mock.patch( - "lightning.pytorch.strategies.launchers.multiprocessing.mp.current_process", - return_value=Mock(_inheriting=True), # pretend that main is importing itself - ), pytest.raises(RuntimeError, match="requires that your script guards the main"): + with ( + mock.patch( + "lightning.pytorch.strategies.launchers.multiprocessing.mp.current_process", + return_value=Mock(_inheriting=True), # pretend that main is importing itself + ), + pytest.raises(RuntimeError, match="requires that your script guards the main"), + ): launcher.launch(function=Mock()) diff --git a/tests/tests_pytorch/strategies/test_custom_strategy.py b/tests/tests_pytorch/strategies/test_custom_strategy.py index 7f7d018f63..347dacbd9a 100644 --- a/tests/tests_pytorch/strategies/test_custom_strategy.py +++ b/tests/tests_pytorch/strategies/test_custom_strategy.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any import pytest import torch diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index aec01b83e9..2aee68f7ae 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -444,9 +444,12 @@ def test_activation_checkpointing(): strategy._parallel_devices = [torch.device("cuda", 0)] strategy._lightning_module = model strategy._process_group = Mock() - with mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock), mock.patch( - "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing" - ) as apply_mock: + with ( + mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock), + mock.patch( + "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing" + ) as apply_mock, + ): wrapped = strategy._setup_model(model) apply_mock.assert_called_with(wrapped, checkpoint_wrapper_fn=ANY, **strategy._activation_checkpointing_kwargs) diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index 56b58d4d15..688dfff962 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -409,8 +409,9 @@ def test_lightning_cli_config_and_subclass_mode(cleandir): with open(config_path, "w") as f: f.write(yaml.dump(input_config)) - with mock.patch("sys.argv", ["any.py", "--config", config_path]), mock_subclasses( - LightningDataModule, DataDirDataModule + with ( + mock.patch("sys.argv", ["any.py", "--config", config_path]), + mock_subclasses(LightningDataModule, DataDirDataModule), ): cli = LightningCLI( BoringModel, @@ -461,9 +462,12 @@ def test_lightning_cli_help(): cli_args = ["any.py", "fit", "--data.help=DataDirDataModule"] out = StringIO() - with mock.patch("sys.argv", cli_args), redirect_stdout(out), mock_subclasses( - 
LightningDataModule, DataDirDataModule - ), pytest.raises(SystemExit): + with ( + mock.patch("sys.argv", cli_args), + redirect_stdout(out), + mock_subclasses(LightningDataModule, DataDirDataModule), + pytest.raises(SystemExit), + ): any_model_any_data_cli() assert ("--data.data_dir" in out.getvalue()) or ("--data.init_args.data_dir" in out.getvalue()) @@ -609,8 +613,9 @@ class EarlyExitTestModel(BoringModel): def test_cli_distributed_save_config_callback(cleandir, logger, strategy): from torch.multiprocessing import ProcessRaisedException - with mock.patch("sys.argv", ["any.py", "fit"]), pytest.raises( - (MisconfigurationException, ProcessRaisedException), match=r"Error on fit start" + with ( + mock.patch("sys.argv", ["any.py", "fit"]), + pytest.raises((MisconfigurationException, ProcessRaisedException), match=r"Error on fit start"), ): LightningCLI( EarlyExitTestModel, @@ -710,12 +715,14 @@ def test_cli_no_need_configure_optimizers(cleandir): from lightning.pytorch.trainer.configuration_validator import __verify_train_val_loop_configuration - with mock.patch("sys.argv", ["any.py", "fit", "--optimizer=Adam"]), mock.patch( - "lightning.pytorch.Trainer._run_stage" - ) as run, mock.patch( - "lightning.pytorch.trainer.configuration_validator.__verify_train_val_loop_configuration", - wraps=__verify_train_val_loop_configuration, - ) as verify: + with ( + mock.patch("sys.argv", ["any.py", "fit", "--optimizer=Adam"]), + mock.patch("lightning.pytorch.Trainer._run_stage") as run, + mock.patch( + "lightning.pytorch.trainer.configuration_validator.__verify_train_val_loop_configuration", + wraps=__verify_train_val_loop_configuration, + ) as verify, + ): cli = LightningCLI(BoringModel) run.assert_called_once() verify.assert_called_once_with(cli.trainer, cli.model) @@ -1074,15 +1081,18 @@ class TestModel(BoringModel): @_xfail_python_ge_3_11_9 def test_lightning_cli_model_short_arguments(): - with mock.patch("sys.argv", ["any.py", "fit", "--model=BoringModel"]), mock.patch( - "lightning.pytorch.Trainer._fit_impl" - ) as run, mock_subclasses(LightningModule, BoringModel, TestModel): + with ( + mock.patch("sys.argv", ["any.py", "fit", "--model=BoringModel"]), + mock.patch("lightning.pytorch.Trainer._fit_impl") as run, + mock_subclasses(LightningModule, BoringModel, TestModel), + ): cli = LightningCLI(trainer_defaults={"fast_dev_run": 1}) assert isinstance(cli.model, BoringModel) run.assert_called_once_with(cli.model, ANY, ANY, ANY, ANY) - with mock.patch("sys.argv", ["any.py", "--model=TestModel", "--model.foo", "123"]), mock_subclasses( - LightningModule, BoringModel, TestModel + with ( + mock.patch("sys.argv", ["any.py", "--model=TestModel", "--model.foo", "123"]), + mock_subclasses(LightningModule, BoringModel, TestModel), ): cli = LightningCLI(run=False) assert isinstance(cli.model, TestModel) @@ -1100,15 +1110,18 @@ class MyDataModule(BoringDataModule): @_xfail_python_ge_3_11_9 def test_lightning_cli_datamodule_short_arguments(): # with set model - with mock.patch("sys.argv", ["any.py", "fit", "--data=BoringDataModule"]), mock.patch( - "lightning.pytorch.Trainer._fit_impl" - ) as run, mock_subclasses(LightningDataModule, BoringDataModule): + with ( + mock.patch("sys.argv", ["any.py", "fit", "--data=BoringDataModule"]), + mock.patch("lightning.pytorch.Trainer._fit_impl") as run, + mock_subclasses(LightningDataModule, BoringDataModule), + ): cli = LightningCLI(BoringModel, trainer_defaults={"fast_dev_run": 1}) assert isinstance(cli.datamodule, BoringDataModule) run.assert_called_once_with(ANY, ANY, ANY, 
cli.datamodule, ANY) - with mock.patch("sys.argv", ["any.py", "--data=MyDataModule", "--data.foo", "123"]), mock_subclasses( - LightningDataModule, MyDataModule + with ( + mock.patch("sys.argv", ["any.py", "--data=MyDataModule", "--data.foo", "123"]), + mock_subclasses(LightningDataModule, MyDataModule), ): cli = LightningCLI(BoringModel, run=False) assert isinstance(cli.datamodule, MyDataModule) @@ -1116,17 +1129,22 @@ def test_lightning_cli_datamodule_short_arguments(): assert cli.datamodule.bar == 5 # with configurable model - with mock.patch("sys.argv", ["any.py", "fit", "--model", "BoringModel", "--data=BoringDataModule"]), mock.patch( - "lightning.pytorch.Trainer._fit_impl" - ) as run, mock_subclasses(LightningModule, BoringModel), mock_subclasses(LightningDataModule, BoringDataModule): + with ( + mock.patch("sys.argv", ["any.py", "fit", "--model", "BoringModel", "--data=BoringDataModule"]), + mock.patch("lightning.pytorch.Trainer._fit_impl") as run, + mock_subclasses(LightningModule, BoringModel), + mock_subclasses(LightningDataModule, BoringDataModule), + ): cli = LightningCLI(trainer_defaults={"fast_dev_run": 1}) assert isinstance(cli.model, BoringModel) assert isinstance(cli.datamodule, BoringDataModule) run.assert_called_once_with(cli.model, ANY, ANY, cli.datamodule, ANY) - with mock.patch("sys.argv", ["any.py", "--model", "BoringModel", "--data=MyDataModule"]), mock_subclasses( - LightningModule, BoringModel - ), mock_subclasses(LightningDataModule, MyDataModule): + with ( + mock.patch("sys.argv", ["any.py", "--model", "BoringModel", "--data=MyDataModule"]), + mock_subclasses(LightningModule, BoringModel), + mock_subclasses(LightningDataModule, MyDataModule), + ): cli = LightningCLI(run=False) assert isinstance(cli.model, BoringModel) assert isinstance(cli.datamodule, MyDataModule) @@ -1293,9 +1311,10 @@ def test_optimizers_and_lr_schedulers_add_arguments_to_parser_implemented_reload def test_lightning_cli_config_with_subcommand(): config = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}} - with mock.patch("sys.argv", ["any.py", f"--config={config}"]), mock.patch( - "lightning.pytorch.Trainer.test", autospec=True - ) as test_mock: + with ( + mock.patch("sys.argv", ["any.py", f"--config={config}"]), + mock.patch("lightning.pytorch.Trainer.test", autospec=True) as test_mock, + ): cli = LightningCLI(BoringModel) test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar") @@ -1308,9 +1327,10 @@ def test_lightning_cli_config_before_subcommand(): "test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}, } - with mock.patch("sys.argv", ["any.py", f"--config={config}", "test"]), mock.patch( - "lightning.pytorch.Trainer.test", autospec=True - ) as test_mock: + with ( + mock.patch("sys.argv", ["any.py", f"--config={config}", "test"]), + mock.patch("lightning.pytorch.Trainer.test", autospec=True) as test_mock, + ): cli = LightningCLI(BoringModel) test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar") @@ -1320,9 +1340,10 @@ def test_lightning_cli_config_before_subcommand(): assert save_config_callback.config.trainer.limit_test_batches == 1 assert save_config_callback.parser.subcommand == "test" - with mock.patch("sys.argv", ["any.py", f"--config={config}", "validate"]), mock.patch( - "lightning.pytorch.Trainer.validate", autospec=True - ) as validate_mock: + with ( + mock.patch("sys.argv", ["any.py", f"--config={config}", "validate"]), + 
mock.patch("lightning.pytorch.Trainer.validate", autospec=True) as validate_mock, + ): cli = LightningCLI(BoringModel) validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo") @@ -1337,17 +1358,19 @@ def test_lightning_cli_config_before_subcommand_two_configs(): config1 = {"validate": {"trainer": {"limit_val_batches": 1}, "verbose": False, "ckpt_path": "barfoo"}} config2 = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}} - with mock.patch("sys.argv", ["any.py", f"--config={config1}", f"--config={config2}", "test"]), mock.patch( - "lightning.pytorch.Trainer.test", autospec=True - ) as test_mock: + with ( + mock.patch("sys.argv", ["any.py", f"--config={config1}", f"--config={config2}", "test"]), + mock.patch("lightning.pytorch.Trainer.test", autospec=True) as test_mock, + ): cli = LightningCLI(BoringModel) test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar") assert cli.trainer.limit_test_batches == 1 - with mock.patch("sys.argv", ["any.py", f"--config={config1}", f"--config={config2}", "validate"]), mock.patch( - "lightning.pytorch.Trainer.validate", autospec=True - ) as validate_mock: + with ( + mock.patch("sys.argv", ["any.py", f"--config={config1}", f"--config={config2}", "validate"]), + mock.patch("lightning.pytorch.Trainer.validate", autospec=True) as validate_mock, + ): cli = LightningCLI(BoringModel) validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo") @@ -1356,9 +1379,10 @@ def test_lightning_cli_config_before_subcommand_two_configs(): def test_lightning_cli_config_after_subcommand(): config = {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"} - with mock.patch("sys.argv", ["any.py", "test", f"--config={config}"]), mock.patch( - "lightning.pytorch.Trainer.test", autospec=True - ) as test_mock: + with ( + mock.patch("sys.argv", ["any.py", "test", f"--config={config}"]), + mock.patch("lightning.pytorch.Trainer.test", autospec=True) as test_mock, + ): cli = LightningCLI(BoringModel) test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar") @@ -1368,9 +1392,10 @@ def test_lightning_cli_config_after_subcommand(): def test_lightning_cli_config_before_and_after_subcommand(): config1 = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}} config2 = {"trainer": {"fast_dev_run": 1}, "verbose": False, "ckpt_path": "foobar"} - with mock.patch("sys.argv", ["any.py", f"--config={config1}", "test", f"--config={config2}"]), mock.patch( - "lightning.pytorch.Trainer.test", autospec=True - ) as test_mock: + with ( + mock.patch("sys.argv", ["any.py", f"--config={config1}", "test", f"--config={config2}"]), + mock.patch("lightning.pytorch.Trainer.test", autospec=True) as test_mock, + ): cli = LightningCLI(BoringModel) test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=False, ckpt_path="foobar") @@ -1392,17 +1417,19 @@ def test_lightning_cli_parse_kwargs_with_subcommands(cleandir): "validate": {"default_config_files": [str(validate_config_path)]}, } - with mock.patch("sys.argv", ["any.py", "fit"]), mock.patch( - "lightning.pytorch.Trainer.fit", autospec=True - ) as fit_mock: + with ( + mock.patch("sys.argv", ["any.py", "fit"]), + mock.patch("lightning.pytorch.Trainer.fit", autospec=True) as fit_mock, + ): cli = LightningCLI(BoringModel, parser_kwargs=parser_kwargs) fit_mock.assert_called() assert 
cli.trainer.limit_train_batches == 2 assert cli.trainer.limit_val_batches == 1.0 - with mock.patch("sys.argv", ["any.py", "validate"]), mock.patch( - "lightning.pytorch.Trainer.validate", autospec=True - ) as validate_mock: + with ( + mock.patch("sys.argv", ["any.py", "validate"]), + mock.patch("lightning.pytorch.Trainer.validate", autospec=True) as validate_mock, + ): cli = LightningCLI(BoringModel, parser_kwargs=parser_kwargs) validate_mock.assert_called() assert cli.trainer.limit_train_batches == 1.0 @@ -1420,9 +1447,10 @@ def test_lightning_cli_subcommands_common_default_config_files(cleandir): config_path.write_text(str(config)) parser_kwargs = {"default_config_files": [str(config_path)]} - with mock.patch("sys.argv", ["any.py", "fit"]), mock.patch( - "lightning.pytorch.Trainer.fit", autospec=True - ) as fit_mock: + with ( + mock.patch("sys.argv", ["any.py", "fit"]), + mock.patch("lightning.pytorch.Trainer.fit", autospec=True) as fit_mock, + ): cli = LightningCLI(Model, parser_kwargs=parser_kwargs) fit_mock.assert_called() assert cli.model.foo == 123 diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py index a820a3d6ee..ca5690ed20 100644 --- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Sized from re import escape -from typing import Sized from unittest import mock from unittest.mock import Mock diff --git a/tests/tests_pytorch/trainer/optimization/test_optimizers.py b/tests/tests_pytorch/trainer/optimization/test_optimizers.py index 451557d084..ac660b6651 100644 --- a/tests/tests_pytorch/trainer/optimization/test_optimizers.py +++ b/tests/tests_pytorch/trainer/optimization/test_optimizers.py @@ -592,9 +592,10 @@ def test_lr_scheduler_step_hook(tmp_path): limit_train_batches=limit_train_batches, limit_val_batches=0, ) - with mock.patch.object(CustomEpochScheduler, "step") as mock_method_epoch, mock.patch.object( - torch.optim.lr_scheduler.StepLR, "step" - ) as mock_method_step: + with ( + mock.patch.object(CustomEpochScheduler, "step") as mock_method_epoch, + mock.patch.object(torch.optim.lr_scheduler.StepLR, "step") as mock_method_step, + ): trainer.fit(model) assert mock_method_epoch.mock_calls == [call(epoch=e) for e in range(max_epochs)] diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index 8946fb4ed9..d66f3aafee 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -1887,8 +1887,9 @@ def test_detect_anomaly_nan(tmp_path): model = NanModel() trainer = Trainer(default_root_dir=tmp_path, detect_anomaly=True) - with pytest.raises(RuntimeError, match=r"returned nan values in its 0th output."), pytest.warns( - UserWarning, match=r".*Error detected in.* Traceback of forward call that caused the error.*" + with ( + pytest.raises(RuntimeError, match=r"returned nan values in its 0th output."), + pytest.warns(UserWarning, match=r".*Error detected in.* Traceback of forward call that caused the error.*"), ): trainer.fit(model) @@ -2067,8 +2068,9 @@ def test_trainer_calls_strategy_on_exception(exception_type, tmp_path): raise exception trainer = Trainer(default_root_dir=tmp_path) - with 
mock.patch("lightning.pytorch.strategies.strategy.Strategy.on_exception") as on_exception_mock, suppress( - Exception, SystemExit + with ( + mock.patch("lightning.pytorch.strategies.strategy.Strategy.on_exception") as on_exception_mock, + suppress(Exception, SystemExit), ): trainer.fit(ExceptionModel()) on_exception_mock.assert_called_once_with(exception) diff --git a/tests/tests_pytorch/utilities/test_combined_loader.py b/tests/tests_pytorch/utilities/test_combined_loader.py index 74f5c1330a..43a146c6eb 100644 --- a/tests/tests_pytorch/utilities/test_combined_loader.py +++ b/tests/tests_pytorch/utilities/test_combined_loader.py @@ -13,7 +13,8 @@ # limitations under the License. import math import pickle -from typing import Any, NamedTuple, Sequence, get_args +from collections.abc import Sequence +from typing import Any, NamedTuple, get_args from unittest.mock import Mock import pytest diff --git a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py index 1bdac616e7..2ef1ecd4fe 100644 --- a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py +++ b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py @@ -33,8 +33,9 @@ def test_upgrade_checkpoint_file_missing(tmp_path, caplog): # path to non-empty directory, but no checkpoints with matching extension file.touch() - with mock.patch("sys.argv", ["upgrade_checkpoint.py", str(tmp_path), "--extension", ".other"]), caplog.at_level( - logging.ERROR + with ( + mock.patch("sys.argv", ["upgrade_checkpoint.py", str(tmp_path), "--extension", ".other"]), + caplog.at_level(logging.ERROR), ): with pytest.raises(SystemExit): upgrade_main()