diff --git a/CHANGELOG.md b/CHANGELOG.md index eb6b6262..efd84d66 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,4 @@ + # Changelog All notable changes to this project will be documented in this file. @@ -14,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added ProgressColumn `MofNCompleteColumn` to display raw `completed/total` column (similar to DownloadColumn, but displays values as ints, does not convert to floats or add bit/bytes units). https://github.com/Textualize/rich/pull/1941 +- Add support for namedtuples to `Pretty` https://github.com/Textualize/rich/pull/2031 ### Fixed diff --git a/rich/pretty.py b/rich/pretty.py index 216075af..57f3e62e 100644 --- a/rich/pretty.py +++ b/rich/pretty.py @@ -1,4 +1,5 @@ import builtins +import collections import dataclasses import inspect import os @@ -30,7 +31,6 @@ try: except ImportError: # pragma: no cover _attr_module = None # type: ignore - from . import get_console from ._loop import loop_last from ._pick import pick_bool @@ -79,6 +79,29 @@ def _is_dataclass_repr(obj: object) -> bool: return False +_dummy_namedtuple = collections.namedtuple("_dummy_namedtuple", []) + + +def _has_default_namedtuple_repr(obj: object) -> bool: + """Check if an instance of namedtuple contains the default repr + + Args: + obj (object): A namedtuple + + Returns: + bool: True if the default repr is used, False if there's a custom repr. + """ + obj_file = None + try: + obj_file = inspect.getfile(obj.__repr__) + except (OSError, TypeError): + # OSError handles case where object is defined in __main__ scope, e.g. REPL - no filename available. + # TypeError trapped defensively, in case of object without filename slips through. + pass + default_repr_file = inspect.getfile(_dummy_namedtuple.__repr__) + return obj_file == default_repr_file + + def _ipy_display_hook( value: Any, console: Optional["Console"] = None, @@ -383,6 +406,7 @@ class Node: empty: str = "" last: bool = False is_tuple: bool = False + is_namedtuple: bool = False children: Optional[List["Node"]] = None key_separator = ": " separator: str = ", " @@ -397,7 +421,7 @@ class Node: elif self.children is not None: if self.children: yield self.open_brace - if self.is_tuple and len(self.children) == 1: + if self.is_tuple and not self.is_namedtuple and len(self.children) == 1: yield from self.children[0].iter_tokens() yield "," else: @@ -524,6 +548,25 @@ class _Line: ) +def _is_namedtuple(obj: Any) -> bool: + """Checks if an object is most likely a namedtuple. It is possible + to craft an object that passes this check and isn't a namedtuple, but + there is only a minuscule chance of this happening unintentionally. + + Args: + obj (Any): The object to test + + Returns: + bool: True if the object is a namedtuple. False otherwise. + """ + try: + fields = getattr(obj, "_fields", None) + except Exception: + # Being very defensive - if we cannot get the attr then its not a namedtuple + return False + return isinstance(obj, tuple) and isinstance(fields, tuple) + + def traverse( _object: Any, max_length: Optional[int] = None, @@ -731,7 +774,25 @@ def traverse( append(child_node) pop_visited(obj_id) - + elif _is_namedtuple(obj) and _has_default_namedtuple_repr(obj): + if reached_max_depth: + node = Node(value_repr="...") + else: + children = [] + class_name = obj.__class__.__name__ + node = Node( + open_brace=f"{class_name}(", + close_brace=")", + children=children, + empty=f"{class_name}()", + ) + append = children.append + for last, (key, value) in loop_last(obj._asdict().items()): + child_node = _traverse(value, depth=depth + 1) + child_node.key_repr = key + child_node.last = last + child_node.key_separator = "=" + append(child_node) elif _safe_isinstance(obj, _CONTAINERS): for container_type in _CONTAINERS: if _safe_isinstance(obj, container_type): @@ -780,7 +841,7 @@ def traverse( child_node.last = index == last_item_index append(child_node) if max_length is not None and num_items > max_length: - append(Node(value_repr=f"... +{num_items-max_length}", last=True)) + append(Node(value_repr=f"... +{num_items - max_length}", last=True)) else: node = Node(empty=empty, children=[], last=root) @@ -788,6 +849,7 @@ def traverse( else: node = Node(value_repr=to_repr(obj), last=root) node.is_tuple = _safe_isinstance(obj, tuple) + node.is_namedtuple = _is_namedtuple(obj) return node node = _traverse(_object, root=True) @@ -878,6 +940,15 @@ if __name__ == "__main__": # pragma: no cover 1 / 0 return "this will fail" + from typing import NamedTuple + + class StockKeepingUnit(NamedTuple): + name: str + description: str + price: float + category: str + reviews: List[str] + d = defaultdict(int) d["foo"] = 5 data = { @@ -904,6 +975,13 @@ if __name__ == "__main__": # pragma: no cover ] ), "atomic": (False, True, None), + "namedtuple": StockKeepingUnit( + "Sparkling British Spring Water", + "Carbonated spring water", + 0.9, + "water", + ["its amazing!", "its terrible!"], + ), "Broken": BrokenRepr(), } data["foo"].append(data) # type: ignore diff --git a/tests/test_pretty.py b/tests/test_pretty.py index 02488f46..d45b0c22 100644 --- a/tests/test_pretty.py +++ b/tests/test_pretty.py @@ -1,9 +1,10 @@ +import collections import io import sys from array import array from collections import UserDict, defaultdict from dataclasses import dataclass, field -from typing import List +from typing import List, NamedTuple import attr import pytest @@ -169,6 +170,74 @@ def test_pretty_dataclass(): assert result == "ExampleDataclass(foo=1000, bar=..., baz=['foo', 'bar', 'baz'])" +class StockKeepingUnit(NamedTuple): + name: str + description: str + price: float + category: str + reviews: List[str] + + +def test_pretty_namedtuple(): + console = Console(color_system=None) + console.begin_capture() + + example_namedtuple = StockKeepingUnit( + "Sparkling British Spring Water", + "Carbonated spring water", + 0.9, + "water", + ["its amazing!", "its terrible!"], + ) + + result = pretty_repr(example_namedtuple) + + print(result) + assert ( + result + == """StockKeepingUnit( + name='Sparkling British Spring Water', + description='Carbonated spring water', + price=0.9, + category='water', + reviews=['its amazing!', 'its terrible!'] +)""" + ) + + +def test_pretty_namedtuple_length_one_no_trailing_comma(): + instance = collections.namedtuple("Thing", ["name"])(name="Bob") + assert pretty_repr(instance) == "Thing(name='Bob')" + + +def test_pretty_namedtuple_empty(): + instance = collections.namedtuple("Thing", [])() + assert pretty_repr(instance) == "Thing()" + + +def test_pretty_namedtuple_custom_repr(): + class Thing(NamedTuple): + def __repr__(self): + return "XX" + + assert pretty_repr(Thing()) == "XX" + + +def test_pretty_namedtuple_fields_invalid_type(): + class LooksLikeANamedTupleButIsnt(tuple): + _fields = "blah" + + instance = LooksLikeANamedTupleButIsnt() + result = pretty_repr(instance) + assert result == "()" # Treated as tuple + + +def test_pretty_namedtuple_max_depth(): + instance = {"unit": StockKeepingUnit("a", "b", 1.0, "c", ["d", "e"])} + result = pretty_repr(instance, max_depth=1) + assert result == "{'unit': ...}" + + def test_small_width(): test = ["Hello world! 12345"] result = pretty_repr(test, max_width=10)