[App] Fix multi-node pytorch example CI (#15753)
This commit is contained in:
parent
1ffbe1bf1e
commit
bc797fd376
|
@ -246,7 +246,7 @@ subprojects:
|
|||
- ".github/workflows/ci-app-examples.yml"
|
||||
- "src/lightning_app/**"
|
||||
- "tests/tests_app_examples/**"
|
||||
- "examples/app_*"
|
||||
- "examples/app_*/**"
|
||||
- "requirements/app/**"
|
||||
- "setup.py"
|
||||
- ".actions/**"
|
||||
|
|
|
@ -11,7 +11,7 @@ on:
|
|||
- ".github/workflows/ci-app-examples.yml"
|
||||
- "src/lightning_app/**"
|
||||
- "tests/tests_app_examples/**"
|
||||
- "examples/app_*"
|
||||
- "examples/app_*/**"
|
||||
- "requirements/app/**"
|
||||
- "setup.py"
|
||||
- ".actions/**"
|
||||
|
|
|
@ -22,7 +22,7 @@ def distributed_train(local_rank: int, main_address: str, main_port: int, num_no
|
|||
# 2. PREPARE DISTRIBUTED MODEL
|
||||
model = torch.nn.Linear(32, 2)
|
||||
device = torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else torch.device("cpu")
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank]).to(device)
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank] if torch.cuda.is_available() else None).to(device)
|
||||
|
||||
# 3. SETUP LOSS AND OPTIMIZER
|
||||
criterion = torch.nn.MSELoss()
|
||||
|
|
|
@ -23,7 +23,7 @@ def distributed_train(local_rank: int, main_address: str, main_port: int, num_no
|
|||
# 2. PREPARE DISTRIBUTED MODEL
|
||||
model = torch.nn.Linear(32, 2)
|
||||
device = torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else torch.device("cpu")
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank]).to(device)
|
||||
model = DistributedDataParallel(model, device_ids=[local_rank] if torch.cuda.is_available() else None).to(device)
|
||||
|
||||
# 3. SETUP LOSS AND OPTIMIZER
|
||||
criterion = torch.nn.MSELoss()
|
||||
|
@ -55,7 +55,7 @@ class PyTorchDistributed(L.LightningWork):
|
|||
)
|
||||
|
||||
|
||||
# 32 GPUs: (8 nodes x 4 v 100)
|
||||
# 8 GPUs: (2 nodes x 4 v 100)
|
||||
compute = L.CloudCompute("gpu-fast-multi") # 4xV100
|
||||
component = MultiNode(PyTorchDistributed, num_nodes=2, cloud_compute=compute)
|
||||
app = L.LightningApp(component)
|
||||
|
|
Loading…
Reference in New Issue