diff --git a/boltons/funcutils.py b/boltons/funcutils.py index c19348b..b678034 100644 --- a/boltons/funcutils.py +++ b/boltons/funcutils.py @@ -12,7 +12,6 @@ import inspect import functools import itertools from types import MethodType, FunctionType -from collections import OrderedDict try: xrange @@ -286,8 +285,6 @@ def wraps(func, injected=None, expected=None, **kw): :class:`FunctionBuilder` type, on which wraps was built. """ - # TODO: maybe automatically use normal wraps in the very rare case - # that the signatures actually match and no adapter is needed. if injected is None: injected = [] elif isinstance(injected, basestring): @@ -295,20 +292,7 @@ def wraps(func, injected=None, expected=None, **kw): else: injected = list(injected) - if expected is None: - expected = [] - elif isinstance(expected, basestring): - expected = [(expected, NO_DEFAULT)] - - try: - try: - expected = OrderedDict(expected) - except (ValueError, TypeError): - expected = OrderedDict.fromkeys(expected, NO_DEFAULT) - except Exception as e: - raise ValueError('"expected" takes string name, sequence of string names,' - ' iterable of (name, default) pairs, or a mapping of ' - '{name: default}, not %r (got: %r)' % (expected, e)) + expected_items = _parse_wraps_expected(expected) if isinstance(func, (classmethod, staticmethod)): raise TypeError('wraps does not support wrapping classmethods and' @@ -330,7 +314,7 @@ def wraps(func, injected=None, expected=None, **kw): continue # keyword arg will be caught by the varkw raise - for arg, default in expected.items(): + for arg, default in expected_items: fb.add_arg(arg, default) # may raise ExistingArgument if fb.is_async: @@ -348,6 +332,46 @@ def wraps(func, injected=None, expected=None, **kw): return wrapper_wrapper +def _parse_wraps_expected(expected): + # expected takes a pretty powerful argument, it's processed + # here. admittedly this would be less trouble if I relied on + # OrderedDict (there's an impl of that in the commit history if + # you look + if expected is None: + expected = [] + elif isinstance(expected, basestring): + expected = [(expected, NO_DEFAULT)] + + expected_items = [] + try: + expected_iter = iter(expected) + except TypeError as e: + raise ValueError('"expected" takes string name, sequence of string names,' + ' iterable of (name, default) pairs, or a mapping of ' + ' {name: default}, not %r (got: %r)' % (expected, e)) + for argname in expected_iter: + if isinstance(argname, basestring): + # dict keys and bare strings + try: + default = expected[argname] + except TypeError: + default = NO_DEFAULT + else: + # pairs + try: + argname, default = argname + except (TypeError, ValueError): + raise ValueError('"expected" takes string name, sequence of string names,' + ' iterable of (name, default) pairs, or a mapping of ' + ' {name: default}, not %r') + if not isinstance(argname, basestring): + raise ValueError('all "expected" argnames must be strings, not %r' % (argname,)) + + expected_items.append((argname, default)) + + return expected_items + + class FunctionBuilder(object): """The FunctionBuilder type provides an interface for programmatically creating new functions, either based on existing functions or from diff --git a/tests/test_funcutils_fb.py b/tests/test_funcutils_fb.py index fb8152f..001186c 100644 --- a/tests/test_funcutils_fb.py +++ b/tests/test_funcutils_fb.py @@ -177,58 +177,6 @@ def test_wraps_wrappers(): return -def test_wraps_expected(): - def expect_string(func): - @wraps(func, expected="c") - def wrapped(*args, **kwargs): - args, c = args[:2], args[-1] - return func(*args, **kwargs) + (c,) - return wrapped - - assert expect_string(wrappable_func)(1, 2, 3) == (1, 2, 3) - ''' - def inject_list(func): - @wraps(func, injected=["b"]) - def wrapped(a, *args, **kwargs): - return func(a, 2, *args, **kwargs) - return wrapped - - assert inject_list(wrappable_func)(1) == (1, 2) - - def inject_nonexistent_arg(func): - @wraps(func, injected=["X"]) - def wrapped(*args, **kwargs): - return func(*args, **kwargs) - return wrapped - - with pytest.raises(ValueError): - inject_nonexistent_arg(wrappable_func) - - def inject_missing_argument(func): - @wraps(func, injected="c") - def wrapped(*args, **kwargs): - return func(1, *args, **kwargs) - return wrapped - - def inject_misc_argument(func): - # inject_to_varkw is default True, just being explicit - @wraps(func, injected="c", inject_to_varkw=True) - def wrapped(*args, **kwargs): - return func(c=1, *args, **kwargs) - return wrapped - - assert inject_misc_argument(wrappable_varkw_func)(1, 2) == (1, 2) - - def inject_misc_argument_no_varkw(func): - @wraps(func, injected="c", inject_to_varkw=False) - def wrapped(*args, **kwargs): - return func(c=1, *args, **kwargs) - return wrapped - - with pytest.raises(ValueError): - inject_misc_argument_no_varkw(wrappable_varkw_func) - ''' - def test_FunctionBuilder_add_arg(): fb = FunctionBuilder('return_five', doc='returns the integer 5', body='return 5') @@ -257,3 +205,48 @@ def test_FunctionBuilder_add_arg(): assert better_func('positional') == 'positional' assert better_func(val='keyword') == 'keyword' + + +def test_wraps_expected(): + def expect_string(func): + @wraps(func, expected="c") + def wrapped(*args, **kwargs): + args, c = args[:2], args[-1] + return func(*args, **kwargs) + (c,) + return wrapped + + expected_string = expect_string(wrappable_func) + assert expected_string(1, 2, 3) == (1, 2, 3) + + with pytest.raises(TypeError) as excinfo: + expected_string(1, 2) + + # a rough way of making sure we got the kind of error we expected + assert 'argument' in repr(excinfo.value) + + def expect_list(func): + @wraps(func, expected=["c"]) + def wrapped(*args, **kwargs): + args, c = args[:2], args[-1] + return func(*args, **kwargs) + (c,) + return wrapped + + assert expect_list(wrappable_func)(1, 2, c=4) == (1, 2, 4) + + def expect_pair(func): + @wraps(func, expected=[('c', 5)]) + def wrapped(*args, **kwargs): + args, c = args[:2], args[-1] + return func(*args, **kwargs) + (c,) + return wrapped + + assert expect_pair(wrappable_func)(1, 2) == (1, 2, 5) + + def expect_dict(func): + @wraps(func, expected={'c': 6}) + def wrapped(*args, **kwargs): + args, c = args[:2], args[-1] + return func(*args, **kwargs) + (c,) + return wrapped + + assert expect_dict(wrappable_func)(1, 2) == (1, 2, 6)