[App] Fix `AutoScaler` trying to replicate multiple works in a single machine (#15991)
* don't try to replicate new works in the existing machine * update changelog * Update comment * Update src/lightning_app/components/auto_scaler.py * add test
This commit is contained in:
parent
9ed43c64b6
commit
c1d0156e1d
|
@ -43,6 +43,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed Registration for CloudComputes of Works in `L.app.structures` ([#15964](https://github.com/Lightning-AI/lightning/pull/15964))
|
||||
|
||||
|
||||
- Fixed `AutoScaler` raising an exception when non-default cloud compute is specified ([#15991](https://github.com/Lightning-AI/lightning/pull/15991))
|
||||
|
||||
|
||||
## [1.8.4] - 2022-12-08
|
||||
|
||||
### Added
|
||||
|
|
|
@ -449,8 +449,15 @@ class AutoScaler(LightningFlow):
|
|||
|
||||
def create_work(self) -> LightningWork:
|
||||
"""Replicates a LightningWork instance with args and kwargs provided via ``__init__``."""
|
||||
# TODO: Remove `start_with_flow=False` for faster initialization on the cloud
|
||||
self._work_kwargs.update(dict(start_with_flow=False))
|
||||
cloud_compute = self._work_kwargs.get("cloud_compute", None)
|
||||
self._work_kwargs.update(
|
||||
dict(
|
||||
# TODO: Remove `start_with_flow=False` for faster initialization on the cloud
|
||||
start_with_flow=False,
|
||||
# don't try to create multiple works in a single machine
|
||||
cloud_compute=cloud_compute.clone() if cloud_compute else None,
|
||||
)
|
||||
)
|
||||
return self._work_cls(*self._work_args, **self._work_kwargs)
|
||||
|
||||
def add_work(self, work) -> str:
|
||||
|
|
|
@ -3,7 +3,7 @@ from unittest.mock import patch
|
|||
|
||||
import pytest
|
||||
|
||||
from lightning_app import LightningWork
|
||||
from lightning_app import CloudCompute, LightningWork
|
||||
from lightning_app.components import AutoScaler
|
||||
|
||||
|
||||
|
@ -90,3 +90,11 @@ def test_scale(replicas, metrics, expected_replicas):
|
|||
)
|
||||
|
||||
assert auto_scaler.scale(replicas, metrics) == expected_replicas
|
||||
|
||||
|
||||
def test_create_work_cloud_compute_cloned():
|
||||
"""Test CloudCompute is cloned to avoid creating multiple works in a single machine."""
|
||||
cloud_compute = CloudCompute("gpu")
|
||||
auto_scaler = AutoScaler(EmptyWork, cloud_compute=cloud_compute)
|
||||
_ = auto_scaler.create_work()
|
||||
assert auto_scaler._work_kwargs["cloud_compute"] is not cloud_compute
|
||||
|
|
Loading…
Reference in New Issue