Slightly safer multi node (#15538)
update

Co-authored-by: Luca Antiga <luca.antiga@gmail.com>
This commit is contained in:
parent
dcfaa065ab
commit
d48aa03207
|
@ -58,27 +58,30 @@ class MultiNode(LightningFlow):
|
|||
self._cloud_compute = cloud_compute
|
||||
self._work_args = work_args
|
||||
self._work_kwargs = work_kwargs
|
||||
self.has_initialized = False
|
||||
self.has_started = False
|
||||
|
||||
def run(self) -> None:
|
||||
# 1. Create & start the works
|
||||
if not self.has_initialized:
|
||||
for node_rank in range(self.nodes):
|
||||
self.ws.append(
|
||||
self._work_cls(
|
||||
*self._work_args,
|
||||
cloud_compute=self._cloud_compute,
|
||||
**self._work_kwargs,
|
||||
parallel=True,
|
||||
)
|
||||
)
|
||||
# Starting node `node_rank` ...
|
||||
self.ws[-1].start()
|
||||
self.has_initialized = True
|
||||
if not self.has_started:
|
||||
|
||||
# 2. Wait for all machines to be started !
|
||||
if all(w.status.stage == WorkStageStatus.STARTED for w in self.ws):
|
||||
return
|
||||
# 1. Create & start the works
|
||||
if not self.ws:
|
||||
for node_rank in range(self.nodes):
|
||||
self.ws.append(
|
||||
self._work_cls(
|
||||
*self._work_args,
|
||||
cloud_compute=self._cloud_compute,
|
||||
**self._work_kwargs,
|
||||
parallel=True,
|
||||
)
|
||||
)
|
||||
# Starting node `node_rank` ...
|
||||
self.ws[-1].start()
|
||||
|
||||
# 2. Wait for all machines to be started !
|
||||
if not all(w.status.stage == WorkStageStatus.STARTED for w in self.ws):
|
||||
return
|
||||
|
||||
self.has_started = True
|
||||
|
||||
# Loop over all node machines
|
||||
for node_rank in range(self.nodes):
|
||||
|
|
Loading…
Reference in New Issue