39 lines
973 B
Python
39 lines
973 B
Python
import os
|
|
|
|
import lightning as L
|
|
from lightning.app.components import MultiNode
|
|
from lightning.pytorch.demos.boring_classes import BoringModel
|
|
|
|
|
|
class PyTorchLightningDistributed(L.LightningWork):
|
|
def run(
|
|
self,
|
|
main_address: str,
|
|
main_port: int,
|
|
num_nodes: int,
|
|
node_rank: int,
|
|
):
|
|
os.environ["MASTER_ADDR"] = main_address
|
|
os.environ["MASTER_PORT"] = str(main_port)
|
|
os.environ["NODE_RANK"] = str(node_rank)
|
|
|
|
model = BoringModel()
|
|
trainer = L.Trainer(
|
|
max_epochs=10,
|
|
devices="auto",
|
|
accelerator="auto",
|
|
num_nodes=num_nodes,
|
|
strategy="ddp_spawn", # Only spawn based strategies are supported for now.
|
|
)
|
|
trainer.fit(model)
|
|
|
|
|
|
compute = L.CloudCompute("gpu-fast-multi") # 4xV100
|
|
app = L.LightningApp(
|
|
MultiNode(
|
|
PyTorchLightningDistributed,
|
|
num_nodes=2,
|
|
cloud_compute=compute,
|
|
)
|
|
)
|