# lightning/tests/tests_app/utilities/test_tree.py

# 103 lines
# 2.8 KiB
# Python

import pytest
from lightning.app import LightningFlow, LightningWork
from lightning.app.testing.helpers import EmptyFlow, EmptyWork
from lightning.app.utilities.tree import breadth_first
class LeafFlow(EmptyFlow):
    """Childless flow used as a leaf node in the traversal fixtures below."""


class LeafWork(EmptyWork):
    """Work used as a leaf node in the traversal fixtures below."""


class SimpleFlowTree(EmptyFlow):
    """Flow with exactly two child flows, both of which are leaves."""

    def __init__(self) -> None:
        super().__init__()
        # Children are registered as attributes; their attribute names become
        # the last segment of the dotted component names the test checks.
        self.simple_flow_left = LeafFlow()
        self.simple_flow_right = LeafFlow()


class SimpleWorkTree(EmptyFlow):
    """Flow (not a work) that owns two leaf works and no child flows."""

    def __init__(self) -> None:
        super().__init__()
        # Both children are works, so a flow-only traversal stops at this node.
        self.simple_work_left = LeafWork()
        self.simple_work_right = LeafWork()


class MixedTree(EmptyFlow):
    """Three-level tree mixing flow-only subtrees with a work-only subtree."""

    def __init__(self) -> None:
        super().__init__()
        # NOTE: ``work_tree`` is assigned between the two flow subtrees, while
        # the expected traversal order in ``test_breadth_first`` lists these
        # siblings alphabetically (mixed_left, mixed_right, work_tree).
        self.mixed_left = SimpleFlowTree()
        self.work_tree = SimpleWorkTree()
        self.mixed_right = SimpleFlowTree()


@pytest.mark.parametrize(
    ("input_tree", "types", "expected_sequence"),
    [
        # A lone flow matches itself; a lone work is filtered out by a flow-only filter.
        pytest.param(LeafFlow(), (LightningFlow,), ["root"]),
        pytest.param(LeafWork(), (LightningFlow,), []),
        # One level of child flows.
        pytest.param(
            SimpleFlowTree(),
            (LightningFlow,),
            ["root", "root.simple_flow_left", "root.simple_flow_right"],
        ),
        # Child works are only reported when LightningWork is among the requested types.
        pytest.param(SimpleWorkTree(), (LightningFlow,), ["root"]),
        pytest.param(
            SimpleWorkTree(),
            (LightningFlow, LightningWork),
            ["root", "root.simple_work_left", "root.simple_work_right"],
        ),
        # Mixed tree: every node at depth 1 is listed before any node at depth 2.
        pytest.param(
            MixedTree(),
            (LightningFlow,),
            [
                "root",
                "root.mixed_left",
                "root.mixed_right",
                "root.work_tree",
                "root.mixed_left.simple_flow_left",
                "root.mixed_left.simple_flow_right",
                "root.mixed_right.simple_flow_left",
                "root.mixed_right.simple_flow_right",
            ],
        ),
        pytest.param(
            MixedTree(),
            (LightningWork,),
            ["root.work_tree.simple_work_left", "root.work_tree.simple_work_right"],
        ),
        pytest.param(
            MixedTree(),
            (LightningFlow, LightningWork),
            [
                "root",
                "root.mixed_left",
                "root.mixed_right",
                "root.work_tree",
                "root.mixed_left.simple_flow_left",
                "root.mixed_left.simple_flow_right",
                "root.mixed_right.simple_flow_left",
                "root.mixed_right.simple_flow_right",
                "root.work_tree.simple_work_left",
                "root.work_tree.simple_work_right",
            ],
        ),
    ],
)
def test_breadth_first(input_tree, types, expected_sequence):
    """Check the names, and their order, yielded by ``breadth_first`` for each fixture tree.

    Only components that are instances of ``types`` are expected in the output.
    """
    visited_names = [component.name for component in breadth_first(input_tree, types=types)]
    assert visited_names == expected_sequence