[App] Introduce Multi Node Component (#15524)
This commit is contained in:
parent
0c63534b7e
commit
ecc8ac07c6
|
@ -1,8 +1,8 @@
|
|||
from lightning import LightningApp
|
||||
import lightning as L
|
||||
from lightning.app.components import LightningTrainingComponent
|
||||
from lightning.app.utilities.packaging.cloud_compute import CloudCompute
|
||||
|
||||
app = LightningApp(
|
||||
app = L.LightningApp(
|
||||
LightningTrainingComponent(
|
||||
"train.py",
|
||||
num_nodes=2,
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
import lightning.app as L
|
||||
from lightning.app.components import MultiNode
|
||||
|
||||
|
||||
class AnyDistributedComponent(L.LightningWork):
|
||||
def run(
|
||||
self,
|
||||
main_address: str,
|
||||
main_port: int,
|
||||
node_rank: int,
|
||||
):
|
||||
print(f"ADD YOUR DISTRIBUTED CODE: {main_address} {main_port} {node_rank}")
|
||||
|
||||
|
||||
compute = L.CloudCompute("gpu")
|
||||
app = L.LightningApp(
|
||||
MultiNode(
|
||||
AnyDistributedComponent,
|
||||
nodes=2,
|
||||
cloud_compute=compute,
|
||||
)
|
||||
)
|
|
@ -1,5 +1,5 @@
|
|||
from lightning.pytorch import Trainer
|
||||
from lightning.pytorch.demos.boring_classes import BoringModel
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.demos.boring_classes import BoringModel
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = BoringModel()
|
||||
|
|
|
@ -36,6 +36,7 @@ exclude = [
|
|||
"src/lightning_app/cli/react-ui-template",
|
||||
"src/lightning_app/cli/app-template",
|
||||
"src/lightning_app/components/database",
|
||||
"src/lightning_app/components/multi_node",
|
||||
"src/lightning_app/frontend/just_py/just_py",
|
||||
]
|
||||
install_types = "True"
|
||||
|
@ -58,6 +59,7 @@ warn_no_return = "False"
|
|||
# the list can be generated with:
|
||||
# mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",'
|
||||
module = [
|
||||
"lightning_app.components.multi_node",
|
||||
"lightning_app.api.http_methods",
|
||||
"lightning_app.api.request_types",
|
||||
"lightning_app.cli.commands.app_commands",
|
||||
|
|
|
@ -12,7 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Added the `start` method to the work ([#15523](https://github.com/Lightning-AI/lightning/pull/15523))
|
||||
|
||||
-
|
||||
- Added a `MultiNode` Component to run with distributed computation with any frameworks ([#15524](https://github.com/Lightning-AI/lightning/pull/15524))
|
||||
|
||||
-
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from lightning_app.components.database.client import DatabaseClient
|
||||
from lightning_app.components.database.server import Database
|
||||
from lightning_app.components.multi_node import MultiNode
|
||||
from lightning_app.components.python.popen import PopenPythonScript
|
||||
from lightning_app.components.python.tracer import Code, TracerPythonScript
|
||||
from lightning_app.components.serve.gradio import ServeGradio
|
||||
|
@ -16,6 +17,7 @@ __all__ = [
|
|||
"ServeGradio",
|
||||
"ServeStreamlit",
|
||||
"ModelInferenceAPI",
|
||||
"MultiNode",
|
||||
"LightningTrainingComponent",
|
||||
"PyTorchLightningScriptRunner",
|
||||
]
|
||||
|
|
|
@ -0,0 +1,95 @@
|
|||
from typing import Any, Type
|
||||
|
||||
from lightning_app import structures
|
||||
from lightning_app.core.flow import LightningFlow
|
||||
from lightning_app.core.work import LightningWork
|
||||
from lightning_app.utilities.enum import WorkStageStatus
|
||||
from lightning_app.utilities.packaging.cloud_compute import CloudCompute
|
||||
|
||||
|
||||
class MultiNode(LightningFlow):
|
||||
def __init__(
|
||||
self,
|
||||
work_cls: Type["LightningWork"],
|
||||
nodes: int,
|
||||
cloud_compute: "CloudCompute",
|
||||
*work_args: Any,
|
||||
**work_kwargs: Any,
|
||||
) -> None:
|
||||
"""This component enables performing distributed multi-node multi-device training.
|
||||
|
||||
Example::
|
||||
|
||||
import torch
|
||||
|
||||
import lightning as L
|
||||
from lightning.components import MultiNode
|
||||
|
||||
class AnyDistributedComponent(L.LightningWork):
|
||||
def run(
|
||||
self,
|
||||
main_address: str,
|
||||
main_port: int,
|
||||
node_rank: int,
|
||||
):
|
||||
print(f"ADD YOUR DISTRIBUTED CODE: {main_address} {main_port} {node_rank}")
|
||||
|
||||
|
||||
compute = L.CloudCompute("gpu")
|
||||
app = L.LightningApp(
|
||||
MultiNode(
|
||||
AnyDistributedComponent,
|
||||
nodes=8,
|
||||
cloud_compute=compute,
|
||||
)
|
||||
)
|
||||
|
||||
Arguments:
|
||||
work_cls: The work to be executed
|
||||
nodes: Number of nodes.
|
||||
cloud_compute: The cloud compute object used in the cloud.
|
||||
work_args: Arguments to be provided to the work on instantiation.
|
||||
work_kwargs: Keywords arguments to be provided to the work on instantiation.
|
||||
"""
|
||||
super().__init__()
|
||||
self.ws = structures.List()
|
||||
self._work_cls = work_cls
|
||||
self.nodes = nodes
|
||||
self._cloud_compute = cloud_compute
|
||||
self._work_args = work_args
|
||||
self._work_kwargs = work_kwargs
|
||||
self.has_initialized = False
|
||||
|
||||
def run(self) -> None:
|
||||
# 1. Create & start the works
|
||||
if not self.has_initialized:
|
||||
for node_rank in range(self.nodes):
|
||||
self.ws.append(
|
||||
self._work_cls(
|
||||
*self._work_args,
|
||||
cloud_compute=self._cloud_compute,
|
||||
**self._work_kwargs,
|
||||
parallel=True,
|
||||
)
|
||||
)
|
||||
# Starting node `node_rank`` ...
|
||||
self.ws[-1].start()
|
||||
self.has_initialized = True
|
||||
|
||||
# 2. Wait for all machines to be started !
|
||||
if all(w.status.stage == WorkStageStatus.STARTED for w in self.ws):
|
||||
return
|
||||
|
||||
# Loop over all node machines
|
||||
for node_rank in range(self.nodes):
|
||||
|
||||
# 3. Run the user code in a distributed way !
|
||||
self.ws[node_rank].run(
|
||||
main_address=self.ws[0].internal_ip,
|
||||
main_port=self.ws[0].port,
|
||||
node_rank=node_rank,
|
||||
)
|
||||
|
||||
# 4. Stop the machine when finished.
|
||||
if self.ws[node_rank].has_succeeded:
|
||||
self.ws[node_rank].stop()
|
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
from tests_app import _PROJECT_ROOT
|
||||
|
||||
from lightning_app.testing.testing import application_testing, LightningTestApp
|
||||
|
@ -8,11 +9,13 @@ from lightning_app.testing.testing import application_testing, LightningTestApp
|
|||
class LightningTestMultiNodeApp(LightningTestApp):
|
||||
def on_before_run_once(self):
|
||||
res = super().on_before_run_once()
|
||||
if all(w.has_finished for w in self.works):
|
||||
if self.works and all(w.has_stopped for w in self.works):
|
||||
assert len([w for w in self.works]) == 2
|
||||
return True
|
||||
return res
|
||||
|
||||
|
||||
@pytest.mark.skipif(True, reason="flaky")
|
||||
def test_multi_node_example():
|
||||
cwd = os.getcwd()
|
||||
new_cwd = os.path.join(_PROJECT_ROOT, "examples/app_multi_node")
|
||||
|
@ -27,3 +30,28 @@ def test_multi_node_example():
|
|||
result = application_testing(LightningTestMultiNodeApp, command_line)
|
||||
assert result.exit_code == 0
|
||||
os.chdir(cwd)
|
||||
|
||||
|
||||
class LightningTestMultiNodeWorksApp(LightningTestApp):
|
||||
def on_before_run_once(self):
|
||||
res = super().on_before_run_once()
|
||||
if self.works and all(w.has_stopped for w in self.works):
|
||||
assert len([w for w in self.works]) == 2
|
||||
return True
|
||||
return res
|
||||
|
||||
|
||||
def test_multi_node_example_2():
|
||||
cwd = os.getcwd()
|
||||
new_cwd = os.path.join(_PROJECT_ROOT, "examples/app_multi_node")
|
||||
os.chdir(new_cwd)
|
||||
command_line = [
|
||||
"app_work.py",
|
||||
"--blocking",
|
||||
"False",
|
||||
"--open-ui",
|
||||
"False",
|
||||
]
|
||||
result = application_testing(LightningTestMultiNodeWorksApp, command_line)
|
||||
assert result.exit_code == 0
|
||||
os.chdir(cwd)
|
||||
|
|
Loading…
Reference in New Issue