diff --git a/src/pipdeptree/__main__.py b/src/pipdeptree/__main__.py index 5bf3f65..ae9d63a 100644 --- a/src/pipdeptree/__main__.py +++ b/src/pipdeptree/__main__.py @@ -10,19 +10,26 @@ from pipdeptree._discovery import get_installed_distributions from pipdeptree._models import PackageDAG from pipdeptree._render import render from pipdeptree._validate import validate +from pipdeptree._warning import WarningPrinter, WarningType, get_warning_printer def main(args: Sequence[str] | None = None) -> None | int: """CLI - The main function called as entry point.""" options = get_options(args) + # Warnings are only enabled when using text output. + is_text_output = not any([options.json, options.json_tree, options.output_format]) + if not is_text_output: + options.warn = WarningType.SILENCE + warning_printer = get_warning_printer() + warning_printer.warning_type = options.warn + pkgs = get_installed_distributions( interpreter=options.python, local_only=options.local_only, user_only=options.user_only ) tree = PackageDAG.from_pkgs(pkgs) - is_text_output = not any([options.json, options.json_tree, options.output_format]) - return_code = validate(options, is_text_output, tree) + validate(tree) # Reverse the tree (if applicable) before filtering, thus ensuring, that the filter will be applied on ReverseTree if options.reverse: @@ -35,14 +42,17 @@ def main(args: Sequence[str] | None = None) -> None | int: try: tree = tree.filter_nodes(show_only, exclude) except ValueError as e: - if options.warn in {"suppress", "fail"}: - print(e, file=sys.stderr) # noqa: T201 - return_code |= 1 if options.warn == "fail" else 0 - return return_code + if warning_printer.should_warn(): + warning_printer.print_single_line(str(e)) + return _determine_return_code(warning_printer) render(options, tree) - return return_code + return _determine_return_code(warning_printer) + + +def _determine_return_code(warning_printer: WarningPrinter) -> int: + return 1 if warning_printer.has_warned_with_failure() else 0 if __name__ == "__main__": diff --git a/src/pipdeptree/_cli.py b/src/pipdeptree/_cli.py index 2cf43c1..9359b1a 100644 --- a/src/pipdeptree/_cli.py +++ b/src/pipdeptree/_cli.py @@ -1,14 +1,14 @@ from __future__ import annotations +import enum import sys -from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, Namespace -from typing import TYPE_CHECKING, Sequence, cast +from argparse import Action, ArgumentDefaultsHelpFormatter, ArgumentParser, Namespace +from typing import Any, Sequence, cast + +from pipdeptree._warning import WarningType from .version import __version__ -if TYPE_CHECKING: - from typing import Literal - class Options(Namespace): freeze: bool @@ -16,7 +16,7 @@ class Options(Namespace): all: bool local_only: bool user_only: bool - warn: Literal["silence", "suppress", "fail"] + warn: WarningType reverse: bool packages: str exclude: str @@ -40,11 +40,11 @@ def build_parser() -> ArgumentParser: parser.add_argument( "-w", "--warn", - action="store", dest="warn", + type=WarningType, nargs="?", default="suppress", - choices=("silence", "suppress", "fail"), + action=EnumAction, help=( "warning control: suppress will show warnings but return 0 whether or not they are present; silence will " "not show warnings at all and always return 0; fail will show warnings and return 1 if any are present" @@ -154,6 +154,71 @@ def get_options(args: Sequence[str] | None) -> Options: return cast(Options, parsed_args) +class EnumAction(Action): + """ + Generic action that exists to convert a string into a Enum value that is then added into a `Namespace` object. + + This custom action exists because argparse doesn't have support for enums. + + References + ---------- + - https://github.com/python/cpython/issues/69247#issuecomment-1308082792 + - https://docs.python.org/3/library/argparse.html#action-classes + + """ + + def __init__( # noqa: PLR0913, PLR0917 + self, + option_strings: list[str], + dest: str, + nargs: str | None = None, + const: Any | None = None, + default: Any | None = None, + type: Any | None = None, # noqa: A002 + choices: Any | None = None, + required: bool = False, # noqa: FBT001, FBT002 + help: str | None = None, # noqa: A002 + metavar: str | None = None, + ) -> None: + if not type or not issubclass(type, enum.Enum): + msg = "type must be a subclass of Enum" + raise TypeError(msg) + if not isinstance(default, str): + msg = "default must be defined with a string value" + raise TypeError(msg) + + choices = tuple(e.name.lower() for e in type) + if default not in choices: + msg = "default value should be among the enum choices" + raise ValueError(msg) + + super().__init__( + option_strings=option_strings, + dest=dest, + nargs=nargs, + const=const, + default=default, + type=None, # We return None here so that we default to str. + choices=choices, + required=required, + help=help, + metavar=metavar, + ) + + self._enum = type + + def __call__( + self, + parser: ArgumentParser, # noqa: ARG002 + namespace: Namespace, + value: Any, + option_string: str | None = None, # noqa: ARG002 + ) -> None: + value = value or self.default + value = next(e for e in self._enum if e.name.lower() == value) + setattr(namespace, self.dest, value) + + __all__ = [ "Options", "get_options", diff --git a/src/pipdeptree/_discovery.py b/src/pipdeptree/_discovery.py index 56e24fd..ffc0b22 100644 --- a/src/pipdeptree/_discovery.py +++ b/src/pipdeptree/_discovery.py @@ -10,6 +10,8 @@ from typing import Iterable, Tuple from packaging.utils import canonicalize_name +from pipdeptree._warning import get_warning_printer + def get_installed_distributions( interpreter: str = str(sys.executable), @@ -42,6 +44,8 @@ def get_installed_distributions( else: original_dists = distributions() + warning_printer = get_warning_printer() + # Since importlib.metadata.distributions() can return duplicate packages, we need to handle this. pip's approach is # to keep track of each package metadata it finds, and if it encounters one again it will simply just ignore it. We # take it one step further and warn the user that there are duplicate packages in their environment. @@ -55,11 +59,17 @@ def get_installed_distributions( seen_dists[normalized_name] = dist dists.append(dist) continue - already_seen_dists = first_seen_to_already_seen_dists_dict.setdefault(seen_dists[normalized_name], []) - already_seen_dists.append(dist) + if warning_printer.should_warn(): + already_seen_dists = first_seen_to_already_seen_dists_dict.setdefault(seen_dists[normalized_name], []) + already_seen_dists.append(dist) - if first_seen_to_already_seen_dists_dict: - render_duplicated_dist_metadata_text(first_seen_to_already_seen_dists_dict) + should_print_warning = warning_printer.should_warn() and first_seen_to_already_seen_dists_dict + if should_print_warning: + warning_printer.print_multi_line( + "Duplicate package metadata found", + lambda: render_duplicated_dist_metadata_text(first_seen_to_already_seen_dists_dict), + ignore_fail=True, + ) return dists @@ -77,7 +87,6 @@ def render_duplicated_dist_metadata_text( dist_list = entries_to_pairs_dict.setdefault(entry, []) dist_list.append((first_seen, dist)) - print("Warning!!! Duplicate package metadata found:", file=sys.stderr) # noqa: T201 for entry, pairs in entries_to_pairs_dict.items(): print(f'"{entry}"', file=sys.stderr) # noqa: T201 for first_seen, dist in pairs: @@ -88,7 +97,6 @@ def render_duplicated_dist_metadata_text( ), file=sys.stderr, ) - print("-" * 72, file=sys.stderr) # noqa: T201 __all__ = [ diff --git a/src/pipdeptree/_models/dag.py b/src/pipdeptree/_models/dag.py index 1410aeb..2f5df60 100644 --- a/src/pipdeptree/_models/dag.py +++ b/src/pipdeptree/_models/dag.py @@ -12,20 +12,17 @@ if TYPE_CHECKING: from importlib.metadata import Distribution +from pipdeptree._warning import get_warning_printer + from .package import DistPackage, InvalidRequirementError, ReqPackage -def render_invalid_reqs_text_if_necessary(dist_name_to_invalid_reqs_dict: dict[str, list[str]]) -> None: - if not dist_name_to_invalid_reqs_dict: - return - - print("Warning!!! Invalid requirement strings found for the following distributions:", file=sys.stderr) # noqa: T201 +def render_invalid_reqs_text(dist_name_to_invalid_reqs_dict: dict[str, list[str]]) -> None: for dist_name, invalid_reqs in dist_name_to_invalid_reqs_dict.items(): print(dist_name, file=sys.stderr) # noqa: T201 for invalid_req in invalid_reqs: print(f' Skipping "{invalid_req}"', file=sys.stderr) # noqa: T201 - print("-" * 72, file=sys.stderr) # noqa: T201 class PackageDAG(Mapping[DistPackage, List[ReqPackage]]): @@ -53,6 +50,7 @@ class PackageDAG(Mapping[DistPackage, List[ReqPackage]]): @classmethod def from_pkgs(cls, pkgs: list[Distribution]) -> PackageDAG: + warning_printer = get_warning_printer() dist_pkgs = [DistPackage(p) for p in pkgs] idx = {p.key: p for p in dist_pkgs} m: dict[DistPackage, list[ReqPackage]] = {} @@ -65,7 +63,8 @@ class PackageDAG(Mapping[DistPackage, List[ReqPackage]]): req = next(requires_iterator) except InvalidRequirementError as err: # We can't work with invalid requirement strings. Let's warn the user about them. - dist_name_to_invalid_reqs_dict.setdefault(p.project_name, []).append(str(err)) + if warning_printer.should_warn(): + dist_name_to_invalid_reqs_dict.setdefault(p.project_name, []).append(str(err)) continue except StopIteration: break @@ -78,7 +77,12 @@ class PackageDAG(Mapping[DistPackage, List[ReqPackage]]): reqs.append(pkg) m[p] = reqs - render_invalid_reqs_text_if_necessary(dist_name_to_invalid_reqs_dict) + should_print_warning = warning_printer.should_warn() and dist_name_to_invalid_reqs_dict + if should_print_warning: + warning_printer.print_multi_line( + "Invalid requirement strings found for the following distributions", + lambda: render_invalid_reqs_text(dist_name_to_invalid_reqs_dict), + ) return cls(m) diff --git a/src/pipdeptree/_validate.py b/src/pipdeptree/_validate.py index 646afd3..15eae51 100644 --- a/src/pipdeptree/_validate.py +++ b/src/pipdeptree/_validate.py @@ -4,30 +4,28 @@ import sys from collections import defaultdict from typing import TYPE_CHECKING +from pipdeptree._warning import get_warning_printer + if TYPE_CHECKING: from pipdeptree._models.package import Package - from ._cli import Options from ._models import DistPackage, PackageDAG, ReqPackage -def validate(args: Options, is_text_output: bool, tree: PackageDAG) -> int: # noqa: FBT001 +def validate(tree: PackageDAG) -> None: # Before any reversing or filtering, show warnings to console, about possibly conflicting or cyclic deps if found # and warnings are enabled (i.e. only if output is to be printed to console) - if is_text_output and args.warn != "silence": + warning_printer = get_warning_printer() + if warning_printer.should_warn(): conflicts = conflicting_deps(tree) if conflicts: - render_conflicts_text(conflicts) - print("-" * 72, file=sys.stderr) # noqa: T201 + warning_printer.print_multi_line( + "Possibly conflicting dependencies found", lambda: render_conflicts_text(conflicts) + ) cycles = cyclic_deps(tree) if cycles: - render_cycles_text(cycles) - print("-" * 72, file=sys.stderr) # noqa: T201 - - if args.warn == "fail" and (conflicts or cycles): - return 1 - return 0 + warning_printer.print_multi_line("Cyclic dependencies found", lambda: render_cycles_text(cycles)) def conflicting_deps(tree: PackageDAG) -> dict[DistPackage, list[ReqPackage]]: @@ -50,16 +48,14 @@ def conflicting_deps(tree: PackageDAG) -> dict[DistPackage, list[ReqPackage]]: def render_conflicts_text(conflicts: dict[DistPackage, list[ReqPackage]]) -> None: - if conflicts: - print("Warning!!! Possibly conflicting dependencies found:", file=sys.stderr) # noqa: T201 - # Enforce alphabetical order when listing conflicts - pkgs = sorted(conflicts.keys()) - for p in pkgs: - pkg = p.render_as_root(frozen=False) - print(f"* {pkg}", file=sys.stderr) # noqa: T201 - for req in conflicts[p]: - req_str = req.render_as_branch(frozen=False) - print(f" - {req_str}", file=sys.stderr) # noqa: T201 + # Enforce alphabetical order when listing conflicts + pkgs = sorted(conflicts.keys()) + for p in pkgs: + pkg = p.render_as_root(frozen=False) + print(f"* {pkg}", file=sys.stderr) # noqa: T201 + for req in conflicts[p]: + req_str = req.render_as_branch(frozen=False) + print(f" - {req_str}", file=sys.stderr) # noqa: T201 def cyclic_deps(tree: PackageDAG) -> list[list[Package]]: @@ -104,20 +100,18 @@ def cyclic_deps(tree: PackageDAG) -> list[list[Package]]: def render_cycles_text(cycles: list[list[Package]]) -> None: - if cycles: - print("Warning!! Cyclic dependencies found:", file=sys.stderr) # noqa: T201 - # List in alphabetical order the dependency that caused the cycle (i.e. the second-to-last Package element) - cycles = sorted(cycles, key=lambda c: c[len(c) - 2].key) - for cycle in cycles: - print("*", end=" ", file=sys.stderr) # noqa: T201 + # List in alphabetical order the dependency that caused the cycle (i.e. the second-to-last Package element) + cycles = sorted(cycles, key=lambda c: c[len(c) - 2].key) + for cycle in cycles: + print("*", end=" ", file=sys.stderr) # noqa: T201 - size = len(cycle) - 1 - for idx, pkg in enumerate(cycle): - if idx == size: - print(f"{pkg.project_name}", end="", file=sys.stderr) # noqa: T201 - else: - print(f"{pkg.project_name} =>", end=" ", file=sys.stderr) # noqa: T201 - print(file=sys.stderr) # noqa: T201 + size = len(cycle) - 1 + for idx, pkg in enumerate(cycle): + if idx == size: + print(f"{pkg.project_name}", end="", file=sys.stderr) # noqa: T201 + else: + print(f"{pkg.project_name} =>", end=" ", file=sys.stderr) # noqa: T201 + print(file=sys.stderr) # noqa: T201 __all__ = [ diff --git a/src/pipdeptree/_warning.py b/src/pipdeptree/_warning.py new file mode 100644 index 0000000..c3b1a58 --- /dev/null +++ b/src/pipdeptree/_warning.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import sys +from enum import Enum +from typing import Callable + +WarningType = Enum("WarningType", ["SILENCE", "SUPPRESS", "FAIL"]) + + +class WarningPrinter: + """Non-thread safe class that handles printing warning logic.""" + + def __init__(self, warning_type: WarningType = WarningType.SUPPRESS) -> None: + self._warning_type = warning_type + self._has_warned = False + + @property + def warning_type(self) -> WarningType: + return self._warning_type + + @warning_type.setter + def warning_type(self, new_warning_type: WarningType) -> None: + self._warning_type = new_warning_type + + def should_warn(self) -> bool: + return self._warning_type != WarningType.SILENCE + + def has_warned_with_failure(self) -> bool: + return self._has_warned and self.warning_type == WarningType.FAIL + + def print_single_line(self, line: str) -> None: + self._has_warned = True + print(line, file=sys.stderr) # noqa: T201 + + def print_multi_line(self, summary: str, print_func: Callable[[], None], ignore_fail: bool = False) -> None: # noqa: FBT001, FBT002 + """ + Print a multi-line warning, delegating most of the printing logic to the caller. + + :param summary: a summary of the warning + :param print_func: a callback that the caller passes that performs most of the multi-line printing + :param ignore_fail: if True, this warning won't be a fail when `self.warning_type == WarningType.FAIL` + """ + print(f"Warning!!! {summary}:", file=sys.stderr) # noqa: T201 + print_func() + if ignore_fail: + print("NOTE: This warning isn't a failure warning.", file=sys.stderr) # noqa: T201 + else: + self._has_warned = True + print("-" * 72, file=sys.stderr) # noqa: T201 + + +_shared_warning_printer = WarningPrinter() + + +def get_warning_printer() -> WarningPrinter: + """Shared warning printer, representing a module-level singleton object.""" + return _shared_warning_printer + + +__all__ = ["WarningPrinter", "get_warning_printer"] diff --git a/tests/test_cli.py b/tests/test_cli.py index a55c90a..bf3604b 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,8 +1,12 @@ from __future__ import annotations +import argparse +from typing import Any + import pytest -from pipdeptree._cli import build_parser, get_options +from pipdeptree._cli import EnumAction, build_parser, get_options +from pipdeptree._warning import WarningType def test_parser_default() -> None: @@ -106,3 +110,35 @@ def test_parser_get_options_license_and_freeze_together_not_supported(capsys: py out, err = capsys.readouterr() assert not out assert "cannot use --license with --freeze" in err + + +@pytest.mark.parametrize(("bad_type"), [None, str]) +def test_enum_action_type_argument(bad_type: Any) -> None: + with pytest.raises(TypeError, match="type must be a subclass of Enum"): + EnumAction(["--test"], "test", type=bad_type) + + +def test_enum_action_default_argument_not_str() -> None: + with pytest.raises(TypeError, match="default must be defined with a string value"): + EnumAction(["--test"], "test", type=WarningType) + + +def test_enum_action_default_argument_not_a_valid_choice() -> None: + with pytest.raises(ValueError, match="default value should be among the enum choices"): + EnumAction(["--test"], "test", type=WarningType, default="bad-warning-type") + + +def test_enum_action_call_with_value() -> None: + action = EnumAction(["--test"], "test", type=WarningType, default="silence") + namespace = argparse.Namespace() + action(argparse.ArgumentParser(), namespace, "suppress") + assert getattr(namespace, "test", None) == WarningType.SUPPRESS + + +def test_enum_action_call_without_value() -> None: + # ensures that we end up using the default value in case no value is specified (currently we pass nargs='?' when + # creating the --warn option, which is why this test exists) + action = EnumAction(["--test"], "test", type=WarningType, default="silence") + namespace = argparse.Namespace() + action(argparse.ArgumentParser(), namespace, None) + assert getattr(namespace, "test", None) == WarningType.SILENCE diff --git a/tests/test_discovery.py b/tests/test_discovery.py index 140473d..351d417 100644 --- a/tests/test_discovery.py +++ b/tests/test_discovery.py @@ -79,6 +79,7 @@ def test_duplicate_metadata(mocker: MockerFixture, capfd: pytest.CaptureFixture[ _, err = capfd.readouterr() expected = ( 'Warning!!! Duplicate package metadata found:\n"/path/2"\n foo 5.9.0 ' - ' (using 1.2.5, "/path/1")\n------------------------------------------------------------------------\n' + ' (using 1.2.5, "/path/1")\nNOTE: This warning isn\'t a failure warning.\n---------------------------------' + "---------------------------------------\n" ) assert err == expected diff --git a/tests/test_validate.py b/tests/test_validate.py index 03755fb..efb164b 100644 --- a/tests/test_validate.py +++ b/tests/test_validate.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Callable, Iterator import pytest from pipdeptree._models import PackageDAG -from pipdeptree._validate import conflicting_deps, cyclic_deps, render_conflicts_text, render_cycles_text +from pipdeptree._validate import conflicting_deps, cyclic_deps, render_conflicts_text, render_cycles_text, validate if TYPE_CHECKING: from unittest.mock import Mock @@ -24,7 +24,7 @@ if TYPE_CHECKING: ("d", "2.0"): [], }, [["a", "b", "a"], ["b", "a", "b"]], - ["Warning!! Cyclic dependencies found:", "* b => a => b", "* a => b => a"], + ["* b => a => b", "* a => b => a"], id="depth-of-2", ), pytest.param( @@ -42,7 +42,6 @@ if TYPE_CHECKING: ["a", "b", "c", "d", "a"], ], [ - "Warning!! Cyclic dependencies found:", "* b => c => d => a => b", "* c => d => a => b => c", "* d => a => b => c => d", @@ -90,7 +89,6 @@ def test_cyclic_deps( {("a", "1.0.1"): [("b", [(">=", "2.3.0")])], ("b", "1.9.1"): []}, {"a": ["b"]}, [ - "Warning!!! Possibly conflicting dependencies found:", "* a==1.0.1", " - b [required: >=2.3.0, installed: 1.9.1]", ], @@ -99,7 +97,6 @@ def test_cyclic_deps( {("a", "1.0.1"): [("c", [(">=", "9.4.1")])], ("b", "2.3.0"): [("c", [(">=", "7.0")])], ("c", "8.0.1"): []}, {"a": ["c"]}, [ - "Warning!!! Possibly conflicting dependencies found:", "* a==1.0.1", " - c [required: >=9.4.1, installed: 8.0.1]", ], @@ -108,7 +105,6 @@ def test_cyclic_deps( {("a", "1.0.1"): [("c", [(">=", "9.4.1")])], ("b", "2.3.0"): [("c", [(">=", "9.4.0")])]}, {"a": ["c"], "b": ["c"]}, [ - "Warning!!! Possibly conflicting dependencies found:", "* a==1.0.1", " - c [required: >=9.4.1, installed: ?]", "* b==2.3.0", @@ -136,3 +132,43 @@ def test_conflicting_deps( render_conflicts_text(result) captured = capsys.readouterr() assert "\n".join(expected_output).strip() == captured.err.strip() + + +@pytest.mark.parametrize( + ("mpkgs", "expected_output"), + [ + ( + {("a", "1.0.1"): [("b", [(">=", "2.3.0")])], ("b", "1.9.1"): []}, + [ + "Warning!!! Possibly conflicting dependencies found:", + "* a==1.0.1", + " - b [required: >=2.3.0, installed: 1.9.1]", + "------------------------------------------------------------------------", + ], + ), + ( + { + ("a", "1.0.1"): [("b", [(">=", "2.0.0")])], + ("b", "2.3.0"): [("a", [(">=", "1.0.1")])], + ("c", "4.5.0"): [], + }, + [ + "Warning!!! Cyclic dependencies found:", + "* b => a => b", + "* a => b => a", + "------------------------------------------------------------------------", + ], + ), + ], +) +def test_validate( + capsys: pytest.CaptureFixture[str], + mock_pkgs: Callable[[MockGraph], Iterator[Mock]], + mpkgs: dict[tuple[str, str], list[tuple[str, list[tuple[str, str]]]]], + expected_output: list[str], +) -> None: + tree = PackageDAG.from_pkgs(list(mock_pkgs(mpkgs))) + validate(tree) + out, err = capsys.readouterr() + assert len(out) == 0 + assert "\n".join(expected_output).strip() == err.strip() diff --git a/tests/test_warning.py b/tests/test_warning.py new file mode 100644 index 0000000..8811255 --- /dev/null +++ b/tests/test_warning.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from pipdeptree._warning import WarningPrinter, WarningType + +if TYPE_CHECKING: + import pytest + + +def test_warning_printer_print_single_line(capsys: pytest.CaptureFixture[str]) -> None: + # Use WarningType.FAIL so that we can be able to test to see if WarningPrinter remembers it has warned before. + warning_printer = WarningPrinter(WarningType.FAIL) + warning_printer.print_single_line("test") + assert warning_printer.has_warned_with_failure() + out, err = capsys.readouterr() + assert len(out) == 0 + assert err == "test\n"