diff --git a/src/lightning_app/components/multi_node.py b/src/lightning_app/components/multi_node.py index 3d308b83c3..66bebb76a4 100644 --- a/src/lightning_app/components/multi_node.py +++ b/src/lightning_app/components/multi_node.py @@ -58,27 +58,30 @@ class MultiNode(LightningFlow): self._cloud_compute = cloud_compute self._work_args = work_args self._work_kwargs = work_kwargs - self.has_initialized = False + self.has_started = False def run(self) -> None: - # 1. Create & start the works - if not self.has_initialized: - for node_rank in range(self.nodes): - self.ws.append( - self._work_cls( - *self._work_args, - cloud_compute=self._cloud_compute, - **self._work_kwargs, - parallel=True, - ) - ) - # Starting node `node_rank`` ... - self.ws[-1].start() - self.has_initialized = True + if not self.has_started: - # 2. Wait for all machines to be started ! - if all(w.status.stage == WorkStageStatus.STARTED for w in self.ws): - return + # 1. Create & start the works + if not self.ws: + for node_rank in range(self.nodes): + self.ws.append( + self._work_cls( + *self._work_args, + cloud_compute=self._cloud_compute, + **self._work_kwargs, + parallel=True, + ) + ) + # Starting node `node_rank`` ... + self.ws[-1].start() + + # 2. Wait for all machines to be started ! + if not all(w.status.stage == WorkStageStatus.STARTED for w in self.ws): + return + + self.has_started = True # Loop over all node machines for node_rank in range(self.nodes):