add a `convert` keyword to attr.ib() that allows specifying a function to convert the passed-in value.

This commit is contained in:
Christopher Armstrong 2015-09-16 10:45:51 -05:00 committed by Hynek Schlawack
parent 253491908b
commit 365cd89921
2 changed files with 70 additions and 10 deletions

View File

@ -41,7 +41,8 @@ Sentinel to indicate the lack of a value when ``None`` is ambiguous.
def attr(default=NOTHING, validator=None, def attr(default=NOTHING, validator=None,
repr=True, cmp=True, hash=True, init=True): repr=True, cmp=True, hash=True, init=True,
convert=None):
""" """
Create a new attribute on a class. Create a new attribute on a class.
@ -81,6 +82,13 @@ def attr(default=NOTHING, validator=None,
:param init: Include this attribute in the generated ``__init__`` method. :param init: Include this attribute in the generated ``__init__`` method.
:type init: bool :type init: bool
:param convert: :func:`callable` that is called by ``attrs``-generated
``__init__`` methods to convert attribute's value to the desired format.
It is given the passed-in value, and the returned value will be used as
the new value of the attribute.
:type convert: callable
""" """
return _CountingAttr( return _CountingAttr(
default=default, default=default,
@ -89,6 +97,7 @@ def attr(default=NOTHING, validator=None,
cmp=cmp, cmp=cmp,
hash=hash, hash=hash,
init=init, init=init,
convert=convert,
) )
@ -342,7 +351,8 @@ def _add_init(cl):
attr_dict = dict((a.name, a) for a in attrs) attr_dict = dict((a.name, a) for a in attrs)
exec_(bytecode, {"NOTHING": NOTHING, exec_(bytecode, {"NOTHING": NOTHING,
"attr_dict": attr_dict, "attr_dict": attr_dict,
"validate": validate}, locs) "validate": validate,
"_convert": _convert}, locs)
init = locs["__init__"] init = locs["__init__"]
# In order of debuggers like PDB being able to step through the code, # In order of debuggers like PDB being able to step through the code,
@ -395,6 +405,19 @@ def validate(inst):
a.validator(inst, a, getattr(inst, a.name)) a.validator(inst, a, getattr(inst, a.name))
def _convert(inst):
"""
Convert all attributes on *inst* that have a converter.
Leaves all exceptions through.
:param inst: Instance of a class with ``attrs`` attributes.
"""
for a in fields(inst.__class__):
if a.convert is not None:
setattr(inst, a.name, a.convert(getattr(inst, a.name)))
def _attrs_to_script(attrs): def _attrs_to_script(attrs):
""" """
Return a valid Python script of an initializer for *attrs*. Return a valid Python script of an initializer for *attrs*.
@ -402,9 +425,12 @@ def _attrs_to_script(attrs):
lines = [] lines = []
args = [] args = []
has_validator = False has_validator = False
has_convert = False
for a in attrs: for a in attrs:
if a.validator is not None: if a.validator is not None:
has_validator = True has_validator = True
if a.convert is not None:
has_convert = True
attr_name = a.name attr_name = a.name
arg_name = a.name.lstrip("_") arg_name = a.name.lstrip("_")
if a.default is not NOTHING and not isinstance(a.default, Factory): if a.default is not NOTHING and not isinstance(a.default, Factory):
@ -437,6 +463,8 @@ else:
if has_validator: if has_validator:
lines.append("validate(self)") lines.append("validate(self)")
if has_convert:
lines.append("_convert(self)")
return """\ return """\
def __init__(self, {args}): def __init__(self, {args}):
@ -456,17 +484,22 @@ class Attribute(object):
Plus *all* arguments of :func:`attr.ib`. Plus *all* arguments of :func:`attr.ib`.
""" """
_attributes = [ _attributes = [
"name", "default", "validator", "repr", "cmp", "hash", "init", "name", "default", "validator", "repr", "cmp", "hash", "init", "convert"
] # we can't use ``attrs`` so we have to cheat a little. ] # we can't use ``attrs`` so we have to cheat a little.
_optional = {"convert": None}
def __init__(self, **kw): def __init__(self, **kw):
if len(kw) > len(Attribute._attributes): if len(kw) > len(Attribute._attributes):
raise TypeError("Too many arguments.") raise TypeError("Too many arguments.")
try: for a in Attribute._attributes:
for a in Attribute._attributes: try:
setattr(self, a, kw[a]) setattr(self, a, kw[a])
except KeyError: except KeyError:
raise TypeError("Missing argument '{arg}'.".format(arg=a)) if a in Attribute._optional:
setattr(self, a, self._optional[a])
else:
raise TypeError("Missing argument '{arg}'.".format(arg=a))
@classmethod @classmethod
def from_counting_attr(cl, name, ca): def from_counting_attr(cl, name, ca):
@ -498,7 +531,7 @@ class _CountingAttr(object):
] ]
counter = 0 counter = 0
def __init__(self, default, validator, repr, cmp, hash, init): def __init__(self, default, validator, repr, cmp, hash, init, convert):
_CountingAttr.counter += 1 _CountingAttr.counter += 1
self.counter = _CountingAttr.counter self.counter = _CountingAttr.counter
self.default = default self.default = default
@ -507,6 +540,7 @@ class _CountingAttr(object):
self.cmp = cmp self.cmp = cmp
self.hash = hash self.hash = hash
self.init = init self.init = init
self.convert = convert
_CountingAttr = _add_cmp(_add_repr(_CountingAttr)) _CountingAttr = _add_cmp(_add_repr(_CountingAttr))

View File

@ -96,7 +96,7 @@ class TestTransformAttrs(object):
"No mandatory attributes allowed after an attribute with a " "No mandatory attributes allowed after an attribute with a "
"default value or factory. Attribute in question: Attribute" "default value or factory. Attribute in question: Attribute"
"(name='y', default=NOTHING, validator=None, repr=True, " "(name='y', default=NOTHING, validator=None, repr=True, "
"cmp=True, hash=True, init=True)", "cmp=True, hash=True, init=True, convert=None)",
) == e.value.args ) == e.value.args
def test_these(self): def test_these(self):
@ -297,7 +297,7 @@ class TestAttribute(object):
with pytest.raises(TypeError) as e: with pytest.raises(TypeError) as e:
Attribute(name="foo", default=NOTHING, Attribute(name="foo", default=NOTHING,
factory=NOTHING, validator=None, factory=NOTHING, validator=None,
repr=True, cmp=True, hash=True, init=True) repr=True, cmp=True, hash=True, init=True, convert=None)
assert ("Too many arguments.",) == e.value.args assert ("Too many arguments.",) == e.value.args
@ -392,6 +392,32 @@ class TestFields(object):
in zip(fields(C), C.__attrs_attrs__)) in zip(fields(C), C.__attrs_attrs__))
class TestConvert(object):
"""
Tests for attribute conversion.
"""
def test_convert(self):
"""
Return value of convert is used as the attribute's value.
"""
C = make_class("C", {"x": attr(convert=lambda v: v + 1),
"y": attr()})
c = C(1, 2)
assert c.x == 2
assert c.y == 2
def test_convert_after_validate(self):
"""
Validation happens before conversion.
"""
def validator(inst, attr, val):
raise RuntimeError("foo")
C = make_class("C", {"x": attr(validator=validator, convert=lambda v: 1 / 0),
"y": attr()})
with pytest.raises(RuntimeError):
C(1, 2)
class TestValidate(object): class TestValidate(object):
""" """
Tests for `validate`. Tests for `validate`.