From 7ec15ae43a5d3506f24036bed148ff37fe87c7a8 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Thu, 10 Nov 2022 15:19:18 +0000 Subject: [PATCH] [App] Rename to new convention (#15621) * update * update --- examples/app_multi_node/{train_pl.py => train_lt.py} | 8 ++++---- .../{train_pl_script.py => train_lt_script.py} | 0 pyproject.toml | 2 +- src/lightning_app/components/__init__.py | 4 ++-- src/lightning_app/components/multi_node/__init__.py | 4 ++-- .../components/multi_node/{pl.py => trainer.py} | 10 +++++----- 6 files changed, 14 insertions(+), 14 deletions(-) rename examples/app_multi_node/{train_pl.py => train_lt.py} (69%) rename examples/app_multi_node/{train_pl_script.py => train_lt_script.py} (100%) rename src/lightning_app/components/multi_node/{pl.py => trainer.py} (92%) diff --git a/examples/app_multi_node/train_pl.py b/examples/app_multi_node/train_lt.py similarity index 69% rename from examples/app_multi_node/train_pl.py rename to examples/app_multi_node/train_lt.py index e887eaef7c..5cbee32dd8 100644 --- a/examples/app_multi_node/train_pl.py +++ b/examples/app_multi_node/train_lt.py @@ -1,9 +1,9 @@ import lightning as L -from lightning.app.components import PyTorchLightningMultiNode +from lightning.app.components import LightningTrainerMultiNode from lightning.pytorch.demos.boring_classes import BoringModel -class PyTorchLightningDistributed(L.LightningWork): +class LightningTrainerDistributed(L.LightningWork): @staticmethod def run(): model = BoringModel() @@ -16,8 +16,8 @@ class PyTorchLightningDistributed(L.LightningWork): # Run over 2 nodes of 4 x V100 app = L.LightningApp( - PyTorchLightningMultiNode( - PyTorchLightningDistributed, + LightningTrainerMultiNode( + LightningTrainerDistributed, num_nodes=2, cloud_compute=L.CloudCompute("gpu-fast-multi"), # 4 x V100 ) diff --git a/examples/app_multi_node/train_pl_script.py b/examples/app_multi_node/train_lt_script.py similarity index 100% rename from examples/app_multi_node/train_pl_script.py rename to examples/app_multi_node/train_lt_script.py diff --git a/pyproject.toml b/pyproject.toml index 005eba2846..bc8d9c7658 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,7 @@ module = [ "lightning_app.components.multi_node.lite", "lightning_app.components.multi_node.base", "lightning_app.components.multi_node.pytorch_spawn", - "lightning_app.components.multi_node.pl", + "lightning_app.components.multi_node.trainer", "lightning_app.api.http_methods", "lightning_app.api.request_types", "lightning_app.cli.commands.app_commands", diff --git a/src/lightning_app/components/__init__.py b/src/lightning_app/components/__init__.py index e72d1f443b..2426a9042b 100644 --- a/src/lightning_app/components/__init__.py +++ b/src/lightning_app/components/__init__.py @@ -1,9 +1,9 @@ from lightning_app.components.database.client import DatabaseClient from lightning_app.components.database.server import Database from lightning_app.components.multi_node import ( + LightningTrainerMultiNode, LiteMultiNode, MultiNode, - PyTorchLightningMultiNode, PyTorchSpawnMultiNode, ) from lightning_app.components.python.popen import PopenPythonScript @@ -29,5 +29,5 @@ __all__ = [ "LightningTrainingComponent", "PyTorchLightningScriptRunner", "PyTorchSpawnMultiNode", - "PyTorchLightningMultiNode", + "LightningTrainerMultiNode", ] diff --git a/src/lightning_app/components/multi_node/__init__.py b/src/lightning_app/components/multi_node/__init__.py index 2921f79dc7..b2d45a2610 100644 --- a/src/lightning_app/components/multi_node/__init__.py +++ b/src/lightning_app/components/multi_node/__init__.py @@ -1,6 +1,6 @@ from lightning_app.components.multi_node.base import MultiNode from lightning_app.components.multi_node.lite import LiteMultiNode -from lightning_app.components.multi_node.pl import PyTorchLightningMultiNode from lightning_app.components.multi_node.pytorch_spawn import PyTorchSpawnMultiNode +from lightning_app.components.multi_node.trainer import LightningTrainerMultiNode -__all__ = ["LiteMultiNode", "MultiNode", "PyTorchSpawnMultiNode", "PyTorchLightningMultiNode"] +__all__ = ["LiteMultiNode", "MultiNode", "PyTorchSpawnMultiNode", "LightningTrainerMultiNode"] diff --git a/src/lightning_app/components/multi_node/pl.py b/src/lightning_app/components/multi_node/trainer.py similarity index 92% rename from src/lightning_app/components/multi_node/pl.py rename to src/lightning_app/components/multi_node/trainer.py index c11b72b6ce..ea33106a7e 100644 --- a/src/lightning_app/components/multi_node/pl.py +++ b/src/lightning_app/components/multi_node/trainer.py @@ -13,14 +13,14 @@ from lightning_app.utilities.tracer import Tracer @runtime_checkable -class _PyTorchLightningWorkProtocol(Protocol): +class _LightningTrainerWorkProtocol(Protocol): @staticmethod def run() -> None: ... @dataclass -class _PyTorchLightningRunExecutor(_PyTorchSpawnRunExecutor): +class _LightningTrainerRunExecutor(_PyTorchSpawnRunExecutor): @staticmethod def run( local_rank: int, @@ -71,7 +71,7 @@ class _PyTorchLightningRunExecutor(_PyTorchSpawnRunExecutor): tracer._restore() -class PyTorchLightningMultiNode(MultiNode): +class LightningTrainerMultiNode(MultiNode): def __init__( self, work_cls: Type["LightningWork"], @@ -80,7 +80,7 @@ class PyTorchLightningMultiNode(MultiNode): *work_args: Any, **work_kwargs: Any, ) -> None: - assert issubclass(work_cls, _PyTorchLightningWorkProtocol) + assert issubclass(work_cls, _LightningTrainerWorkProtocol) if not is_static_method(work_cls, "run"): raise TypeError( f"The provided {work_cls} run method needs to be static for now." @@ -89,7 +89,7 @@ class PyTorchLightningMultiNode(MultiNode): # Note: Private way to modify the work run executor # Probably exposed to the users in the future if needed. - work_cls._run_executor_cls = _PyTorchLightningRunExecutor + work_cls._run_executor_cls = _LightningTrainerRunExecutor super().__init__( work_cls,