lightning/tests/tests_app/utilities/test_app_helpers.py

108 lines
2.8 KiB
Python
Raw Normal View History

from unittest import mock
from unittest.mock import Mock
import pytest
from lightning_app import LightningFlow, LightningWork
from lightning_app.utilities.app_helpers import (
AppStatePlugin,
BaseStatePlugin,
InMemoryStateStore,
is_overridden,
StateStore,
)
from lightning_app.utilities.exceptions import LightningAppStateException
class Work(LightningWork):
    """Minimal concrete ``LightningWork`` that defines its own (empty) ``run``."""

    def run(self):
        """Do nothing; the method only needs to exist on this subclass."""
class Flow(LightningFlow):
    """Minimal concrete ``LightningFlow`` that defines its own (empty) ``run``."""

    def run(self):
        """Do nothing; the method only needs to exist on this subclass."""
def test_is_overridden():
    """Exercise ``is_overridden`` on edge cases, real flows/works and ``Mock`` wrappers."""
    flow = Flow()
    work = Work()

    # edge cases
    assert not is_overridden("whatever", None)
    # A plain `object()` is neither a flow nor a work and no explicit parent is given.
    with pytest.raises(ValueError, match="Expected a parent"):
        is_overridden("whatever", object())
    assert not is_overridden("whatever", flow)
    assert not is_overridden("whatever", flow, parent=Flow)

    class TestFlow(LightningFlow):
        def run(self):
            pass

        def foo(self):
            pass

        def bar(self):
            return 1

    # `foo` is only defined on the subclass; presumably the default parent
    # (`LightningFlow`) lacks it, which is what triggers this error.
    with pytest.raises(ValueError, match="The parent should define the method"):
        is_overridden("foo", TestFlow())

    # normal usage
    assert is_overridden("run", flow)
    assert is_overridden("run", work)

    # `Mock` support. These locals are deliberately not called `mock` so they do
    # not shadow the module-level `unittest.mock` import.
    mocked_flow = Mock(spec=Flow, wraps=flow)
    assert is_overridden("run", mocked_flow)
    mocked_work = Mock(spec=LightningWork, wraps=work)
    assert is_overridden("run", mocked_work)
def test_simple_app_store():
    """An ``InMemoryStateStore`` hands back exactly what was stored for a user."""
    store = InMemoryStateStore()
    user = "1234"
    store.add(user)
    payload = {"data": user}

    # Write every piece of per-user state first (the user id doubles as session id) ...
    store.set_app_state(user, payload)
    store.set_served_state(user, payload)
    store.set_served_session_id(user, user)

    # ... then read each piece back.
    assert store.get_app_state(user) == payload
    assert store.get_served_state(user) == payload
    assert store.get_served_session_id(user) == user

    store.remove(user)
    assert isinstance(store, StateStore)
@mock.patch("lightning_app.utilities.app_helpers.APP_STATE_MAX_SIZE_BYTES", 120)
def test_simple_app_store_warning():
    """Storing an app state above the (patched, 120-byte) size limit is rejected.

    The patch is applied as a decorator, so it is active for the whole test,
    including store construction. How the state size is measured is decided by
    ``set_app_state`` and is not visible here.
    """
    store = InMemoryStateStore()
    user = "1234"
    store.add(user)
    oversized = {"data": "I'm a state that's larger than 120 bytes"}
    with pytest.raises(LightningAppStateException, match="is larger than the"):
        store.set_app_state(user, oversized)
def test_base_state_plugin():
    """Smoke-test the plugin hooks on a minimal subclass and on ``AppStatePlugin``."""

    class DummyStatePlugin(BaseStatePlugin):
        # Every hook simply defers to the base-class implementation.
        def should_update_app(self, deep_diff):
            super().should_update_app(deep_diff)

        def get_context(self):
            super().get_context()

        def render_non_authorized(self):
            super().render_non_authorized()

    # Instantiate each plugin right before calling its hooks, keeping the
    # original construction/call order.
    for plugin_cls in (DummyStatePlugin, AppStatePlugin):
        plugin = plugin_cls()
        plugin.should_update_app(None)
        plugin.get_context()
        plugin.render_non_authorized()