2022-11-08 12:55:31 +00:00
|
|
|
import torch
|
|
|
|
from torch.nn.parallel.distributed import DistributedDataParallel
|
|
|
|
|
|
|
|
import lightning as L
|
|
|
|
from lightning.app.components import PyTorchSpawnMultiNode
|
|
|
|
|
|
|
|
|
|
|
|
class PyTorchDistributed(L.LightningWork):
    """Toy multi-node training work using ``DistributedDataParallel``.

    Each process trains a tiny ``Linear -> ReLU -> Linear`` model with MSE loss
    to map the constant input ``0.8`` to the target ``1.0``.
    """

    def run(
        self,
        world_size: int,
        node_rank: int,
        global_rank: int,
        local_rank: int,
    ) -> None:
        """Train the model for 1000 SGD steps in this process.

        Args:
            world_size: Number of participating processes (not used directly here).
            node_rank: Index of the node running this process (not used directly here).
            global_rank: Rank of this process across all nodes; only used in the
                per-step progress print.
            local_rank: Rank of this process on its node; selects ``cuda:<local_rank>``
                when CUDA is available.
        """
        # NOTE(review): DistributedDataParallel needs an initialized default process
        # group. This method does not create one, so it presumably relies on the
        # launcher (PyTorchSpawnMultiNode) having done so before calling run — confirm.

        # 1. Prepare the model
        model = torch.nn.Sequential(
            torch.nn.Linear(1, 1),
            torch.nn.ReLU(),
            torch.nn.Linear(1, 1),
        )

        # 2. Setup distributed training
        use_cuda = torch.cuda.is_available()
        device = torch.device(f"cuda:{local_rank}") if use_cuda else torch.device("cpu")
        # On CPU, device_ids must be None; on GPU, pin the replica to this process's device.
        model = DistributedDataParallel(
            model.to(device), device_ids=[local_rank] if use_cuda else None
        )

        # 3. Prepare loss and optimizer
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

        # The input and target never change, so allocate and move them to the
        # device once instead of on every step.
        x = torch.tensor([0.8]).to(device)
        target = torch.tensor([1.0]).to(device)

        # 4. Train the model for 1000 steps.
        for step in range(1000):
            model.zero_grad()
            output = model(x)
            loss = criterion(output, target)
            print(f"global_rank: {global_rank} step: {step} loss: {loss}")
            loss.backward()
            optimizer.step()
|
|
|
|
|
|
|
|
|
|
|
|
# Launch PyTorchDistributed on 2 nodes, each a "gpu-fast-multi" machine (4 x V100).
app = L.LightningApp(
    PyTorchSpawnMultiNode(
        PyTorchDistributed,
        cloud_compute=L.CloudCompute("gpu-fast-multi"),  # 4 x V100 per node
        num_nodes=2,
    )
)
|