[App] Fixed Multi Node and add examples (#15557)

thomas chaton 2022-11-07 09:36:41 +00:00 committed by GitHub
parent 96c574425d
commit 820233176b
15 changed files with 298 additions and 57 deletions

View File

@ -0,0 +1,41 @@
# Lightning & Multi Node Training
Lightning makes multi-node training simple by providing an interface to orchestrate compute and data.
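The interface in question is the `MultiNode` component touched by this PR: it replicates a `LightningWork` across `num_nodes` machines and passes each replica its coordinates (`main_address`, `main_port`, `num_nodes`, `node_rank`). A condensed sketch of the generic pattern, mirroring the `AnyDistributedComponent` example further down in this diff (presumably `app_generic_work.py`):
```python
import lightning as L
from lightning.app.components import MultiNode


class AnyDistributedComponent(L.LightningWork):
    def run(
        self,
        main_address: str,
        main_port: int,
        num_nodes: int,
        node_rank: int,
    ):
        # Plug any framework's distributed bootstrap in here; MultiNode provides
        # the address/port of node 0, the total node count and this node's rank.
        print(f"ADD YOUR DISTRIBUTED CODE: {main_address} {main_port} {num_nodes} {node_rank}.")


app = L.LightningApp(
    MultiNode(
        AnyDistributedComponent,
        num_nodes=2,
        cloud_compute=L.CloudCompute("gpu"),
    )
)
```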
## Multi Node with raw PyTorch
You can run the multi-node raw PyTorch example with the following command.
```bash
lightning run app app_torch_work.py
```
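Under the hood, the work launched by this command (presumably `app_torch_work.py`, shown in full further down in this diff) spawns one process per GPU and bootstraps `torch.distributed` from the coordinates injected by `MultiNode`. A condensed sketch, with the training loop elided:
```python
import torch
from torch.nn.parallel.distributed import DistributedDataParallel

import lightning as L
from lightning.app.components import MultiNode


def distributed_train(local_rank: int, main_address: str, main_port: int, num_nodes: int, node_rank: int, nprocs: int):
    # Derive the global rank / world size and rendezvous with node 0.
    global_rank = local_rank + node_rank * nprocs
    world_size = num_nodes * nprocs
    torch.distributed.init_process_group(
        "nccl" if torch.cuda.is_available() else "gloo",
        rank=global_rank,
        world_size=world_size,
        init_method=f"tcp://{main_address}:{main_port}",
    )
    # Wrap the model in DDP on this process's device.
    device = torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else torch.device("cpu")
    device_ids = [local_rank] if torch.cuda.is_available() else None
    model = DistributedDataParallel(torch.nn.Linear(32, 2).to(device), device_ids=device_ids)
    # ... regular PyTorch training loop ...


class PyTorchDistributed(L.LightningWork):
    def run(self, main_address: str, main_port: int, num_nodes: int, node_rank: int):
        # One process per local GPU (or a single CPU process).
        nprocs = torch.cuda.device_count() if torch.cuda.is_available() else 1
        torch.multiprocessing.spawn(
            distributed_train, args=(main_address, main_port, num_nodes, node_rank, nprocs), nprocs=nprocs
        )


app = L.LightningApp(
    MultiNode(PyTorchDistributed, num_nodes=2, cloud_compute=L.CloudCompute("gpu-fast-multi"))
)
```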
## Multi Node with raw PyTorch + Lite
You can run the multi-node raw PyTorch + Lite example with the following command.
```bash
lightning run app app_lite_work.py
```
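Here the work (presumably `app_lite_work.py`, shown in full further down in this diff) only exports the rendezvous information as environment variables; `LightningLite` then handles process-group setup, device placement, and DDP wrapping. A condensed sketch:
```python
import os

import torch

import lightning as L
from lightning.app.components import MultiNode
from lightning.lite import LightningLite


def distributed_train(lite: LightningLite):
    # Lite returns the model and optimizer already wrapped for distributed training.
    model = torch.nn.Linear(32, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    model, optimizer = lite.setup(model, optimizer)
    # ... regular training loop, calling lite.backward(loss) instead of loss.backward() ...


class PyTorchDistributed(L.LightningWork):
    def run(self, main_address: str, main_port: int, num_nodes: int, node_rank: int):
        # Lite reads the rendezvous information from these environment variables.
        os.environ["MASTER_ADDR"] = main_address
        os.environ["MASTER_PORT"] = str(main_port)
        os.environ["NODE_RANK"] = str(node_rank)
        lite = LightningLite(accelerator="auto", devices="auto", strategy="ddp_spawn", num_nodes=num_nodes)
        lite.launch(function=distributed_train)


app = L.LightningApp(
    MultiNode(PyTorchDistributed, num_nodes=2, cloud_compute=L.CloudCompute("gpu-fast-multi"))
)
```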
## Multi Node with PyTorch Lightning
Lightning supports running PyTorch Lightning multi-node either from a script or from within a Lightning Work.
### Multi Node PyTorch Lightning Script
```bash
lightning run app app_pl_script.py
```
### Multi Node PyTorch Lightning Work
```bash
lightning run app app_pl_work.py
```
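In this variant the work (presumably `app_pl_work.py`, shown in full further down in this diff) exports the rendezvous environment variables and hands everything else to the PyTorch Lightning `Trainer`. A condensed sketch:
```python
import os

import lightning as L
from lightning.app.components import MultiNode
from lightning.pytorch.demos.boring_classes import BoringModel


class PyTorchLightningDistributed(L.LightningWork):
    def run(self, main_address: str, main_port: int, num_nodes: int, node_rank: int):
        # The Trainer picks up the rendezvous information from these environment variables.
        os.environ["MASTER_ADDR"] = main_address
        os.environ["MASTER_PORT"] = str(main_port)
        os.environ["NODE_RANK"] = str(node_rank)
        trainer = L.Trainer(
            max_epochs=10,
            devices="auto",
            accelerator="auto",
            num_nodes=num_nodes,
            strategy="ddp_spawn",  # only spawn-based strategies are supported for now
        )
        trainer.fit(BoringModel())


app = L.LightningApp(
    MultiNode(PyTorchLightningDistributed, num_nodes=2, cloud_compute=L.CloudCompute("gpu-fast-multi"))
)
```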
## Multi Node with any framework
```bash
lightning run app app_generic_work.py
```

View File

@ -1,4 +1,4 @@
import lightning.app as L
import lightning as L
from lightning.app.components import MultiNode
@ -7,16 +7,17 @@ class AnyDistributedComponent(L.LightningWork):
self,
main_address: str,
main_port: int,
num_nodes: int,
node_rank: int,
):
print(f"ADD YOUR DISTRIBUTED CODE: {main_address} {main_port} {node_rank}")
print(f"ADD YOUR DISTRIBUTED CODE: {main_address} {main_port} {num_nodes} {node_rank}.")
compute = L.CloudCompute("gpu")
app = L.LightningApp(
MultiNode(
AnyDistributedComponent,
nodes=2,
num_nodes=2,
cloud_compute=compute,
)
)

View File

@ -0,0 +1,59 @@
import os
import torch
import lightning as L
from lightning.app.components import MultiNode
from lightning.lite import LightningLite
def distributed_train(lite: LightningLite):
# 1. Prepare distributed model and optimizer
model = torch.nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
model, optimizer = lite.setup(model, optimizer)
criterion = torch.nn.MSELoss()
# 2. Train the model for 50 steps.
for step in range(50):
model.zero_grad()
x = torch.randn(64, 32).to(lite.device)
output = model(x)
loss = criterion(output, torch.ones_like(output))
print(f"global_rank: {lite.global_rank} step: {step} loss: {loss}")
lite.backward(loss)
optimizer.step()
# 3. Verify all processes have the same weights at the end of training.
weight = model.module.weight.clone()
torch.distributed.all_reduce(weight)
assert torch.equal(model.module.weight, weight / lite.world_size)
print("Multi Node Distributed Training Done!")
class PyTorchDistributed(L.LightningWork):
def run(
self,
main_address: str,
main_port: int,
num_nodes: int,
node_rank: int,
):
os.environ["MASTER_ADDR"] = main_address
os.environ["MASTER_PORT"] = str(main_port)
os.environ["NODE_RANK"] = str(node_rank)
lite = LightningLite(accelerator="auto", devices="auto", strategy="ddp_spawn", num_nodes=num_nodes)
lite.launch(function=distributed_train)
compute = L.CloudCompute("gpu-fast-multi") # 4xV100
app = L.LightningApp(
MultiNode(
PyTorchDistributed,
num_nodes=2,
cloud_compute=compute,
)
)

View File

@ -0,0 +1,38 @@
import os
import lightning as L
from lightning.app.components import MultiNode
from lightning.pytorch.demos.boring_classes import BoringModel
class PyTorchLightningDistributed(L.LightningWork):
def run(
self,
main_address: str,
main_port: int,
num_nodes: int,
node_rank: int,
):
os.environ["MASTER_ADDR"] = main_address
os.environ["MASTER_PORT"] = str(main_port)
os.environ["NODE_RANK"] = str(node_rank)
model = BoringModel()
trainer = L.Trainer(
max_epochs=10,
devices="auto",
accelerator="auto",
num_nodes=num_nodes,
strategy="ddp_spawn", # Only spawn based strategies are supported for now.
)
trainer.fit(model)
compute = L.CloudCompute("gpu-fast-multi") # 4xV100
app = L.LightningApp(
MultiNode(
PyTorchLightningDistributed,
num_nodes=2,
cloud_compute=compute,
)
)

View File

@ -0,0 +1,70 @@
import torch
from torch.nn.parallel.distributed import DistributedDataParallel
import lightning as L
from lightning.app.components import MultiNode
def distributed_train(local_rank: int, main_address: str, main_port: int, num_nodes: int, node_rank: int, nprocs: int):
# 1. Setting distributed environment
global_rank = local_rank + node_rank * nprocs
world_size = num_nodes * nprocs
if torch.distributed.is_available() and not torch.distributed.is_initialized():
torch.distributed.init_process_group(
"nccl" if torch.cuda.is_available() else "gloo",
rank=global_rank,
world_size=world_size,
init_method=f"tcp://{main_address}:{main_port}",
)
# 2. Prepare distributed model
model = torch.nn.Linear(32, 2)
device = torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else torch.device("cpu")
device_ids = [local_rank] if torch.cuda.is_available() else None
model = DistributedDataParallel(model.to(device), device_ids=device_ids)
# 3. Prepare loss and optimizer
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 4. Train the model for 50 steps.
for step in range(50):
model.zero_grad()
x = torch.randn(64, 32).to(device)
output = model(x)
loss = criterion(output, torch.ones_like(output))
print(f"global_rank: {global_rank} step: {step} loss: {loss}")
loss.backward()
optimizer.step()
# 5. Verify all processes have the same weights at the end of training.
weight = model.module.weight.clone()
torch.distributed.all_reduce(weight)
assert torch.equal(model.module.weight, weight / world_size)
print("Multi Node Distributed Training Done!")
class PyTorchDistributed(L.LightningWork):
def run(
self,
main_address: str,
main_port: int,
num_nodes: int,
node_rank: int,
):
nprocs = torch.cuda.device_count() if torch.cuda.is_available() else 1
torch.multiprocessing.spawn(
distributed_train, args=(main_address, main_port, num_nodes, node_rank, nprocs), nprocs=nprocs
)
compute = L.CloudCompute("gpu-fast-multi") # 4xV100
app = L.LightningApp(
MultiNode(
PyTorchDistributed,
num_nodes=2,
cloud_compute=compute,
)
)

View File

@ -1,2 +0,0 @@
.storage/
.shared/

View File

@ -1,36 +0,0 @@
import lightning as L
class Work(L.LightningWork):
def __init__(self, cloud_compute: L.CloudCompute = L.CloudCompute(), **kwargs):
super().__init__(parallel=True, **kwargs, cloud_compute=cloud_compute)
def run(self, main_address="localhost", main_port=1111, world_size=1, rank=0, init=False):
if init:
return
import torch.distributed
print(f"Initializing process group: {main_address=}, {main_port=}, {world_size=}, {rank=}")
torch.distributed.init_process_group(
backend="gloo", init_method=f"tcp://{main_address}:{main_port}", world_size=world_size, rank=rank
)
gathered = [torch.zeros(1) for _ in range(world_size)]
torch.distributed.all_gather(gathered, torch.tensor([rank]).float())
print(gathered)
class MultiNodeDemo(L.LightningFlow):
def __init__(self):
super().__init__()
self.work0 = Work()
self.work1 = Work()
def run(self):
self.work0.run(init=True)
if self.work0.internal_ip:
self.work0.run(main_address=self.work0.internal_ip, main_port=self.work0.port, world_size=2, rank=0)
self.work1.run(main_address=self.work0.internal_ip, main_port=self.work0.port, world_size=2, rank=1)
app = L.LightningApp(MultiNodeDemo())

View File

@ -1,7 +1,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel
import lightning as L
from lightning.pytorch.demos.boring_classes import BoringModel
if __name__ == "__main__":
model = BoringModel()
trainer = Trainer(max_epochs=1)
trainer = L.Trainer(max_epochs=1)
trainer.fit(model)

View File

@ -51,7 +51,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed missing root flow among the flows of the app ([#15531](https://github.com/Lightning-AI/lightning/pull/15531))
-
- Fixed bug with Multi Node Component and added some examples ([#15557](https://github.com/Lightning-AI/lightning/pull/15557))

View File

@ -11,7 +11,7 @@ class MultiNode(LightningFlow):
def __init__(
self,
work_cls: Type["LightningWork"],
nodes: int,
num_nodes: int,
cloud_compute: "CloudCompute",
*work_args: Any,
**work_kwargs: Any,
@ -39,14 +39,14 @@ class MultiNode(LightningFlow):
app = L.LightningApp(
MultiNode(
AnyDistributedComponent,
nodes=8,
num_nodes=8,
cloud_compute=compute,
)
)
Arguments:
work_cls: The work to be executed
nodes: Number of nodes.
num_nodes: Number of nodes.
cloud_compute: The cloud compute object used in the cloud.
work_args: Arguments to be provided to the work on instantiation.
work_kwargs: Keyword arguments to be provided to the work on instantiation.
@ -54,7 +54,7 @@ class MultiNode(LightningFlow):
super().__init__()
self.ws = structures.List()
self._work_cls = work_cls
self.nodes = nodes
self.num_nodes = num_nodes
self._cloud_compute = cloud_compute
self._work_args = work_args
self._work_kwargs = work_kwargs
@ -65,7 +65,7 @@ class MultiNode(LightningFlow):
# 1. Create & start the works
if not self.ws:
for node_rank in range(self.nodes):
for node_rank in range(self.num_nodes):
self.ws.append(
self._work_cls(
*self._work_args,
@ -84,12 +84,13 @@ class MultiNode(LightningFlow):
self.has_started = True
# Loop over all node machines
for node_rank in range(self.nodes):
for node_rank in range(self.num_nodes):
# 3. Run the user code in a distributed way!
self.ws[node_rank].run(
main_address=self.ws[0].internal_ip,
main_port=self.ws[0].port,
num_nodes=self.num_nodes,
node_rank=node_rank,
)

View File

@ -395,18 +395,18 @@ class WorkRunner:
# 6. Create the state observer thread.
self.state_observer = WorkStateObserver(self.work, delta_queue=self.delta_queue)
# 7. Deepcopy the work state and send the first `RUNNING` status delta to the flow.
reference_state = deepcopy(self.work.state)
# Set the internal IP address.
# Set this here after the state observer is initialized, since it needs to record it as a change and send
# it back to the flow
self.work._internal_ip = os.environ.get("LIGHTNING_NODE_IP", "127.0.0.1")
# 7. Patch the setattr method of the work. This needs to be done after step 4, so we don't
# 8. Patch the setattr method of the work. This needs to be done after step 4, so we don't
# send delta while calling `set_state`.
self._proxy_setattr()
# 8. Deepcopy the work state and send the first `RUNNING` status delta to the flow.
reference_state = deepcopy(self.work.state)
if self._is_starting(called, reference_state, call_hash):
return

View File

@ -0,0 +1,60 @@
import os
import shutil
import threading
import psutil
import pytest
from lightning_app.storage.path import _storage_root_dir
from lightning_app.utilities.component import _set_context
from lightning_app.utilities.packaging import cloud_compute
from lightning_app.utilities.packaging.app_config import _APP_CONFIG_FILENAME
from lightning_app.utilities.state import AppState
def pytest_sessionfinish(session, exitstatus):
"""Pytest hook that get called after whole test run finished, right before returning the exit status to the
system."""
# kill all the processes and threads created by parent
# TODO: this isn't great. Each test should do its own cleanup.
current_process = psutil.Process()
for child in current_process.children(recursive=True):
params = child.as_dict() or {}
cmd_lines = params.get("cmdline", [])
# we shouldn't kill the resource tracker from multiprocessing. If we do,
# `atexit` will throw as it uses resource tracker to try to clean up
if cmd_lines and "resource_tracker" in cmd_lines[-1]:
continue
child.kill()
main_thread = threading.current_thread()
for t in threading.enumerate():
if t is not main_thread:
t.join(0)
@pytest.fixture(scope="function", autouse=True)
def cleanup():
from lightning_app.utilities.app_helpers import _LightningAppRef
yield
_LightningAppRef._app_instance = None
shutil.rmtree("./storage", ignore_errors=True)
shutil.rmtree(_storage_root_dir(), ignore_errors=True)
shutil.rmtree("./.shared", ignore_errors=True)
if os.path.isfile(_APP_CONFIG_FILENAME):
os.remove(_APP_CONFIG_FILENAME)
_set_context(None)
@pytest.fixture(scope="function", autouse=True)
def clear_app_state_state_variables():
"""Resets global variables in order to prevent interference between tests."""
yield
import lightning_app.utilities.state
lightning_app.utilities.state._STATE = None
lightning_app.utilities.state._LAST_STATE = None
AppState._MY_AFFILIATION = ()
if hasattr(cloud_compute, "_CLOUD_COMPUTE_STORE"):
cloud_compute._CLOUD_COMPUTE_STORE.clear()

View File

@ -41,12 +41,21 @@ class LightningTestMultiNodeWorksApp(LightningTestApp):
return res
def test_multi_node_example_2():
@pytest.mark.parametrize(
"app_name",
[
"app_torch_work.py",
"app_generic_work.py",
# "app_lite_work.py",
# "app_pl_work.py": TODO Add once https://github.com/Lightning-AI/lightning/issues/15556 is resolved.
],
)
def test_multi_node_examples(app_name):
cwd = os.getcwd()
new_cwd = os.path.join(_PROJECT_ROOT, "examples/app_multi_node")
os.chdir(new_cwd)
command_line = [
"app_work.py",
app_name,
"--blocking",
"False",
"--open-ui",