diff --git a/rich/tree.py b/rich/tree.py index 5fd46fd4..66203e69 100644 --- a/rich/tree.py +++ b/rich/tree.py @@ -28,6 +28,7 @@ class Tree(JupyterMixin): guide_style: StyleType = "tree.line", expanded: bool = True, highlight: bool = False, + hide_root: bool = False, ) -> None: self.label = label self.style = style @@ -35,6 +36,7 @@ class Tree(JupyterMixin): self.children: List[Tree] = [] self.expanded = expanded self.highlight = highlight + self.hide_root = hide_root def add( self, @@ -105,6 +107,8 @@ class Tree(JupyterMixin): style_stack = StyleStack(get_style(self.style)) remove_guide_styles = Style(bold=False, underline2=False) + depth = 0 + while stack: stack_node = pop() try: @@ -123,7 +127,7 @@ class Tree(JupyterMixin): guide_style = guide_style_stack.current + get_style(node.guide_style) style = style_stack.current + get_style(node.style) - prefix = levels[1:] + prefix = levels[(2 if self.hide_root else 1) :] renderable_lines = console.render_lines( Styled(node.label, style), options.update( @@ -133,19 +137,21 @@ class Tree(JupyterMixin): height=None, ), ) - for first, line in loop_first(renderable_lines): - if prefix: - yield from _Segment.apply_style( - prefix, - style.background_style, - post_style=remove_guide_styles, - ) - yield from line - yield new_line - if first and prefix: - prefix[-1] = make_guide( - SPACE if last else CONTINUE, prefix[-1].style or null_style - ) + + if not (depth == 0 and self.hide_root): + for first, line in loop_first(renderable_lines): + if prefix: + yield from _Segment.apply_style( + prefix, + style.background_style, + post_style=remove_guide_styles, + ) + yield from line + yield new_line + if first and prefix: + prefix[-1] = make_guide( + SPACE if last else CONTINUE, prefix[-1].style or null_style + ) if node.expanded and node.children: levels[-1] = make_guide( @@ -157,6 +163,7 @@ class Tree(JupyterMixin): style_stack.push(get_style(node.style)) guide_style_stack.push(get_style(node.guide_style)) push(iter(loop_last(node.children))) + depth += 1 def __rich_measure__( self, console: "Console", options: "ConsoleOptions" @@ -222,7 +229,7 @@ class Segment(NamedTuple): """ ) - root = Tree("🌲 [b green]Rich Tree", highlight=True) + root = Tree("🌲 [b green]Rich Tree", highlight=True, hide_root=True) node = root.add(":file_folder: Renderables", guide_style="red") simple_node = node.add(":file_folder: [bold yellow]Atomic", guide_style="uu green") diff --git a/tests/test_tree.py b/tests/test_tree.py index babcc3cf..8fc070d3 100644 --- a/tests/test_tree.py +++ b/tests/test_tree.py @@ -94,6 +94,42 @@ def test_render_tree_win32(): assert result == expected +@pytest.mark.skipif(sys.platform == "win32", reason="different on Windows") +def test_render_tree_hide_root_non_win32(): + tree = Tree("foo", hide_root=True) + tree.add("bar", style="italic") + baz_tree = tree.add("baz", guide_style="bold red", style="on blue") + baz_tree.add("1") + baz_tree.add("2") + tree.add("egg") + + console = Console(width=20, force_terminal=True, color_system="standard") + console.begin_capture() + console.print(tree) + result = console.end_capture() + print(repr(result)) + expected = "\x1b[3mbar\x1b[0m\x1b[3m \x1b[0m\n\x1b[44mbaz\x1b[0m\x1b[44m \x1b[0m\n\x1b[31;44m┣━━ \x1b[0m\x1b[44m1\x1b[0m\x1b[44m \x1b[0m\n\x1b[31;44m┗━━ \x1b[0m\x1b[44m2\x1b[0m\x1b[44m \x1b[0m\negg \n" + assert result == expected + + +@pytest.mark.skipif(sys.platform != "win32", reason="Windows specific") +def test_render_tree_hide_root_win32(): + tree = Tree("foo", hide_root=True) + tree.add("bar", style="italic") + baz_tree = tree.add("baz", guide_style="bold red", style="on blue") + baz_tree.add("1") + baz_tree.add("2") + tree.add("egg") + + console = Console(width=20, force_terminal=True, color_system="standard") + console.begin_capture() + console.print(tree) + result = console.end_capture() + print(repr(result)) + expected = "\x1b[3mbar\x1b[0m\x1b[3m \x1b[0m\n\x1b[44mbaz\x1b[0m\x1b[44m \x1b[0m\n\x1b[31;44m├── \x1b[0m\x1b[44m1\x1b[0m\x1b[44m \x1b[0m\n\x1b[31;44m└── \x1b[0m\x1b[44m2\x1b[0m\x1b[44m \x1b[0m\negg \n" + assert result == expected + + def test_tree_measure(): tree = Tree("foo") tree.add("bar")