Move src/pytorch_lightning/lite to src/lightning_lite (#14735)

This commit is contained in:
awaelchli 2022-09-18 00:34:53 +02:00 committed by Adrian Wälchli
parent f167d76508
commit e3e71670e6
14 changed files with 37 additions and 42 deletions

View File

@ -38,10 +38,10 @@ import torchvision.transforms as T
from torch.optim.lr_scheduler import StepLR from torch.optim.lr_scheduler import StepLR
from torchmetrics.classification import Accuracy from torchmetrics.classification import Accuracy
from lightning_lite.lite import LightningLite # import LightningLite
from pytorch_lightning import seed_everything from pytorch_lightning import seed_everything
from pytorch_lightning.demos.boring_classes import Net from pytorch_lightning.demos.boring_classes import Net
from pytorch_lightning.demos.mnist_datamodule import MNIST from pytorch_lightning.demos.mnist_datamodule import MNIST
from pytorch_lightning.lite import LightningLite # import LightningLite
DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets") DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")

View File

@ -34,10 +34,10 @@ import torchvision.transforms as T
from torch.optim.lr_scheduler import StepLR from torch.optim.lr_scheduler import StepLR
from torchmetrics import Accuracy from torchmetrics import Accuracy
from lightning_lite.lite import LightningLite
from pytorch_lightning import seed_everything from pytorch_lightning import seed_everything
from pytorch_lightning.demos.boring_classes import Net from pytorch_lightning.demos.boring_classes import Net
from pytorch_lightning.demos.mnist_datamodule import MNIST from pytorch_lightning.demos.mnist_datamodule import MNIST
from pytorch_lightning.lite import LightningLite
DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets") DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")

View File

@ -22,10 +22,10 @@ import torchvision.transforms as T
from torch.optim.lr_scheduler import StepLR from torch.optim.lr_scheduler import StepLR
from torchmetrics import Accuracy from torchmetrics import Accuracy
from lightning_lite.lite import LightningLite
from pytorch_lightning import seed_everything from pytorch_lightning import seed_everything
from pytorch_lightning.demos.boring_classes import Net from pytorch_lightning.demos.boring_classes import Net
from pytorch_lightning.demos.mnist_datamodule import MNIST from pytorch_lightning.demos.mnist_datamodule import MNIST
from pytorch_lightning.lite import LightningLite
from pytorch_lightning.loops import Loop from pytorch_lightning.loops import Loop
DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets") DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")

View File

@ -12,15 +12,11 @@ if not _root_logger.hasHandlers():
_logger.addHandler(logging.StreamHandler()) _logger.addHandler(logging.StreamHandler())
_logger.propagate = False _logger.propagate = False
# TODO(lite): Re-enable this import
# from lightning_lite.lite import LightningLite from lightning_lite.lite import LightningLite # noqa: E402
from lightning_lite.utilities.seed import seed_everything # noqa: E402 from lightning_lite.utilities.seed import seed_everything # noqa: E402
__all__ = [ __all__ = ["LightningLite", "seed_everything"]
# TODO(lite): Re-enable this import
# "LightningLite",
"seed_everything",
]
# for compatibility with namespace packages # for compatibility with namespace packages
__import__("pkg_resources").declare_namespace(__name__) __import__("pkg_resources").declare_namespace(__name__)

View File

@ -40,7 +40,7 @@ from lightning_lite.utilities.data import (
) )
from lightning_lite.utilities.distributed import DistributedSamplerWrapper from lightning_lite.utilities.distributed import DistributedSamplerWrapper
from lightning_lite.utilities.seed import seed_everything from lightning_lite.utilities.seed import seed_everything
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer from lightning_lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
class LightningLite(ABC): class LightningLite(ABC):

View File

@ -34,7 +34,7 @@ rank_zero_module.log = logging.getLogger(__name__)
def _get_rank( def _get_rank(
strategy: Optional["lightning_lite.strategies.Strategy"] = None, # type: ignore[name-defined] strategy: Optional["lightning_lite.strategies.Strategy"] = None,
) -> Optional[int]: ) -> Optional[int]:
if strategy is not None: if strategy is not None:
return strategy.global_rank return strategy.global_rank

View File

@ -261,7 +261,7 @@ class EarlyStopping(Callback):
@staticmethod @staticmethod
def _log_info(trainer: Optional["pl.Trainer"], message: str, log_rank_zero_only: bool) -> None: def _log_info(trainer: Optional["pl.Trainer"], message: str, log_rank_zero_only: bool) -> None:
rank = _get_rank(strategy=(trainer.strategy if trainer is not None else None)) rank = _get_rank(strategy=(trainer.strategy if trainer is not None else None)) # type: ignore[arg-type]
if trainer is not None and trainer.world_size <= 1: if trainer is not None and trainer.world_size <= 1:
rank = None rank = None
message = rank_prefixed_message(message, rank) message = rank_prefixed_message(message, rank)

View File

@ -1,17 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.lite.lite import LightningLite
__all__ = ["LightningLite"]

View File

@ -23,7 +23,7 @@ from pkg_resources import get_distribution
from lightning_lite.accelerators.mps import MPSAccelerator from lightning_lite.accelerators.mps import MPSAccelerator
from lightning_lite.strategies.deepspeed import _DEEPSPEED_AVAILABLE from lightning_lite.strategies.deepspeed import _DEEPSPEED_AVAILABLE
from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
from lightning_lite.utilities.imports import _TPU_AVAILABLE from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_10, _TPU_AVAILABLE
class RunIf: class RunIf:
@ -42,6 +42,7 @@ class RunIf:
min_torch: Optional[str] = None, min_torch: Optional[str] = None,
max_torch: Optional[str] = None, max_torch: Optional[str] = None,
min_python: Optional[str] = None, min_python: Optional[str] = None,
bf16_cuda: bool = False,
tpu: bool = False, tpu: bool = False,
mps: Optional[bool] = None, mps: Optional[bool] = None,
skip_windows: bool = False, skip_windows: bool = False,
@ -57,6 +58,7 @@ class RunIf:
min_torch: Require that PyTorch is greater or equal than this version. min_torch: Require that PyTorch is greater or equal than this version.
max_torch: Require that PyTorch is less than this version. max_torch: Require that PyTorch is less than this version.
min_python: Require that Python is greater or equal than this version. min_python: Require that Python is greater or equal than this version.
bf16_cuda: Require that CUDA device supports bf16.
tpu: Require that TPU is available. tpu: Require that TPU is available.
mps: If True: Require that MPS (Apple Silicon) is available, mps: If True: Require that MPS (Apple Silicon) is available,
if False: Explicitly Require that MPS is not available if False: Explicitly Require that MPS is not available
@ -91,6 +93,20 @@ class RunIf:
conditions.append(Version(py_version) < Version(min_python)) conditions.append(Version(py_version) < Version(min_python))
reasons.append(f"python>={min_python}") reasons.append(f"python>={min_python}")
if bf16_cuda:
try:
cond = not (torch.cuda.is_available() and _TORCH_GREATER_EQUAL_1_10 and torch.cuda.is_bf16_supported())
except (AssertionError, RuntimeError) as e:
# AssertionError: Torch not compiled with CUDA enabled
# RuntimeError: Found no NVIDIA driver on your system.
is_unrelated = "Found no NVIDIA driver" not in str(e) or "Torch not compiled with CUDA" not in str(e)
if is_unrelated:
raise e
cond = True
conditions.append(cond)
reasons.append("CUDA device bf16")
if skip_windows: if skip_windows:
conditions.append(sys.platform == "win32") conditions.append(sys.platform == "win32")
reasons.append("unimplemented on Windows") reasons.append("unimplemented on Windows")

View File

@ -20,17 +20,17 @@ import pytest
import torch import torch
import torch.distributed import torch.distributed
import torch.nn.functional import torch.nn.functional
from tests_lite.helpers.runif import RunIf
from torch import nn from torch import nn
from torch.utils.data import DataLoader, DistributedSampler, Sampler from torch.utils.data import DataLoader, DistributedSampler, Sampler
from lightning_lite.lite import LightningLite
from lightning_lite.plugins import Precision from lightning_lite.plugins import Precision
from lightning_lite.strategies import DeepSpeedStrategy, Strategy from lightning_lite.strategies import DeepSpeedStrategy, Strategy
from lightning_lite.utilities import _StrategyType from lightning_lite.utilities import _StrategyType
from lightning_lite.utilities.exceptions import MisconfigurationException from lightning_lite.utilities.exceptions import MisconfigurationException
from lightning_lite.utilities.seed import pl_worker_init_function from lightning_lite.utilities.seed import pl_worker_init_function
from pytorch_lightning.lite import LightningLite from lightning_lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
from tests_pytorch.helpers.runif import RunIf
class EmptyLite(LightningLite): class EmptyLite(LightningLite):
@ -165,7 +165,7 @@ def test_setup_dataloaders_return_type():
assert lite_dataloader1.dataset is dataset1 assert lite_dataloader1.dataset is dataset1
@mock.patch("pytorch_lightning.lite.lite._replace_dunder_methods") @mock.patch("lightning_lite.lite._replace_dunder_methods")
def test_setup_dataloaders_captures_dataloader_arguments(ctx_manager): def test_setup_dataloaders_captures_dataloader_arguments(ctx_manager):
"""Test that Lite intercepts the DataLoader constructor arguments with a context manager in its run method.""" """Test that Lite intercepts the DataLoader constructor arguments with a context manager in its run method."""
@ -210,7 +210,7 @@ def test_setup_dataloaders_twice_fails():
@mock.patch( @mock.patch(
"pytorch_lightning.lite.lite.LightningLite.device", "lightning_lite.lite.LightningLite.device",
new_callable=PropertyMock, new_callable=PropertyMock,
return_value=torch.device("cuda", 1), return_value=torch.device("cuda", 1),
) )

View File

@ -23,18 +23,18 @@ import torch.distributed
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn.functional import torch.nn.functional
from lightning_utilities.core.apply_func import apply_to_collection from lightning_utilities.core.apply_func import apply_to_collection
from tests_lite.helpers.runif import RunIf
from torch import nn from torch import nn
from torch.nn.parallel.distributed import DistributedDataParallel from torch.nn.parallel.distributed import DistributedDataParallel
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from lightning_lite.lite import LightningLite
from lightning_lite.plugins.environments.lightning_environment import find_free_network_port from lightning_lite.plugins.environments.lightning_environment import find_free_network_port
from lightning_lite.strategies.ddp_spawn import DDPSpawnStrategy
from lightning_lite.utilities.apply_func import move_data_to_device from lightning_lite.utilities.apply_func import move_data_to_device
from lightning_lite.utilities.cloud_io import atomic_save from lightning_lite.utilities.cloud_io import atomic_save
from pytorch_lightning.demos.boring_classes import RandomDataset from pytorch_lightning.demos.boring_classes import RandomDataset
from pytorch_lightning.lite import LightningLite
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
from tests_pytorch.helpers.runif import RunIf
class BoringModel(nn.Module): class BoringModel(nn.Module):

View File

@ -15,12 +15,12 @@ from unittest.mock import Mock
import pytest import pytest
import torch import torch
from tests_lite.helpers.runif import RunIf
from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataloader import DataLoader
from lightning_lite.lite import LightningLite
from lightning_lite.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin from lightning_lite.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from pytorch_lightning.lite import LightningLite from lightning_lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
from tests_pytorch.helpers.runif import RunIf
class EmptyLite(LightningLite): class EmptyLite(LightningLite):