103 lines
2.8 KiB
Python
import pytest
|
|
from lightning.app import LightningFlow, LightningWork
|
|
from lightning.app.testing.helpers import EmptyFlow, EmptyWork
|
|
from lightning.app.utilities.tree import breadth_first
|
|
|
|
|
|
class LeafFlow(EmptyFlow):
    """A flow with no child components, used as a leaf node in the test trees."""
|
|
|
|
|
|
class LeafWork(EmptyWork):
    """A work with no behavior of its own, used as a leaf node in the test trees."""
|
|
|
|
|
|
class SimpleFlowTree(EmptyFlow):
    """A flow whose two direct children are both ``LeafFlow`` instances."""

    def __init__(self) -> None:
        super().__init__()
        # Two leaf flows attached as attributes of this flow.
        self.simple_flow_left = LeafFlow()
        self.simple_flow_right = LeafFlow()
|
|
|
|
|
|
class SimpleWorkTree(EmptyFlow):
    """A flow whose two direct children are both ``LeafWork`` instances."""

    def __init__(self) -> None:
        super().__init__()
        # Two leaf works attached as attributes of this flow.
        self.simple_work_left = LeafWork()
        self.simple_work_right = LeafWork()
|
|
|
|
|
|
class MixedTree(EmptyFlow):
    """A three-level tree mixing flow subtrees and a work subtree.

    Layout::

        root
        ├── mixed_left   (SimpleFlowTree: two LeafFlow children)
        ├── work_tree    (SimpleWorkTree: two LeafWork children)
        └── mixed_right  (SimpleFlowTree: two LeafFlow children)
    """

    def __init__(self) -> None:
        super().__init__()
        # The attributes are assigned in the order left, work_tree, right.
        # That is deliberately not alphabetical, so the test can observe the
        # order in which the traversal visits siblings.
        self.mixed_left = SimpleFlowTree()
        self.work_tree = SimpleWorkTree()
        self.mixed_right = SimpleFlowTree()
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("input_tree", "types", "expected_sequence"),
    [
        pytest.param(LeafFlow(), (LightningFlow,), ["root"], id="leaf_flow"),
        # The root is a work, so a flows-only filter yields nothing at all.
        pytest.param(LeafWork(), (LightningFlow,), [], id="leaf_work_filtered_out"),
        pytest.param(
            SimpleFlowTree(),
            (LightningFlow,),
            [
                "root",
                "root.simple_flow_left",
                "root.simple_flow_right",
            ],
            id="flow_tree_flows",
        ),
        # Work children exist, but the filter only admits flows.
        pytest.param(SimpleWorkTree(), (LightningFlow,), ["root"], id="work_tree_flows_only"),
        pytest.param(
            SimpleWorkTree(),
            (LightningFlow, LightningWork),
            [
                "root",
                "root.simple_work_left",
                "root.simple_work_right",
            ],
            id="work_tree_flows_and_works",
        ),
        # Within one level, the expected sibling order is by attribute name:
        # "work_tree" comes after "mixed_right" even though MixedTree assigns it
        # between the two mixed_* children.
        pytest.param(
            MixedTree(),
            (LightningFlow,),
            [
                "root",
                "root.mixed_left",
                "root.mixed_right",
                "root.work_tree",
                "root.mixed_left.simple_flow_left",
                "root.mixed_left.simple_flow_right",
                "root.mixed_right.simple_flow_left",
                "root.mixed_right.simple_flow_right",
            ],
            id="mixed_tree_flows",
        ),
        pytest.param(
            MixedTree(),
            (LightningWork,),
            [
                "root.work_tree.simple_work_left",
                "root.work_tree.simple_work_right",
            ],
            id="mixed_tree_works",
        ),
        pytest.param(
            MixedTree(),
            (LightningFlow, LightningWork),
            [
                "root",
                "root.mixed_left",
                "root.mixed_right",
                "root.work_tree",
                "root.mixed_left.simple_flow_left",
                "root.mixed_left.simple_flow_right",
                "root.mixed_right.simple_flow_left",
                "root.mixed_right.simple_flow_right",
                "root.work_tree.simple_work_left",
                "root.work_tree.simple_work_right",
            ],
            id="mixed_tree_flows_and_works",
        ),
    ],
)
def test_breadth_first(input_tree, types, expected_sequence):
    """Check the names yielded by ``breadth_first`` and their order.

    ``breadth_first`` should emit every node of a given depth before any node
    of the next depth, and yield only nodes that are instances of ``types``.

    Args:
        input_tree: Root component of the tree to traverse. It is reported
            under the name ``"root"``.
        types: Tuple of component classes that a node must match to be yielded.
        expected_sequence: Dotted node names in the expected visiting order.
    """
    visited = [node.name for node in breadth_first(input_tree, types=types)]
    assert visited == expected_sequence
|