[App] Fix `AutoScaler` trying to replicate multiple works in a single machine (#15991)
* dont try to replicate new works in the existing machine * update chglog * Update comment * Update src/lightning_app/components/auto_scaler.py * add test
This commit is contained in:
parent
9ed43c64b6
commit
c1d0156e1d
|
@ -43,6 +43,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
||||||
- Fixed Registration for CloudComputes of Works in `L.app.structures` ([#15964](https://github.com/Lightning-AI/lightning/pull/15964))
|
- Fixed Registration for CloudComputes of Works in `L.app.structures` ([#15964](https://github.com/Lightning-AI/lightning/pull/15964))
|
||||||
|
|
||||||
|
|
||||||
|
- Fixed `AutoScaler` raising an exception when non-default cloud compute is specified ([#15991](https://github.com/Lightning-AI/lightning/pull/15991))
|
||||||
|
|
||||||
|
|
||||||
## [1.8.4] - 2022-12-08
|
## [1.8.4] - 2022-12-08
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
|
|
@ -449,8 +449,15 @@ class AutoScaler(LightningFlow):
|
||||||
|
|
||||||
def create_work(self) -> LightningWork:
|
def create_work(self) -> LightningWork:
|
||||||
"""Replicates a LightningWork instance with args and kwargs provided via ``__init__``."""
|
"""Replicates a LightningWork instance with args and kwargs provided via ``__init__``."""
|
||||||
|
cloud_compute = self._work_kwargs.get("cloud_compute", None)
|
||||||
|
self._work_kwargs.update(
|
||||||
|
dict(
|
||||||
# TODO: Remove `start_with_flow=False` for faster initialization on the cloud
|
# TODO: Remove `start_with_flow=False` for faster initialization on the cloud
|
||||||
self._work_kwargs.update(dict(start_with_flow=False))
|
start_with_flow=False,
|
||||||
|
# don't try to create multiple works in a single machine
|
||||||
|
cloud_compute=cloud_compute.clone() if cloud_compute else None,
|
||||||
|
)
|
||||||
|
)
|
||||||
return self._work_cls(*self._work_args, **self._work_kwargs)
|
return self._work_cls(*self._work_args, **self._work_kwargs)
|
||||||
|
|
||||||
def add_work(self, work) -> str:
|
def add_work(self, work) -> str:
|
||||||
|
|
|
@ -3,7 +3,7 @@ from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from lightning_app import LightningWork
|
from lightning_app import CloudCompute, LightningWork
|
||||||
from lightning_app.components import AutoScaler
|
from lightning_app.components import AutoScaler
|
||||||
|
|
||||||
|
|
||||||
|
@ -90,3 +90,11 @@ def test_scale(replicas, metrics, expected_replicas):
|
||||||
)
|
)
|
||||||
|
|
||||||
assert auto_scaler.scale(replicas, metrics) == expected_replicas
|
assert auto_scaler.scale(replicas, metrics) == expected_replicas
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_work_cloud_compute_cloned():
|
||||||
|
"""Test CloudCompute is cloned to avoid creating multiple works in a single machine."""
|
||||||
|
cloud_compute = CloudCompute("gpu")
|
||||||
|
auto_scaler = AutoScaler(EmptyWork, cloud_compute=cloud_compute)
|
||||||
|
_ = auto_scaler.create_work()
|
||||||
|
assert auto_scaler._work_kwargs["cloud_compute"] is not cloud_compute
|
||||||
|
|
Loading…
Reference in New Issue