Move src/pytorch_lightning/lite to src/lightning_lite (#14735)
commit e3e71670e6
parent f167d76508
@@ -38,10 +38,10 @@ import torchvision.transforms as T
 from torch.optim.lr_scheduler import StepLR
 from torchmetrics.classification import Accuracy
 
+from lightning_lite.lite import LightningLite  # import LightningLite
 from pytorch_lightning import seed_everything
 from pytorch_lightning.demos.boring_classes import Net
 from pytorch_lightning.demos.mnist_datamodule import MNIST
-from pytorch_lightning.lite import LightningLite  # import LightningLite
 
 DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")

@@ -34,10 +34,10 @@ import torchvision.transforms as T
 from torch.optim.lr_scheduler import StepLR
 from torchmetrics import Accuracy
 
+from lightning_lite.lite import LightningLite
 from pytorch_lightning import seed_everything
 from pytorch_lightning.demos.boring_classes import Net
 from pytorch_lightning.demos.mnist_datamodule import MNIST
-from pytorch_lightning.lite import LightningLite
 
 DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")

@@ -22,10 +22,10 @@ import torchvision.transforms as T
 from torch.optim.lr_scheduler import StepLR
 from torchmetrics import Accuracy
 
+from lightning_lite.lite import LightningLite
 from pytorch_lightning import seed_everything
 from pytorch_lightning.demos.boring_classes import Net
 from pytorch_lightning.demos.mnist_datamodule import MNIST
-from pytorch_lightning.lite import LightningLite
 from pytorch_lightning.loops import Loop
 
 DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")

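All three example scripts above now import LightningLite from the new lightning_lite package rather than pytorch_lightning.lite. A minimal sketch of how such a script drives training with the moved class, assuming the Lite API (run, setup, setup_dataloaders, backward) is otherwise unchanged by the move; class and variable names here are illustrative:

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from lightning_lite.lite import LightningLite  # new import location


class Lite(LightningLite):
    def run(self):
        model = nn.Linear(32, 2)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        # Lite wraps the model/optimizer and moves them to the selected device.
        model, optimizer = self.setup(model, optimizer)

        dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        dataloader = self.setup_dataloaders(DataLoader(dataset, batch_size=16))

        model.train()
        for x, y in dataloader:
            optimizer.zero_grad()
            loss = nn.functional.cross_entropy(model(x), y)
            self.backward(loss)  # replaces loss.backward()
            optimizer.step()


Lite(accelerator="cpu").run()
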
@@ -12,15 +12,11 @@ if not _root_logger.hasHandlers():
     _logger.addHandler(logging.StreamHandler())
     _logger.propagate = False
 
-# TODO(lite): Re-enable this import
-# from lightning_lite.lite import LightningLite
-
+from lightning_lite.lite import LightningLite  # noqa: E402
 from lightning_lite.utilities.seed import seed_everything  # noqa: E402
 
-__all__ = [
-    # TODO(lite): Re-enable this import
-    # "LightningLite",
-    "seed_everything",
-]
+__all__ = ["LightningLite", "seed_everything"]
 
 # for compatibility with namespace packages
 __import__("pkg_resources").declare_namespace(__name__)

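With the TODO resolved, the package root re-exports both names, so the short form works again. A quick sketch:

from lightning_lite import LightningLite, seed_everything  # both re-exported by __init__.py

seed_everything(42)  # seeds the python, numpy and torch RNGs for reproducible runs
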
@@ -40,7 +40,7 @@ from lightning_lite.utilities.data import (
 )
 from lightning_lite.utilities.distributed import DistributedSamplerWrapper
 from lightning_lite.utilities.seed import seed_everything
-from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
+from lightning_lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
 
 
 class LightningLite(ABC):

@@ -34,7 +34,7 @@ rank_zero_module.log = logging.getLogger(__name__)
 
 
 def _get_rank(
-    strategy: Optional["lightning_lite.strategies.Strategy"] = None,  # type: ignore[name-defined]
+    strategy: Optional["lightning_lite.strategies.Strategy"] = None,
 ) -> Optional[int]:
     if strategy is not None:
         return strategy.global_rank

@@ -261,7 +261,7 @@ class EarlyStopping(Callback):
 
     @staticmethod
     def _log_info(trainer: Optional["pl.Trainer"], message: str, log_rank_zero_only: bool) -> None:
-        rank = _get_rank(strategy=(trainer.strategy if trainer is not None else None))
+        rank = _get_rank(strategy=(trainer.strategy if trainer is not None else None))  # type: ignore[arg-type]
        if trainer is not None and trainer.world_size <= 1:
             rank = None
         message = rank_prefixed_message(message, rank)

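The type suppression moves from the callee to the caller: now that lightning_lite is a real importable package, the forward reference in _get_rank resolves and its ignore can go, while EarlyStopping passes a pytorch_lightning strategy where a lightning_lite one is annotated, hence the new arg-type ignore. A sketch of the rank-prefixing behaviour this feeds, assuming rank_prefixed_message sits next to _get_rank in lightning_lite.utilities.rank_zero and keeps its usual "[rank: N]" format:

from lightning_lite.utilities.rank_zero import rank_prefixed_message  # assumed location

print(rank_prefixed_message("early stopping triggered", rank=1))
# -> "[rank: 1] early stopping triggered"
print(rank_prefixed_message("early stopping triggered", rank=None))
# -> "early stopping triggered"  (left unprefixed on single-process runs)
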
@@ -1,17 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from pytorch_lightning.lite.lite import LightningLite
-
-__all__ = ["LightningLite"]

@@ -23,7 +23,7 @@ from pkg_resources import get_distribution
 from lightning_lite.accelerators.mps import MPSAccelerator
 from lightning_lite.strategies.deepspeed import _DEEPSPEED_AVAILABLE
 from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
-from lightning_lite.utilities.imports import _TPU_AVAILABLE
+from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_10, _TPU_AVAILABLE
 
 
 class RunIf:

@@ -42,6 +42,7 @@ class RunIf:
         min_torch: Optional[str] = None,
         max_torch: Optional[str] = None,
         min_python: Optional[str] = None,
+        bf16_cuda: bool = False,
         tpu: bool = False,
         mps: Optional[bool] = None,
         skip_windows: bool = False,

@@ -57,6 +58,7 @@ class RunIf:
             min_torch: Require that PyTorch is greater or equal than this version.
             max_torch: Require that PyTorch is less than this version.
             min_python: Require that Python is greater or equal than this version.
+            bf16_cuda: Require that CUDA device supports bf16.
             tpu: Require that TPU is available.
             mps: If True: Require that MPS (Apple Silicon) is available,
                 if False: Explicitly Require that MPS is not available

@@ -91,6 +93,20 @@ class RunIf:
             conditions.append(Version(py_version) < Version(min_python))
             reasons.append(f"python>={min_python}")
 
+        if bf16_cuda:
+            try:
+                cond = not (torch.cuda.is_available() and _TORCH_GREATER_EQUAL_1_10 and torch.cuda.is_bf16_supported())
+            except (AssertionError, RuntimeError) as e:
+                # AssertionError: Torch not compiled with CUDA enabled
+                # RuntimeError: Found no NVIDIA driver on your system.
+                # Re-raise only when the error matches neither known message.
+                is_unrelated = "Found no NVIDIA driver" not in str(e) and "Torch not compiled with CUDA" not in str(e)
+                if is_unrelated:
+                    raise e
+                cond = True
+
+            conditions.append(cond)
+            reasons.append("CUDA device bf16")
+
         if skip_windows:
             conditions.append(sys.platform == "win32")
             reasons.append("unimplemented on Windows")

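A sketch of how a test would opt into the new flag; the decorator collapses to a pytest skip marker when no bf16-capable CUDA device is found (the test function itself is illustrative):

import torch

from tests_lite.helpers.runif import RunIf


@RunIf(bf16_cuda=True)
def test_bf16_matmul():
    # Only runs on machines with a CUDA device that supports bfloat16.
    a = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16)
    assert (a @ a).dtype == torch.bfloat16
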
@@ -20,17 +20,17 @@ import pytest
 import torch
 import torch.distributed
 import torch.nn.functional
+from tests_lite.helpers.runif import RunIf
 from torch import nn
 from torch.utils.data import DataLoader, DistributedSampler, Sampler
 
+from lightning_lite.lite import LightningLite
 from lightning_lite.plugins import Precision
 from lightning_lite.strategies import DeepSpeedStrategy, Strategy
 from lightning_lite.utilities import _StrategyType
 from lightning_lite.utilities.exceptions import MisconfigurationException
 from lightning_lite.utilities.seed import pl_worker_init_function
-from pytorch_lightning.lite import LightningLite
-from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
-from tests_pytorch.helpers.runif import RunIf
+from lightning_lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
 
 
 class EmptyLite(LightningLite):

@@ -165,7 +165,7 @@ def test_setup_dataloaders_return_type():
     assert lite_dataloader1.dataset is dataset1
 
 
-@mock.patch("pytorch_lightning.lite.lite._replace_dunder_methods")
+@mock.patch("lightning_lite.lite._replace_dunder_methods")
 def test_setup_dataloaders_captures_dataloader_arguments(ctx_manager):
     """Test that Lite intercepts the DataLoader constructor arguments with a context manager in its run method."""
 
@@ -210,7 +210,7 @@ def test_setup_dataloaders_twice_fails():
 
 
 @mock.patch(
-    "pytorch_lightning.lite.lite.LightningLite.device",
+    "lightning_lite.lite.LightningLite.device",
     new_callable=PropertyMock,
     return_value=torch.device("cuda", 1),
 )

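The patch targets change with the module: unittest.mock replaces a name in the module where it is looked up, so once lite.py lives at lightning_lite.lite, patching the old pytorch_lightning.lite.lite path would no longer intercept anything. A minimal sketch of the mechanism:

from unittest import mock

import lightning_lite.lite


# Patch where the name is resolved at call time, i.e. the new module.
with mock.patch("lightning_lite.lite._replace_dunder_methods") as replace_ctx:
    lightning_lite.lite._replace_dunder_methods()  # hits the mock, not the real helper
replace_ctx.assert_called_once()
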
@@ -23,18 +23,18 @@ import torch.distributed
 import torch.multiprocessing as mp
 import torch.nn.functional
 from lightning_utilities.core.apply_func import apply_to_collection
+from tests_lite.helpers.runif import RunIf
 from torch import nn
 from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
+from lightning_lite.lite import LightningLite
 from lightning_lite.plugins.environments.lightning_environment import find_free_network_port
+from lightning_lite.strategies.ddp_spawn import DDPSpawnStrategy
 from lightning_lite.utilities.apply_func import move_data_to_device
 from lightning_lite.utilities.cloud_io import atomic_save
 from pytorch_lightning.demos.boring_classes import RandomDataset
-from pytorch_lightning.lite import LightningLite
-from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
-from tests_pytorch.helpers.runif import RunIf
 
 
 class BoringModel(nn.Module):

@@ -15,12 +15,12 @@ from unittest.mock import Mock
 
 import pytest
 import torch
+from tests_lite.helpers.runif import RunIf
 from torch.utils.data.dataloader import DataLoader
 
+from lightning_lite.lite import LightningLite
 from lightning_lite.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
-from pytorch_lightning.lite import LightningLite
-from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
-from tests_pytorch.helpers.runif import RunIf
+from lightning_lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
 
 
 class EmptyLite(LightningLite):