[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent 060d951605
commit c11c392acd
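The changes below are mechanical: imports of container ABCs such as Iterable, Iterator, Mapping, and Sequence move from `typing` to `collections.abc` (the `typing` aliases are deprecated since Python 3.9), and stacked context managers are merged into a single parenthesized `with` block (syntax available since Python 3.10). A minimal sketch of both patterns, using `nullcontext` stand-ins rather than code from this commit:

from collections.abc import Iterator  # preferred over the deprecated `from typing import Iterator`
from contextlib import nullcontext


def count_up(limit: int) -> Iterator[int]:
    # collections.abc types work directly in annotations
    yield from range(limit)


# Before: managers chained on one line, awkward to wrap when the line grows long
with nullcontext() as a, nullcontext() as b:
    pass

# After: one manager per line inside parentheses (Python 3.10+)
with (
    nullcontext() as a,
    nullcontext() as b,
):
    pass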
@@ -18,10 +18,11 @@ import re
 import shutil
 import tempfile
 import urllib.request
+from collections.abc import Iterable, Iterator, Sequence
 from itertools import chain
 from os.path import dirname, isfile
 from pathlib import Path
-from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 from packaging.requirements import Requirement
 from packaging.version import Version

@@ -1,7 +1,7 @@
 import os
-from collections.abc import Mapping
+from collections.abc import Iterable, Mapping
 from functools import partial
-from typing import Any, Iterable, List, Literal, Optional, Tuple, Union, cast
+from typing import Any, List, Literal, Optional, Tuple, Union, cast

 import lightning as L
 import torch

@@ -35,7 +35,8 @@ Second-Edition/blob/master/Chapter06/02_dqn_pong.py
 import argparse
 import random
 from collections import OrderedDict, deque, namedtuple
-from typing import Iterator, List, Tuple
+from collections.abc import Iterator
+from typing import List, Tuple

 import gym
 import torch

@@ -30,7 +30,8 @@ References
 """

 import argparse
-from typing import Callable, Iterator, List, Tuple
+from collections.abc import Iterator
+from typing import Callable, List, Tuple

 import gym
 import torch

setup.py (3 changes)
@@ -45,9 +45,10 @@ import glob
 import logging
 import os
 import tempfile
+from collections.abc import Generator, Mapping
 from importlib.util import module_from_spec, spec_from_file_location
 from types import ModuleType
-from typing import Generator, Mapping, Optional
+from typing import Optional

 import setuptools
 import setuptools.command.egg_info

@@ -13,6 +13,7 @@
 # limitations under the License.
 import inspect
 import os
+from collections.abc import Generator, Mapping, Sequence
 from contextlib import contextmanager, nullcontext
 from functools import partial
 from pathlib import Path
@@ -21,11 +22,8 @@ from typing import (
     Callable,
     ContextManager,
     Dict,
-    Generator,
     List,
-    Mapping,
     Optional,
-    Sequence,
     Tuple,
     Union,
     cast,

@@ -14,7 +14,8 @@

 import os
 from argparse import Namespace
-from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union
+from collections.abc import Mapping
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union

 from lightning_utilities.core.imports import RequirementCache
 from torch import Tensor

@@ -16,10 +16,11 @@ import logging
 import math
 import os
 import warnings
+from collections import OrderedDict
 from contextlib import ExitStack
 from functools import partial
 from types import ModuleType
-from typing import Any, Callable, ContextManager, Literal, Optional, OrderedDict, Set, Tuple, Type, cast
+from typing import Any, Callable, ContextManager, Literal, Optional, Set, Tuple, Type, cast

 import torch
 from lightning_utilities import apply_to_collection

@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from collections.abc import Mapping
 from contextlib import ExitStack
-from typing import TYPE_CHECKING, Any, ContextManager, Literal, Mapping, Optional, Union
+from typing import TYPE_CHECKING, Any, ContextManager, Literal, Optional, Union

 import torch
 from lightning_utilities import apply_to_collection

@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Mapping, Type, Union
+from collections.abc import Mapping
+from typing import Any, Type, Union

 import torch
 from torch import Tensor

@@ -16,10 +16,11 @@ import json
 import logging
 import os
 import platform
+from collections.abc import Mapping
 from contextlib import ExitStack
 from itertools import chain
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, List, Mapping, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, List, Optional, Tuple, Union

 import torch
 from lightning_utilities.core.imports import RequirementCache

@@ -13,6 +13,7 @@
 # limitations under the License.
 import shutil
 import warnings
+from collections.abc import Generator
 from contextlib import ExitStack, nullcontext
 from datetime import timedelta
 from functools import partial
@@ -23,7 +24,6 @@ from typing import (
     Callable,
     ContextManager,
     Dict,
-    Generator,
     List,
     Literal,
     Optional,

@@ -18,7 +18,8 @@ import subprocess
 import sys
 import threading
 import time
-from typing import Any, Callable, List, Optional, Sequence, Tuple
+from collections.abc import Sequence
+from typing import Any, Callable, List, Optional, Tuple

 from lightning_utilities.core.imports import RequirementCache
 from typing_extensions import override

@@ -13,10 +13,11 @@
 # limitations under the License.
 import itertools
 import shutil
+from collections.abc import Generator
 from contextlib import ExitStack
 from datetime import timedelta
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generator, Literal, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Literal, Optional, TypeVar, Union

 import torch
 from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only

@@ -13,8 +13,9 @@
 # limitations under the License.
 import logging
 from abc import ABC, abstractmethod
+from collections.abc import Iterable
 from contextlib import ExitStack
-from typing import Any, Callable, ContextManager, Dict, Iterable, List, Optional, Tuple, TypeVar, Union
+from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple, TypeVar, Union

 import torch
 from torch import Tensor

@@ -16,9 +16,10 @@ import functools
 import inspect
 import os
 from collections import OrderedDict
+from collections.abc import Generator, Iterable, Sized
 from contextlib import contextmanager
 from functools import partial
-from typing import Any, Callable, Dict, Generator, Iterable, Optional, Sized, Tuple, Type, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Type, Union

 from lightning_utilities.core.inheritance import get_all_subclasses
 from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler

@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, MutableSequence, Optional, Tuple, Union
+from collections.abc import MutableSequence
+from typing import List, Optional, Tuple, Union

 import torch

@@ -4,10 +4,11 @@ import logging
 import os
 import signal
 import time
+from collections.abc import Iterable, Iterator, Sized
 from contextlib import nullcontext
 from datetime import timedelta
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sized, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Union

 import torch
 import torch.nn.functional as F

@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
-from typing import Any, Callable, Dict, Optional, Sequence, Union
+from collections.abc import Sequence
+from typing import Any, Callable, Dict, Optional, Union

 import torch
 from torch.nn import Module, Parameter

@@ -13,10 +13,12 @@
 import os
 import pickle
 import warnings
+from collections import OrderedDict
+from collections.abc import Sequence
 from functools import partial
 from io import BytesIO
 from pathlib import Path
-from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, OrderedDict, Sequence, Set, Union
+from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, Set, Union

 import torch
 from lightning_utilities.core.apply_func import apply_to_collection

@@ -15,8 +15,9 @@
 import inspect
 import json
 from argparse import Namespace
+from collections.abc import Mapping, MutableMapping
 from dataclasses import asdict, is_dataclass
-from typing import Any, Dict, Mapping, MutableMapping, Optional, Union
+from typing import Any, Dict, Optional, Union

 from torch import Tensor

@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from collections.abc import MutableMapping
-from typing import Iterable
+from collections.abc import Iterable, MutableMapping

 from torch import Tensor
 from torch.optim import Optimizer

@@ -11,13 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections.abc import Iterator
 from pathlib import Path
 from typing import (
     Any,
     Callable,
     DefaultDict,
     Dict,
-    Iterator,
     List,
     Optional,
     Protocol,

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+from collections.abc import Generator, Iterator, Mapping
 from copy import deepcopy
 from functools import partial, wraps
 from types import MethodType
@@ -19,10 +20,7 @@ from typing import (
     Any,
     Callable,
     Dict,
-    Generator,
-    Iterator,
     List,
-    Mapping,
     Optional,
     Tuple,
     TypeVar,

@@ -19,7 +19,8 @@ Freeze and unfreeze models for finetuning purposes.
 """

 import logging
-from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union
+from collections.abc import Generator, Iterable
+from typing import Any, Callable, Dict, List, Optional, Union

 import torch
 from torch.nn import Module, ModuleDict

@@ -18,7 +18,8 @@ BasePredictionWriter
 Aids in saving predictions
 """

-from typing import Any, Literal, Optional, Sequence
+from collections.abc import Sequence
+from typing import Any, Literal, Optional

 from typing_extensions import override

@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
+from collections.abc import Generator
 from dataclasses import dataclass
 from datetime import timedelta
-from typing import Any, Dict, Generator, Optional, Union, cast
+from typing import Any, Dict, Optional, Union, cast

 from lightning_utilities.core.imports import RequirementCache
 from typing_extensions import override

@@ -18,9 +18,10 @@ ModelPruning

 import inspect
 import logging
+from collections.abc import Sequence
 from copy import deepcopy
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch.nn.utils.prune as pytorch_prune
 from lightning_utilities.core.apply_func import apply_to_collection

@@ -1,5 +1,6 @@
 import os
-from typing import Any, Mapping, Union
+from collections.abc import Mapping
+from typing import Any, Union

 import torch

@@ -14,9 +14,10 @@
 import inspect
 import os
 import sys
+from collections.abc import Iterable
 from functools import partial, update_wrapper
 from types import MethodType
-from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, TypeVar, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union

 import torch
 import yaml

@@ -14,7 +14,8 @@
 """LightningDataModule for loading DataLoaders with ease."""

 import inspect
-from typing import IO, Any, Dict, Iterable, Optional, Union, cast
+from collections.abc import Iterable
+from typing import IO, Any, Dict, Optional, Union, cast

 from lightning_utilities import apply_to_collection
 from torch.utils.data import DataLoader, Dataset, IterableDataset

@@ -15,9 +15,10 @@ import copy
 import inspect
 import types
 from argparse import Namespace
+from collections.abc import Iterator, MutableMapping, Sequence
 from contextlib import contextmanager
 from contextvars import ContextVar
-from typing import Any, Iterator, List, MutableMapping, Optional, Sequence, Union
+from typing import Any, List, Optional, Union

 from lightning.fabric.utilities.data import AttributeDict
 from lightning.pytorch.utilities.parsing import save_hyperparameters

@@ -16,6 +16,7 @@
 import logging
 import numbers
 import weakref
+from collections.abc import Generator, Mapping, Sequence
 from contextlib import contextmanager
 from io import BytesIO
 from pathlib import Path
@@ -25,12 +26,9 @@ from typing import (
     Any,
     Callable,
     Dict,
-    Generator,
     List,
     Literal,
-    Mapping,
     Optional,
-    Sequence,
     Tuple,
     Union,
     cast,

@@ -11,9 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections.abc import Generator
 from contextlib import contextmanager
 from dataclasses import fields
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union, overload
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, overload
 from weakref import proxy

 import torch

@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Iterator, List, Optional, Tuple
+from collections.abc import Iterator
+from typing import Any, Dict, List, Optional, Tuple

 import torch
 import torch.nn as nn

@@ -5,7 +5,8 @@ https://github.com/pytorch/examples/blob/main/word_language_model

 """

-from typing import Iterator, List, Optional, Sized, Tuple
+from collections.abc import Iterator, Sized
+from typing import List, Optional, Tuple

 import torch
 import torch.nn as nn

@@ -16,7 +16,8 @@ import os
 import random
 import time
 import urllib
-from typing import Any, Callable, Optional, Sized, Tuple, Union
+from collections.abc import Sized
+from typing import Any, Callable, Optional, Tuple, Union
 from urllib.error import HTTPError
 from warnings import warn

@@ -19,7 +19,8 @@ Comet Logger
 import logging
 import os
 from argparse import Namespace
-from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union
+from collections.abc import Mapping
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union

 from lightning_utilities.core.imports import RequirementCache
 from torch import Tensor

@@ -18,7 +18,8 @@ import operator
 import statistics
 from abc import ABC
 from collections import defaultdict
-from typing import Any, Callable, Dict, Mapping, Optional, Sequence
+from collections.abc import Mapping, Sequence
+from typing import Any, Callable, Dict, Optional

 from typing_extensions import override

@@ -21,9 +21,10 @@ import os
 import re
 import tempfile
 from argparse import Namespace
+from collections.abc import Mapping
 from pathlib import Path
 from time import time
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Mapping, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Union

 import yaml
 from lightning_utilities.core.imports import RequirementCache

@@ -20,8 +20,9 @@ import contextlib
 import logging
 import os
 from argparse import Namespace
+from collections.abc import Generator
 from functools import wraps
-from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Set, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Union

 from lightning_utilities.core.imports import RequirementCache
 from torch import Tensor

@@ -18,8 +18,9 @@ Weights and Biases Logger

 import os
 from argparse import Namespace
+from collections.abc import Mapping
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union

 import torch.nn as nn
 from lightning_utilities.core.imports import RequirementCache

@@ -15,7 +15,8 @@ import os
 import shutil
 import sys
 from collections import ChainMap, OrderedDict, defaultdict
-from typing import Any, DefaultDict, Iterable, Iterator, List, Optional, Tuple, Union
+from collections.abc import Iterable, Iterator
+from typing import Any, DefaultDict, List, Optional, Tuple, Union

 from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor

@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Iterator, List, Optional
+from collections.abc import Iterator
+from typing import Any, List, Optional

 from typing_extensions import override

@@ -11,9 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections import OrderedDict
+from collections.abc import Mapping
 from dataclasses import dataclass, field
 from functools import partial
-from typing import Any, Callable, Dict, Mapping, Optional, OrderedDict
+from typing import Any, Callable, Dict, Optional

 import torch
 from torch import Tensor

@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from collections import OrderedDict
-from typing import Any, Iterator, List, Optional, Union
+from collections.abc import Iterator
+from typing import Any, List, Optional, Union

 import torch
 from lightning_utilities import WarningCache

@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+from collections.abc import Generator
 from contextlib import contextmanager
-from typing import Any, Callable, ContextManager, Generator, Optional, Tuple, Type
+from typing import Any, Callable, ContextManager, Optional, Tuple, Type

 import torch
 import torch.distributed as dist

@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
-from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Sized, Union, cast
+from collections.abc import Iterable, Iterator, Sized
+from typing import Any, Callable, Dict, List, Optional, Union, cast

 import torch
 from torch import Tensor

@@ -9,8 +9,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections.abc import Generator
 from contextlib import contextmanager
-from typing import Any, Callable, Dict, Generator, Literal, Optional, Union
+from typing import Any, Callable, Dict, Literal, Optional, Union

 import torch
 from torch import Tensor

@@ -11,8 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections.abc import Generator
 from contextlib import contextmanager
-from typing import Any, ContextManager, Generator, Literal
+from typing import Any, ContextManager, Literal

 import torch
 import torch.nn as nn

@@ -11,8 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections.abc import Generator
 from contextlib import contextmanager
-from typing import Any, ContextManager, Generator, Literal
+from typing import Any, ContextManager, Literal

 import torch
 from lightning_utilities import apply_to_collection

@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import contextlib
+from collections.abc import Generator
 from functools import partial
-from typing import Any, Callable, Generator, List, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union

 import torch
 from torch import Tensor

@@ -89,9 +89,10 @@ class AdvancedProfiler(Profiler):
         dst_fs = get_filesystem(dst_filepath)
         dst_fs.mkdirs(self.dirpath, exist_ok=True)
         # temporarily save to local since pstats can only dump into a local file
-        with tempfile.TemporaryDirectory(
-            prefix="test", suffix=str(rank_zero_only.rank), dir=os.getcwd()
-        ) as tmp_dir, dst_fs.open(dst_filepath, "wb") as dst_file:
+        with (
+            tempfile.TemporaryDirectory(prefix="test", suffix=str(rank_zero_only.rank), dir=os.getcwd()) as tmp_dir,
+            dst_fs.open(dst_filepath, "wb") as dst_file,
+        ):
             src_filepath = os.path.join(tmp_dir, "tmp.prof")
             profile.dump_stats(src_filepath)
             src_fs = get_filesystem(src_filepath)

@@ -16,9 +16,10 @@
 import logging
 import os
 from abc import ABC, abstractmethod
+from collections.abc import Generator
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Callable, Dict, Generator, Optional, TextIO, Union
+from typing import Any, Callable, Dict, Optional, TextIO, Union

 from lightning.fabric.utilities.cloud_io import get_filesystem

@@ -17,9 +17,10 @@ import logging
 import os
 import platform
 from collections import OrderedDict
+from collections.abc import Generator, Mapping
 from contextlib import contextmanager
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

 import torch
 from torch.nn import Module

@@ -13,6 +13,7 @@
 # limitations under the License.
 import logging
 import shutil
+from collections.abc import Generator, Mapping
 from contextlib import contextmanager, nullcontext
 from datetime import timedelta
 from pathlib import Path
@@ -21,10 +22,8 @@ from typing import (
     Any,
     Callable,
     Dict,
-    Generator,
     List,
     Literal,
-    Mapping,
     Optional,
     Set,
     Tuple,

@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import shutil
+from collections.abc import Generator, Mapping
 from contextlib import contextmanager, nullcontext
 from datetime import timedelta
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Mapping, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union

 import torch
 from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only

@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import ABC, abstractmethod
+from collections.abc import Generator
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, List, Optional
+from typing import Any, Dict, List, Optional

 import torch
 from torch import Tensor

@@ -13,8 +13,9 @@
 # limitations under the License.
 import logging
 from abc import ABC, abstractmethod
+from collections.abc import Generator, Mapping
 from contextlib import contextmanager
-from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, TypeVar, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union

 import torch
 from torch import Tensor

@@ -115,7 +115,11 @@ def _call_configure_model(trainer: "pl.Trainer") -> None:
     # we don't normally check for this before calling the hook. it is done here to avoid instantiating the context
     # managers
     if is_overridden("configure_model", trainer.lightning_module):
-        with trainer.strategy.tensor_init_context(), trainer.strategy.model_sharded_context(), trainer.precision_plugin.module_init_context():  # noqa: E501
+        with (
+            trainer.strategy.tensor_init_context(),
+            trainer.strategy.model_sharded_context(),
+            trainer.precision_plugin.module_init_context(),
+        ):  # noqa: E501
             _call_lightning_module_hook(trainer, "configure_model")


@@ -14,8 +14,9 @@

 import logging
 import os
+from collections.abc import Sequence
 from datetime import timedelta
-from typing import Dict, List, Optional, Sequence, Union
+from typing import Dict, List, Optional, Union

 import lightning.pytorch as pl
 from lightning.fabric.utilities.registry import _load_external_callbacks

@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from collections.abc import Iterable
 from dataclasses import dataclass, field
-from typing import Any, Iterable, Optional, Tuple, Union
+from typing import Any, Optional, Tuple, Union

 import torch.multiprocessing as mp
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sampler, SequentialSampler

@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Iterable, Optional, Union
+from collections.abc import Iterable
+from typing import Any, Optional, Union

 from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor

@@ -11,9 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections.abc import Generator
 from dataclasses import dataclass
 from functools import partial, wraps
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union, cast
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast

 import torch
 from lightning_utilities.core.apply_func import apply_to_collection

@@ -23,9 +23,10 @@
 import logging
 import math
 import os
+from collections.abc import Generator, Iterable
 from contextlib import contextmanager
 from datetime import timedelta
-from typing import Any, Dict, Generator, Iterable, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 from weakref import proxy

 import torch

@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import contextlib
-from collections.abc import Iterable
-from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Tuple, Type, Union
+from collections.abc import Iterable, Iterator
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union

 from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter
 from typing_extensions import Self, TypedDict, override

@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+from collections.abc import Generator, Iterable, Mapping, Sized
 from dataclasses import fields
-from typing import Any, Dict, Generator, Iterable, Mapping, Optional, Sized, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union

 import torch
 from lightning_utilities.core.apply_func import is_dataclass_instance

@@ -17,8 +17,9 @@ import copy
 import inspect
 import pickle
 import types
+from collections.abc import MutableMapping, Sequence
 from dataclasses import fields, is_dataclass
-from typing import Any, Dict, List, Literal, MutableMapping, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union

 from torch import nn

@@ -13,8 +13,8 @@
 # limitations under the License.
 """Utilities to help with reproducibility of models."""

+from collections.abc import Generator
 from contextlib import contextmanager
-from typing import Generator

 from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states

@@ -17,17 +17,14 @@ Convention:
 - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no leading `_`)
 """

+from collections.abc import Generator, Iterator, Mapping, Sequence
 from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import (
     Any,
-    Generator,
-    Iterator,
     List,
-    Mapping,
     Optional,
     Protocol,
-    Sequence,
     Tuple,
     Type,
     TypedDict,

@@ -121,27 +121,32 @@ def test_tf32_message(_, __, ___, caplog, monkeypatch):
 def test_find_usable_cuda_devices_error_handling():
     """Test error handling for edge cases when using `find_usable_cuda_devices`."""
     # Asking for GPUs if no GPUs visible
-    with mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=0), pytest.raises(
-        ValueError, match="You requested to find 2 devices but there are no visible CUDA"
+    with (
+        mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=0),
+        pytest.raises(ValueError, match="You requested to find 2 devices but there are no visible CUDA"),
     ):
         find_usable_cuda_devices(2)

     # Asking for more GPUs than are visible
-    with mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=1), pytest.raises(
-        ValueError, match="this machine only has 1 GPUs"
+    with (
+        mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=1),
+        pytest.raises(ValueError, match="this machine only has 1 GPUs"),
     ):
         find_usable_cuda_devices(2)

     # All GPUs are unusable
     tensor_mock = Mock(side_effect=RuntimeError)  # simulate device placement fails
-    with mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2), mock.patch(
-        "lightning.fabric.accelerators.cuda.torch.tensor", tensor_mock
-    ), pytest.raises(RuntimeError, match=escape("The devices [0, 1] are occupied by other processes")):
+    with (
+        mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2),
+        mock.patch("lightning.fabric.accelerators.cuda.torch.tensor", tensor_mock),
+        pytest.raises(RuntimeError, match=escape("The devices [0, 1] are occupied by other processes")),
+    ):
         find_usable_cuda_devices(2)

     # Request for as many GPUs as there are, no error should be raised
-    with mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=5), mock.patch(
-        "lightning.fabric.accelerators.cuda.torch.tensor"
+    with (
+        mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=5),
+        mock.patch("lightning.fabric.accelerators.cuda.torch.tensor"),
     ):
         assert find_usable_cuda_devices(-1) == [0, 1, 2, 3, 4]

@@ -1,4 +1,4 @@
-from typing import Iterator
+from collections.abc import Iterator

 import torch
 from torch import Tensor

@@ -29,13 +29,16 @@ PASSED_OBJECT = mock.Mock()

 @contextlib.contextmanager
 def check_destroy_group():
-    with mock.patch(
-        "lightning.fabric.plugins.collectives.torch_collective.TorchCollective.new_group",
-        wraps=TorchCollective.new_group,
-    ) as mock_new, mock.patch(
-        "lightning.fabric.plugins.collectives.torch_collective.TorchCollective.destroy_group",
-        wraps=TorchCollective.destroy_group,
-    ) as mock_destroy:
+    with (
+        mock.patch(
+            "lightning.fabric.plugins.collectives.torch_collective.TorchCollective.new_group",
+            wraps=TorchCollective.new_group,
+        ) as mock_new,
+        mock.patch(
+            "lightning.fabric.plugins.collectives.torch_collective.TorchCollective.destroy_group",
+            wraps=TorchCollective.destroy_group,
+        ) as mock_destroy,
+    ):
         yield
         # 0 to account for tests that mock distributed
         # -1 to account for destroying the default process group
@@ -155,9 +158,10 @@ def test_repeated_create_and_destroy():
     with pytest.raises(RuntimeError, match="TorchCollective` already owns a group"):
         collective.create_group()

-    with mock.patch.dict("torch.distributed.distributed_c10d._pg_map", {collective.group: ("", None)}), mock.patch(
-        "torch.distributed.destroy_process_group"
-    ) as destroy_mock:
+    with (
+        mock.patch.dict("torch.distributed.distributed_c10d._pg_map", {collective.group: ("", None)}),
+        mock.patch("torch.distributed.destroy_process_group") as destroy_mock,
+    ):
         collective.teardown()
     # this would be called twice if `init_process_group` wasn't patched. once for the group and once for the default
     # group
@@ -300,9 +304,11 @@ def test_collective_manages_default_group():

     assert TorchCollective.manages_default_group

-    with mock.patch.object(collective, "_group") as mock_group, mock.patch.dict(
-        "torch.distributed.distributed_c10d._pg_map", {mock_group: ("", None)}
-    ), mock.patch("torch.distributed.destroy_process_group") as destroy_mock:
+    with (
+        mock.patch.object(collective, "_group") as mock_group,
+        mock.patch.dict("torch.distributed.distributed_c10d._pg_map", {mock_group: ("", None)}),
+        mock.patch("torch.distributed.destroy_process_group") as destroy_mock,
+    ):
         collective.teardown()
     destroy_mock.assert_called_once_with(mock_group)

@@ -41,8 +41,9 @@ def test_empty_lsb_djob_rankfile():

 def test_missing_lsb_job_id(tmp_path):
     """Test an error when the job id cannot be found."""
-    with mock.patch.dict(os.environ, {"LSB_DJOB_RANKFILE": _make_rankfile(tmp_path)}), pytest.raises(
-        ValueError, match="Could not find job id in environment variable LSB_JOBID"
+    with (
+        mock.patch.dict(os.environ, {"LSB_DJOB_RANKFILE": _make_rankfile(tmp_path)}),
+        pytest.raises(ValueError, match="Could not find job id in environment variable LSB_JOBID"),
     ):
         LSFEnvironment()

@@ -155,8 +155,9 @@ def test_srun_variable_validation():
     """Test that we raise useful errors when `srun` variables are misconfigured."""
     with mock.patch.dict(os.environ, {"SLURM_NTASKS": "1"}):
         SLURMEnvironment()
-    with mock.patch.dict(os.environ, {"SLURM_NTASKS": "2"}), pytest.raises(
-        RuntimeError, match="You set `--ntasks=2` in your SLURM"
+    with (
+        mock.patch.dict(os.environ, {"SLURM_NTASKS": "2"}),
+        pytest.raises(RuntimeError, match="You set `--ntasks=2` in your SLURM"),
     ):
         SLURMEnvironment()

@@ -100,8 +100,11 @@ def test_check_for_bad_cuda_fork(mp_mock, _, start_method):

 def test_check_for_missing_main_guard():
     launcher = _MultiProcessingLauncher(strategy=Mock(), start_method="spawn")
-    with mock.patch(
-        "lightning.fabric.strategies.launchers.multiprocessing.mp.current_process",
-        return_value=Mock(_inheriting=True),  # pretend that main is importing itself
-    ), pytest.raises(RuntimeError, match="requires that your script guards the main"):
+    with (
+        mock.patch(
+            "lightning.fabric.strategies.launchers.multiprocessing.mp.current_process",
+            return_value=Mock(_inheriting=True),  # pretend that main is importing itself
+        ),
+        pytest.raises(RuntimeError, match="requires that your script guards the main"),
+    ):
         launcher.launch(function=Mock())

@@ -58,9 +58,12 @@ def test_ddp_no_backward_sync():
     strategy = DDPStrategy()
     assert isinstance(strategy._backward_sync_control, _DDPBackwardSyncControl)

-    with pytest.raises(
-        TypeError, match="is only possible if the module passed to .* is wrapped in `DistributedDataParallel`"
-    ), strategy._backward_sync_control.no_backward_sync(Mock(), True):
+    with (
+        pytest.raises(
+            TypeError, match="is only possible if the module passed to .* is wrapped in `DistributedDataParallel`"
+        ),
+        strategy._backward_sync_control.no_backward_sync(Mock(), True),
+    ):
         pass

     module = MagicMock(spec=DistributedDataParallel)

@@ -404,9 +404,11 @@ def test_deepspeed_init_module_with_stages_1_2(stage, empty_init):
     fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy, precision="bf16-true")
     fabric.launch()

-    with mock.patch("deepspeed.zero.Init") as zero_init_mock, mock.patch(
-        "torch.Tensor.uniform_"
-    ) as init_mock, fabric.init_module(empty_init=empty_init):
+    with (
+        mock.patch("deepspeed.zero.Init") as zero_init_mock,
+        mock.patch("torch.Tensor.uniform_") as init_mock,
+        fabric.init_module(empty_init=empty_init),
+    ):
         model = BoringModel()

     zero_init_mock.assert_called_with(enabled=False, remote_device=None, config_dict_or_path=ANY)

@@ -133,9 +133,12 @@ def test_no_backward_sync():
     strategy = FSDPStrategy()
     assert isinstance(strategy._backward_sync_control, _FSDPBackwardSyncControl)

-    with pytest.raises(
-        TypeError, match="is only possible if the module passed to .* is wrapped in `FullyShardedDataParallel`"
-    ), strategy._backward_sync_control.no_backward_sync(Mock(), True):
+    with (
+        pytest.raises(
+            TypeError, match="is only possible if the module passed to .* is wrapped in `FullyShardedDataParallel`"
+        ),
+        strategy._backward_sync_control.no_backward_sync(Mock(), True),
+    ):
         pass

     module = MagicMock(spec=FullyShardedDataParallel)
@@ -172,9 +175,12 @@ def test_activation_checkpointing():
     assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)

     strategy._parallel_devices = [torch.device("cuda", 0)]
-    with mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock), mock.patch(
-        "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing"
-    ) as apply_mock:
+    with (
+        mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock),
+        mock.patch(
+            "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing"
+        ) as apply_mock,
+    ):
         wrapped = strategy.setup_module(Model())
     apply_mock.assert_called_with(wrapped, checkpoint_wrapper_fn=ANY, **strategy._activation_checkpointing_kwargs)

@@ -48,9 +48,12 @@ def test_xla_fsdp_no_backward_sync():
     strategy = XLAFSDPStrategy()
    assert isinstance(strategy._backward_sync_control, _XLAFSDPBackwardSyncControl)

-    with pytest.raises(
-        TypeError, match="is only possible if the module passed to .* is wrapped in `XlaFullyShardedDataParallel`"
-    ), strategy._backward_sync_control.no_backward_sync(object(), True):
+    with (
+        pytest.raises(
+            TypeError, match="is only possible if the module passed to .* is wrapped in `XlaFullyShardedDataParallel`"
+        ),
+        strategy._backward_sync_control.no_backward_sync(object(), True),
+    ):
         pass

     module = MagicMock(spec=XlaFullyShardedDataParallel)

@@ -960,28 +960,33 @@ def test_arguments_from_environment_collision():
     with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"}):
         _Connector(accelerator="cuda")

-    with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu", "LT_CLI_USED": "1"}), pytest.raises(
-        ValueError, match="`Fabric\\(accelerator='cuda', ...\\)` but .* `--accelerator=cpu`"
+    with (
+        mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu", "LT_CLI_USED": "1"}),
+        pytest.raises(ValueError, match="`Fabric\\(accelerator='cuda', ...\\)` but .* `--accelerator=cpu`"),
     ):
         _Connector(accelerator="cuda")

-    with mock.patch.dict(os.environ, {"LT_STRATEGY": "ddp", "LT_CLI_USED": "1"}), pytest.raises(
-        ValueError, match="`Fabric\\(strategy='ddp_spawn', ...\\)` but .* `--strategy=ddp`"
+    with (
+        mock.patch.dict(os.environ, {"LT_STRATEGY": "ddp", "LT_CLI_USED": "1"}),
+        pytest.raises(ValueError, match="`Fabric\\(strategy='ddp_spawn', ...\\)` but .* `--strategy=ddp`"),
     ):
         _Connector(strategy="ddp_spawn")

-    with mock.patch.dict(os.environ, {"LT_DEVICES": "2", "LT_CLI_USED": "1"}), pytest.raises(
-        ValueError, match="`Fabric\\(devices=3, ...\\)` but .* `--devices=2`"
+    with (
+        mock.patch.dict(os.environ, {"LT_DEVICES": "2", "LT_CLI_USED": "1"}),
+        pytest.raises(ValueError, match="`Fabric\\(devices=3, ...\\)` but .* `--devices=2`"),
     ):
         _Connector(devices=3)

-    with mock.patch.dict(os.environ, {"LT_NUM_NODES": "3", "LT_CLI_USED": "1"}), pytest.raises(
-        ValueError, match="`Fabric\\(num_nodes=2, ...\\)` but .* `--num_nodes=3`"
+    with (
+        mock.patch.dict(os.environ, {"LT_NUM_NODES": "3", "LT_CLI_USED": "1"}),
+        pytest.raises(ValueError, match="`Fabric\\(num_nodes=2, ...\\)` but .* `--num_nodes=3`"),
     ):
         _Connector(num_nodes=2)

-    with mock.patch.dict(os.environ, {"LT_PRECISION": "16-mixed", "LT_CLI_USED": "1"}), pytest.raises(
-        ValueError, match="`Fabric\\(precision='64-true', ...\\)` but .* `--precision=16-mixed`"
+    with (
+        mock.patch.dict(os.environ, {"LT_PRECISION": "16-mixed", "LT_CLI_USED": "1"}),
+        pytest.raises(ValueError, match="`Fabric\\(precision='64-true', ...\\)` but .* `--precision=16-mixed`"),
     ):
         _Connector(precision="64-true")

@@ -746,9 +746,10 @@ def test_no_backward_sync():

     # pretend that the strategy does not support skipping backward sync
     fabric._strategy = Mock(spec=ParallelStrategy, _backward_sync_control=None)
-    with pytest.warns(
-        PossibleUserWarning, match="The `ParallelStrategy` does not support skipping the"
-    ), fabric.no_backward_sync(model):
+    with (
+        pytest.warns(PossibleUserWarning, match="The `ParallelStrategy` does not support skipping the"),
+        fabric.no_backward_sync(model),
+    ):
         pass

     # for single-device strategies, it becomes a no-op without warning

@@ -41,8 +41,9 @@ def test_parse_cli_args(args, expected):
 def test_process_cli_args(tmp_path, caplog, monkeypatch):
     # PyTorch version < 2.3
     monkeypatch.setattr(lightning.fabric.utilities.consolidate_checkpoint, "_TORCH_GREATER_EQUAL_2_3", False)
-    with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises(
-        SystemExit
+    with (
+        caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"),
+        pytest.raises(SystemExit),
     ):
         _process_cli_args(Namespace())
     assert "requires PyTorch >= 2.3." in caplog.text
@@ -51,8 +52,9 @@ def test_process_cli_args(tmp_path, caplog, monkeypatch):

     # Checkpoint does not exist
     checkpoint_folder = Path("does/not/exist")
-    with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises(
-        SystemExit
+    with (
+        caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"),
+        pytest.raises(SystemExit),
     ):
         _process_cli_args(Namespace(checkpoint_folder=checkpoint_folder))
     assert f"checkpoint folder does not exist: {checkpoint_folder}" in caplog.text
@@ -61,8 +63,9 @@ def test_process_cli_args(tmp_path, caplog, monkeypatch):
     # Checkpoint exists but is not a folder
     file = tmp_path / "checkpoint_file"
     file.touch()
-    with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises(
-        SystemExit
+    with (
+        caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"),
+        pytest.raises(SystemExit),
     ):
         _process_cli_args(Namespace(checkpoint_folder=file))
     assert "checkpoint path must be a folder" in caplog.text
@@ -71,8 +74,9 @@ def test_process_cli_args(tmp_path, caplog, monkeypatch):
     # Checkpoint exists but is not an FSDP checkpoint
     folder = tmp_path / "checkpoint_folder"
     folder.mkdir()
-    with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises(
-        SystemExit
+    with (
+        caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"),
+        pytest.raises(SystemExit),
     ):
         _process_cli_args(Namespace(checkpoint_folder=folder))
     assert "Only FSDP-sharded checkpoints saved with Lightning are supported" in caplog.text
@@ -89,8 +93,9 @@ def test_process_cli_args(tmp_path, caplog, monkeypatch):
     # Checkpoint is a FSDP folder, output file already exists
     file = tmp_path / "ouput_file"
     file.touch()
-    with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises(
-        SystemExit
+    with (
+        caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"),
+        pytest.raises(SystemExit),
     ):
         _process_cli_args(Namespace(checkpoint_folder=folder, output_file=file))
     assert "path for the converted checkpoint already exists" in caplog.text

@@ -215,9 +215,10 @@ def test_infinite_barrier():

     # distributed available
     barrier = _InfiniteBarrier()
-    with mock.patch(
-        "lightning.fabric.utilities.distributed._distributed_is_initialized", return_value=True
-    ), mock.patch("lightning.fabric.utilities.distributed.torch.distributed") as dist_mock:
+    with (
+        mock.patch("lightning.fabric.utilities.distributed._distributed_is_initialized", return_value=True),
+        mock.patch("lightning.fabric.utilities.distributed.torch.distributed") as dist_mock,
+    ):
         barrier.__enter__()
         dist_mock.new_group.assert_called_once()
         assert barrier.barrier == barrier.group.monitored_barrier

@@ -39,8 +39,9 @@ def test_get_available_flops(xla_available):
     with pytest.warns(match="not found for 'CocoNut"), mock.patch("torch.cuda.get_device_name", return_value="CocoNut"):
         assert get_available_flops(torch.device("cuda"), torch.bfloat16) is None

-    with pytest.warns(match="t4' does not support torch.bfloat"), mock.patch(
-        "torch.cuda.get_device_name", return_value="t4"
+    with (
+        pytest.warns(match="t4' does not support torch.bfloat"),
+        mock.patch("torch.cuda.get_device_name", return_value="t4"),
     ):
         assert get_available_flops(torch.device("cuda"), torch.bfloat16) is None

@@ -141,9 +141,12 @@ def test_rich_progress_bar_keyboard_interrupt(tmp_path):

     model = TestModel()

-    with mock.patch(
-        "lightning.pytorch.callbacks.progress.rich_progress.Progress.stop", autospec=True
-    ) as mock_progress_stop, pytest.raises(SystemExit):
+    with (
+        mock.patch(
+            "lightning.pytorch.callbacks.progress.rich_progress.Progress.stop", autospec=True
+        ) as mock_progress_stop,
+        pytest.raises(SystemExit),
+    ):
         progress_bar = RichProgressBar()
         trainer = Trainer(
             default_root_dir=tmp_path,

@@ -43,8 +43,9 @@ def test_throughput_monitor_fit(tmp_path):
     )
     # these timing results are meant to precisely match the `test_throughput_monitor` test in fabric
     timings = [0.0] + [0.5 + i for i in range(1, 6)]
-    with mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100), mock.patch(
-        "time.perf_counter", side_effect=timings
+    with (
+        mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100),
+        mock.patch("time.perf_counter", side_effect=timings),
     ):
         trainer.fit(model)

@@ -179,8 +180,9 @@ def test_throughput_monitor_fit_gradient_accumulation(log_every_n_steps, tmp_pat
         enable_progress_bar=False,
     )
     timings = [0.0] + [0.5 + i for i in range(1, 11)]
-    with mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100), mock.patch(
-        "time.perf_counter", side_effect=timings
+    with (
+        mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100),
+        mock.patch("time.perf_counter", side_effect=timings),
     ):
         trainer.fit(model)

@@ -92,9 +92,10 @@ def test_trainer_save_checkpoint_storage_options(tmp_path, xla_available):
     io_mock.assert_called_with(ANY, instance_path, storage_options=None)

     checkpoint_mock = Mock()
-    with mock.patch.object(trainer.strategy, "save_checkpoint") as save_mock, mock.patch.object(
-        trainer._checkpoint_connector, "dump_checkpoint", return_value=checkpoint_mock
-    ) as dump_mock:
+    with (
+        mock.patch.object(trainer.strategy, "save_checkpoint") as save_mock,
+        mock.patch.object(trainer._checkpoint_connector, "dump_checkpoint", return_value=checkpoint_mock) as dump_mock,
+    ):
         trainer.save_checkpoint(instance_path, True)
     dump_mock.assert_called_with(True)
     save_mock.assert_called_with(checkpoint_mock, instance_path, storage_options=None)

@@ -110,9 +110,10 @@ def test_lightning_optimizer_manual_optimization_and_accumulated_gradients(tmp_p
         default_root_dir=tmp_path, limit_train_batches=8, limit_val_batches=1, max_epochs=1, enable_model_summary=False
     )

-    with patch.multiple(torch.optim.SGD, zero_grad=DEFAULT, step=DEFAULT) as sgd, patch.multiple(
-        torch.optim.Adam, zero_grad=DEFAULT, step=DEFAULT
-    ) as adam:
+    with (
+        patch.multiple(torch.optim.SGD, zero_grad=DEFAULT, step=DEFAULT) as sgd,
+        patch.multiple(torch.optim.Adam, zero_grad=DEFAULT, step=DEFAULT) as adam,
+    ):
         trainer.fit(model)

     assert sgd["step"].call_count == 4

@@ -625,8 +625,9 @@ def test_logger_sync_dist(distributed_env, log_val):
         else nullcontext()
     )

-    with warning_ctx(
-        PossibleUserWarning, match=r"recommended to use `self.log\('bar', ..., sync_dist=True\)`"
-    ), patch_ctx:
+    with (
+        warning_ctx(PossibleUserWarning, match=r"recommended to use `self.log\('bar', ..., sync_dist=True\)`"),
+        patch_ctx,
+    ):
         value = _ResultCollection._get_cache(result_metric, on_step=False)
         assert value == 0.5

@@ -16,7 +16,8 @@ import os
 import random
 import time
 import urllib.request
-from typing import Optional, Sequence, Tuple
+from collections.abc import Sequence
+from typing import Optional, Tuple

 import torch
 from torch import Tensor

@@ -11,8 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections.abc import Iterator, Mapping
 from contextlib import nullcontext
-from typing import Dict, Generic, Iterator, Mapping, TypeVar
+from typing import Dict, Generic, TypeVar

 import pytest
 import torch

@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from collections import Counter
-from typing import Any, Iterator
+from collections.abc import Iterator
+from typing import Any

 import pytest
 import torch

@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from collections.abc import Iterator
 from copy import deepcopy
 from dataclasses import dataclass
-from typing import Dict, Iterator
+from typing import Dict
 from unittest.mock import ANY, Mock

 import pytest

@@ -15,8 +15,9 @@ import glob
 import logging as log
 import os
 import pickle
+from collections.abc import Mapping
 from copy import deepcopy
-from typing import Generic, Mapping, TypeVar
+from typing import Generic, TypeVar

 import cloudpickle
 import pytest

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Iterable
+from collections.abc import Iterable

 import pytest
 import torch

@@ -209,10 +209,13 @@ def test_memory_sharing_disabled(tmp_path):

 def test_check_for_missing_main_guard():
     launcher = _MultiProcessingLauncher(strategy=Mock(), start_method="spawn")
-    with mock.patch(
-        "lightning.pytorch.strategies.launchers.multiprocessing.mp.current_process",
-        return_value=Mock(_inheriting=True),  # pretend that main is importing itself
-    ), pytest.raises(RuntimeError, match="requires that your script guards the main"):
+    with (
+        mock.patch(
+            "lightning.pytorch.strategies.launchers.multiprocessing.mp.current_process",
+            return_value=Mock(_inheriting=True),  # pretend that main is importing itself
+        ),
+        pytest.raises(RuntimeError, match="requires that your script guards the main"),
+    ):
         launcher.launch(function=Mock())


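Background for the hunk above: when a spawned child process re-imports the parent's main module, CPython's multiprocessing machinery temporarily sets an _inheriting flag on the current process, and the launcher's guard check looks for that flag, which is why the test fakes it with Mock(_inheriting=True). A rough sketch of such a check (a hypothetical helper, not the launcher's actual implementation):

    import multiprocessing as mp

    def _check_main_guard() -> None:
        # _inheriting is only present while a spawned child imports __main__
        if getattr(mp.current_process(), "_inheriting", False):
            raise RuntimeError(
                "requires that your script guards the main module with "
                "an `if __name__ == '__main__'` check"
            )
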
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from typing import Any, Mapping
+from collections.abc import Mapping
+from typing import Any

 import pytest
 import torch

@@ -444,9 +444,12 @@ def test_activation_checkpointing():
     strategy._parallel_devices = [torch.device("cuda", 0)]
     strategy._lightning_module = model
     strategy._process_group = Mock()
-    with mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock), mock.patch(
-        "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing"
-    ) as apply_mock:
+    with (
+        mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock),
+        mock.patch(
+            "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing"
+        ) as apply_mock,
+    ):
         wrapped = strategy._setup_model(model)
     apply_mock.assert_called_with(wrapped, checkpoint_wrapper_fn=ANY, **strategy._activation_checkpointing_kwargs)

@@ -409,8 +409,9 @@ def test_lightning_cli_config_and_subclass_mode(cleandir):
     with open(config_path, "w") as f:
         f.write(yaml.dump(input_config))

-    with mock.patch("sys.argv", ["any.py", "--config", config_path]), mock_subclasses(
-        LightningDataModule, DataDirDataModule
+    with (
+        mock.patch("sys.argv", ["any.py", "--config", config_path]),
+        mock_subclasses(LightningDataModule, DataDirDataModule),
     ):
         cli = LightningCLI(
             BoringModel,

@@ -461,9 +462,12 @@ def test_lightning_cli_help():

     cli_args = ["any.py", "fit", "--data.help=DataDirDataModule"]
     out = StringIO()
-    with mock.patch("sys.argv", cli_args), redirect_stdout(out), mock_subclasses(
-        LightningDataModule, DataDirDataModule
-    ), pytest.raises(SystemExit):
+    with (
+        mock.patch("sys.argv", cli_args),
+        redirect_stdout(out),
+        mock_subclasses(LightningDataModule, DataDirDataModule),
+        pytest.raises(SystemExit),
+    ):
         any_model_any_data_cli()

     assert ("--data.data_dir" in out.getvalue()) or ("--data.init_args.data_dir" in out.getvalue())

@@ -609,8 +613,9 @@ class EarlyExitTestModel(BoringModel):
 def test_cli_distributed_save_config_callback(cleandir, logger, strategy):
     from torch.multiprocessing import ProcessRaisedException

-    with mock.patch("sys.argv", ["any.py", "fit"]), pytest.raises(
-        (MisconfigurationException, ProcessRaisedException), match=r"Error on fit start"
+    with (
+        mock.patch("sys.argv", ["any.py", "fit"]),
+        pytest.raises((MisconfigurationException, ProcessRaisedException), match=r"Error on fit start"),
     ):
         LightningCLI(
             EarlyExitTestModel,

@@ -710,12 +715,14 @@ def test_cli_no_need_configure_optimizers(cleandir):

     from lightning.pytorch.trainer.configuration_validator import __verify_train_val_loop_configuration

-    with mock.patch("sys.argv", ["any.py", "fit", "--optimizer=Adam"]), mock.patch(
-        "lightning.pytorch.Trainer._run_stage"
-    ) as run, mock.patch(
-        "lightning.pytorch.trainer.configuration_validator.__verify_train_val_loop_configuration",
-        wraps=__verify_train_val_loop_configuration,
-    ) as verify:
+    with (
+        mock.patch("sys.argv", ["any.py", "fit", "--optimizer=Adam"]),
+        mock.patch("lightning.pytorch.Trainer._run_stage") as run,
+        mock.patch(
+            "lightning.pytorch.trainer.configuration_validator.__verify_train_val_loop_configuration",
+            wraps=__verify_train_val_loop_configuration,
+        ) as verify,
+    ):
         cli = LightningCLI(BoringModel)
     run.assert_called_once()
     verify.assert_called_once_with(cli.trainer, cli.model)

@@ -1074,15 +1081,18 @@ class TestModel(BoringModel):

 @_xfail_python_ge_3_11_9
 def test_lightning_cli_model_short_arguments():
-    with mock.patch("sys.argv", ["any.py", "fit", "--model=BoringModel"]), mock.patch(
-        "lightning.pytorch.Trainer._fit_impl"
-    ) as run, mock_subclasses(LightningModule, BoringModel, TestModel):
+    with (
+        mock.patch("sys.argv", ["any.py", "fit", "--model=BoringModel"]),
+        mock.patch("lightning.pytorch.Trainer._fit_impl") as run,
+        mock_subclasses(LightningModule, BoringModel, TestModel),
+    ):
         cli = LightningCLI(trainer_defaults={"fast_dev_run": 1})
         assert isinstance(cli.model, BoringModel)
         run.assert_called_once_with(cli.model, ANY, ANY, ANY, ANY)

-    with mock.patch("sys.argv", ["any.py", "--model=TestModel", "--model.foo", "123"]), mock_subclasses(
-        LightningModule, BoringModel, TestModel
+    with (
+        mock.patch("sys.argv", ["any.py", "--model=TestModel", "--model.foo", "123"]),
+        mock_subclasses(LightningModule, BoringModel, TestModel),
     ):
         cli = LightningCLI(run=False)
         assert isinstance(cli.model, TestModel)

@@ -1100,15 +1110,18 @@ class MyDataModule(BoringDataModule):
 @_xfail_python_ge_3_11_9
 def test_lightning_cli_datamodule_short_arguments():
     # with set model
-    with mock.patch("sys.argv", ["any.py", "fit", "--data=BoringDataModule"]), mock.patch(
-        "lightning.pytorch.Trainer._fit_impl"
-    ) as run, mock_subclasses(LightningDataModule, BoringDataModule):
+    with (
+        mock.patch("sys.argv", ["any.py", "fit", "--data=BoringDataModule"]),
+        mock.patch("lightning.pytorch.Trainer._fit_impl") as run,
+        mock_subclasses(LightningDataModule, BoringDataModule),
+    ):
         cli = LightningCLI(BoringModel, trainer_defaults={"fast_dev_run": 1})
         assert isinstance(cli.datamodule, BoringDataModule)
         run.assert_called_once_with(ANY, ANY, ANY, cli.datamodule, ANY)

-    with mock.patch("sys.argv", ["any.py", "--data=MyDataModule", "--data.foo", "123"]), mock_subclasses(
-        LightningDataModule, MyDataModule
+    with (
+        mock.patch("sys.argv", ["any.py", "--data=MyDataModule", "--data.foo", "123"]),
+        mock_subclasses(LightningDataModule, MyDataModule),
     ):
         cli = LightningCLI(BoringModel, run=False)
         assert isinstance(cli.datamodule, MyDataModule)

@@ -1116,17 +1129,22 @@ def test_lightning_cli_datamodule_short_arguments():
         assert cli.datamodule.bar == 5

     # with configurable model
-    with mock.patch("sys.argv", ["any.py", "fit", "--model", "BoringModel", "--data=BoringDataModule"]), mock.patch(
-        "lightning.pytorch.Trainer._fit_impl"
-    ) as run, mock_subclasses(LightningModule, BoringModel), mock_subclasses(LightningDataModule, BoringDataModule):
+    with (
+        mock.patch("sys.argv", ["any.py", "fit", "--model", "BoringModel", "--data=BoringDataModule"]),
+        mock.patch("lightning.pytorch.Trainer._fit_impl") as run,
+        mock_subclasses(LightningModule, BoringModel),
+        mock_subclasses(LightningDataModule, BoringDataModule),
+    ):
         cli = LightningCLI(trainer_defaults={"fast_dev_run": 1})
         assert isinstance(cli.model, BoringModel)
         assert isinstance(cli.datamodule, BoringDataModule)
         run.assert_called_once_with(cli.model, ANY, ANY, cli.datamodule, ANY)

-    with mock.patch("sys.argv", ["any.py", "--model", "BoringModel", "--data=MyDataModule"]), mock_subclasses(
-        LightningModule, BoringModel
-    ), mock_subclasses(LightningDataModule, MyDataModule):
+    with (
+        mock.patch("sys.argv", ["any.py", "--model", "BoringModel", "--data=MyDataModule"]),
+        mock_subclasses(LightningModule, BoringModel),
+        mock_subclasses(LightningDataModule, MyDataModule),
+    ):
         cli = LightningCLI(run=False)
         assert isinstance(cli.model, BoringModel)
         assert isinstance(cli.datamodule, MyDataModule)

@@ -1293,9 +1311,10 @@ def test_optimizers_and_lr_schedulers_add_arguments_to_parser_implemented_reload

 def test_lightning_cli_config_with_subcommand():
     config = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}}
-    with mock.patch("sys.argv", ["any.py", f"--config={config}"]), mock.patch(
-        "lightning.pytorch.Trainer.test", autospec=True
-    ) as test_mock:
+    with (
+        mock.patch("sys.argv", ["any.py", f"--config={config}"]),
+        mock.patch("lightning.pytorch.Trainer.test", autospec=True) as test_mock,
+    ):
         cli = LightningCLI(BoringModel)

     test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar")

@@ -1308,9 +1327,10 @@ def test_lightning_cli_config_before_subcommand():
         "test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"},
     }

-    with mock.patch("sys.argv", ["any.py", f"--config={config}", "test"]), mock.patch(
-        "lightning.pytorch.Trainer.test", autospec=True
-    ) as test_mock:
+    with (
+        mock.patch("sys.argv", ["any.py", f"--config={config}", "test"]),
+        mock.patch("lightning.pytorch.Trainer.test", autospec=True) as test_mock,
+    ):
         cli = LightningCLI(BoringModel)

     test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar")

@@ -1320,9 +1340,10 @@ def test_lightning_cli_config_before_subcommand():
     assert save_config_callback.config.trainer.limit_test_batches == 1
     assert save_config_callback.parser.subcommand == "test"

-    with mock.patch("sys.argv", ["any.py", f"--config={config}", "validate"]), mock.patch(
-        "lightning.pytorch.Trainer.validate", autospec=True
-    ) as validate_mock:
+    with (
+        mock.patch("sys.argv", ["any.py", f"--config={config}", "validate"]),
+        mock.patch("lightning.pytorch.Trainer.validate", autospec=True) as validate_mock,
+    ):
         cli = LightningCLI(BoringModel)

     validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo")

@@ -1337,17 +1358,19 @@ def test_lightning_cli_config_before_subcommand_two_configs():
     config1 = {"validate": {"trainer": {"limit_val_batches": 1}, "verbose": False, "ckpt_path": "barfoo"}}
     config2 = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}}

-    with mock.patch("sys.argv", ["any.py", f"--config={config1}", f"--config={config2}", "test"]), mock.patch(
-        "lightning.pytorch.Trainer.test", autospec=True
-    ) as test_mock:
+    with (
+        mock.patch("sys.argv", ["any.py", f"--config={config1}", f"--config={config2}", "test"]),
+        mock.patch("lightning.pytorch.Trainer.test", autospec=True) as test_mock,
+    ):
         cli = LightningCLI(BoringModel)

     test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar")
     assert cli.trainer.limit_test_batches == 1

-    with mock.patch("sys.argv", ["any.py", f"--config={config1}", f"--config={config2}", "validate"]), mock.patch(
-        "lightning.pytorch.Trainer.validate", autospec=True
-    ) as validate_mock:
+    with (
+        mock.patch("sys.argv", ["any.py", f"--config={config1}", f"--config={config2}", "validate"]),
+        mock.patch("lightning.pytorch.Trainer.validate", autospec=True) as validate_mock,
+    ):
         cli = LightningCLI(BoringModel)

     validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo")

@@ -1356,9 +1379,10 @@ def test_lightning_cli_config_before_subcommand_two_configs():

 def test_lightning_cli_config_after_subcommand():
     config = {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}
-    with mock.patch("sys.argv", ["any.py", "test", f"--config={config}"]), mock.patch(
-        "lightning.pytorch.Trainer.test", autospec=True
-    ) as test_mock:
+    with (
+        mock.patch("sys.argv", ["any.py", "test", f"--config={config}"]),
+        mock.patch("lightning.pytorch.Trainer.test", autospec=True) as test_mock,
+    ):
         cli = LightningCLI(BoringModel)

     test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar")

@@ -1368,9 +1392,10 @@ def test_lightning_cli_config_after_subcommand():
 def test_lightning_cli_config_before_and_after_subcommand():
     config1 = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}}
     config2 = {"trainer": {"fast_dev_run": 1}, "verbose": False, "ckpt_path": "foobar"}
-    with mock.patch("sys.argv", ["any.py", f"--config={config1}", "test", f"--config={config2}"]), mock.patch(
-        "lightning.pytorch.Trainer.test", autospec=True
-    ) as test_mock:
+    with (
+        mock.patch("sys.argv", ["any.py", f"--config={config1}", "test", f"--config={config2}"]),
+        mock.patch("lightning.pytorch.Trainer.test", autospec=True) as test_mock,
+    ):
         cli = LightningCLI(BoringModel)

     test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=False, ckpt_path="foobar")

@@ -1392,17 +1417,19 @@ def test_lightning_cli_parse_kwargs_with_subcommands(cleandir):
         "validate": {"default_config_files": [str(validate_config_path)]},
     }

-    with mock.patch("sys.argv", ["any.py", "fit"]), mock.patch(
-        "lightning.pytorch.Trainer.fit", autospec=True
-    ) as fit_mock:
+    with (
+        mock.patch("sys.argv", ["any.py", "fit"]),
+        mock.patch("lightning.pytorch.Trainer.fit", autospec=True) as fit_mock,
+    ):
         cli = LightningCLI(BoringModel, parser_kwargs=parser_kwargs)
     fit_mock.assert_called()
     assert cli.trainer.limit_train_batches == 2
     assert cli.trainer.limit_val_batches == 1.0

-    with mock.patch("sys.argv", ["any.py", "validate"]), mock.patch(
-        "lightning.pytorch.Trainer.validate", autospec=True
-    ) as validate_mock:
+    with (
+        mock.patch("sys.argv", ["any.py", "validate"]),
+        mock.patch("lightning.pytorch.Trainer.validate", autospec=True) as validate_mock,
+    ):
         cli = LightningCLI(BoringModel, parser_kwargs=parser_kwargs)
     validate_mock.assert_called()
     assert cli.trainer.limit_train_batches == 1.0

@@ -1420,9 +1447,10 @@ def test_lightning_cli_subcommands_common_default_config_files(cleandir):
     config_path.write_text(str(config))
     parser_kwargs = {"default_config_files": [str(config_path)]}

-    with mock.patch("sys.argv", ["any.py", "fit"]), mock.patch(
-        "lightning.pytorch.Trainer.fit", autospec=True
-    ) as fit_mock:
+    with (
+        mock.patch("sys.argv", ["any.py", "fit"]),
+        mock.patch("lightning.pytorch.Trainer.fit", autospec=True) as fit_mock,
+    ):
         cli = LightningCLI(Model, parser_kwargs=parser_kwargs)
     fit_mock.assert_called()
     assert cli.model.foo == 123

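Taken together, the LightningCLI hunks apply the same mechanical rewrite dozens of times. A self-contained distillation of the shape they all share, runnable under pytest on Python 3.10+ and using toy stand-ins rather than the real CLI:

    import sys
    from unittest import mock

    import pytest


    def toy_cli() -> None:
        # stand-in for LightningCLI: demand a subcommand or bail out
        if len(sys.argv) < 2:
            raise SystemExit(2)


    def test_toy_cli_requires_subcommand() -> None:
        with (
            mock.patch("sys.argv", ["any.py"]),
            pytest.raises(SystemExit),
        ):
            toy_cli()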