add a `convert` keyword to attr.ib() that allows specifying a function to convert the passed-in value.
This commit is contained in:
parent
253491908b
commit
365cd89921
|
@ -41,7 +41,8 @@ Sentinel to indicate the lack of a value when ``None`` is ambiguous.
|
||||||
|
|
||||||
|
|
||||||
def attr(default=NOTHING, validator=None,
|
def attr(default=NOTHING, validator=None,
|
||||||
repr=True, cmp=True, hash=True, init=True):
|
repr=True, cmp=True, hash=True, init=True,
|
||||||
|
convert=None):
|
||||||
"""
|
"""
|
||||||
Create a new attribute on a class.
|
Create a new attribute on a class.
|
||||||
|
|
||||||
|
@ -81,6 +82,13 @@ def attr(default=NOTHING, validator=None,
|
||||||
|
|
||||||
:param init: Include this attribute in the generated ``__init__`` method.
|
:param init: Include this attribute in the generated ``__init__`` method.
|
||||||
:type init: bool
|
:type init: bool
|
||||||
|
|
||||||
|
:param convert: :func:`callable` that is called by ``attrs``-generated
|
||||||
|
``__init__`` methods to convert attribute's value to the desired format.
|
||||||
|
It is given the passed-in value, and the returned value will be used as
|
||||||
|
the new value of the attribute.
|
||||||
|
:type convert: callable
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return _CountingAttr(
|
return _CountingAttr(
|
||||||
default=default,
|
default=default,
|
||||||
|
@ -89,6 +97,7 @@ def attr(default=NOTHING, validator=None,
|
||||||
cmp=cmp,
|
cmp=cmp,
|
||||||
hash=hash,
|
hash=hash,
|
||||||
init=init,
|
init=init,
|
||||||
|
convert=convert,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -342,7 +351,8 @@ def _add_init(cl):
|
||||||
attr_dict = dict((a.name, a) for a in attrs)
|
attr_dict = dict((a.name, a) for a in attrs)
|
||||||
exec_(bytecode, {"NOTHING": NOTHING,
|
exec_(bytecode, {"NOTHING": NOTHING,
|
||||||
"attr_dict": attr_dict,
|
"attr_dict": attr_dict,
|
||||||
"validate": validate}, locs)
|
"validate": validate,
|
||||||
|
"_convert": _convert}, locs)
|
||||||
init = locs["__init__"]
|
init = locs["__init__"]
|
||||||
|
|
||||||
# In order of debuggers like PDB being able to step through the code,
|
# In order of debuggers like PDB being able to step through the code,
|
||||||
|
@ -395,6 +405,19 @@ def validate(inst):
|
||||||
a.validator(inst, a, getattr(inst, a.name))
|
a.validator(inst, a, getattr(inst, a.name))
|
||||||
|
|
||||||
|
|
||||||
|
def _convert(inst):
|
||||||
|
"""
|
||||||
|
Convert all attributes on *inst* that have a converter.
|
||||||
|
|
||||||
|
Leaves all exceptions through.
|
||||||
|
|
||||||
|
:param inst: Instance of a class with ``attrs`` attributes.
|
||||||
|
"""
|
||||||
|
for a in fields(inst.__class__):
|
||||||
|
if a.convert is not None:
|
||||||
|
setattr(inst, a.name, a.convert(getattr(inst, a.name)))
|
||||||
|
|
||||||
|
|
||||||
def _attrs_to_script(attrs):
|
def _attrs_to_script(attrs):
|
||||||
"""
|
"""
|
||||||
Return a valid Python script of an initializer for *attrs*.
|
Return a valid Python script of an initializer for *attrs*.
|
||||||
|
@ -402,9 +425,12 @@ def _attrs_to_script(attrs):
|
||||||
lines = []
|
lines = []
|
||||||
args = []
|
args = []
|
||||||
has_validator = False
|
has_validator = False
|
||||||
|
has_convert = False
|
||||||
for a in attrs:
|
for a in attrs:
|
||||||
if a.validator is not None:
|
if a.validator is not None:
|
||||||
has_validator = True
|
has_validator = True
|
||||||
|
if a.convert is not None:
|
||||||
|
has_convert = True
|
||||||
attr_name = a.name
|
attr_name = a.name
|
||||||
arg_name = a.name.lstrip("_")
|
arg_name = a.name.lstrip("_")
|
||||||
if a.default is not NOTHING and not isinstance(a.default, Factory):
|
if a.default is not NOTHING and not isinstance(a.default, Factory):
|
||||||
|
@ -437,6 +463,8 @@ else:
|
||||||
|
|
||||||
if has_validator:
|
if has_validator:
|
||||||
lines.append("validate(self)")
|
lines.append("validate(self)")
|
||||||
|
if has_convert:
|
||||||
|
lines.append("_convert(self)")
|
||||||
|
|
||||||
return """\
|
return """\
|
||||||
def __init__(self, {args}):
|
def __init__(self, {args}):
|
||||||
|
@ -456,17 +484,22 @@ class Attribute(object):
|
||||||
Plus *all* arguments of :func:`attr.ib`.
|
Plus *all* arguments of :func:`attr.ib`.
|
||||||
"""
|
"""
|
||||||
_attributes = [
|
_attributes = [
|
||||||
"name", "default", "validator", "repr", "cmp", "hash", "init",
|
"name", "default", "validator", "repr", "cmp", "hash", "init", "convert"
|
||||||
] # we can't use ``attrs`` so we have to cheat a little.
|
] # we can't use ``attrs`` so we have to cheat a little.
|
||||||
|
|
||||||
|
_optional = {"convert": None}
|
||||||
|
|
||||||
def __init__(self, **kw):
|
def __init__(self, **kw):
|
||||||
if len(kw) > len(Attribute._attributes):
|
if len(kw) > len(Attribute._attributes):
|
||||||
raise TypeError("Too many arguments.")
|
raise TypeError("Too many arguments.")
|
||||||
try:
|
for a in Attribute._attributes:
|
||||||
for a in Attribute._attributes:
|
try:
|
||||||
setattr(self, a, kw[a])
|
setattr(self, a, kw[a])
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise TypeError("Missing argument '{arg}'.".format(arg=a))
|
if a in Attribute._optional:
|
||||||
|
setattr(self, a, self._optional[a])
|
||||||
|
else:
|
||||||
|
raise TypeError("Missing argument '{arg}'.".format(arg=a))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_counting_attr(cl, name, ca):
|
def from_counting_attr(cl, name, ca):
|
||||||
|
@ -498,7 +531,7 @@ class _CountingAttr(object):
|
||||||
]
|
]
|
||||||
counter = 0
|
counter = 0
|
||||||
|
|
||||||
def __init__(self, default, validator, repr, cmp, hash, init):
|
def __init__(self, default, validator, repr, cmp, hash, init, convert):
|
||||||
_CountingAttr.counter += 1
|
_CountingAttr.counter += 1
|
||||||
self.counter = _CountingAttr.counter
|
self.counter = _CountingAttr.counter
|
||||||
self.default = default
|
self.default = default
|
||||||
|
@ -507,6 +540,7 @@ class _CountingAttr(object):
|
||||||
self.cmp = cmp
|
self.cmp = cmp
|
||||||
self.hash = hash
|
self.hash = hash
|
||||||
self.init = init
|
self.init = init
|
||||||
|
self.convert = convert
|
||||||
|
|
||||||
|
|
||||||
_CountingAttr = _add_cmp(_add_repr(_CountingAttr))
|
_CountingAttr = _add_cmp(_add_repr(_CountingAttr))
|
||||||
|
|
|
@ -96,7 +96,7 @@ class TestTransformAttrs(object):
|
||||||
"No mandatory attributes allowed after an attribute with a "
|
"No mandatory attributes allowed after an attribute with a "
|
||||||
"default value or factory. Attribute in question: Attribute"
|
"default value or factory. Attribute in question: Attribute"
|
||||||
"(name='y', default=NOTHING, validator=None, repr=True, "
|
"(name='y', default=NOTHING, validator=None, repr=True, "
|
||||||
"cmp=True, hash=True, init=True)",
|
"cmp=True, hash=True, init=True, convert=None)",
|
||||||
) == e.value.args
|
) == e.value.args
|
||||||
|
|
||||||
def test_these(self):
|
def test_these(self):
|
||||||
|
@ -297,7 +297,7 @@ class TestAttribute(object):
|
||||||
with pytest.raises(TypeError) as e:
|
with pytest.raises(TypeError) as e:
|
||||||
Attribute(name="foo", default=NOTHING,
|
Attribute(name="foo", default=NOTHING,
|
||||||
factory=NOTHING, validator=None,
|
factory=NOTHING, validator=None,
|
||||||
repr=True, cmp=True, hash=True, init=True)
|
repr=True, cmp=True, hash=True, init=True, convert=None)
|
||||||
assert ("Too many arguments.",) == e.value.args
|
assert ("Too many arguments.",) == e.value.args
|
||||||
|
|
||||||
|
|
||||||
|
@ -392,6 +392,32 @@ class TestFields(object):
|
||||||
in zip(fields(C), C.__attrs_attrs__))
|
in zip(fields(C), C.__attrs_attrs__))
|
||||||
|
|
||||||
|
|
||||||
|
class TestConvert(object):
|
||||||
|
"""
|
||||||
|
Tests for attribute conversion.
|
||||||
|
"""
|
||||||
|
def test_convert(self):
|
||||||
|
"""
|
||||||
|
Return value of convert is used as the attribute's value.
|
||||||
|
"""
|
||||||
|
C = make_class("C", {"x": attr(convert=lambda v: v + 1),
|
||||||
|
"y": attr()})
|
||||||
|
c = C(1, 2)
|
||||||
|
assert c.x == 2
|
||||||
|
assert c.y == 2
|
||||||
|
|
||||||
|
def test_convert_after_validate(self):
|
||||||
|
"""
|
||||||
|
Validation happens before conversion.
|
||||||
|
"""
|
||||||
|
def validator(inst, attr, val):
|
||||||
|
raise RuntimeError("foo")
|
||||||
|
C = make_class("C", {"x": attr(validator=validator, convert=lambda v: 1 / 0),
|
||||||
|
"y": attr()})
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
C(1, 2)
|
||||||
|
|
||||||
|
|
||||||
class TestValidate(object):
|
class TestValidate(object):
|
||||||
"""
|
"""
|
||||||
Tests for `validate`.
|
Tests for `validate`.
|
||||||
|
|
Loading…
Reference in New Issue