Argument to overwrite file #5

This commit is contained in:
Ram Rachum 2019-04-25 20:28:45 +03:00
parent ae001ccd19
commit 948fa7a312
3 changed files with 114 additions and 13 deletions

View File

@ -16,25 +16,30 @@ from . import pycompat
from .tracer import Tracer
def get_write_function(output):
def get_write_and_truncate_functions(output):
if output is None:
def write(s):
stderr = sys.stderr
stderr.write(s)
truncate = None
elif isinstance(output, (pycompat.PathLike, str)):
def write(s):
with open(output, 'a') as output_file:
output_file.write(s)
def truncate():
with open(output, 'w') as output_file:
pass
else:
assert isinstance(output, utils.WritableStream)
def write(s):
output.write(s)
truncate = None
return write
return (write, truncate)
def snoop(output=None, variables=(), depth=1, prefix=''):
def snoop(output=None, variables=(), depth=1, prefix='', overwrite=False):
'''
Snoop on the function, writing everything it's doing to stderr.
@ -62,14 +67,20 @@ def snoop(output=None, variables=(), depth=1, prefix=''):
@pysnooper.snoop(prefix='ZZZ ')
'''
write = get_write_function(output)
@decorator.decorator
def decorate(function, *args, **kwargs):
write, truncate = get_write_and_truncate_functions(output)
if truncate is None and overwrite:
raise Exception("`overwrite=True` can only be used when writing "
"content to file.")
def decorate(function):
target_code_object = function.__code__
with Tracer(target_code_object=target_code_object,
write=write, variables=variables,
depth=depth, prefix=prefix):
return function(*args, **kwargs)
tracer = Tracer(target_code_object=target_code_object, write=write,
truncate=truncate, variables=variables, depth=depth,
prefix=prefix, overwrite=overwrite)
def inner(function_, *args, **kwargs):
with tracer:
return function(*args, **kwargs)
return decorator.decorate(function, inner)
return decorate

View File

@ -119,18 +119,24 @@ def get_source_from_frame(frame):
return source
class Tracer:
def __init__(self, target_code_object, write, variables=(), depth=1,
prefix=''):
def __init__(self, target_code_object, write, truncate, variables=(),
depth=1, prefix='', overwrite=False):
self.target_code_object = target_code_object
self._write = write
self.truncate = truncate
self.variables = variables
self.frame_to_old_local_reprs = collections.defaultdict(lambda: {})
self.frame_to_local_reprs = collections.defaultdict(lambda: {})
self.depth = depth
self.prefix = prefix
self.overwrite = overwrite
self._did_overwrite = False
assert self.depth >= 1
def write(self, s):
if self.overwrite and not self._did_overwrite:
self.truncate()
self._did_overwrite = True
s = '{self.prefix}{s}\n'.format(**locals())
if isinstance(s, bytes): # Python 2 compatibility
s = s.decode()

View File

@ -9,6 +9,7 @@ from python_toolbox import caching
from python_toolbox import sys_tools
from python_toolbox import temp_file_tools
from pysnooper.third_party import six
import pytest
import pysnooper
@ -179,7 +180,6 @@ def test_method_and_prefix():
)
def test_file_output():
with temp_file_tools.create_temp_folder(prefix='pysnooper') as folder:
path = folder / 'foo.log'
@pysnooper.snoop(str(path))
@ -295,3 +295,87 @@ def test_unavailable_source():
)
)
def test_no_overwrite_by_default():
with temp_file_tools.create_temp_folder(prefix='pysnooper') as folder:
path = folder / 'foo.log'
with path.open('w') as output_file:
output_file.write(u'lala')
@pysnooper.snoop(str(path))
def my_function(foo):
x = 7
y = 8
return y + x
result = my_function('baba')
assert result == 15
with path.open() as output_file:
output = output_file.read()
assert output.startswith('lala')
shortened_output = output[4:]
assert_output(
shortened_output,
(
VariableEntry('foo', value_regex="u?'baba'"),
CallEntry('def my_function(foo):'),
LineEntry('x = 7'),
VariableEntry('x', '7'),
LineEntry('y = 8'),
VariableEntry('y', '8'),
LineEntry('return y + x'),
ReturnEntry('return y + x'),
ReturnValueEntry('15'),
)
)
def test_overwrite():
with temp_file_tools.create_temp_folder(prefix='pysnooper') as folder:
path = folder / 'foo.log'
with path.open('w') as output_file:
output_file.write(u'lala')
@pysnooper.snoop(str(path), overwrite=True)
def my_function(foo):
x = 7
y = 8
return y + x
result = my_function('baba')
result = my_function('baba')
assert result == 15
with path.open() as output_file:
output = output_file.read()
assert 'lala' not in output
assert_output(
output,
(
VariableEntry('foo', value_regex="u?'baba'"),
CallEntry('def my_function(foo):'),
LineEntry('x = 7'),
VariableEntry('x', '7'),
LineEntry('y = 8'),
VariableEntry('y', '8'),
LineEntry('return y + x'),
ReturnEntry('return y + x'),
ReturnValueEntry('15'),
VariableEntry('foo', value_regex="u?'baba'"),
CallEntry('def my_function(foo):'),
LineEntry('x = 7'),
VariableEntry('x', '7'),
LineEntry('y = 8'),
VariableEntry('y', '8'),
LineEntry('return y + x'),
ReturnEntry('return y + x'),
ReturnValueEntry('15'),
)
)
def test_error_in_overwrite_argument():
with temp_file_tools.create_temp_folder(prefix='pysnooper') as folder:
with pytest.raises(Exception, match='can only be used when writing'):
@pysnooper.snoop(overwrite=True)
def my_function(foo):
x = 7
y = 8
return y + x