[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2024-11-12 20:29:03 +00:00
parent 060d951605
commit c11c392acd
105 changed files with 408 additions and 259 deletions

View File

@ -18,10 +18,11 @@ import re
import shutil
import tempfile
import urllib.request
from collections.abc import Iterable, Iterator, Sequence
from itertools import chain
from os.path import dirname, isfile
from pathlib import Path
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple
from typing import Any, Dict, List, Optional, Tuple
from packaging.requirements import Requirement
from packaging.version import Version

View File

@ -1,7 +1,7 @@
import os
from collections.abc import Mapping
from collections.abc import Iterable, Mapping
from functools import partial
from typing import Any, Iterable, List, Literal, Optional, Tuple, Union, cast
from typing import Any, List, Literal, Optional, Tuple, Union, cast
import lightning as L
import torch

View File

@ -35,7 +35,8 @@ Second-Edition/blob/master/Chapter06/02_dqn_pong.py
import argparse
import random
from collections import OrderedDict, deque, namedtuple
from typing import Iterator, List, Tuple
from collections.abc import Iterator
from typing import List, Tuple
import gym
import torch

View File

@ -30,7 +30,8 @@ References
"""
import argparse
from typing import Callable, Iterator, List, Tuple
from collections.abc import Iterator
from typing import Callable, List, Tuple
import gym
import torch

View File

@ -45,9 +45,10 @@ import glob
import logging
import os
import tempfile
from collections.abc import Generator, Mapping
from importlib.util import module_from_spec, spec_from_file_location
from types import ModuleType
from typing import Generator, Mapping, Optional
from typing import Optional
import setuptools
import setuptools.command.egg_info

View File

@ -13,6 +13,7 @@
# limitations under the License.
import inspect
import os
from collections.abc import Generator, Mapping, Sequence
from contextlib import contextmanager, nullcontext
from functools import partial
from pathlib import Path
@ -21,11 +22,8 @@ from typing import (
Callable,
ContextManager,
Dict,
Generator,
List,
Mapping,
Optional,
Sequence,
Tuple,
Union,
cast,

View File

@ -14,7 +14,8 @@
import os
from argparse import Namespace
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from lightning_utilities.core.imports import RequirementCache
from torch import Tensor

View File

@ -16,10 +16,11 @@ import logging
import math
import os
import warnings
from collections import OrderedDict
from contextlib import ExitStack
from functools import partial
from types import ModuleType
from typing import Any, Callable, ContextManager, Literal, Optional, OrderedDict, Set, Tuple, Type, cast
from typing import Any, Callable, ContextManager, Literal, Optional, Set, Tuple, Type, cast
import torch
from lightning_utilities import apply_to_collection

View File

@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from collections.abc import Mapping
from contextlib import ExitStack
from typing import TYPE_CHECKING, Any, ContextManager, Literal, Mapping, Optional, Union
from typing import TYPE_CHECKING, Any, ContextManager, Literal, Optional, Union
import torch
from lightning_utilities import apply_to_collection

View File

@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Mapping, Type, Union
from collections.abc import Mapping
from typing import Any, Type, Union
import torch
from torch import Tensor

View File

@ -16,10 +16,11 @@ import json
import logging
import os
import platform
from collections.abc import Mapping
from contextlib import ExitStack
from itertools import chain
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, List, Mapping, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, List, Optional, Tuple, Union
import torch
from lightning_utilities.core.imports import RequirementCache

View File

@ -13,6 +13,7 @@
# limitations under the License.
import shutil
import warnings
from collections.abc import Generator
from contextlib import ExitStack, nullcontext
from datetime import timedelta
from functools import partial
@ -23,7 +24,6 @@ from typing import (
Callable,
ContextManager,
Dict,
Generator,
List,
Literal,
Optional,

View File

@ -18,7 +18,8 @@ import subprocess
import sys
import threading
import time
from typing import Any, Callable, List, Optional, Sequence, Tuple
from collections.abc import Sequence
from typing import Any, Callable, List, Optional, Tuple
from lightning_utilities.core.imports import RequirementCache
from typing_extensions import override

View File

@ -13,10 +13,11 @@
# limitations under the License.
import itertools
import shutil
from collections.abc import Generator
from contextlib import ExitStack
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generator, Literal, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Literal, Optional, TypeVar, Union
import torch
from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only

View File

@ -13,8 +13,9 @@
# limitations under the License.
import logging
from abc import ABC, abstractmethod
from collections.abc import Iterable
from contextlib import ExitStack
from typing import Any, Callable, ContextManager, Dict, Iterable, List, Optional, Tuple, TypeVar, Union
from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple, TypeVar, Union
import torch
from torch import Tensor

View File

@ -16,9 +16,10 @@ import functools
import inspect
import os
from collections import OrderedDict
from collections.abc import Generator, Iterable, Sized
from contextlib import contextmanager
from functools import partial
from typing import Any, Callable, Dict, Generator, Iterable, Optional, Sized, Tuple, Type, Union
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
from lightning_utilities.core.inheritance import get_all_subclasses
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler

View File

@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, MutableSequence, Optional, Tuple, Union
from collections.abc import MutableSequence
from typing import List, Optional, Tuple, Union
import torch

View File

@ -4,10 +4,11 @@ import logging
import os
import signal
import time
from collections.abc import Iterable, Iterator, Sized
from contextlib import nullcontext
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sized, Union
from typing import TYPE_CHECKING, Any, List, Optional, Union
import torch
import torch.nn.functional as F

View File

@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from typing import Any, Callable, Dict, Optional, Sequence, Union
from collections.abc import Sequence
from typing import Any, Callable, Dict, Optional, Union
import torch
from torch.nn import Module, Parameter

View File

@ -13,10 +13,12 @@
import os
import pickle
import warnings
from collections import OrderedDict
from collections.abc import Sequence
from functools import partial
from io import BytesIO
from pathlib import Path
from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, OrderedDict, Sequence, Set, Union
from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, Set, Union
import torch
from lightning_utilities.core.apply_func import apply_to_collection

View File

@ -15,8 +15,9 @@
import inspect
import json
from argparse import Namespace
from collections.abc import Mapping, MutableMapping
from dataclasses import asdict, is_dataclass
from typing import Any, Dict, Mapping, MutableMapping, Optional, Union
from typing import Any, Dict, Optional, Union
from torch import Tensor

View File

@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import MutableMapping
from typing import Iterable
from collections.abc import Iterable, MutableMapping
from torch import Tensor
from torch.optim import Optimizer

View File

@ -11,13 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterator
from pathlib import Path
from typing import (
Any,
Callable,
DefaultDict,
Dict,
Iterator,
List,
Optional,
Protocol,

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from collections.abc import Generator, Iterator, Mapping
from copy import deepcopy
from functools import partial, wraps
from types import MethodType
@ -19,10 +20,7 @@ from typing import (
Any,
Callable,
Dict,
Generator,
Iterator,
List,
Mapping,
Optional,
Tuple,
TypeVar,

View File

@ -19,7 +19,8 @@ Freeze and unfreeze models for finetuning purposes.
"""
import logging
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union
from collections.abc import Generator, Iterable
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from torch.nn import Module, ModuleDict

View File

@ -18,7 +18,8 @@ BasePredictionWriter
Aids in saving predictions
"""
from typing import Any, Literal, Optional, Sequence
from collections.abc import Sequence
from typing import Any, Literal, Optional
from typing_extensions import override

View File

@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections.abc import Generator
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, Dict, Generator, Optional, Union, cast
from typing import Any, Dict, Optional, Union, cast
from lightning_utilities.core.imports import RequirementCache
from typing_extensions import override

View File

@ -18,9 +18,10 @@ ModelPruning
import inspect
import logging
from collections.abc import Sequence
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch.nn.utils.prune as pytorch_prune
from lightning_utilities.core.apply_func import apply_to_collection

View File

@ -1,5 +1,6 @@
import os
from typing import Any, Mapping, Union
from collections.abc import Mapping
from typing import Any, Union
import torch

View File

@ -14,9 +14,10 @@
import inspect
import os
import sys
from collections.abc import Iterable
from functools import partial, update_wrapper
from types import MethodType
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, TypeVar, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
import torch
import yaml

View File

@ -14,7 +14,8 @@
"""LightningDataModule for loading DataLoaders with ease."""
import inspect
from typing import IO, Any, Dict, Iterable, Optional, Union, cast
from collections.abc import Iterable
from typing import IO, Any, Dict, Optional, Union, cast
from lightning_utilities import apply_to_collection
from torch.utils.data import DataLoader, Dataset, IterableDataset

View File

@ -15,9 +15,10 @@ import copy
import inspect
import types
from argparse import Namespace
from collections.abc import Iterator, MutableMapping, Sequence
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Iterator, List, MutableMapping, Optional, Sequence, Union
from typing import Any, List, Optional, Union
from lightning.fabric.utilities.data import AttributeDict
from lightning.pytorch.utilities.parsing import save_hyperparameters

View File

@ -16,6 +16,7 @@
import logging
import numbers
import weakref
from collections.abc import Generator, Mapping, Sequence
from contextlib import contextmanager
from io import BytesIO
from pathlib import Path
@ -25,12 +26,9 @@ from typing import (
Any,
Callable,
Dict,
Generator,
List,
Literal,
Mapping,
Optional,
Sequence,
Tuple,
Union,
cast,

View File

@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Generator
from contextlib import contextmanager
from dataclasses import fields
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union, overload
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, overload
from weakref import proxy
import torch

View File

@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Iterator, List, Optional, Tuple
from collections.abc import Iterator
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn as nn

View File

@ -5,7 +5,8 @@ https://github.com/pytorch/examples/blob/main/word_language_model
"""
from typing import Iterator, List, Optional, Sized, Tuple
from collections.abc import Iterator, Sized
from typing import List, Optional, Tuple
import torch
import torch.nn as nn

View File

@ -16,7 +16,8 @@ import os
import random
import time
import urllib
from typing import Any, Callable, Optional, Sized, Tuple, Union
from collections.abc import Sized
from typing import Any, Callable, Optional, Tuple, Union
from urllib.error import HTTPError
from warnings import warn

View File

@ -19,7 +19,8 @@ Comet Logger
import logging
import os
from argparse import Namespace
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from lightning_utilities.core.imports import RequirementCache
from torch import Tensor

View File

@ -18,7 +18,8 @@ import operator
import statistics
from abc import ABC
from collections import defaultdict
from typing import Any, Callable, Dict, Mapping, Optional, Sequence
from collections.abc import Mapping, Sequence
from typing import Any, Callable, Dict, Optional
from typing_extensions import override

View File

@ -21,9 +21,10 @@ import os
import re
import tempfile
from argparse import Namespace
from collections.abc import Mapping
from pathlib import Path
from time import time
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Mapping, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Union
import yaml
from lightning_utilities.core.imports import RequirementCache

View File

@ -20,8 +20,9 @@ import contextlib
import logging
import os
from argparse import Namespace
from collections.abc import Generator
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Set, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Union
from lightning_utilities.core.imports import RequirementCache
from torch import Tensor

View File

@ -18,8 +18,9 @@ Weights and Biases Logger
import os
from argparse import Namespace
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
import torch.nn as nn
from lightning_utilities.core.imports import RequirementCache

View File

@ -15,7 +15,8 @@ import os
import shutil
import sys
from collections import ChainMap, OrderedDict, defaultdict
from typing import Any, DefaultDict, Iterable, Iterator, List, Optional, Tuple, Union
from collections.abc import Iterable, Iterator
from typing import Any, DefaultDict, List, Optional, Tuple, Union
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor

View File

@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Iterator, List, Optional
from collections.abc import Iterator
from typing import Any, List, Optional
from typing_extensions import override

View File

@ -11,9 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from collections.abc import Mapping
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Callable, Dict, Mapping, Optional, OrderedDict
from typing import Any, Callable, Dict, Optional
import torch
from torch import Tensor

View File

@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from typing import Any, Iterator, List, Optional, Union
from collections.abc import Iterator
from typing import Any, List, Optional, Union
import torch
from lightning_utilities import WarningCache

View File

@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, Callable, ContextManager, Generator, Optional, Tuple, Type
from typing import Any, Callable, ContextManager, Optional, Tuple, Type
import torch
import torch.distributed as dist

View File

@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Sized, Union, cast
from collections.abc import Iterable, Iterator, Sized
from typing import Any, Callable, Dict, List, Optional, Union, cast
import torch
from torch import Tensor

View File

@ -9,8 +9,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator, Literal, Optional, Union
from typing import Any, Callable, Dict, Literal, Optional, Union
import torch
from torch import Tensor

View File

@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, ContextManager, Generator, Literal
from typing import Any, ContextManager, Literal
import torch
import torch.nn as nn

View File

@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, ContextManager, Generator, Literal
from typing import Any, ContextManager, Literal
import torch
from lightning_utilities import apply_to_collection

View File

@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from collections.abc import Generator
from functools import partial
from typing import Any, Callable, Generator, List, Optional, Tuple, Union
from typing import Any, Callable, List, Optional, Tuple, Union
import torch
from torch import Tensor

View File

@ -89,9 +89,10 @@ class AdvancedProfiler(Profiler):
dst_fs = get_filesystem(dst_filepath)
dst_fs.mkdirs(self.dirpath, exist_ok=True)
# temporarily save to local since pstats can only dump into a local file
with tempfile.TemporaryDirectory(
prefix="test", suffix=str(rank_zero_only.rank), dir=os.getcwd()
) as tmp_dir, dst_fs.open(dst_filepath, "wb") as dst_file:
with (
tempfile.TemporaryDirectory(prefix="test", suffix=str(rank_zero_only.rank), dir=os.getcwd()) as tmp_dir,
dst_fs.open(dst_filepath, "wb") as dst_file,
):
src_filepath = os.path.join(tmp_dir, "tmp.prof")
profile.dump_stats(src_filepath)
src_fs = get_filesystem(src_filepath)

View File

@ -16,9 +16,10 @@
import logging
import os
from abc import ABC, abstractmethod
from collections.abc import Generator
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, Generator, Optional, TextIO, Union
from typing import Any, Callable, Dict, Optional, TextIO, Union
from lightning.fabric.utilities.cloud_io import get_filesystem

View File

@ -17,9 +17,10 @@ import logging
import os
import platform
from collections import OrderedDict
from collections.abc import Generator, Mapping
from contextlib import contextmanager
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
from torch.nn import Module

View File

@ -13,6 +13,7 @@
# limitations under the License.
import logging
import shutil
from collections.abc import Generator, Mapping
from contextlib import contextmanager, nullcontext
from datetime import timedelta
from pathlib import Path
@ -21,10 +22,8 @@ from typing import (
Any,
Callable,
Dict,
Generator,
List,
Literal,
Mapping,
Optional,
Set,
Tuple,

View File

@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
from collections.abc import Generator, Mapping
from contextlib import contextmanager, nullcontext
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Mapping, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
import torch
from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only

View File

@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, Dict, Generator, List, Optional
from typing import Any, Dict, List, Optional
import torch
from torch import Tensor

View File

@ -13,8 +13,9 @@
# limitations under the License.
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping
from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, TypeVar, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
import torch
from torch import Tensor

View File

@ -115,7 +115,11 @@ def _call_configure_model(trainer: "pl.Trainer") -> None:
# we don't normally check for this before calling the hook. it is done here to avoid instantiating the context
# managers
if is_overridden("configure_model", trainer.lightning_module):
with trainer.strategy.tensor_init_context(), trainer.strategy.model_sharded_context(), trainer.precision_plugin.module_init_context(): # noqa: E501
with (
trainer.strategy.tensor_init_context(),
trainer.strategy.model_sharded_context(),
trainer.precision_plugin.module_init_context(),
): # noqa: E501
_call_lightning_module_hook(trainer, "configure_model")

View File

@ -14,8 +14,9 @@
import logging
import os
from collections.abc import Sequence
from datetime import timedelta
from typing import Dict, List, Optional, Sequence, Union
from typing import Dict, List, Optional, Union
import lightning.pytorch as pl
from lightning.fabric.utilities.registry import _load_external_callbacks

View File

@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import Any, Iterable, Optional, Tuple, Union
from typing import Any, Optional, Tuple, Union
import torch.multiprocessing as mp
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sampler, SequentialSampler

View File

@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Iterable, Optional, Union
from collections.abc import Iterable
from typing import Any, Optional, Union
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor

View File

@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Generator
from dataclasses import dataclass
from functools import partial, wraps
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union, cast
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
import torch
from lightning_utilities.core.apply_func import apply_to_collection

View File

@ -23,9 +23,10 @@
import logging
import math
import os
from collections.abc import Generator, Iterable
from contextlib import contextmanager
from datetime import timedelta
from typing import Any, Dict, Generator, Iterable, List, Optional, Union
from typing import Any, Dict, List, Optional, Union
from weakref import proxy
import torch

View File

@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from collections.abc import Iterable
from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Tuple, Type, Union
from collections.abc import Iterable, Iterator
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter
from typing_extensions import Self, TypedDict, override

View File

@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from collections.abc import Generator, Iterable, Mapping, Sized
from dataclasses import fields
from typing import Any, Dict, Generator, Iterable, Mapping, Optional, Sized, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import torch
from lightning_utilities.core.apply_func import is_dataclass_instance

View File

@ -17,8 +17,9 @@ import copy
import inspect
import pickle
import types
from collections.abc import MutableMapping, Sequence
from dataclasses import fields, is_dataclass
from typing import Any, Dict, List, Literal, MutableMapping, Optional, Sequence, Tuple, Type, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
from torch import nn

View File

@ -13,8 +13,8 @@
# limitations under the License.
"""Utilities to help with reproducibility of models."""
from collections.abc import Generator
from contextlib import contextmanager
from typing import Generator
from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states

View File

@ -17,17 +17,14 @@ Convention:
- Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no leading `_`)
"""
from collections.abc import Generator, Iterator, Mapping, Sequence
from contextlib import contextmanager
from dataclasses import dataclass
from typing import (
Any,
Generator,
Iterator,
List,
Mapping,
Optional,
Protocol,
Sequence,
Tuple,
Type,
TypedDict,

View File

@ -121,27 +121,32 @@ def test_tf32_message(_, __, ___, caplog, monkeypatch):
def test_find_usable_cuda_devices_error_handling():
"""Test error handling for edge cases when using `find_usable_cuda_devices`."""
# Asking for GPUs if no GPUs visible
with mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=0), pytest.raises(
ValueError, match="You requested to find 2 devices but there are no visible CUDA"
with (
mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=0),
pytest.raises(ValueError, match="You requested to find 2 devices but there are no visible CUDA"),
):
find_usable_cuda_devices(2)
# Asking for more GPUs than are visible
with mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=1), pytest.raises(
ValueError, match="this machine only has 1 GPUs"
with (
mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=1),
pytest.raises(ValueError, match="this machine only has 1 GPUs"),
):
find_usable_cuda_devices(2)
# All GPUs are unusable
tensor_mock = Mock(side_effect=RuntimeError) # simulate device placement fails
with mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2), mock.patch(
"lightning.fabric.accelerators.cuda.torch.tensor", tensor_mock
), pytest.raises(RuntimeError, match=escape("The devices [0, 1] are occupied by other processes")):
with (
mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2),
mock.patch("lightning.fabric.accelerators.cuda.torch.tensor", tensor_mock),
pytest.raises(RuntimeError, match=escape("The devices [0, 1] are occupied by other processes")),
):
find_usable_cuda_devices(2)
# Request for as many GPUs as there are, no error should be raised
with mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=5), mock.patch(
"lightning.fabric.accelerators.cuda.torch.tensor"
with (
mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=5),
mock.patch("lightning.fabric.accelerators.cuda.torch.tensor"),
):
assert find_usable_cuda_devices(-1) == [0, 1, 2, 3, 4]

View File

@ -1,4 +1,4 @@
from typing import Iterator
from collections.abc import Iterator
import torch
from torch import Tensor

View File

@ -29,13 +29,16 @@ PASSED_OBJECT = mock.Mock()
@contextlib.contextmanager
def check_destroy_group():
with mock.patch(
"lightning.fabric.plugins.collectives.torch_collective.TorchCollective.new_group",
wraps=TorchCollective.new_group,
) as mock_new, mock.patch(
"lightning.fabric.plugins.collectives.torch_collective.TorchCollective.destroy_group",
wraps=TorchCollective.destroy_group,
) as mock_destroy:
with (
mock.patch(
"lightning.fabric.plugins.collectives.torch_collective.TorchCollective.new_group",
wraps=TorchCollective.new_group,
) as mock_new,
mock.patch(
"lightning.fabric.plugins.collectives.torch_collective.TorchCollective.destroy_group",
wraps=TorchCollective.destroy_group,
) as mock_destroy,
):
yield
# 0 to account for tests that mock distributed
# -1 to account for destroying the default process group
@ -155,9 +158,10 @@ def test_repeated_create_and_destroy():
with pytest.raises(RuntimeError, match="TorchCollective` already owns a group"):
collective.create_group()
with mock.patch.dict("torch.distributed.distributed_c10d._pg_map", {collective.group: ("", None)}), mock.patch(
"torch.distributed.destroy_process_group"
) as destroy_mock:
with (
mock.patch.dict("torch.distributed.distributed_c10d._pg_map", {collective.group: ("", None)}),
mock.patch("torch.distributed.destroy_process_group") as destroy_mock,
):
collective.teardown()
# this would be called twice if `init_process_group` wasn't patched. once for the group and once for the default
# group
@ -300,9 +304,11 @@ def test_collective_manages_default_group():
assert TorchCollective.manages_default_group
with mock.patch.object(collective, "_group") as mock_group, mock.patch.dict(
"torch.distributed.distributed_c10d._pg_map", {mock_group: ("", None)}
), mock.patch("torch.distributed.destroy_process_group") as destroy_mock:
with (
mock.patch.object(collective, "_group") as mock_group,
mock.patch.dict("torch.distributed.distributed_c10d._pg_map", {mock_group: ("", None)}),
mock.patch("torch.distributed.destroy_process_group") as destroy_mock,
):
collective.teardown()
destroy_mock.assert_called_once_with(mock_group)

View File

@ -41,8 +41,9 @@ def test_empty_lsb_djob_rankfile():
def test_missing_lsb_job_id(tmp_path):
"""Test an error when the job id cannot be found."""
with mock.patch.dict(os.environ, {"LSB_DJOB_RANKFILE": _make_rankfile(tmp_path)}), pytest.raises(
ValueError, match="Could not find job id in environment variable LSB_JOBID"
with (
mock.patch.dict(os.environ, {"LSB_DJOB_RANKFILE": _make_rankfile(tmp_path)}),
pytest.raises(ValueError, match="Could not find job id in environment variable LSB_JOBID"),
):
LSFEnvironment()

View File

@ -155,8 +155,9 @@ def test_srun_variable_validation():
"""Test that we raise useful errors when `srun` variables are misconfigured."""
with mock.patch.dict(os.environ, {"SLURM_NTASKS": "1"}):
SLURMEnvironment()
with mock.patch.dict(os.environ, {"SLURM_NTASKS": "2"}), pytest.raises(
RuntimeError, match="You set `--ntasks=2` in your SLURM"
with (
mock.patch.dict(os.environ, {"SLURM_NTASKS": "2"}),
pytest.raises(RuntimeError, match="You set `--ntasks=2` in your SLURM"),
):
SLURMEnvironment()

View File

@ -100,8 +100,11 @@ def test_check_for_bad_cuda_fork(mp_mock, _, start_method):
def test_check_for_missing_main_guard():
launcher = _MultiProcessingLauncher(strategy=Mock(), start_method="spawn")
with mock.patch(
"lightning.fabric.strategies.launchers.multiprocessing.mp.current_process",
return_value=Mock(_inheriting=True), # pretend that main is importing itself
), pytest.raises(RuntimeError, match="requires that your script guards the main"):
with (
mock.patch(
"lightning.fabric.strategies.launchers.multiprocessing.mp.current_process",
return_value=Mock(_inheriting=True), # pretend that main is importing itself
),
pytest.raises(RuntimeError, match="requires that your script guards the main"),
):
launcher.launch(function=Mock())

View File

@ -58,9 +58,12 @@ def test_ddp_no_backward_sync():
strategy = DDPStrategy()
assert isinstance(strategy._backward_sync_control, _DDPBackwardSyncControl)
with pytest.raises(
TypeError, match="is only possible if the module passed to .* is wrapped in `DistributedDataParallel`"
), strategy._backward_sync_control.no_backward_sync(Mock(), True):
with (
pytest.raises(
TypeError, match="is only possible if the module passed to .* is wrapped in `DistributedDataParallel`"
),
strategy._backward_sync_control.no_backward_sync(Mock(), True),
):
pass
module = MagicMock(spec=DistributedDataParallel)

View File

@ -404,9 +404,11 @@ def test_deepspeed_init_module_with_stages_1_2(stage, empty_init):
fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy, precision="bf16-true")
fabric.launch()
with mock.patch("deepspeed.zero.Init") as zero_init_mock, mock.patch(
"torch.Tensor.uniform_"
) as init_mock, fabric.init_module(empty_init=empty_init):
with (
mock.patch("deepspeed.zero.Init") as zero_init_mock,
mock.patch("torch.Tensor.uniform_") as init_mock,
fabric.init_module(empty_init=empty_init),
):
model = BoringModel()
zero_init_mock.assert_called_with(enabled=False, remote_device=None, config_dict_or_path=ANY)

View File

@ -133,9 +133,12 @@ def test_no_backward_sync():
strategy = FSDPStrategy()
assert isinstance(strategy._backward_sync_control, _FSDPBackwardSyncControl)
with pytest.raises(
TypeError, match="is only possible if the module passed to .* is wrapped in `FullyShardedDataParallel`"
), strategy._backward_sync_control.no_backward_sync(Mock(), True):
with (
pytest.raises(
TypeError, match="is only possible if the module passed to .* is wrapped in `FullyShardedDataParallel`"
),
strategy._backward_sync_control.no_backward_sync(Mock(), True),
):
pass
module = MagicMock(spec=FullyShardedDataParallel)
@ -172,9 +175,12 @@ def test_activation_checkpointing():
assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
strategy._parallel_devices = [torch.device("cuda", 0)]
with mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock), mock.patch(
"torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing"
) as apply_mock:
with (
mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock),
mock.patch(
"torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing"
) as apply_mock,
):
wrapped = strategy.setup_module(Model())
apply_mock.assert_called_with(wrapped, checkpoint_wrapper_fn=ANY, **strategy._activation_checkpointing_kwargs)

View File

@ -48,9 +48,12 @@ def test_xla_fsdp_no_backward_sync():
strategy = XLAFSDPStrategy()
assert isinstance(strategy._backward_sync_control, _XLAFSDPBackwardSyncControl)
with pytest.raises(
TypeError, match="is only possible if the module passed to .* is wrapped in `XlaFullyShardedDataParallel`"
), strategy._backward_sync_control.no_backward_sync(object(), True):
with (
pytest.raises(
TypeError, match="is only possible if the module passed to .* is wrapped in `XlaFullyShardedDataParallel`"
),
strategy._backward_sync_control.no_backward_sync(object(), True),
):
pass
module = MagicMock(spec=XlaFullyShardedDataParallel)

View File

@ -960,28 +960,33 @@ def test_arguments_from_environment_collision():
with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"}):
_Connector(accelerator="cuda")
with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu", "LT_CLI_USED": "1"}), pytest.raises(
ValueError, match="`Fabric\\(accelerator='cuda', ...\\)` but .* `--accelerator=cpu`"
with (
mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu", "LT_CLI_USED": "1"}),
pytest.raises(ValueError, match="`Fabric\\(accelerator='cuda', ...\\)` but .* `--accelerator=cpu`"),
):
_Connector(accelerator="cuda")
with mock.patch.dict(os.environ, {"LT_STRATEGY": "ddp", "LT_CLI_USED": "1"}), pytest.raises(
ValueError, match="`Fabric\\(strategy='ddp_spawn', ...\\)` but .* `--strategy=ddp`"
with (
mock.patch.dict(os.environ, {"LT_STRATEGY": "ddp", "LT_CLI_USED": "1"}),
pytest.raises(ValueError, match="`Fabric\\(strategy='ddp_spawn', ...\\)` but .* `--strategy=ddp`"),
):
_Connector(strategy="ddp_spawn")
with mock.patch.dict(os.environ, {"LT_DEVICES": "2", "LT_CLI_USED": "1"}), pytest.raises(
ValueError, match="`Fabric\\(devices=3, ...\\)` but .* `--devices=2`"
with (
mock.patch.dict(os.environ, {"LT_DEVICES": "2", "LT_CLI_USED": "1"}),
pytest.raises(ValueError, match="`Fabric\\(devices=3, ...\\)` but .* `--devices=2`"),
):
_Connector(devices=3)
with mock.patch.dict(os.environ, {"LT_NUM_NODES": "3", "LT_CLI_USED": "1"}), pytest.raises(
ValueError, match="`Fabric\\(num_nodes=2, ...\\)` but .* `--num_nodes=3`"
with (
mock.patch.dict(os.environ, {"LT_NUM_NODES": "3", "LT_CLI_USED": "1"}),
pytest.raises(ValueError, match="`Fabric\\(num_nodes=2, ...\\)` but .* `--num_nodes=3`"),
):
_Connector(num_nodes=2)
with mock.patch.dict(os.environ, {"LT_PRECISION": "16-mixed", "LT_CLI_USED": "1"}), pytest.raises(
ValueError, match="`Fabric\\(precision='64-true', ...\\)` but .* `--precision=16-mixed`"
with (
mock.patch.dict(os.environ, {"LT_PRECISION": "16-mixed", "LT_CLI_USED": "1"}),
pytest.raises(ValueError, match="`Fabric\\(precision='64-true', ...\\)` but .* `--precision=16-mixed`"),
):
_Connector(precision="64-true")

View File

@ -746,9 +746,10 @@ def test_no_backward_sync():
# pretend that the strategy does not support skipping backward sync
fabric._strategy = Mock(spec=ParallelStrategy, _backward_sync_control=None)
with pytest.warns(
PossibleUserWarning, match="The `ParallelStrategy` does not support skipping the"
), fabric.no_backward_sync(model):
with (
pytest.warns(PossibleUserWarning, match="The `ParallelStrategy` does not support skipping the"),
fabric.no_backward_sync(model),
):
pass
# for single-device strategies, it becomes a no-op without warning

View File

@ -41,8 +41,9 @@ def test_parse_cli_args(args, expected):
def test_process_cli_args(tmp_path, caplog, monkeypatch):
# PyTorch version < 2.3
monkeypatch.setattr(lightning.fabric.utilities.consolidate_checkpoint, "_TORCH_GREATER_EQUAL_2_3", False)
with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises(
SystemExit
with (
caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"),
pytest.raises(SystemExit),
):
_process_cli_args(Namespace())
assert "requires PyTorch >= 2.3." in caplog.text
@ -51,8 +52,9 @@ def test_process_cli_args(tmp_path, caplog, monkeypatch):
# Checkpoint does not exist
checkpoint_folder = Path("does/not/exist")
with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises(
SystemExit
with (
caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"),
pytest.raises(SystemExit),
):
_process_cli_args(Namespace(checkpoint_folder=checkpoint_folder))
assert f"checkpoint folder does not exist: {checkpoint_folder}" in caplog.text
@ -61,8 +63,9 @@ def test_process_cli_args(tmp_path, caplog, monkeypatch):
# Checkpoint exists but is not a folder
file = tmp_path / "checkpoint_file"
file.touch()
with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises(
SystemExit
with (
caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"),
pytest.raises(SystemExit),
):
_process_cli_args(Namespace(checkpoint_folder=file))
assert "checkpoint path must be a folder" in caplog.text
@ -71,8 +74,9 @@ def test_process_cli_args(tmp_path, caplog, monkeypatch):
# Checkpoint exists but is not an FSDP checkpoint
folder = tmp_path / "checkpoint_folder"
folder.mkdir()
with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises(
SystemExit
with (
caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"),
pytest.raises(SystemExit),
):
_process_cli_args(Namespace(checkpoint_folder=folder))
assert "Only FSDP-sharded checkpoints saved with Lightning are supported" in caplog.text
@ -89,8 +93,9 @@ def test_process_cli_args(tmp_path, caplog, monkeypatch):
# Checkpoint is a FSDP folder, output file already exists
file = tmp_path / "ouput_file"
file.touch()
with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises(
SystemExit
with (
caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"),
pytest.raises(SystemExit),
):
_process_cli_args(Namespace(checkpoint_folder=folder, output_file=file))
assert "path for the converted checkpoint already exists" in caplog.text

View File

@ -215,9 +215,10 @@ def test_infinite_barrier():
# distributed available
barrier = _InfiniteBarrier()
with mock.patch(
"lightning.fabric.utilities.distributed._distributed_is_initialized", return_value=True
), mock.patch("lightning.fabric.utilities.distributed.torch.distributed") as dist_mock:
with (
mock.patch("lightning.fabric.utilities.distributed._distributed_is_initialized", return_value=True),
mock.patch("lightning.fabric.utilities.distributed.torch.distributed") as dist_mock,
):
barrier.__enter__()
dist_mock.new_group.assert_called_once()
assert barrier.barrier == barrier.group.monitored_barrier

View File

@ -39,8 +39,9 @@ def test_get_available_flops(xla_available):
with pytest.warns(match="not found for 'CocoNut"), mock.patch("torch.cuda.get_device_name", return_value="CocoNut"):
assert get_available_flops(torch.device("cuda"), torch.bfloat16) is None
with pytest.warns(match="t4' does not support torch.bfloat"), mock.patch(
"torch.cuda.get_device_name", return_value="t4"
with (
pytest.warns(match="t4' does not support torch.bfloat"),
mock.patch("torch.cuda.get_device_name", return_value="t4"),
):
assert get_available_flops(torch.device("cuda"), torch.bfloat16) is None

View File

@ -141,9 +141,12 @@ def test_rich_progress_bar_keyboard_interrupt(tmp_path):
model = TestModel()
with mock.patch(
"lightning.pytorch.callbacks.progress.rich_progress.Progress.stop", autospec=True
) as mock_progress_stop, pytest.raises(SystemExit):
with (
mock.patch(
"lightning.pytorch.callbacks.progress.rich_progress.Progress.stop", autospec=True
) as mock_progress_stop,
pytest.raises(SystemExit),
):
progress_bar = RichProgressBar()
trainer = Trainer(
default_root_dir=tmp_path,

View File

@ -43,8 +43,9 @@ def test_throughput_monitor_fit(tmp_path):
)
# these timing results are meant to precisely match the `test_throughput_monitor` test in fabric
timings = [0.0] + [0.5 + i for i in range(1, 6)]
with mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100), mock.patch(
"time.perf_counter", side_effect=timings
with (
mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100),
mock.patch("time.perf_counter", side_effect=timings),
):
trainer.fit(model)
@ -179,8 +180,9 @@ def test_throughput_monitor_fit_gradient_accumulation(log_every_n_steps, tmp_pat
enable_progress_bar=False,
)
timings = [0.0] + [0.5 + i for i in range(1, 11)]
with mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100), mock.patch(
"time.perf_counter", side_effect=timings
with (
mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100),
mock.patch("time.perf_counter", side_effect=timings),
):
trainer.fit(model)

View File

@ -92,9 +92,10 @@ def test_trainer_save_checkpoint_storage_options(tmp_path, xla_available):
io_mock.assert_called_with(ANY, instance_path, storage_options=None)
checkpoint_mock = Mock()
with mock.patch.object(trainer.strategy, "save_checkpoint") as save_mock, mock.patch.object(
trainer._checkpoint_connector, "dump_checkpoint", return_value=checkpoint_mock
) as dump_mock:
with (
mock.patch.object(trainer.strategy, "save_checkpoint") as save_mock,
mock.patch.object(trainer._checkpoint_connector, "dump_checkpoint", return_value=checkpoint_mock) as dump_mock,
):
trainer.save_checkpoint(instance_path, True)
dump_mock.assert_called_with(True)
save_mock.assert_called_with(checkpoint_mock, instance_path, storage_options=None)

View File

@ -110,9 +110,10 @@ def test_lightning_optimizer_manual_optimization_and_accumulated_gradients(tmp_p
default_root_dir=tmp_path, limit_train_batches=8, limit_val_batches=1, max_epochs=1, enable_model_summary=False
)
with patch.multiple(torch.optim.SGD, zero_grad=DEFAULT, step=DEFAULT) as sgd, patch.multiple(
torch.optim.Adam, zero_grad=DEFAULT, step=DEFAULT
) as adam:
with (
patch.multiple(torch.optim.SGD, zero_grad=DEFAULT, step=DEFAULT) as sgd,
patch.multiple(torch.optim.Adam, zero_grad=DEFAULT, step=DEFAULT) as adam,
):
trainer.fit(model)
assert sgd["step"].call_count == 4

View File

@ -625,8 +625,9 @@ def test_logger_sync_dist(distributed_env, log_val):
else nullcontext()
)
with warning_ctx(
PossibleUserWarning, match=r"recommended to use `self.log\('bar', ..., sync_dist=True\)`"
), patch_ctx:
with (
warning_ctx(PossibleUserWarning, match=r"recommended to use `self.log\('bar', ..., sync_dist=True\)`"),
patch_ctx,
):
value = _ResultCollection._get_cache(result_metric, on_step=False)
assert value == 0.5

View File

@ -16,7 +16,8 @@ import os
import random
import time
import urllib.request
from typing import Optional, Sequence, Tuple
from collections.abc import Sequence
from typing import Optional, Tuple
import torch
from torch import Tensor

View File

@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterator, Mapping
from contextlib import nullcontext
from typing import Dict, Generic, Iterator, Mapping, TypeVar
from typing import Dict, Generic, TypeVar
import pytest
import torch

View File

@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import Counter
from typing import Any, Iterator
from collections.abc import Iterator
from typing import Any
import pytest
import torch

View File

@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from collections.abc import Iterator
from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, Iterator
from typing import Dict
from unittest.mock import ANY, Mock
import pytest

View File

@ -15,8 +15,9 @@ import glob
import logging as log
import os
import pickle
from collections.abc import Mapping
from copy import deepcopy
from typing import Generic, Mapping, TypeVar
from typing import Generic, TypeVar
import cloudpickle
import pytest

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable
from collections.abc import Iterable
import pytest
import torch

View File

@ -209,10 +209,13 @@ def test_memory_sharing_disabled(tmp_path):
def test_check_for_missing_main_guard():
launcher = _MultiProcessingLauncher(strategy=Mock(), start_method="spawn")
with mock.patch(
"lightning.pytorch.strategies.launchers.multiprocessing.mp.current_process",
return_value=Mock(_inheriting=True), # pretend that main is importing itself
), pytest.raises(RuntimeError, match="requires that your script guards the main"):
with (
mock.patch(
"lightning.pytorch.strategies.launchers.multiprocessing.mp.current_process",
return_value=Mock(_inheriting=True), # pretend that main is importing itself
),
pytest.raises(RuntimeError, match="requires that your script guards the main"),
):
launcher.launch(function=Mock())

View File

@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Mapping
from collections.abc import Mapping
from typing import Any
import pytest
import torch

View File

@ -444,9 +444,12 @@ def test_activation_checkpointing():
strategy._parallel_devices = [torch.device("cuda", 0)]
strategy._lightning_module = model
strategy._process_group = Mock()
with mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock), mock.patch(
"torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing"
) as apply_mock:
with (
mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock),
mock.patch(
"torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing"
) as apply_mock,
):
wrapped = strategy._setup_model(model)
apply_mock.assert_called_with(wrapped, checkpoint_wrapper_fn=ANY, **strategy._activation_checkpointing_kwargs)

View File

@ -409,8 +409,9 @@ def test_lightning_cli_config_and_subclass_mode(cleandir):
with open(config_path, "w") as f:
f.write(yaml.dump(input_config))
with mock.patch("sys.argv", ["any.py", "--config", config_path]), mock_subclasses(
LightningDataModule, DataDirDataModule
with (
mock.patch("sys.argv", ["any.py", "--config", config_path]),
mock_subclasses(LightningDataModule, DataDirDataModule),
):
cli = LightningCLI(
BoringModel,
@ -461,9 +462,12 @@ def test_lightning_cli_help():
cli_args = ["any.py", "fit", "--data.help=DataDirDataModule"]
out = StringIO()
with mock.patch("sys.argv", cli_args), redirect_stdout(out), mock_subclasses(
LightningDataModule, DataDirDataModule
), pytest.raises(SystemExit):
with (
mock.patch("sys.argv", cli_args),
redirect_stdout(out),
mock_subclasses(LightningDataModule, DataDirDataModule),
pytest.raises(SystemExit),
):
any_model_any_data_cli()
assert ("--data.data_dir" in out.getvalue()) or ("--data.init_args.data_dir" in out.getvalue())
@ -609,8 +613,9 @@ class EarlyExitTestModel(BoringModel):
def test_cli_distributed_save_config_callback(cleandir, logger, strategy):
from torch.multiprocessing import ProcessRaisedException
with mock.patch("sys.argv", ["any.py", "fit"]), pytest.raises(
(MisconfigurationException, ProcessRaisedException), match=r"Error on fit start"
with (
mock.patch("sys.argv", ["any.py", "fit"]),
pytest.raises((MisconfigurationException, ProcessRaisedException), match=r"Error on fit start"),
):
LightningCLI(
EarlyExitTestModel,
@ -710,12 +715,14 @@ def test_cli_no_need_configure_optimizers(cleandir):
from lightning.pytorch.trainer.configuration_validator import __verify_train_val_loop_configuration
with mock.patch("sys.argv", ["any.py", "fit", "--optimizer=Adam"]), mock.patch(
"lightning.pytorch.Trainer._run_stage"
) as run, mock.patch(
"lightning.pytorch.trainer.configuration_validator.__verify_train_val_loop_configuration",
wraps=__verify_train_val_loop_configuration,
) as verify:
with (
mock.patch("sys.argv", ["any.py", "fit", "--optimizer=Adam"]),
mock.patch("lightning.pytorch.Trainer._run_stage") as run,
mock.patch(
"lightning.pytorch.trainer.configuration_validator.__verify_train_val_loop_configuration",
wraps=__verify_train_val_loop_configuration,
) as verify,
):
cli = LightningCLI(BoringModel)
run.assert_called_once()
verify.assert_called_once_with(cli.trainer, cli.model)
@ -1074,15 +1081,18 @@ class TestModel(BoringModel):
@_xfail_python_ge_3_11_9
def test_lightning_cli_model_short_arguments():
with mock.patch("sys.argv", ["any.py", "fit", "--model=BoringModel"]), mock.patch(
"lightning.pytorch.Trainer._fit_impl"
) as run, mock_subclasses(LightningModule, BoringModel, TestModel):
with (
mock.patch("sys.argv", ["any.py", "fit", "--model=BoringModel"]),
mock.patch("lightning.pytorch.Trainer._fit_impl") as run,
mock_subclasses(LightningModule, BoringModel, TestModel),
):
cli = LightningCLI(trainer_defaults={"fast_dev_run": 1})
assert isinstance(cli.model, BoringModel)
run.assert_called_once_with(cli.model, ANY, ANY, ANY, ANY)
with mock.patch("sys.argv", ["any.py", "--model=TestModel", "--model.foo", "123"]), mock_subclasses(
LightningModule, BoringModel, TestModel
with (
mock.patch("sys.argv", ["any.py", "--model=TestModel", "--model.foo", "123"]),
mock_subclasses(LightningModule, BoringModel, TestModel),
):
cli = LightningCLI(run=False)
assert isinstance(cli.model, TestModel)
@ -1100,15 +1110,18 @@ class MyDataModule(BoringDataModule):
@_xfail_python_ge_3_11_9
def test_lightning_cli_datamodule_short_arguments():
# with set model
with mock.patch("sys.argv", ["any.py", "fit", "--data=BoringDataModule"]), mock.patch(
"lightning.pytorch.Trainer._fit_impl"
) as run, mock_subclasses(LightningDataModule, BoringDataModule):
with (
mock.patch("sys.argv", ["any.py", "fit", "--data=BoringDataModule"]),
mock.patch("lightning.pytorch.Trainer._fit_impl") as run,
mock_subclasses(LightningDataModule, BoringDataModule),
):
cli = LightningCLI(BoringModel, trainer_defaults={"fast_dev_run": 1})
assert isinstance(cli.datamodule, BoringDataModule)
run.assert_called_once_with(ANY, ANY, ANY, cli.datamodule, ANY)
with mock.patch("sys.argv", ["any.py", "--data=MyDataModule", "--data.foo", "123"]), mock_subclasses(
LightningDataModule, MyDataModule
with (
mock.patch("sys.argv", ["any.py", "--data=MyDataModule", "--data.foo", "123"]),
mock_subclasses(LightningDataModule, MyDataModule),
):
cli = LightningCLI(BoringModel, run=False)
assert isinstance(cli.datamodule, MyDataModule)
@ -1116,17 +1129,22 @@ def test_lightning_cli_datamodule_short_arguments():
assert cli.datamodule.bar == 5
# with configurable model
with mock.patch("sys.argv", ["any.py", "fit", "--model", "BoringModel", "--data=BoringDataModule"]), mock.patch(
"lightning.pytorch.Trainer._fit_impl"
) as run, mock_subclasses(LightningModule, BoringModel), mock_subclasses(LightningDataModule, BoringDataModule):
with (
mock.patch("sys.argv", ["any.py", "fit", "--model", "BoringModel", "--data=BoringDataModule"]),
mock.patch("lightning.pytorch.Trainer._fit_impl") as run,
mock_subclasses(LightningModule, BoringModel),
mock_subclasses(LightningDataModule, BoringDataModule),
):
cli = LightningCLI(trainer_defaults={"fast_dev_run": 1})
assert isinstance(cli.model, BoringModel)
assert isinstance(cli.datamodule, BoringDataModule)
run.assert_called_once_with(cli.model, ANY, ANY, cli.datamodule, ANY)
with mock.patch("sys.argv", ["any.py", "--model", "BoringModel", "--data=MyDataModule"]), mock_subclasses(
LightningModule, BoringModel
), mock_subclasses(LightningDataModule, MyDataModule):
with (
mock.patch("sys.argv", ["any.py", "--model", "BoringModel", "--data=MyDataModule"]),
mock_subclasses(LightningModule, BoringModel),
mock_subclasses(LightningDataModule, MyDataModule),
):
cli = LightningCLI(run=False)
assert isinstance(cli.model, BoringModel)
assert isinstance(cli.datamodule, MyDataModule)
@ -1293,9 +1311,10 @@ def test_optimizers_and_lr_schedulers_add_arguments_to_parser_implemented_reload
def test_lightning_cli_config_with_subcommand():
config = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}}
with mock.patch("sys.argv", ["any.py", f"--config={config}"]), mock.patch(
"lightning.pytorch.Trainer.test", autospec=True
) as test_mock:
with (
mock.patch("sys.argv", ["any.py", f"--config={config}"]),
mock.patch("lightning.pytorch.Trainer.test", autospec=True) as test_mock,
):
cli = LightningCLI(BoringModel)
test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar")
@ -1308,9 +1327,10 @@ def test_lightning_cli_config_before_subcommand():
"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"},
}
with mock.patch("sys.argv", ["any.py", f"--config={config}", "test"]), mock.patch(
"lightning.pytorch.Trainer.test", autospec=True
) as test_mock:
with (
mock.patch("sys.argv", ["any.py", f"--config={config}", "test"]),
mock.patch("lightning.pytorch.Trainer.test", autospec=True) as test_mock,
):
cli = LightningCLI(BoringModel)
test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar")
@ -1320,9 +1340,10 @@ def test_lightning_cli_config_before_subcommand():
assert save_config_callback.config.trainer.limit_test_batches == 1
assert save_config_callback.parser.subcommand == "test"
with mock.patch("sys.argv", ["any.py", f"--config={config}", "validate"]), mock.patch(
"lightning.pytorch.Trainer.validate", autospec=True
) as validate_mock:
with (
mock.patch("sys.argv", ["any.py", f"--config={config}", "validate"]),
mock.patch("lightning.pytorch.Trainer.validate", autospec=True) as validate_mock,
):
cli = LightningCLI(BoringModel)
validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo")
@ -1337,17 +1358,19 @@ def test_lightning_cli_config_before_subcommand_two_configs():
config1 = {"validate": {"trainer": {"limit_val_batches": 1}, "verbose": False, "ckpt_path": "barfoo"}}
config2 = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}}
with mock.patch("sys.argv", ["any.py", f"--config={config1}", f"--config={config2}", "test"]), mock.patch(
"lightning.pytorch.Trainer.test", autospec=True
) as test_mock:
with (
mock.patch("sys.argv", ["any.py", f"--config={config1}", f"--config={config2}", "test"]),
mock.patch("lightning.pytorch.Trainer.test", autospec=True) as test_mock,
):
cli = LightningCLI(BoringModel)
test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar")
assert cli.trainer.limit_test_batches == 1
with mock.patch("sys.argv", ["any.py", f"--config={config1}", f"--config={config2}", "validate"]), mock.patch(
"lightning.pytorch.Trainer.validate", autospec=True
) as validate_mock:
with (
mock.patch("sys.argv", ["any.py", f"--config={config1}", f"--config={config2}", "validate"]),
mock.patch("lightning.pytorch.Trainer.validate", autospec=True) as validate_mock,
):
cli = LightningCLI(BoringModel)
validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo")
@ -1356,9 +1379,10 @@ def test_lightning_cli_config_before_subcommand_two_configs():
def test_lightning_cli_config_after_subcommand():
config = {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}
with mock.patch("sys.argv", ["any.py", "test", f"--config={config}"]), mock.patch(
"lightning.pytorch.Trainer.test", autospec=True
) as test_mock:
with (
mock.patch("sys.argv", ["any.py", "test", f"--config={config}"]),
mock.patch("lightning.pytorch.Trainer.test", autospec=True) as test_mock,
):
cli = LightningCLI(BoringModel)
test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar")
@ -1368,9 +1392,10 @@ def test_lightning_cli_config_after_subcommand():
def test_lightning_cli_config_before_and_after_subcommand():
config1 = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}}
config2 = {"trainer": {"fast_dev_run": 1}, "verbose": False, "ckpt_path": "foobar"}
with mock.patch("sys.argv", ["any.py", f"--config={config1}", "test", f"--config={config2}"]), mock.patch(
"lightning.pytorch.Trainer.test", autospec=True
) as test_mock:
with (
mock.patch("sys.argv", ["any.py", f"--config={config1}", "test", f"--config={config2}"]),
mock.patch("lightning.pytorch.Trainer.test", autospec=True) as test_mock,
):
cli = LightningCLI(BoringModel)
test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=False, ckpt_path="foobar")
@ -1392,17 +1417,19 @@ def test_lightning_cli_parse_kwargs_with_subcommands(cleandir):
"validate": {"default_config_files": [str(validate_config_path)]},
}
with mock.patch("sys.argv", ["any.py", "fit"]), mock.patch(
"lightning.pytorch.Trainer.fit", autospec=True
) as fit_mock:
with (
mock.patch("sys.argv", ["any.py", "fit"]),
mock.patch("lightning.pytorch.Trainer.fit", autospec=True) as fit_mock,
):
cli = LightningCLI(BoringModel, parser_kwargs=parser_kwargs)
fit_mock.assert_called()
assert cli.trainer.limit_train_batches == 2
assert cli.trainer.limit_val_batches == 1.0
with mock.patch("sys.argv", ["any.py", "validate"]), mock.patch(
"lightning.pytorch.Trainer.validate", autospec=True
) as validate_mock:
with (
mock.patch("sys.argv", ["any.py", "validate"]),
mock.patch("lightning.pytorch.Trainer.validate", autospec=True) as validate_mock,
):
cli = LightningCLI(BoringModel, parser_kwargs=parser_kwargs)
validate_mock.assert_called()
assert cli.trainer.limit_train_batches == 1.0
@ -1420,9 +1447,10 @@ def test_lightning_cli_subcommands_common_default_config_files(cleandir):
config_path.write_text(str(config))
parser_kwargs = {"default_config_files": [str(config_path)]}
with mock.patch("sys.argv", ["any.py", "fit"]), mock.patch(
"lightning.pytorch.Trainer.fit", autospec=True
) as fit_mock:
with (
mock.patch("sys.argv", ["any.py", "fit"]),
mock.patch("lightning.pytorch.Trainer.fit", autospec=True) as fit_mock,
):
cli = LightningCLI(Model, parser_kwargs=parser_kwargs)
fit_mock.assert_called()
assert cli.model.foo == 123

Some files were not shown because too many files have changed in this diff Show More