lightning/examples/app_multi_node/train_pytorch.py

62 lines
2.3 KiB
Python
Raw Normal View History

# app.py
# ! pip install torch
import torch
from torch.nn.parallel.distributed import DistributedDataParallel
import lightning as L
from lightning.app.components import MultiNode
def distributed_train(local_rank: int, main_address: str, main_port: int, num_nodes: int, node_rank: int, nprocs: int):
    """Train a toy linear model with DistributedDataParallel and verify the replicas agree.

    Written to be used as the target of ``torch.multiprocessing.spawn``, which
    supplies the process index as the first positional argument.

    Args:
        local_rank: Index of this process on its node; also used as the CUDA
            device index when CUDA is available.
        main_address: Host part of the ``tcp://`` rendezvous URL.
        main_port: Port part of the ``tcp://`` rendezvous URL.
        num_nodes: Number of participating nodes.
        node_rank: Index of this node among ``num_nodes``.
        nprocs: Number of processes started on each node.

    Raises:
        AssertionError: If, after training, the local weights differ from the
            average of the weights across all ranks.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device(f"cuda:{local_rank}") if use_cuda else torch.device("cpu")

    # 1. SET UP DISTRIBUTED ENVIRONMENT
    global_rank = local_rank + node_rank * nprocs
    world_size = num_nodes * nprocs

    if use_cuda:
        # Bind this process to its own GPU before any communicator is created,
        # so collectives do not default to device 0.
        torch.cuda.set_device(device)

    # Only tear down a process group that this call created itself.
    created_group = False
    if torch.distributed.is_available() and not torch.distributed.is_initialized():
        torch.distributed.init_process_group(
            "nccl" if use_cuda else "gloo",
            rank=global_rank,
            world_size=world_size,
            init_method=f"tcp://{main_address}:{main_port}",
        )
        created_group = True

    try:
        # 2. PREPARE DISTRIBUTED MODEL
        # The module must already live on its target device when it is wrapped:
        # DDP rejects a CPU module combined with ``device_ids``.
        model = torch.nn.Linear(32, 2).to(device)
        model = DistributedDataParallel(model, device_ids=[local_rank] if use_cuda else None)

        # 3. SETUP LOSS AND OPTIMIZER
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

        # 4. TRAIN THE MODEL FOR 50 STEPS
        for step in range(50):
            model.zero_grad()
            x = torch.randn(64, 32).to(device)
            output = model(x)
            loss = criterion(output, torch.ones_like(output))
            print(f"global_rank: {global_rank} step: {step} loss: {loss}")
            loss.backward()
            optimizer.step()

        # 5. VERIFY ALL COPIES OF THE MODEL HAVE THE SAME WEIGHTS AT END OF TRAINING
        # Sum the weights over all ranks and compare the mean with the local copy.
        # A tolerance-based comparison is used because the reduction order may
        # introduce floating-point rounding even when every replica is in sync.
        local_weight = model.module.weight.detach()
        summed_weight = local_weight.clone()
        torch.distributed.all_reduce(summed_weight)
        if not torch.allclose(local_weight, summed_weight / world_size):
            # Raised explicitly so the check still runs under ``python -O``.
            raise AssertionError("Model weights diverged across ranks after training")

        print("Multi Node Distributed Training Done!")
    finally:
        if created_group:
            torch.distributed.destroy_process_group()
class PyTorchDistributed(L.LightningWork):
    """Lightning work that starts the ``distributed_train`` processes for one node."""

    def run(self, main_address: str, main_port: int, num_nodes: int, node_rank: int):
        """Spawn one training process per local CUDA device (or one on CPU).

        Args:
            main_address: Forwarded to ``distributed_train`` as the rendezvous host.
            main_port: Forwarded to ``distributed_train`` as the rendezvous port.
            num_nodes: Forwarded to ``distributed_train`` as the node count.
            node_rank: Forwarded to ``distributed_train`` as this node's index.
        """
        # One process per visible GPU; a single process when CUDA is unavailable.
        if torch.cuda.is_available():
            procs_per_node = torch.cuda.device_count()
        else:
            procs_per_node = 1

        # ``spawn`` prepends the process index, which becomes ``local_rank``.
        spawn_args = (main_address, main_port, num_nodes, node_rank, procs_per_node)
        torch.multiprocessing.spawn(distributed_train, args=spawn_args, nprocs=procs_per_node)
# 8 GPUs in total: 2 nodes x 4 V100 each.
compute = L.CloudCompute("gpu-fast-multi")  # 4xV100
component = MultiNode(
    PyTorchDistributed,
    num_nodes=2,
    cloud_compute=compute,
)
app = L.LightningApp(component)