try edit import in overwrite

This commit is contained in:
Jirka 2022-10-24 17:55:40 +02:00
parent cbb16587f1
commit c02f766521
1 changed files with 12 additions and 12 deletions

View File

@ -16,7 +16,7 @@ from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tupl
import websockets import websockets
from deepdiff import Delta from deepdiff import Delta
import lightning_app from lightning_app import _logger, LightningWork, structures
from lightning_app.core.constants import APP_SERVER_PORT, APP_STATE_MAX_SIZE_BYTES, SUPPORTED_PRIMITIVE_TYPES from lightning_app.core.constants import APP_SERVER_PORT, APP_STATE_MAX_SIZE_BYTES, SUPPORTED_PRIMITIVE_TYPES
from lightning_app.utilities.exceptions import LightningAppStateException from lightning_app.utilities.exceptions import LightningAppStateException
@ -233,10 +233,10 @@ def is_overridden(method_name: str, instance: Optional[object] = None, parent: O
if instance is None: if instance is None:
return False return False
if parent is None: if parent is None:
if isinstance(instance, lightning_app.LightningFlow): if isinstance(instance, LightningFlow):
parent = lightning_app.LightningFlow parent = LightningFlow
elif isinstance(instance, lightning_app.LightningWork): elif isinstance(instance, LightningWork):
parent = lightning_app.LightningWork parent = LightningWork
if parent is None: if parent is None:
raise ValueError("Expected a parent") raise ValueError("Expected a parent")
from lightning_utilities.core.overrides import is_overridden from lightning_utilities.core.overrides import is_overridden
@ -263,7 +263,7 @@ def _set_child_name(component: "Component", child: "Component", new_name: str) -
child._name = child_name child._name = child_name
# the name changed, so recursively update the names of the children of this child # the name changed, so recursively update the names of the children of this child
if isinstance(child, lightning_app.core.LightningFlow): if isinstance(child, LightningFlow):
for n, c in child.flows.items(): for n, c in child.flows.items():
_set_child_name(child, c, n) _set_child_name(child, c, n)
for n, w in child.named_works(recurse=False): for n, w in child.named_works(recurse=False):
@ -271,10 +271,10 @@ def _set_child_name(component: "Component", child: "Component", new_name: str) -
for n in child._structures: for n in child._structures:
s = getattr(child, n) s = getattr(child, n)
_set_child_name(child, s, n) _set_child_name(child, s, n)
if isinstance(child, lightning_app.structures.Dict): if isinstance(child, structures.Dict):
for n, c in child.items(): for n, c in child.items():
_set_child_name(child, c, n) _set_child_name(child, c, n)
if isinstance(child, lightning_app.structures.List): if isinstance(child, structures.List):
for c in child: for c in child:
_set_child_name(child, c, c.name.split(".")[-1]) _set_child_name(child, c, c.name.split(".")[-1])
@ -290,13 +290,13 @@ def _delta_to_app_state_delta(root: "LightningFlow", component: "Component", del
new_prefix = "root" new_prefix = "root"
for p, c in _walk_to_component(root, component): for p, c in _walk_to_component(root, component):
if isinstance(c, lightning_app.core.LightningWork): if isinstance(c, LightningWork):
new_prefix += "['works']" new_prefix += "['works']"
if isinstance(c, lightning_app.core.LightningFlow): if isinstance(c, LightningFlow):
new_prefix += "['flows']" new_prefix += "['flows']"
if isinstance(c, (lightning_app.structures.Dict, lightning_app.structures.List)): if isinstance(c, (structures.Dict, structures.List)):
new_prefix += "['structures']" new_prefix += "['structures']"
c_n = c.name.split(".")[-1] c_n = c.name.split(".")[-1]
@ -341,7 +341,7 @@ def _collect_child_process_pids(pid: int) -> List[int]:
def _print_to_logger_info(*args, **kwargs): def _print_to_logger_info(*args, **kwargs):
# TODO Find a better way to re-direct print to loggers. # TODO Find a better way to re-direct print to loggers.
lightning_app._logger.info(" ".join([str(v) for v in args])) _logger.info(" ".join([str(v) for v in args]))
def convert_print_to_logger_info(func: Callable) -> Callable: def convert_print_to_logger_info(func: Callable) -> Callable: