From 948fa7a312757df722f31b109a18b0ccea83f250 Mon Sep 17 00:00:00 2001 From: Ram Rachum Date: Thu, 25 Apr 2019 20:28:45 +0300 Subject: [PATCH] Argument to overwrite file #5 --- pysnooper/pysnooper.py | 31 ++++++++++----- pysnooper/tracer.py | 10 ++++- tests/test_pysnooper.py | 86 ++++++++++++++++++++++++++++++++++++++++- 3 files changed, 114 insertions(+), 13 deletions(-) diff --git a/pysnooper/pysnooper.py b/pysnooper/pysnooper.py index 6aba062..b70c932 100644 --- a/pysnooper/pysnooper.py +++ b/pysnooper/pysnooper.py @@ -16,25 +16,30 @@ from . import pycompat from .tracer import Tracer -def get_write_function(output): +def get_write_and_truncate_functions(output): if output is None: def write(s): stderr = sys.stderr stderr.write(s) + truncate = None elif isinstance(output, (pycompat.PathLike, str)): def write(s): with open(output, 'a') as output_file: output_file.write(s) + def truncate(): + with open(output, 'w') as output_file: + pass else: assert isinstance(output, utils.WritableStream) def write(s): output.write(s) + truncate = None - return write + return (write, truncate) -def snoop(output=None, variables=(), depth=1, prefix=''): +def snoop(output=None, variables=(), depth=1, prefix='', overwrite=False): ''' Snoop on the function, writing everything it's doing to stderr. @@ -62,14 +67,20 @@ def snoop(output=None, variables=(), depth=1, prefix=''): @pysnooper.snoop(prefix='ZZZ ') ''' - write = get_write_function(output) - @decorator.decorator - def decorate(function, *args, **kwargs): + write, truncate = get_write_and_truncate_functions(output) + if truncate is None and overwrite: + raise Exception("`overwrite=True` can only be used when writing " + "content to file.") + def decorate(function): target_code_object = function.__code__ - with Tracer(target_code_object=target_code_object, - write=write, variables=variables, - depth=depth, prefix=prefix): - return function(*args, **kwargs) + tracer = Tracer(target_code_object=target_code_object, write=write, + truncate=truncate, variables=variables, depth=depth, + prefix=prefix, overwrite=overwrite) + + def inner(function_, *args, **kwargs): + with tracer: + return function(*args, **kwargs) + return decorator.decorate(function, inner) return decorate diff --git a/pysnooper/tracer.py b/pysnooper/tracer.py index 14bfe72..cc75fd5 100644 --- a/pysnooper/tracer.py +++ b/pysnooper/tracer.py @@ -119,18 +119,24 @@ def get_source_from_frame(frame): return source class Tracer: - def __init__(self, target_code_object, write, variables=(), depth=1, - prefix=''): + def __init__(self, target_code_object, write, truncate, variables=(), + depth=1, prefix='', overwrite=False): self.target_code_object = target_code_object self._write = write + self.truncate = truncate self.variables = variables self.frame_to_old_local_reprs = collections.defaultdict(lambda: {}) self.frame_to_local_reprs = collections.defaultdict(lambda: {}) self.depth = depth self.prefix = prefix + self.overwrite = overwrite + self._did_overwrite = False assert self.depth >= 1 def write(self, s): + if self.overwrite and not self._did_overwrite: + self.truncate() + self._did_overwrite = True s = '{self.prefix}{s}\n'.format(**locals()) if isinstance(s, bytes): # Python 2 compatibility s = s.decode() diff --git a/tests/test_pysnooper.py b/tests/test_pysnooper.py index 9dac0e5..3cb7df5 100644 --- a/tests/test_pysnooper.py +++ b/tests/test_pysnooper.py @@ -9,6 +9,7 @@ from python_toolbox import caching from python_toolbox import sys_tools from python_toolbox import temp_file_tools from pysnooper.third_party import six +import pytest import pysnooper @@ -179,7 +180,6 @@ def test_method_and_prefix(): ) def test_file_output(): - with temp_file_tools.create_temp_folder(prefix='pysnooper') as folder: path = folder / 'foo.log' @pysnooper.snoop(str(path)) @@ -295,3 +295,87 @@ def test_unavailable_source(): ) ) + +def test_no_overwrite_by_default(): + with temp_file_tools.create_temp_folder(prefix='pysnooper') as folder: + path = folder / 'foo.log' + with path.open('w') as output_file: + output_file.write(u'lala') + @pysnooper.snoop(str(path)) + def my_function(foo): + x = 7 + y = 8 + return y + x + result = my_function('baba') + assert result == 15 + with path.open() as output_file: + output = output_file.read() + assert output.startswith('lala') + shortened_output = output[4:] + assert_output( + shortened_output, + ( + VariableEntry('foo', value_regex="u?'baba'"), + CallEntry('def my_function(foo):'), + LineEntry('x = 7'), + VariableEntry('x', '7'), + LineEntry('y = 8'), + VariableEntry('y', '8'), + LineEntry('return y + x'), + ReturnEntry('return y + x'), + ReturnValueEntry('15'), + ) + ) + + +def test_overwrite(): + with temp_file_tools.create_temp_folder(prefix='pysnooper') as folder: + path = folder / 'foo.log' + with path.open('w') as output_file: + output_file.write(u'lala') + @pysnooper.snoop(str(path), overwrite=True) + def my_function(foo): + x = 7 + y = 8 + return y + x + result = my_function('baba') + result = my_function('baba') + assert result == 15 + with path.open() as output_file: + output = output_file.read() + assert 'lala' not in output + assert_output( + output, + ( + VariableEntry('foo', value_regex="u?'baba'"), + CallEntry('def my_function(foo):'), + LineEntry('x = 7'), + VariableEntry('x', '7'), + LineEntry('y = 8'), + VariableEntry('y', '8'), + LineEntry('return y + x'), + ReturnEntry('return y + x'), + ReturnValueEntry('15'), + + VariableEntry('foo', value_regex="u?'baba'"), + CallEntry('def my_function(foo):'), + LineEntry('x = 7'), + VariableEntry('x', '7'), + LineEntry('y = 8'), + VariableEntry('y', '8'), + LineEntry('return y + x'), + ReturnEntry('return y + x'), + ReturnValueEntry('15'), + ) + ) + + +def test_error_in_overwrite_argument(): + with temp_file_tools.create_temp_folder(prefix='pysnooper') as folder: + with pytest.raises(Exception, match='can only be used when writing'): + @pysnooper.snoop(overwrite=True) + def my_function(foo): + x = 7 + y = 8 + return y + x +