diff --git a/rich/markdown.py b/rich/markdown.py index e2402f72..0983c500 100644 --- a/rich/markdown.py +++ b/rich/markdown.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Dict, Iterable, List, Optional, Union +from typing import Any, Dict, Iterable, List, Optional, Union from commonmark.blocks import Parser @@ -13,22 +13,26 @@ from .console import ( StyledText, ) from .style import Style, StyleStack -from .text import Text +from .text import Lines, Text from ._stack import Stack class MarkdownElement: - def on_enter(self, context: MarkdownContext, node): + @classmethod + def create(cls, node: Any) -> MarkdownElement: + return cls() + + def on_enter(self, context: MarkdownContext): pass - def on_text(self, context: MarkdownContext, text: str,) -> RenderResult: + def on_text(self, context: MarkdownContext, text: str) -> None: pass - def on_leave(self, context: MarkdownContext): + def on_leave(self, context: MarkdownContext) -> RenderResult: return yield - def on_child_close(self, context: MarkdownContext, child: MarkdownElement): + def on_child_close(self, context: MarkdownContext, child: MarkdownElement) -> None: pass @@ -43,13 +47,13 @@ class TextElement(MarkdownElement): def __init__(self) -> None: self.text = Text() - def on_enter(self, context: MarkdownContext, node: Any) -> None: - context.enter_style(f"markdown.h{node.level}") + def on_enter(self, context: MarkdownContext) -> None: + context.enter_style(self.style_name) - def on_text(self, context: MarkdownContext, text: str): + def on_text(self, context: MarkdownContext, text: str) -> None: self.text.append(text, context.current_style) - def on_leave(self, context: MarkdownContext) -> Iterable[Text]: + def on_leave(self, context: MarkdownContext) -> Iterable[Lines]: context.leave_style() yield self.text @@ -60,20 +64,31 @@ class TextElement(MarkdownElement): class Paragraph(TextElement): style_name = "markdown.paragraph" - def on_leave(self, context: MarkdownContext) -> Iterable[Text]: + def on_leave(self, context: MarkdownContext) -> Iterable[Lines]: context.leave_style() lines = self.text.wrap(context.options.max_width) yield lines class Heading(TextElement): - def on_leave(self, context: MarkdownContext) -> Iterable[Text]: + @classmethod + def create(cls, node: Any) -> Heading: + heading = Heading(node.level) + return heading + + def __init__(self, level: int) -> None: + self.style_name = f"markdown.h{level}" + super().__init__() + + def on_leave(self, context: MarkdownContext) -> Iterable[Lines]: context.leave_style() lines = self.text.wrap(context.options.max_width, justify="center") yield lines class MarkdownContext: + """Manages the console render state.""" + def __init__(self, console: Console, options: ConsoleOptions) -> None: self.console = console self.options = options @@ -102,14 +117,14 @@ class Markdown: inlines = {"emph", "strong"} def __init__(self, markup: str) -> None: + """Parses the markup.""" self.markup = markup parser = Parser() self.parsed = parser.parse(markup) def __console__(self, console: Console, options: ConsoleOptions) -> RenderResult: - + """Render markdown to the console.""" context = MarkdownContext(console, options) - nodes = self.parsed.walker() for current, entering in nodes: @@ -129,9 +144,9 @@ class Markdown: element_class = self.elements.get(node_type) or UnknownElement if entering: - element = element_class() + element = element_class.create(current) context.stack.push(element) - element.on_enter(context, current) + element.on_enter(context) else: element = context.stack.pop() if context.stack: