[App] Introduce Multi Node Component (#15524)

This commit is contained in:
thomas chaton 2022-11-04 17:41:59 +00:00 committed by GitHub
parent 0c63534b7e
commit ecc8ac07c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 155 additions and 6 deletions

View File

@ -1,8 +1,8 @@
from lightning import LightningApp
import lightning as L
from lightning.app.components import LightningTrainingComponent
from lightning.app.utilities.packaging.cloud_compute import CloudCompute
app = LightningApp(
app = L.LightningApp(
LightningTrainingComponent(
"train.py",
num_nodes=2,

View File

@ -0,0 +1,22 @@
import lightning.app as L
from lightning.app.components import MultiNode
class AnyDistributedComponent(L.LightningWork):
def run(
self,
main_address: str,
main_port: int,
node_rank: int,
):
print(f"ADD YOUR DISTRIBUTED CODE: {main_address} {main_port} {node_rank}")
compute = L.CloudCompute("gpu")
app = L.LightningApp(
MultiNode(
AnyDistributedComponent,
nodes=2,
cloud_compute=compute,
)
)

View File

@ -1,5 +1,5 @@
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel
if __name__ == "__main__":
model = BoringModel()

View File

@ -36,6 +36,7 @@ exclude = [
"src/lightning_app/cli/react-ui-template",
"src/lightning_app/cli/app-template",
"src/lightning_app/components/database",
"src/lightning_app/components/multi_node",
"src/lightning_app/frontend/just_py/just_py",
]
install_types = "True"
@ -58,6 +59,7 @@ warn_no_return = "False"
# the list can be generated with:
# mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",'
module = [
"lightning_app.components.multi_node",
"lightning_app.api.http_methods",
"lightning_app.api.request_types",
"lightning_app.cli.commands.app_commands",

View File

@ -12,7 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added the `start` method to the work ([#15523](https://github.com/Lightning-AI/lightning/pull/15523))
-
- Added a `MultiNode` Component to run with distributed computation with any frameworks ([#15524](https://github.com/Lightning-AI/lightning/pull/15524))
-

View File

@ -1,5 +1,6 @@
from lightning_app.components.database.client import DatabaseClient
from lightning_app.components.database.server import Database
from lightning_app.components.multi_node import MultiNode
from lightning_app.components.python.popen import PopenPythonScript
from lightning_app.components.python.tracer import Code, TracerPythonScript
from lightning_app.components.serve.gradio import ServeGradio
@ -16,6 +17,7 @@ __all__ = [
"ServeGradio",
"ServeStreamlit",
"ModelInferenceAPI",
"MultiNode",
"LightningTrainingComponent",
"PyTorchLightningScriptRunner",
]

View File

@ -0,0 +1,95 @@
from typing import Any, Type
from lightning_app import structures
from lightning_app.core.flow import LightningFlow
from lightning_app.core.work import LightningWork
from lightning_app.utilities.enum import WorkStageStatus
from lightning_app.utilities.packaging.cloud_compute import CloudCompute
class MultiNode(LightningFlow):
def __init__(
self,
work_cls: Type["LightningWork"],
nodes: int,
cloud_compute: "CloudCompute",
*work_args: Any,
**work_kwargs: Any,
) -> None:
"""This component enables performing distributed multi-node multi-device training.
Example::
import torch
import lightning as L
from lightning.components import MultiNode
class AnyDistributedComponent(L.LightningWork):
def run(
self,
main_address: str,
main_port: int,
node_rank: int,
):
print(f"ADD YOUR DISTRIBUTED CODE: {main_address} {main_port} {node_rank}")
compute = L.CloudCompute("gpu")
app = L.LightningApp(
MultiNode(
AnyDistributedComponent,
nodes=8,
cloud_compute=compute,
)
)
Arguments:
work_cls: The work to be executed
nodes: Number of nodes.
cloud_compute: The cloud compute object used in the cloud.
work_args: Arguments to be provided to the work on instantiation.
work_kwargs: Keywords arguments to be provided to the work on instantiation.
"""
super().__init__()
self.ws = structures.List()
self._work_cls = work_cls
self.nodes = nodes
self._cloud_compute = cloud_compute
self._work_args = work_args
self._work_kwargs = work_kwargs
self.has_initialized = False
def run(self) -> None:
# 1. Create & start the works
if not self.has_initialized:
for node_rank in range(self.nodes):
self.ws.append(
self._work_cls(
*self._work_args,
cloud_compute=self._cloud_compute,
**self._work_kwargs,
parallel=True,
)
)
# Starting node `node_rank`` ...
self.ws[-1].start()
self.has_initialized = True
# 2. Wait for all machines to be started !
if all(w.status.stage == WorkStageStatus.STARTED for w in self.ws):
return
# Loop over all node machines
for node_rank in range(self.nodes):
# 3. Run the user code in a distributed way !
self.ws[node_rank].run(
main_address=self.ws[0].internal_ip,
main_port=self.ws[0].port,
node_rank=node_rank,
)
# 4. Stop the machine when finished.
if self.ws[node_rank].has_succeeded:
self.ws[node_rank].stop()

View File

@ -1,5 +1,6 @@
import os
import pytest
from tests_app import _PROJECT_ROOT
from lightning_app.testing.testing import application_testing, LightningTestApp
@ -8,11 +9,13 @@ from lightning_app.testing.testing import application_testing, LightningTestApp
class LightningTestMultiNodeApp(LightningTestApp):
def on_before_run_once(self):
res = super().on_before_run_once()
if all(w.has_finished for w in self.works):
if self.works and all(w.has_stopped for w in self.works):
assert len([w for w in self.works]) == 2
return True
return res
@pytest.mark.skipif(True, reason="flaky")
def test_multi_node_example():
cwd = os.getcwd()
new_cwd = os.path.join(_PROJECT_ROOT, "examples/app_multi_node")
@ -27,3 +30,28 @@ def test_multi_node_example():
result = application_testing(LightningTestMultiNodeApp, command_line)
assert result.exit_code == 0
os.chdir(cwd)
class LightningTestMultiNodeWorksApp(LightningTestApp):
def on_before_run_once(self):
res = super().on_before_run_once()
if self.works and all(w.has_stopped for w in self.works):
assert len([w for w in self.works]) == 2
return True
return res
def test_multi_node_example_2():
cwd = os.getcwd()
new_cwd = os.path.join(_PROJECT_ROOT, "examples/app_multi_node")
os.chdir(new_cwd)
command_line = [
"app_work.py",
"--blocking",
"False",
"--open-ui",
"False",
]
result = application_testing(LightningTestMultiNodeWorksApp, command_line)
assert result.exit_code == 0
os.chdir(cwd)