From c1d0156e1db09581c23414a904eace5c23253199 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Sun, 11 Dec 2022 09:56:46 +0900 Subject: [PATCH] [App] Fix `AutoScaler` trying to replicate multiple works in a single machine (#15991) * dont try to replicate new works in the existing machine * update chglog * Update comment * Update src/lightning_app/components/auto_scaler.py * add test --- src/lightning_app/CHANGELOG.md | 3 +++ src/lightning_app/components/auto_scaler.py | 11 +++++++++-- tests/tests_app/components/test_auto_scaler.py | 10 +++++++++- 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index 7439d6a4be..5dc5ca769c 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -43,6 +43,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed Registration for CloudComputes of Works in `L.app.structures` ([#15964](https://github.com/Lightning-AI/lightning/pull/15964)) +- Fixed `AutoScaler` raising an exception when non-default cloud compute is specified ([#15991](https://github.com/Lightning-AI/lightning/pull/15991)) + + ## [1.8.4] - 2022-12-08 ### Added diff --git a/src/lightning_app/components/auto_scaler.py b/src/lightning_app/components/auto_scaler.py index fc6a1a8737..13948ba50a 100644 --- a/src/lightning_app/components/auto_scaler.py +++ b/src/lightning_app/components/auto_scaler.py @@ -449,8 +449,15 @@ class AutoScaler(LightningFlow): def create_work(self) -> LightningWork: """Replicates a LightningWork instance with args and kwargs provided via ``__init__``.""" - # TODO: Remove `start_with_flow=False` for faster initialization on the cloud - self._work_kwargs.update(dict(start_with_flow=False)) + cloud_compute = self._work_kwargs.get("cloud_compute", None) + self._work_kwargs.update( + dict( + # TODO: Remove `start_with_flow=False` for faster initialization on the cloud + start_with_flow=False, + # don't try to create multiple works in a single machine + cloud_compute=cloud_compute.clone() if cloud_compute else None, + ) + ) return self._work_cls(*self._work_args, **self._work_kwargs) def add_work(self, work) -> str: diff --git a/tests/tests_app/components/test_auto_scaler.py b/tests/tests_app/components/test_auto_scaler.py index 436c3517d0..672b05bbc9 100644 --- a/tests/tests_app/components/test_auto_scaler.py +++ b/tests/tests_app/components/test_auto_scaler.py @@ -3,7 +3,7 @@ from unittest.mock import patch import pytest -from lightning_app import LightningWork +from lightning_app import CloudCompute, LightningWork from lightning_app.components import AutoScaler @@ -90,3 +90,11 @@ def test_scale(replicas, metrics, expected_replicas): ) assert auto_scaler.scale(replicas, metrics) == expected_replicas + + +def test_create_work_cloud_compute_cloned(): + """Test CloudCompute is cloned to avoid creating multiple works in a single machine.""" + cloud_compute = CloudCompute("gpu") + auto_scaler = AutoScaler(EmptyWork, cloud_compute=cloud_compute) + _ = auto_scaler.create_work() + assert auto_scaler._work_kwargs["cloud_compute"] is not cloud_compute