[App] Rename to new convention (#15621)

* update

* update
This commit is contained in:
thomas chaton 2022-11-10 15:19:18 +00:00 committed by GitHub
parent 6ba00af1e0
commit 7ec15ae43a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 14 additions and 14 deletions

View File

@ -1,9 +1,9 @@
import lightning as L
from lightning.app.components import PyTorchLightningMultiNode
from lightning.app.components import LightningTrainerMultiNode
from lightning.pytorch.demos.boring_classes import BoringModel
class PyTorchLightningDistributed(L.LightningWork):
class LightningTrainerDistributed(L.LightningWork):
@staticmethod
def run():
model = BoringModel()
@ -16,8 +16,8 @@ class PyTorchLightningDistributed(L.LightningWork):
# Run over 2 nodes of 4 x V100
app = L.LightningApp(
PyTorchLightningMultiNode(
PyTorchLightningDistributed,
LightningTrainerMultiNode(
LightningTrainerDistributed,
num_nodes=2,
cloud_compute=L.CloudCompute("gpu-fast-multi"), # 4 x V100
)

View File

@ -62,7 +62,7 @@ module = [
"lightning_app.components.multi_node.lite",
"lightning_app.components.multi_node.base",
"lightning_app.components.multi_node.pytorch_spawn",
"lightning_app.components.multi_node.pl",
"lightning_app.components.multi_node.trainer",
"lightning_app.api.http_methods",
"lightning_app.api.request_types",
"lightning_app.cli.commands.app_commands",

View File

@ -1,9 +1,9 @@
from lightning_app.components.database.client import DatabaseClient
from lightning_app.components.database.server import Database
from lightning_app.components.multi_node import (
LightningTrainerMultiNode,
LiteMultiNode,
MultiNode,
PyTorchLightningMultiNode,
PyTorchSpawnMultiNode,
)
from lightning_app.components.python.popen import PopenPythonScript
@ -29,5 +29,5 @@ __all__ = [
"LightningTrainingComponent",
"PyTorchLightningScriptRunner",
"PyTorchSpawnMultiNode",
"PyTorchLightningMultiNode",
"LightningTrainerMultiNode",
]

View File

@ -1,6 +1,6 @@
from lightning_app.components.multi_node.base import MultiNode
from lightning_app.components.multi_node.lite import LiteMultiNode
from lightning_app.components.multi_node.pl import PyTorchLightningMultiNode
from lightning_app.components.multi_node.pytorch_spawn import PyTorchSpawnMultiNode
from lightning_app.components.multi_node.trainer import LightningTrainerMultiNode
__all__ = ["LiteMultiNode", "MultiNode", "PyTorchSpawnMultiNode", "PyTorchLightningMultiNode"]
__all__ = ["LiteMultiNode", "MultiNode", "PyTorchSpawnMultiNode", "LightningTrainerMultiNode"]

View File

@ -13,14 +13,14 @@ from lightning_app.utilities.tracer import Tracer
@runtime_checkable
class _PyTorchLightningWorkProtocol(Protocol):
class _LightningTrainerWorkProtocol(Protocol):
@staticmethod
def run() -> None:
...
@dataclass
class _PyTorchLightningRunExecutor(_PyTorchSpawnRunExecutor):
class _LightningTrainerRunExecutor(_PyTorchSpawnRunExecutor):
@staticmethod
def run(
local_rank: int,
@ -71,7 +71,7 @@ class _PyTorchLightningRunExecutor(_PyTorchSpawnRunExecutor):
tracer._restore()
class PyTorchLightningMultiNode(MultiNode):
class LightningTrainerMultiNode(MultiNode):
def __init__(
self,
work_cls: Type["LightningWork"],
@ -80,7 +80,7 @@ class PyTorchLightningMultiNode(MultiNode):
*work_args: Any,
**work_kwargs: Any,
) -> None:
assert issubclass(work_cls, _PyTorchLightningWorkProtocol)
assert issubclass(work_cls, _LightningTrainerWorkProtocol)
if not is_static_method(work_cls, "run"):
raise TypeError(
f"The provided {work_cls} run method needs to be static for now."
@ -89,7 +89,7 @@ class PyTorchLightningMultiNode(MultiNode):
# Note: Private way to modify the work run executor
# Probably exposed to the users in the future if needed.
work_cls._run_executor_cls = _PyTorchLightningRunExecutor
work_cls._run_executor_cls = _LightningTrainerRunExecutor
super().__init__(
work_cls,