Move src/pytorch_lightning/lite to src/lightning_lite (#14735)

This commit is contained in:
awaelchli 2022-09-18 00:34:53 +02:00 committed by Adrian Wälchli
parent f167d76508
commit e3e71670e6
14 changed files with 37 additions and 42 deletions

View File

@ -38,10 +38,10 @@ import torchvision.transforms as T
from torch.optim.lr_scheduler import StepLR
from torchmetrics.classification import Accuracy
from lightning_lite.lite import LightningLite # import LightningLite
from pytorch_lightning import seed_everything
from pytorch_lightning.demos.boring_classes import Net
from pytorch_lightning.demos.mnist_datamodule import MNIST
from pytorch_lightning.lite import LightningLite # import LightningLite
DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")

View File

@ -34,10 +34,10 @@ import torchvision.transforms as T
from torch.optim.lr_scheduler import StepLR
from torchmetrics import Accuracy
from lightning_lite.lite import LightningLite
from pytorch_lightning import seed_everything
from pytorch_lightning.demos.boring_classes import Net
from pytorch_lightning.demos.mnist_datamodule import MNIST
from pytorch_lightning.lite import LightningLite
DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")

View File

@ -22,10 +22,10 @@ import torchvision.transforms as T
from torch.optim.lr_scheduler import StepLR
from torchmetrics import Accuracy
from lightning_lite.lite import LightningLite
from pytorch_lightning import seed_everything
from pytorch_lightning.demos.boring_classes import Net
from pytorch_lightning.demos.mnist_datamodule import MNIST
from pytorch_lightning.lite import LightningLite
from pytorch_lightning.loops import Loop
DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")

View File

@ -12,15 +12,11 @@ if not _root_logger.hasHandlers():
_logger.addHandler(logging.StreamHandler())
_logger.propagate = False
# TODO(lite): Re-enable this import
# from lightning_lite.lite import LightningLite
from lightning_lite.lite import LightningLite # noqa: E402
from lightning_lite.utilities.seed import seed_everything # noqa: E402
__all__ = [
# TODO(lite): Re-enable this import
# "LightningLite",
"seed_everything",
]
__all__ = ["LightningLite", "seed_everything"]
# for compatibility with namespace packages
__import__("pkg_resources").declare_namespace(__name__)

View File

@ -40,7 +40,7 @@ from lightning_lite.utilities.data import (
)
from lightning_lite.utilities.distributed import DistributedSamplerWrapper
from lightning_lite.utilities.seed import seed_everything
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
from lightning_lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
class LightningLite(ABC):

View File

@ -34,7 +34,7 @@ rank_zero_module.log = logging.getLogger(__name__)
def _get_rank(
strategy: Optional["lightning_lite.strategies.Strategy"] = None, # type: ignore[name-defined]
strategy: Optional["lightning_lite.strategies.Strategy"] = None,
) -> Optional[int]:
if strategy is not None:
return strategy.global_rank

View File

@ -261,7 +261,7 @@ class EarlyStopping(Callback):
@staticmethod
def _log_info(trainer: Optional["pl.Trainer"], message: str, log_rank_zero_only: bool) -> None:
rank = _get_rank(strategy=(trainer.strategy if trainer is not None else None))
rank = _get_rank(strategy=(trainer.strategy if trainer is not None else None)) # type: ignore[arg-type]
if trainer is not None and trainer.world_size <= 1:
rank = None
message = rank_prefixed_message(message, rank)

View File

@ -1,17 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.lite.lite import LightningLite
__all__ = ["LightningLite"]

View File

@ -23,7 +23,7 @@ from pkg_resources import get_distribution
from lightning_lite.accelerators.mps import MPSAccelerator
from lightning_lite.strategies.deepspeed import _DEEPSPEED_AVAILABLE
from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
from lightning_lite.utilities.imports import _TPU_AVAILABLE
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_10, _TPU_AVAILABLE
class RunIf:
@ -42,6 +42,7 @@ class RunIf:
min_torch: Optional[str] = None,
max_torch: Optional[str] = None,
min_python: Optional[str] = None,
bf16_cuda: bool = False,
tpu: bool = False,
mps: Optional[bool] = None,
skip_windows: bool = False,
@ -57,6 +58,7 @@ class RunIf:
min_torch: Require that PyTorch is greater or equal than this version.
max_torch: Require that PyTorch is less than this version.
min_python: Require that Python is greater or equal than this version.
bf16_cuda: Require that CUDA device supports bf16.
tpu: Require that TPU is available.
mps: If True: Require that MPS (Apple Silicon) is available,
if False: Explicitly Require that MPS is not available
@ -91,6 +93,20 @@ class RunIf:
conditions.append(Version(py_version) < Version(min_python))
reasons.append(f"python>={min_python}")
if bf16_cuda:
try:
cond = not (torch.cuda.is_available() and _TORCH_GREATER_EQUAL_1_10 and torch.cuda.is_bf16_supported())
except (AssertionError, RuntimeError) as e:
# AssertionError: Torch not compiled with CUDA enabled
# RuntimeError: Found no NVIDIA driver on your system.
is_unrelated = "Found no NVIDIA driver" not in str(e) or "Torch not compiled with CUDA" not in str(e)
if is_unrelated:
raise e
cond = True
conditions.append(cond)
reasons.append("CUDA device bf16")
if skip_windows:
conditions.append(sys.platform == "win32")
reasons.append("unimplemented on Windows")

View File

@ -20,17 +20,17 @@ import pytest
import torch
import torch.distributed
import torch.nn.functional
from tests_lite.helpers.runif import RunIf
from torch import nn
from torch.utils.data import DataLoader, DistributedSampler, Sampler
from lightning_lite.lite import LightningLite
from lightning_lite.plugins import Precision
from lightning_lite.strategies import DeepSpeedStrategy, Strategy
from lightning_lite.utilities import _StrategyType
from lightning_lite.utilities.exceptions import MisconfigurationException
from lightning_lite.utilities.seed import pl_worker_init_function
from pytorch_lightning.lite import LightningLite
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
from tests_pytorch.helpers.runif import RunIf
from lightning_lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
class EmptyLite(LightningLite):
@ -165,7 +165,7 @@ def test_setup_dataloaders_return_type():
assert lite_dataloader1.dataset is dataset1
@mock.patch("pytorch_lightning.lite.lite._replace_dunder_methods")
@mock.patch("lightning_lite.lite._replace_dunder_methods")
def test_setup_dataloaders_captures_dataloader_arguments(ctx_manager):
"""Test that Lite intercepts the DataLoader constructor arguments with a context manager in its run method."""
@ -210,7 +210,7 @@ def test_setup_dataloaders_twice_fails():
@mock.patch(
"pytorch_lightning.lite.lite.LightningLite.device",
"lightning_lite.lite.LightningLite.device",
new_callable=PropertyMock,
return_value=torch.device("cuda", 1),
)

View File

@ -23,18 +23,18 @@ import torch.distributed
import torch.multiprocessing as mp
import torch.nn.functional
from lightning_utilities.core.apply_func import apply_to_collection
from tests_lite.helpers.runif import RunIf
from torch import nn
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from lightning_lite.lite import LightningLite
from lightning_lite.plugins.environments.lightning_environment import find_free_network_port
from lightning_lite.strategies.ddp_spawn import DDPSpawnStrategy
from lightning_lite.utilities.apply_func import move_data_to_device
from lightning_lite.utilities.cloud_io import atomic_save
from pytorch_lightning.demos.boring_classes import RandomDataset
from pytorch_lightning.lite import LightningLite
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
from tests_pytorch.helpers.runif import RunIf
class BoringModel(nn.Module):

View File

@ -15,12 +15,12 @@ from unittest.mock import Mock
import pytest
import torch
from tests_lite.helpers.runif import RunIf
from torch.utils.data.dataloader import DataLoader
from lightning_lite.lite import LightningLite
from lightning_lite.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from pytorch_lightning.lite import LightningLite
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
from tests_pytorch.helpers.runif import RunIf
from lightning_lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
class EmptyLite(LightningLite):