From e3e71670e6dc6a3fd8f1d9a993552f754725ad0f Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 18 Sep 2022 00:34:53 +0200 Subject: [PATCH] Move src/pytorch_lightning/lite to src/lightning_lite (#14735) --- .../image_classifier_2_lite.py | 2 +- ...ge_classifier_3_lite_to_lightning_module.py | 2 +- examples/pl_loops/mnist_lite.py | 2 +- src/lightning_lite/__init__.py | 10 +++------- .../lite => lightning_lite}/lite.py | 2 +- src/lightning_lite/utilities/rank_zero.py | 2 +- .../lite => lightning_lite}/wrappers.py | 0 .../callbacks/early_stopping.py | 2 +- src/pytorch_lightning/lite/__init__.py | 17 ----------------- tests/tests_lite/helpers/runif.py | 18 +++++++++++++++++- .../lite => tests_lite}/test_lite.py | 10 +++++----- .../lite => tests_lite}/test_parity.py | 6 +++--- .../lite => tests_lite}/test_wrappers.py | 6 +++--- tests/tests_pytorch/lite/__init__.py | 0 14 files changed, 37 insertions(+), 42 deletions(-) rename src/{pytorch_lightning/lite => lightning_lite}/lite.py (99%) rename src/{pytorch_lightning/lite => lightning_lite}/wrappers.py (100%) delete mode 100644 src/pytorch_lightning/lite/__init__.py rename tests/{tests_pytorch/lite => tests_lite}/test_lite.py (98%) rename tests/{tests_pytorch/lite => tests_lite}/test_parity.py (98%) rename tests/{tests_pytorch/lite => tests_lite}/test_wrappers.py (98%) delete mode 100644 tests/tests_pytorch/lite/__init__.py diff --git a/examples/convert_from_pt_to_pl/image_classifier_2_lite.py b/examples/convert_from_pt_to_pl/image_classifier_2_lite.py index da82db0328..f5f8e54c14 100644 --- a/examples/convert_from_pt_to_pl/image_classifier_2_lite.py +++ b/examples/convert_from_pt_to_pl/image_classifier_2_lite.py @@ -38,10 +38,10 @@ import torchvision.transforms as T from torch.optim.lr_scheduler import StepLR from torchmetrics.classification import Accuracy +from lightning_lite.lite import LightningLite # import LightningLite from pytorch_lightning import seed_everything from pytorch_lightning.demos.boring_classes import Net from pytorch_lightning.demos.mnist_datamodule import MNIST -from pytorch_lightning.lite import LightningLite # import LightningLite DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets") diff --git a/examples/convert_from_pt_to_pl/image_classifier_3_lite_to_lightning_module.py b/examples/convert_from_pt_to_pl/image_classifier_3_lite_to_lightning_module.py index d2dc9a581b..e398e42fd8 100644 --- a/examples/convert_from_pt_to_pl/image_classifier_3_lite_to_lightning_module.py +++ b/examples/convert_from_pt_to_pl/image_classifier_3_lite_to_lightning_module.py @@ -34,10 +34,10 @@ import torchvision.transforms as T from torch.optim.lr_scheduler import StepLR from torchmetrics import Accuracy +from lightning_lite.lite import LightningLite from pytorch_lightning import seed_everything from pytorch_lightning.demos.boring_classes import Net from pytorch_lightning.demos.mnist_datamodule import MNIST -from pytorch_lightning.lite import LightningLite DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets") diff --git a/examples/pl_loops/mnist_lite.py b/examples/pl_loops/mnist_lite.py index 9310705508..0ae27ad921 100644 --- a/examples/pl_loops/mnist_lite.py +++ b/examples/pl_loops/mnist_lite.py @@ -22,10 +22,10 @@ import torchvision.transforms as T from torch.optim.lr_scheduler import StepLR from torchmetrics import Accuracy +from lightning_lite.lite import LightningLite from pytorch_lightning import seed_everything from pytorch_lightning.demos.boring_classes import Net from pytorch_lightning.demos.mnist_datamodule import MNIST -from pytorch_lightning.lite import LightningLite from pytorch_lightning.loops import Loop DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets") diff --git a/src/lightning_lite/__init__.py b/src/lightning_lite/__init__.py index dccaeae932..09b10e1723 100644 --- a/src/lightning_lite/__init__.py +++ b/src/lightning_lite/__init__.py @@ -12,15 +12,11 @@ if not _root_logger.hasHandlers(): _logger.addHandler(logging.StreamHandler()) _logger.propagate = False -# TODO(lite): Re-enable this import -# from lightning_lite.lite import LightningLite + +from lightning_lite.lite import LightningLite # noqa: E402 from lightning_lite.utilities.seed import seed_everything # noqa: E402 -__all__ = [ - # TODO(lite): Re-enable this import - # "LightningLite", - "seed_everything", -] +__all__ = ["LightningLite", "seed_everything"] # for compatibility with namespace packages __import__("pkg_resources").declare_namespace(__name__) diff --git a/src/pytorch_lightning/lite/lite.py b/src/lightning_lite/lite.py similarity index 99% rename from src/pytorch_lightning/lite/lite.py rename to src/lightning_lite/lite.py index 7a36123135..6628411945 100644 --- a/src/pytorch_lightning/lite/lite.py +++ b/src/lightning_lite/lite.py @@ -40,7 +40,7 @@ from lightning_lite.utilities.data import ( ) from lightning_lite.utilities.distributed import DistributedSamplerWrapper from lightning_lite.utilities.seed import seed_everything -from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer +from lightning_lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer class LightningLite(ABC): diff --git a/src/lightning_lite/utilities/rank_zero.py b/src/lightning_lite/utilities/rank_zero.py index db364dfd8f..6424b38854 100644 --- a/src/lightning_lite/utilities/rank_zero.py +++ b/src/lightning_lite/utilities/rank_zero.py @@ -34,7 +34,7 @@ rank_zero_module.log = logging.getLogger(__name__) def _get_rank( - strategy: Optional["lightning_lite.strategies.Strategy"] = None, # type: ignore[name-defined] + strategy: Optional["lightning_lite.strategies.Strategy"] = None, ) -> Optional[int]: if strategy is not None: return strategy.global_rank diff --git a/src/pytorch_lightning/lite/wrappers.py b/src/lightning_lite/wrappers.py similarity index 100% rename from src/pytorch_lightning/lite/wrappers.py rename to src/lightning_lite/wrappers.py diff --git a/src/pytorch_lightning/callbacks/early_stopping.py b/src/pytorch_lightning/callbacks/early_stopping.py index 4acff87dfa..be3acfc6a0 100644 --- a/src/pytorch_lightning/callbacks/early_stopping.py +++ b/src/pytorch_lightning/callbacks/early_stopping.py @@ -261,7 +261,7 @@ class EarlyStopping(Callback): @staticmethod def _log_info(trainer: Optional["pl.Trainer"], message: str, log_rank_zero_only: bool) -> None: - rank = _get_rank(strategy=(trainer.strategy if trainer is not None else None)) + rank = _get_rank(strategy=(trainer.strategy if trainer is not None else None)) # type: ignore[arg-type] if trainer is not None and trainer.world_size <= 1: rank = None message = rank_prefixed_message(message, rank) diff --git a/src/pytorch_lightning/lite/__init__.py b/src/pytorch_lightning/lite/__init__.py deleted file mode 100644 index f4634fe54e..0000000000 --- a/src/pytorch_lightning/lite/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from pytorch_lightning.lite.lite import LightningLite - -__all__ = ["LightningLite"] diff --git a/tests/tests_lite/helpers/runif.py b/tests/tests_lite/helpers/runif.py index e996aec4d3..a3f484255c 100644 --- a/tests/tests_lite/helpers/runif.py +++ b/tests/tests_lite/helpers/runif.py @@ -23,7 +23,7 @@ from pkg_resources import get_distribution from lightning_lite.accelerators.mps import MPSAccelerator from lightning_lite.strategies.deepspeed import _DEEPSPEED_AVAILABLE from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE -from lightning_lite.utilities.imports import _TPU_AVAILABLE +from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_10, _TPU_AVAILABLE class RunIf: @@ -42,6 +42,7 @@ class RunIf: min_torch: Optional[str] = None, max_torch: Optional[str] = None, min_python: Optional[str] = None, + bf16_cuda: bool = False, tpu: bool = False, mps: Optional[bool] = None, skip_windows: bool = False, @@ -57,6 +58,7 @@ class RunIf: min_torch: Require that PyTorch is greater or equal than this version. max_torch: Require that PyTorch is less than this version. min_python: Require that Python is greater or equal than this version. + bf16_cuda: Require that CUDA device supports bf16. tpu: Require that TPU is available. mps: If True: Require that MPS (Apple Silicon) is available, if False: Explicitly Require that MPS is not available @@ -91,6 +93,20 @@ class RunIf: conditions.append(Version(py_version) < Version(min_python)) reasons.append(f"python>={min_python}") + if bf16_cuda: + try: + cond = not (torch.cuda.is_available() and _TORCH_GREATER_EQUAL_1_10 and torch.cuda.is_bf16_supported()) + except (AssertionError, RuntimeError) as e: + # AssertionError: Torch not compiled with CUDA enabled + # RuntimeError: Found no NVIDIA driver on your system. + is_unrelated = "Found no NVIDIA driver" not in str(e) or "Torch not compiled with CUDA" not in str(e) + if is_unrelated: + raise e + cond = True + + conditions.append(cond) + reasons.append("CUDA device bf16") + if skip_windows: conditions.append(sys.platform == "win32") reasons.append("unimplemented on Windows") diff --git a/tests/tests_pytorch/lite/test_lite.py b/tests/tests_lite/test_lite.py similarity index 98% rename from tests/tests_pytorch/lite/test_lite.py rename to tests/tests_lite/test_lite.py index 8b8c999580..74fd4f8329 100644 --- a/tests/tests_pytorch/lite/test_lite.py +++ b/tests/tests_lite/test_lite.py @@ -20,17 +20,17 @@ import pytest import torch import torch.distributed import torch.nn.functional +from tests_lite.helpers.runif import RunIf from torch import nn from torch.utils.data import DataLoader, DistributedSampler, Sampler +from lightning_lite.lite import LightningLite from lightning_lite.plugins import Precision from lightning_lite.strategies import DeepSpeedStrategy, Strategy from lightning_lite.utilities import _StrategyType from lightning_lite.utilities.exceptions import MisconfigurationException from lightning_lite.utilities.seed import pl_worker_init_function -from pytorch_lightning.lite import LightningLite -from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer -from tests_pytorch.helpers.runif import RunIf +from lightning_lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer class EmptyLite(LightningLite): @@ -165,7 +165,7 @@ def test_setup_dataloaders_return_type(): assert lite_dataloader1.dataset is dataset1 -@mock.patch("pytorch_lightning.lite.lite._replace_dunder_methods") +@mock.patch("lightning_lite.lite._replace_dunder_methods") def test_setup_dataloaders_captures_dataloader_arguments(ctx_manager): """Test that Lite intercepts the DataLoader constructor arguments with a context manager in its run method.""" @@ -210,7 +210,7 @@ def test_setup_dataloaders_twice_fails(): @mock.patch( - "pytorch_lightning.lite.lite.LightningLite.device", + "lightning_lite.lite.LightningLite.device", new_callable=PropertyMock, return_value=torch.device("cuda", 1), ) diff --git a/tests/tests_pytorch/lite/test_parity.py b/tests/tests_lite/test_parity.py similarity index 98% rename from tests/tests_pytorch/lite/test_parity.py rename to tests/tests_lite/test_parity.py index ffb9585515..3f7808d6a8 100644 --- a/tests/tests_pytorch/lite/test_parity.py +++ b/tests/tests_lite/test_parity.py @@ -23,18 +23,18 @@ import torch.distributed import torch.multiprocessing as mp import torch.nn.functional from lightning_utilities.core.apply_func import apply_to_collection +from tests_lite.helpers.runif import RunIf from torch import nn from torch.nn.parallel.distributed import DistributedDataParallel from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler +from lightning_lite.lite import LightningLite from lightning_lite.plugins.environments.lightning_environment import find_free_network_port +from lightning_lite.strategies.ddp_spawn import DDPSpawnStrategy from lightning_lite.utilities.apply_func import move_data_to_device from lightning_lite.utilities.cloud_io import atomic_save from pytorch_lightning.demos.boring_classes import RandomDataset -from pytorch_lightning.lite import LightningLite -from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy -from tests_pytorch.helpers.runif import RunIf class BoringModel(nn.Module): diff --git a/tests/tests_pytorch/lite/test_wrappers.py b/tests/tests_lite/test_wrappers.py similarity index 98% rename from tests/tests_pytorch/lite/test_wrappers.py rename to tests/tests_lite/test_wrappers.py index 0589ac64d9..91a756f43c 100644 --- a/tests/tests_pytorch/lite/test_wrappers.py +++ b/tests/tests_lite/test_wrappers.py @@ -15,12 +15,12 @@ from unittest.mock import Mock import pytest import torch +from tests_lite.helpers.runif import RunIf from torch.utils.data.dataloader import DataLoader +from lightning_lite.lite import LightningLite from lightning_lite.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin -from pytorch_lightning.lite import LightningLite -from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer -from tests_pytorch.helpers.runif import RunIf +from lightning_lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer class EmptyLite(LightningLite): diff --git a/tests/tests_pytorch/lite/__init__.py b/tests/tests_pytorch/lite/__init__.py deleted file mode 100644 index e69de29bb2..0000000000