Slightly safer multi node (#15538)

update

Co-authored-by: Luca Antiga <luca.antiga@gmail.com>
This commit is contained in:
thomas chaton 2022-11-05 02:05:11 +00:00 committed by GitHub
parent dcfaa065ab
commit d48aa03207
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 21 additions and 18 deletions

View File

@ -58,27 +58,30 @@ class MultiNode(LightningFlow):
self._cloud_compute = cloud_compute
self._work_args = work_args
self._work_kwargs = work_kwargs
self.has_initialized = False
self.has_started = False
def run(self) -> None:
# 1. Create & start the works
if not self.has_initialized:
for node_rank in range(self.nodes):
self.ws.append(
self._work_cls(
*self._work_args,
cloud_compute=self._cloud_compute,
**self._work_kwargs,
parallel=True,
)
)
# Starting node `node_rank`` ...
self.ws[-1].start()
self.has_initialized = True
if not self.has_started:
# 2. Wait for all machines to be started !
if all(w.status.stage == WorkStageStatus.STARTED for w in self.ws):
return
# 1. Create & start the works
if not self.ws:
for node_rank in range(self.nodes):
self.ws.append(
self._work_cls(
*self._work_args,
cloud_compute=self._cloud_compute,
**self._work_kwargs,
parallel=True,
)
)
# Starting node `node_rank`` ...
self.ws[-1].start()
# 2. Wait for all machines to be started !
if not all(w.status.stage == WorkStageStatus.STARTED for w in self.ws):
return
self.has_started = True
# Loop over all node machines
for node_rank in range(self.nodes):