diff --git a/README.md b/README.md index 331488c..9b43596 100644 --- a/README.md +++ b/README.md @@ -149,6 +149,9 @@ On multi-threaded apps identify which thread are snooped in output:: @pysnooper.snoop(thread_info=True) ``` +PySnooper supports decorating generators. + + # Installation # You can install **PySnooper** by: diff --git a/pysnooper/pycompat.py b/pysnooper/pycompat.py index 4f2a9a6..b63df9d 100644 --- a/pysnooper/pycompat.py +++ b/pysnooper/pycompat.py @@ -4,6 +4,7 @@ import abc import os +import inspect if hasattr(abc, 'ABC'): ABC = abc.ABC @@ -35,3 +36,9 @@ else: (hasattr(subclass, 'open') and 'path' in subclass.__name__.lower()) ) + + +try: + iscoroutinefunction = inspect.iscoroutinefunction +except AttributeError: + iscoroutinefunction = lambda whatever: False # Lolz diff --git a/pysnooper/tracer.py b/pysnooper/tracer.py index 844809d..b05b0f9 100644 --- a/pysnooper/tracer.py +++ b/pysnooper/tracer.py @@ -209,11 +209,32 @@ class Tracer: self.target_codes.add(function.__code__) @functools.wraps(function) - def inner(*args, **kwargs): + def simple_wrapper(*args, **kwargs): with self: return function(*args, **kwargs) - return inner + @functools.wraps(function) + def generator_wrapper(*args, **kwargs): + gen = function(*args, **kwargs) + method, incoming = gen.send, None + while True: + with self: + try: + outgoing = method(incoming) + except StopIteration: + return + try: + method, incoming = gen.send, (yield outgoing) + except Exception as e: + method, incoming = gen.throw, e + + if pycompat.iscoroutinefunction(function): + # return decorate(function, coroutine_wrapper) + raise NotImplementedError + elif inspect.isgeneratorfunction(function): + return generator_wrapper + else: + return simple_wrapper def write(self, s): if self.overwrite and not self._did_overwrite: diff --git a/tests/test_pysnooper.py b/tests/test_pysnooper.py index d3882e4..4691cfb 100644 --- a/tests/test_pysnooper.py +++ b/tests/test_pysnooper.py @@ -5,6 +5,7 @@ import io import textwrap import threading import types +import sys from pysnooper.utils import truncate from python_toolbox import sys_tools, temp_file_tools @@ -1047,3 +1048,86 @@ def test_indentation(): def test_exception(): from .samples import exception assert_sample_output(exception) + + +def test_generator(): + string_io = io.StringIO() + original_tracer = sys.gettrace() + original_tracer_active = lambda: (sys.gettrace() is original_tracer) + + + @pysnooper.snoop(string_io) + def f(x1): + assert not original_tracer_active() + x2 = (yield x1) + assert not original_tracer_active() + x3 = 'foo' + assert not original_tracer_active() + x4 = (yield 2) + assert not original_tracer_active() + return + + + assert original_tracer_active() + generator = f(0) + assert original_tracer_active() + first_item = next(generator) + assert original_tracer_active() + assert first_item == 0 + second_item = generator.send('blabla') + assert original_tracer_active() + assert second_item == 2 + with pytest.raises(StopIteration) as exc_info: + generator.send('looloo') + assert original_tracer_active() + + output = string_io.getvalue() + assert_output( + output, + ( + VariableEntry('x1', '0'), + VariableEntry(), + CallEntry(), + LineEntry(), + VariableEntry(), + VariableEntry(), + LineEntry(), + ReturnEntry(), + ReturnValueEntry('0'), + + # Pause and resume: + + VariableEntry('x1', '0'), + VariableEntry(), + VariableEntry(), + VariableEntry(), + CallEntry(), + VariableEntry('x2', "'blabla'"), + LineEntry(), + LineEntry(), + VariableEntry('x3', "'foo'"), + LineEntry(), + LineEntry(), + ReturnEntry(), + ReturnValueEntry('2'), + + # Pause and resume: + + VariableEntry('x1', '0'), + VariableEntry(), + VariableEntry(), + VariableEntry(), + VariableEntry(), + VariableEntry(), + CallEntry(), + VariableEntry('x4', "'looloo'"), + LineEntry(), + LineEntry(), + ReturnEntry(), + ReturnValueEntry(None), + + ) + ) + + +