[App] Fixed Multi Node and add examples (#15557)
parent 96c574425d
commit 820233176b
examples/app_multi_node/README.md (new file):

@@ -0,0 +1,41 @@
+# Lightning & Multi Node Training
+
+Lightning makes multi-node training simple by providing an easy interface to orchestrate compute and data.
+
+## Multi Node with raw PyTorch
+
+You can run the multi-node raw PyTorch example with the following command:
+
+```bash
+lightning run app app_torch_work.py
+```
+
+## Multi Node with raw PyTorch + Lite
+
+You can run the multi-node raw PyTorch + Lite example with the following command:
+
+```bash
+lightning run app app_lite_work.py
+```
+
+## Multi Node with PyTorch Lightning
+
+Lightning supports running PyTorch Lightning from a script or within a Lightning Work.
+
+### Multi Node PyTorch Lightning Script
+
+```bash
+lightning run app app_pl_script.py
+```
+
+### Multi Node PyTorch Lightning Work
+
+```bash
+lightning run app app_pl_work.py
+```
+
+## Multi Node with any framework
+
+```bash
+lightning run app app_generic_work.py
+```
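For readers skimming the README on its own, here is a minimal sketch of the app pattern those commands launch. The component name is illustrative; the `MultiNode` API and the `run` signature match the examples added in this commit:

```python
import lightning as L
from lightning.app.components import MultiNode


class YourDistributedWork(L.LightningWork):  # illustrative name
    def run(self, main_address: str, main_port: int, num_nodes: int, node_rank: int):
        # MultiNode passes every node the coordinates needed to join the job.
        print(f"node {node_rank}/{num_nodes} rendezvous at {main_address}:{main_port}")


app = L.LightningApp(
    MultiNode(YourDistributedWork, num_nodes=2, cloud_compute=L.CloudCompute("gpu"))
)
```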
examples/app_multi_node/app_generic_work.py:

@@ -1,4 +1,4 @@
-import lightning.app as L
+import lightning as L
 from lightning.app.components import MultiNode
 
 
@@ -7,16 +7,17 @@ class AnyDistributedComponent(L.LightningWork):
         self,
         main_address: str,
         main_port: int,
+        num_nodes: int,
         node_rank: int,
     ):
-        print(f"ADD YOUR DISTRIBUTED CODE: {main_address} {main_port} {node_rank}")
+        print(f"ADD YOUR DISTRIBUTED CODE: {main_address} {main_port} {num_nodes} {node_rank}.")
 
 
 compute = L.CloudCompute("gpu")
 app = L.LightningApp(
     MultiNode(
         AnyDistributedComponent,
-        nodes=2,
+        num_nodes=2,
         cloud_compute=compute,
     )
 )
examples/app_multi_node/app_lite_work.py (new file):

@@ -0,0 +1,59 @@
+import os
+
+import torch
+
+import lightning as L
+from lightning.app.components import MultiNode
+from lightning.lite import LightningLite
+
+
+def distributed_train(lite: LightningLite):
+    # 1. Prepare the distributed model and optimizer.
+    model = torch.nn.Linear(32, 2)
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+    model, optimizer = lite.setup(model, optimizer)
+    criterion = torch.nn.MSELoss()
+
+    # 2. Train the model for 50 steps.
+    for step in range(50):
+        model.zero_grad()
+        x = torch.randn(64, 32).to(lite.device)
+        output = model(x)
+        loss = criterion(output, torch.ones_like(output))
+        print(f"global_rank: {lite.global_rank} step: {step} loss: {loss}")
+        lite.backward(loss)
+        optimizer.step()
+
+    # 3. Verify all processes have the same weights at the end of training.
+    weight = model.module.weight.clone()
+    torch.distributed.all_reduce(weight)
+    assert torch.equal(model.module.weight, weight / lite.world_size)
+
+    print("Multi Node Distributed Training Done!")
+
+
+class PyTorchDistributed(L.LightningWork):
+    def run(
+        self,
+        main_address: str,
+        main_port: int,
+        num_nodes: int,
+        node_rank: int,
+    ):
+        # Expose the rendezvous coordinates through the standard env variables.
+        os.environ["MASTER_ADDR"] = main_address
+        os.environ["MASTER_PORT"] = str(main_port)
+        os.environ["NODE_RANK"] = str(node_rank)
+
+        lite = LightningLite(accelerator="auto", devices="auto", strategy="ddp_spawn", num_nodes=num_nodes)
+        lite.launch(function=distributed_train)
+
+
+compute = L.CloudCompute("gpu-fast-multi")  # 4xV100
+app = L.LightningApp(
+    MultiNode(
+        PyTorchDistributed,
+        num_nodes=2,
+        cloud_compute=compute,
+    )
+)
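Since `LightningLite.launch` accepts any callable, the training function above can be smoke-tested on one machine before going multi-node. A sketch under that assumption; the CPU/two-process configuration is illustrative, not part of the commit:

```python
# Hypothetical local smoke test for app_lite_work.py's training function,
# assuming the example file is importable from the current directory.
from app_lite_work import distributed_train

from lightning.lite import LightningLite

if __name__ == "__main__":
    # Two CPU processes stand in for the multi-node GPU layout.
    lite = LightningLite(accelerator="cpu", devices=2, strategy="ddp_spawn")
    lite.launch(function=distributed_train)
```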
examples/app_multi_node/app_pl_work.py (new file):

@@ -0,0 +1,38 @@
+import os
+
+import lightning as L
+from lightning.app.components import MultiNode
+from lightning.pytorch.demos.boring_classes import BoringModel
+
+
+class PyTorchLightningDistributed(L.LightningWork):
+    def run(
+        self,
+        main_address: str,
+        main_port: int,
+        num_nodes: int,
+        node_rank: int,
+    ):
+        os.environ["MASTER_ADDR"] = main_address
+        os.environ["MASTER_PORT"] = str(main_port)
+        os.environ["NODE_RANK"] = str(node_rank)
+
+        model = BoringModel()
+        trainer = L.Trainer(
+            max_epochs=10,
+            devices="auto",
+            accelerator="auto",
+            num_nodes=num_nodes,
+            strategy="ddp_spawn",  # Only spawn-based strategies are supported for now.
+        )
+        trainer.fit(model)
+
+
+compute = L.CloudCompute("gpu-fast-multi")  # 4xV100
+app = L.LightningApp(
+    MultiNode(
+        PyTorchLightningDistributed,
+        num_nodes=2,
+        cloud_compute=compute,
+    )
+)
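The Work above only exports the rendezvous variables; PyTorch Lightning's cluster environment picks `MASTER_ADDR`, `MASTER_PORT`, and `NODE_RANK` up from the environment. As a sketch (not part of the commit), the single-node local equivalent needs none of them:

```python
# Hypothetical local run: one machine, no MultiNode, no env plumbing.
import lightning as L
from lightning.pytorch.demos.boring_classes import BoringModel

if __name__ == "__main__":
    trainer = L.Trainer(max_epochs=1, accelerator="auto", devices="auto", num_nodes=1)
    trainer.fit(BoringModel())
```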
examples/app_multi_node/app_torch_work.py (new file):

@@ -0,0 +1,70 @@
+import torch
+from torch.nn.parallel.distributed import DistributedDataParallel
+
+import lightning as L
+from lightning.app.components import MultiNode
+
+
+def distributed_train(local_rank: int, main_address: str, main_port: int, num_nodes: int, node_rank: int, nprocs: int):
+    # 1. Set up the distributed environment.
+    global_rank = local_rank + node_rank * nprocs
+    world_size = num_nodes * nprocs
+
+    if torch.distributed.is_available() and not torch.distributed.is_initialized():
+        torch.distributed.init_process_group(
+            "nccl" if torch.cuda.is_available() else "gloo",
+            rank=global_rank,
+            world_size=world_size,
+            init_method=f"tcp://{main_address}:{main_port}",
+        )
+
+    # 2. Prepare the distributed model.
+    model = torch.nn.Linear(32, 2)
+    device = torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else torch.device("cpu")
+    device_ids = [local_rank] if torch.cuda.is_available() else None
+    model = DistributedDataParallel(model.to(device), device_ids=device_ids)
+
+    # 3. Prepare the loss and optimizer.
+    criterion = torch.nn.MSELoss()
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+
+    # 4. Train the model for 50 steps.
+    for step in range(50):
+        model.zero_grad()
+        x = torch.randn(64, 32).to(device)
+        output = model(x)
+        loss = criterion(output, torch.ones_like(output))
+        print(f"global_rank: {global_rank} step: {step} loss: {loss}")
+        loss.backward()
+        optimizer.step()
+
+    # 5. Verify all processes have the same weights at the end of training.
+    weight = model.module.weight.clone()
+    torch.distributed.all_reduce(weight)
+    assert torch.equal(model.module.weight, weight / world_size)
+
+    print("Multi Node Distributed Training Done!")
+
+
+class PyTorchDistributed(L.LightningWork):
+    def run(
+        self,
+        main_address: str,
+        main_port: int,
+        num_nodes: int,
+        node_rank: int,
+    ):
+        nprocs = torch.cuda.device_count() if torch.cuda.is_available() else 1
+        torch.multiprocessing.spawn(
+            distributed_train, args=(main_address, main_port, num_nodes, node_rank, nprocs), nprocs=nprocs
+        )
+
+
+compute = L.CloudCompute("gpu-fast-multi")  # 4xV100
+app = L.LightningApp(
+    MultiNode(
+        PyTorchDistributed,
+        num_nodes=2,
+        cloud_compute=compute,
+    )
+)
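One detail worth a sanity check is the rank arithmetic above: each node spawns `nprocs` processes and global ranks are laid out node by node, so `global_rank = local_rank + node_rank * nprocs` must enumerate `0..world_size-1` exactly once. A quick standalone check with illustrative values:

```python
# With 2 nodes and 4 processes per node, global ranks must cover 0..7 once.
num_nodes, nprocs = 2, 4
ranks = sorted(
    local_rank + node_rank * nprocs
    for node_rank in range(num_nodes)
    for local_rank in range(nprocs)
)
assert ranks == list(range(num_nodes * nprocs))  # world_size == 8
```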
.gitignore (removed):

@@ -1,2 +0,0 @@
-.storage/
-.shared/
Old multi-node demo app (removed):

@@ -1,36 +0,0 @@
-import lightning as L
-
-
-class Work(L.LightningWork):
-    def __init__(self, cloud_compute: L.CloudCompute = L.CloudCompute(), **kwargs):
-        super().__init__(parallel=True, **kwargs, cloud_compute=cloud_compute)
-
-    def run(self, main_address="localhost", main_port=1111, world_size=1, rank=0, init=False):
-        if init:
-            return
-
-        import torch.distributed
-
-        print(f"Initializing process group: {main_address=}, {main_port=}, {world_size=}, {rank=}")
-        torch.distributed.init_process_group(
-            backend="gloo", init_method=f"tcp://{main_address}:{main_port}", world_size=world_size, rank=rank
-        )
-        gathered = [torch.zeros(1) for _ in range(world_size)]
-        torch.distributed.all_gather(gathered, torch.tensor([rank]).float())
-        print(gathered)
-
-
-class MultiNodeDemo(L.LightningFlow):
-    def __init__(self):
-        super().__init__()
-        self.work0 = Work()
-        self.work1 = Work()
-
-    def run(self):
-        self.work0.run(init=True)
-        if self.work0.internal_ip:
-            self.work0.run(main_address=self.work0.internal_ip, main_port=self.work0.port, world_size=2, rank=0)
-            self.work1.run(main_address=self.work0.internal_ip, main_port=self.work0.port, world_size=2, rank=1)
-
-
-app = L.LightningApp(MultiNodeDemo())
examples/app_multi_node/app_pl_script.py:

@@ -1,7 +1,7 @@
-from pytorch_lightning import Trainer
-from pytorch_lightning.demos.boring_classes import BoringModel
+import lightning as L
+from lightning.pytorch.demos.boring_classes import BoringModel
 
 if __name__ == "__main__":
     model = BoringModel()
-    trainer = Trainer(max_epochs=1)
+    trainer = L.Trainer(max_epochs=1)
     trainer.fit(model)
CHANGELOG:

@@ -51,7 +51,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed missing root flow among the flows of the app ([#15531](https://github.com/Lightning-AI/lightning/pull/15531))
 
--
+- Fixed a bug with the Multi Node Component and added some examples ([#15557](https://github.com/Lightning-AI/lightning/pull/15557))
 
 
 
MultiNode component (lightning.app.components):

@@ -11,7 +11,7 @@ class MultiNode(LightningFlow):
     def __init__(
         self,
         work_cls: Type["LightningWork"],
-        nodes: int,
+        num_nodes: int,
         cloud_compute: "CloudCompute",
         *work_args: Any,
         **work_kwargs: Any,
@@ -39,14 +39,14 @@ class MultiNode(LightningFlow):
         app = L.LightningApp(
             MultiNode(
                 AnyDistributedComponent,
-                nodes=8,
+                num_nodes=8,
                 cloud_compute=compute,
             )
         )
 
     Arguments:
         work_cls: The work to be executed
-        nodes: Number of nodes.
+        num_nodes: Number of nodes.
         cloud_compute: The cloud compute object used in the cloud.
         work_args: Arguments to be provided to the work on instantiation.
         work_kwargs: Keyword arguments to be provided to the work on instantiation.
@@ -54,7 +54,7 @@ class MultiNode(LightningFlow):
         super().__init__()
         self.ws = structures.List()
         self._work_cls = work_cls
-        self.nodes = nodes
+        self.num_nodes = num_nodes
         self._cloud_compute = cloud_compute
         self._work_args = work_args
         self._work_kwargs = work_kwargs
@@ -65,7 +65,7 @@ class MultiNode(LightningFlow):
 
         # 1. Create & start the works
         if not self.ws:
-            for node_rank in range(self.nodes):
+            for node_rank in range(self.num_nodes):
                 self.ws.append(
                     self._work_cls(
                         *self._work_args,
@@ -84,12 +84,13 @@ class MultiNode(LightningFlow):
             self.has_started = True
 
         # Loop over all node machines
-        for node_rank in range(self.nodes):
+        for node_rank in range(self.num_nodes):
 
             # 3. Run the user code in a distributed way!
             self.ws[node_rank].run(
                 main_address=self.ws[0].internal_ip,
                 main_port=self.ws[0].port,
+                num_nodes=self.num_nodes,
                 node_rank=node_rank,
             )
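This loop pins down the calling contract: the flow passes node 0's IP and port as the rendezvous point, so any work handed to `MultiNode` must accept these four arguments. A minimal compatible work, sketched with a hypothetical name:

```python
import lightning as L


class MyWork(L.LightningWork):  # hypothetical user work
    def run(self, main_address: str, main_port: int, num_nodes: int, node_rank: int):
        ...  # distributed setup using the rendezvous info goes here
```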
@@ -395,18 +395,18 @@ class WorkRunner:
         # 6. Create the state observer thread.
         self.state_observer = WorkStateObserver(self.work, delta_queue=self.delta_queue)
 
-        # 7. Deepcopy the work state and send the first `RUNNING` status delta to the flow.
-        reference_state = deepcopy(self.work.state)
-
+        # 7. Set the internal IP address.
+        # Set this here after the state observer is initialized, since it needs to record it as a change and send
+        # it back to the flow.
+        self.work._internal_ip = os.environ.get("LIGHTNING_NODE_IP", "127.0.0.1")
 
-        # 7. Patch the setattr method of the work. This needs to be done after step 4, so we don't
+        # 8. Patch the setattr method of the work. This needs to be done after step 4, so we don't
         # send delta while calling `set_state`.
         self._proxy_setattr()
 
+        # 9. Deepcopy the work state and send the first `RUNNING` status delta to the flow.
+        reference_state = deepcopy(self.work.state)
+
         if self._is_starting(called, reference_state, call_hash):
             return
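A note on the new `_internal_ip` assignment: the fallback suggests the cloud runtime injects `LIGHTNING_NODE_IP` while local runs use loopback, though that env-var contract is an assumption read off this diff:

```python
import os

# Mirrors the assignment above: cloud runs presumably provide the node IP via
# LIGHTNING_NODE_IP; local runs fall back to 127.0.0.1.
internal_ip = os.environ.get("LIGHTNING_NODE_IP", "127.0.0.1")
print(internal_ip)  # "127.0.0.1" unless LIGHTNING_NODE_IP is set
```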
tests conftest.py (new file):

@@ -0,0 +1,60 @@
+import os
+import shutil
+import threading
+
+import psutil
+import pytest
+
+from lightning_app.storage.path import _storage_root_dir
+from lightning_app.utilities.component import _set_context
+from lightning_app.utilities.packaging import cloud_compute
+from lightning_app.utilities.packaging.app_config import _APP_CONFIG_FILENAME
+from lightning_app.utilities.state import AppState
+
+
+def pytest_sessionfinish(session, exitstatus):
+    """Pytest hook that gets called after the whole test run finishes, right before returning the exit status
+    to the system."""
+    # Kill all the processes and threads created by the parent.
+    # TODO: this isn't great. We should have each test do its own cleanup.
+    current_process = psutil.Process()
+    for child in current_process.children(recursive=True):
+        params = child.as_dict() or {}
+        cmd_lines = params.get("cmdline", [])
+        # We shouldn't kill the resource tracker from multiprocessing. If we do,
+        # `atexit` will throw as it uses the resource tracker to try to clean up.
+        if cmd_lines and "resource_tracker" in cmd_lines[-1]:
+            continue
+        child.kill()
+
+    main_thread = threading.current_thread()
+    for t in threading.enumerate():
+        if t is not main_thread:
+            t.join(0)
+
+
+@pytest.fixture(scope="function", autouse=True)
+def cleanup():
+    from lightning_app.utilities.app_helpers import _LightningAppRef
+
+    yield
+    _LightningAppRef._app_instance = None
+    shutil.rmtree("./storage", ignore_errors=True)
+    shutil.rmtree(_storage_root_dir(), ignore_errors=True)
+    shutil.rmtree("./.shared", ignore_errors=True)
+    if os.path.isfile(_APP_CONFIG_FILENAME):
+        os.remove(_APP_CONFIG_FILENAME)
+    _set_context(None)
+
+
+@pytest.fixture(scope="function", autouse=True)
+def clear_app_state_state_variables():
+    """Resets global variables in order to prevent interference between tests."""
+    yield
+    import lightning_app.utilities.state
+
+    lightning_app.utilities.state._STATE = None
+    lightning_app.utilities.state._LAST_STATE = None
+    AppState._MY_AFFILIATION = ()
+    if hasattr(cloud_compute, "_CLOUD_COMPUTE_STORE"):
+        cloud_compute._CLOUD_COMPUTE_STORE.clear()
Multi-node example tests:

@@ -41,12 +41,21 @@ class LightningTestMultiNodeWorksApp(LightningTestApp):
         return res
 
 
-def test_multi_node_example_2():
+@pytest.mark.parametrize(
+    "app_name",
+    [
+        "app_torch_work.py",
+        "app_generic_work.py",
+        # "app_lite_work.py",
+        # "app_pl_work.py": TODO: Add once https://github.com/Lightning-AI/lightning/issues/15556 is resolved.
+    ],
+)
+def test_multi_node_examples(app_name):
     cwd = os.getcwd()
     new_cwd = os.path.join(_PROJECT_ROOT, "examples/app_multi_node")
     os.chdir(new_cwd)
     command_line = [
-        "app_work.py",
+        app_name,
         "--blocking",
         "False",
         "--open-ui",