Argument to overwrite file #5

This commit is contained in:
Ram Rachum 2019-04-25 20:28:45 +03:00
parent ae001ccd19
commit 948fa7a312
3 changed files with 114 additions and 13 deletions

View File

@ -16,25 +16,30 @@ from . import pycompat
from .tracer import Tracer from .tracer import Tracer
def get_write_function(output): def get_write_and_truncate_functions(output):
if output is None: if output is None:
def write(s): def write(s):
stderr = sys.stderr stderr = sys.stderr
stderr.write(s) stderr.write(s)
truncate = None
elif isinstance(output, (pycompat.PathLike, str)): elif isinstance(output, (pycompat.PathLike, str)):
def write(s): def write(s):
with open(output, 'a') as output_file: with open(output, 'a') as output_file:
output_file.write(s) output_file.write(s)
def truncate():
with open(output, 'w') as output_file:
pass
else: else:
assert isinstance(output, utils.WritableStream) assert isinstance(output, utils.WritableStream)
def write(s): def write(s):
output.write(s) output.write(s)
truncate = None
return write return (write, truncate)
def snoop(output=None, variables=(), depth=1, prefix=''): def snoop(output=None, variables=(), depth=1, prefix='', overwrite=False):
''' '''
Snoop on the function, writing everything it's doing to stderr. Snoop on the function, writing everything it's doing to stderr.
@ -62,14 +67,20 @@ def snoop(output=None, variables=(), depth=1, prefix=''):
@pysnooper.snoop(prefix='ZZZ ') @pysnooper.snoop(prefix='ZZZ ')
''' '''
write = get_write_function(output) write, truncate = get_write_and_truncate_functions(output)
@decorator.decorator if truncate is None and overwrite:
def decorate(function, *args, **kwargs): raise Exception("`overwrite=True` can only be used when writing "
"content to file.")
def decorate(function):
target_code_object = function.__code__ target_code_object = function.__code__
with Tracer(target_code_object=target_code_object, tracer = Tracer(target_code_object=target_code_object, write=write,
write=write, variables=variables, truncate=truncate, variables=variables, depth=depth,
depth=depth, prefix=prefix): prefix=prefix, overwrite=overwrite)
return function(*args, **kwargs)
def inner(function_, *args, **kwargs):
with tracer:
return function(*args, **kwargs)
return decorator.decorate(function, inner)
return decorate return decorate

View File

@ -119,18 +119,24 @@ def get_source_from_frame(frame):
return source return source
class Tracer: class Tracer:
def __init__(self, target_code_object, write, variables=(), depth=1, def __init__(self, target_code_object, write, truncate, variables=(),
prefix=''): depth=1, prefix='', overwrite=False):
self.target_code_object = target_code_object self.target_code_object = target_code_object
self._write = write self._write = write
self.truncate = truncate
self.variables = variables self.variables = variables
self.frame_to_old_local_reprs = collections.defaultdict(lambda: {}) self.frame_to_old_local_reprs = collections.defaultdict(lambda: {})
self.frame_to_local_reprs = collections.defaultdict(lambda: {}) self.frame_to_local_reprs = collections.defaultdict(lambda: {})
self.depth = depth self.depth = depth
self.prefix = prefix self.prefix = prefix
self.overwrite = overwrite
self._did_overwrite = False
assert self.depth >= 1 assert self.depth >= 1
def write(self, s): def write(self, s):
if self.overwrite and not self._did_overwrite:
self.truncate()
self._did_overwrite = True
s = '{self.prefix}{s}\n'.format(**locals()) s = '{self.prefix}{s}\n'.format(**locals())
if isinstance(s, bytes): # Python 2 compatibility if isinstance(s, bytes): # Python 2 compatibility
s = s.decode() s = s.decode()

View File

@ -9,6 +9,7 @@ from python_toolbox import caching
from python_toolbox import sys_tools from python_toolbox import sys_tools
from python_toolbox import temp_file_tools from python_toolbox import temp_file_tools
from pysnooper.third_party import six from pysnooper.third_party import six
import pytest
import pysnooper import pysnooper
@ -179,7 +180,6 @@ def test_method_and_prefix():
) )
def test_file_output(): def test_file_output():
with temp_file_tools.create_temp_folder(prefix='pysnooper') as folder: with temp_file_tools.create_temp_folder(prefix='pysnooper') as folder:
path = folder / 'foo.log' path = folder / 'foo.log'
@pysnooper.snoop(str(path)) @pysnooper.snoop(str(path))
@ -295,3 +295,87 @@ def test_unavailable_source():
) )
) )
def test_no_overwrite_by_default():
with temp_file_tools.create_temp_folder(prefix='pysnooper') as folder:
path = folder / 'foo.log'
with path.open('w') as output_file:
output_file.write(u'lala')
@pysnooper.snoop(str(path))
def my_function(foo):
x = 7
y = 8
return y + x
result = my_function('baba')
assert result == 15
with path.open() as output_file:
output = output_file.read()
assert output.startswith('lala')
shortened_output = output[4:]
assert_output(
shortened_output,
(
VariableEntry('foo', value_regex="u?'baba'"),
CallEntry('def my_function(foo):'),
LineEntry('x = 7'),
VariableEntry('x', '7'),
LineEntry('y = 8'),
VariableEntry('y', '8'),
LineEntry('return y + x'),
ReturnEntry('return y + x'),
ReturnValueEntry('15'),
)
)
def test_overwrite():
with temp_file_tools.create_temp_folder(prefix='pysnooper') as folder:
path = folder / 'foo.log'
with path.open('w') as output_file:
output_file.write(u'lala')
@pysnooper.snoop(str(path), overwrite=True)
def my_function(foo):
x = 7
y = 8
return y + x
result = my_function('baba')
result = my_function('baba')
assert result == 15
with path.open() as output_file:
output = output_file.read()
assert 'lala' not in output
assert_output(
output,
(
VariableEntry('foo', value_regex="u?'baba'"),
CallEntry('def my_function(foo):'),
LineEntry('x = 7'),
VariableEntry('x', '7'),
LineEntry('y = 8'),
VariableEntry('y', '8'),
LineEntry('return y + x'),
ReturnEntry('return y + x'),
ReturnValueEntry('15'),
VariableEntry('foo', value_regex="u?'baba'"),
CallEntry('def my_function(foo):'),
LineEntry('x = 7'),
VariableEntry('x', '7'),
LineEntry('y = 8'),
VariableEntry('y', '8'),
LineEntry('return y + x'),
ReturnEntry('return y + x'),
ReturnValueEntry('15'),
)
)
def test_error_in_overwrite_argument():
with temp_file_tools.create_temp_folder(prefix='pysnooper') as folder:
with pytest.raises(Exception, match='can only be used when writing'):
@pysnooper.snoop(overwrite=True)
def my_function(foo):
x = 7
y = 8
return y + x