parent 6ba00af1e0
commit 7ec15ae43a
@@ -1,9 +1,9 @@
 import lightning as L
-from lightning.app.components import PyTorchLightningMultiNode
+from lightning.app.components import LightningTrainerMultiNode
 from lightning.pytorch.demos.boring_classes import BoringModel
 
 
-class PyTorchLightningDistributed(L.LightningWork):
+class LightningTrainerDistributed(L.LightningWork):
     @staticmethod
     def run():
         model = BoringModel()
@@ -16,8 +16,8 @@ class PyTorchLightningDistributed(L.LightningWork):
 
 # Run over 2 nodes of 4 x V100
 app = L.LightningApp(
-    PyTorchLightningMultiNode(
-        PyTorchLightningDistributed,
+    LightningTrainerMultiNode(
+        LightningTrainerDistributed,
         num_nodes=2,
         cloud_compute=L.CloudCompute("gpu-fast-multi"),  # 4 x V100
     )

@@ -62,7 +62,7 @@ module = [
     "lightning_app.components.multi_node.lite",
     "lightning_app.components.multi_node.base",
     "lightning_app.components.multi_node.pytorch_spawn",
-    "lightning_app.components.multi_node.pl",
+    "lightning_app.components.multi_node.trainer",
     "lightning_app.api.http_methods",
     "lightning_app.api.request_types",
     "lightning_app.cli.commands.app_commands",

@@ -1,9 +1,9 @@
 from lightning_app.components.database.client import DatabaseClient
 from lightning_app.components.database.server import Database
 from lightning_app.components.multi_node import (
+    LightningTrainerMultiNode,
     LiteMultiNode,
     MultiNode,
-    PyTorchLightningMultiNode,
     PyTorchSpawnMultiNode,
 )
 from lightning_app.components.python.popen import PopenPythonScript
@@ -29,5 +29,5 @@ __all__ = [
     "LightningTrainingComponent",
     "PyTorchLightningScriptRunner",
     "PyTorchSpawnMultiNode",
-    "PyTorchLightningMultiNode",
+    "LightningTrainerMultiNode",
 ]

@@ -1,6 +1,6 @@
 from lightning_app.components.multi_node.base import MultiNode
 from lightning_app.components.multi_node.lite import LiteMultiNode
-from lightning_app.components.multi_node.pl import PyTorchLightningMultiNode
 from lightning_app.components.multi_node.pytorch_spawn import PyTorchSpawnMultiNode
+from lightning_app.components.multi_node.trainer import LightningTrainerMultiNode
 
-__all__ = ["LiteMultiNode", "MultiNode", "PyTorchSpawnMultiNode", "PyTorchLightningMultiNode"]
+__all__ = ["LiteMultiNode", "MultiNode", "PyTorchSpawnMultiNode", "LightningTrainerMultiNode"]

@@ -13,14 +13,14 @@ from lightning_app.utilities.tracer import Tracer
 
 
 @runtime_checkable
-class _PyTorchLightningWorkProtocol(Protocol):
+class _LightningTrainerWorkProtocol(Protocol):
     @staticmethod
     def run() -> None:
         ...
 
 
 @dataclass
-class _PyTorchLightningRunExecutor(_PyTorchSpawnRunExecutor):
+class _LightningTrainerRunExecutor(_PyTorchSpawnRunExecutor):
     @staticmethod
     def run(
         local_rank: int,
@@ -71,7 +71,7 @@ class _PyTorchLightningRunExecutor(_PyTorchSpawnRunExecutor):
         tracer._restore()
 
 
-class PyTorchLightningMultiNode(MultiNode):
+class LightningTrainerMultiNode(MultiNode):
     def __init__(
         self,
         work_cls: Type["LightningWork"],
@@ -80,7 +80,7 @@ class PyTorchLightningMultiNode(MultiNode):
         *work_args: Any,
         **work_kwargs: Any,
     ) -> None:
-        assert issubclass(work_cls, _PyTorchLightningWorkProtocol)
+        assert issubclass(work_cls, _LightningTrainerWorkProtocol)
         if not is_static_method(work_cls, "run"):
             raise TypeError(
                 f"The provided {work_cls} run method needs to be static for now."
@@ -89,7 +89,7 @@ class PyTorchLightningMultiNode(MultiNode):
 
         # Note: Private way to modify the work run executor
         # Probably exposed to the users in the future if needed.
-        work_cls._run_executor_cls = _PyTorchLightningRunExecutor
+        work_cls._run_executor_cls = _LightningTrainerRunExecutor
 
         super().__init__(
             work_cls,
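
For reference, a minimal sketch of the renamed example assembled from the hunks above; the body of run() between the two hunks of the first file is not shown in this diff, so it is left elided here, and everything else mirrors the post-rename lines:

import lightning as L
from lightning.app.components import LightningTrainerMultiNode
from lightning.pytorch.demos.boring_classes import BoringModel


class LightningTrainerDistributed(L.LightningWork):
    # `run` must stay a @staticmethod: LightningTrainerMultiNode checks the
    # _LightningTrainerWorkProtocol and raises TypeError for a non-static run.
    @staticmethod
    def run():
        model = BoringModel()
        ...  # training body elided in the diff above


# Run over 2 nodes of 4 x V100
app = L.LightningApp(
    LightningTrainerMultiNode(
        LightningTrainerDistributed,
        num_nodes=2,
        cloud_compute=L.CloudCompute("gpu-fast-multi"),  # 4 x V100
    )
)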