[App] Fix `AutoScaler` trying to replicate multiple works in a single machine (#15991)

* dont try to replicate new works in the existing machine

* update chglog

* Update comment

* Update src/lightning_app/components/auto_scaler.py

* add test
This commit is contained in:
Akihiro Nitta 2022-12-11 09:56:46 +09:00 committed by GitHub
parent 9ed43c64b6
commit c1d0156e1d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 21 additions and 3 deletions

View File

@ -43,6 +43,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed Registration for CloudComputes of Works in `L.app.structures` ([#15964](https://github.com/Lightning-AI/lightning/pull/15964)) - Fixed Registration for CloudComputes of Works in `L.app.structures` ([#15964](https://github.com/Lightning-AI/lightning/pull/15964))
- Fixed `AutoScaler` raising an exception when non-default cloud compute is specified ([#15991](https://github.com/Lightning-AI/lightning/pull/15991))
## [1.8.4] - 2022-12-08 ## [1.8.4] - 2022-12-08
### Added ### Added

View File

@ -449,8 +449,15 @@ class AutoScaler(LightningFlow):
def create_work(self) -> LightningWork: def create_work(self) -> LightningWork:
"""Replicates a LightningWork instance with args and kwargs provided via ``__init__``.""" """Replicates a LightningWork instance with args and kwargs provided via ``__init__``."""
cloud_compute = self._work_kwargs.get("cloud_compute", None)
self._work_kwargs.update(
dict(
# TODO: Remove `start_with_flow=False` for faster initialization on the cloud # TODO: Remove `start_with_flow=False` for faster initialization on the cloud
self._work_kwargs.update(dict(start_with_flow=False)) start_with_flow=False,
# don't try to create multiple works in a single machine
cloud_compute=cloud_compute.clone() if cloud_compute else None,
)
)
return self._work_cls(*self._work_args, **self._work_kwargs) return self._work_cls(*self._work_args, **self._work_kwargs)
def add_work(self, work) -> str: def add_work(self, work) -> str:

View File

@ -3,7 +3,7 @@ from unittest.mock import patch
import pytest import pytest
from lightning_app import LightningWork from lightning_app import CloudCompute, LightningWork
from lightning_app.components import AutoScaler from lightning_app.components import AutoScaler
@ -90,3 +90,11 @@ def test_scale(replicas, metrics, expected_replicas):
) )
assert auto_scaler.scale(replicas, metrics) == expected_replicas assert auto_scaler.scale(replicas, metrics) == expected_replicas
def test_create_work_cloud_compute_cloned():
"""Test CloudCompute is cloned to avoid creating multiple works in a single machine."""
cloud_compute = CloudCompute("gpu")
auto_scaler = AutoScaler(EmptyWork, cloud_compute=cloud_compute)
_ = auto_scaler.create_work()
assert auto_scaler._work_kwargs["cloud_compute"] is not cloud_compute