diff --git a/src/attr/__init__.pyi b/src/attr/__init__.pyi index 6e470782..03cc4c82 100644 --- a/src/attr/__init__.pyi +++ b/src/attr/__init__.pyi @@ -3,11 +3,13 @@ import sys from typing import ( Any, Callable, + ClassVar, Dict, Generic, List, Mapping, Optional, + Protocol, Sequence, Tuple, Type, @@ -22,8 +24,8 @@ from . import exceptions as exceptions from . import filters as filters from . import setters as setters from . import validators as validators -from ._version_info import VersionInfo from ._cmp import cmp_using as cmp_using +from ._version_info import VersionInfo __version__: str __version_info__: VersionInfo @@ -57,6 +59,10 @@ _FieldTransformer = Callable[ # _ValidatorType from working when passed in a list or tuple. _ValidatorArgType = Union[_ValidatorType[_T], Sequence[_ValidatorType[_T]]] +# A protocol to be able to statically accept an attrs class. +class AttrsInstance(Protocol): + __attrs_attrs__: ClassVar[Any] + # _make -- NOTHING: object @@ -399,13 +405,9 @@ def define( mutable = define frozen = define # they differ only in their defaults -# TODO: add support for returning NamedTuple from the mypy plugin -class _Fields(Tuple[Attribute[Any], ...]): - def __getattr__(self, name: str) -> Attribute[Any]: ... - -def fields(cls: type) -> _Fields: ... -def fields_dict(cls: type) -> Dict[str, Attribute[Any]]: ... -def validate(inst: Any) -> None: ... +def fields(cls: Type[AttrsInstance]) -> Any: ... +def fields_dict(cls: Type[AttrsInstance]) -> Dict[str, Attribute[Any]]: ... +def validate(inst: AttrsInstance) -> None: ... def resolve_types( cls: _C, globalns: Optional[Dict[str, Any]] = ..., @@ -449,7 +451,7 @@ def make_class( # https://github.com/python/typing/issues/253 # XXX: remember to fix attrs.asdict/astuple too! def asdict( - inst: Any, + inst: AttrsInstance, recurse: bool = ..., filter: Optional[_FilterType[Any]] = ..., dict_factory: Type[Mapping[Any, Any]] = ..., @@ -462,7 +464,7 @@ def asdict( # TODO: add support for returning NamedTuple from the mypy plugin def astuple( - inst: Any, + inst: AttrsInstance, recurse: bool = ..., filter: Optional[_FilterType[Any]] = ..., tuple_factory: Type[Sequence[Any]] = ..., diff --git a/tests/test_mypy.yml b/tests/test_mypy.yml index 61a8df5c..fd09ca7e 100644 --- a/tests/test_mypy.yml +++ b/tests/test_mypy.yml @@ -1,7 +1,7 @@ - case: attr_s_with_type_argument parametrized: - - val: 'a = attr.ib(type=int)' - - val: 'a: int = attr.ib()' + - val: "a = attr.ib(type=int)" + - val: "a: int = attr.ib()" main: | import attr @attr.s @@ -12,7 +12,7 @@ C(a=1) C(a="hi") # E: Argument "a" to "C" has incompatible type "str"; expected "int" - case: attr_s_with_type_annotations - main : | + main: | import attr @attr.s class C: @@ -97,7 +97,7 @@ - case: testAttrsUntypedNoUntypedDefs mypy_config: | - disallow_untyped_defs = True + disallow_untyped_defs = True main: | import attr @attr.s @@ -316,7 +316,6 @@ class Confused: ... - - case: testAttrsInheritance main: | import attr @@ -469,14 +468,13 @@ return self.x # E: Incompatible return value type (got "List[T]", expected "T") reveal_type(A) # N: Revealed type is "def [T] (x: builtins.list[T`1], y: T`1) -> main.A[T`1]" a = A([1], 2) - reveal_type(a) # N: Revealed type is "main.A[builtins.int*]" - reveal_type(a.x) # N: Revealed type is "builtins.list[builtins.int*]" - reveal_type(a.y) # N: Revealed type is "builtins.int*" + reveal_type(a) # N: Revealed type is "main.A[builtins.int]" + reveal_type(a.x) # N: Revealed type is "builtins.list[builtins.int]" + reveal_type(a.y) # N: Revealed type is "builtins.int" A(['str'], 7) # E: Cannot infer type argument 1 of "A" A([1], '2') # E: Cannot infer type argument 1 of "A" - - case: testAttrsUntypedGenericInheritance main: | from typing import Generic, TypeVar @@ -514,12 +512,12 @@ pass sub_int = Sub[int](attr=1) - reveal_type(sub_int) # N: Revealed type is "main.Sub[builtins.int*]" - reveal_type(sub_int.attr) # N: Revealed type is "builtins.int*" + reveal_type(sub_int) # N: Revealed type is "main.Sub[builtins.int]" + reveal_type(sub_int.attr) # N: Revealed type is "builtins.int" sub_str = Sub[str](attr='ok') - reveal_type(sub_str) # N: Revealed type is "main.Sub[builtins.str*]" - reveal_type(sub_str.attr) # N: Revealed type is "builtins.str*" + reveal_type(sub_str) # N: Revealed type is "main.Sub[builtins.str]" + reveal_type(sub_str.attr) # N: Revealed type is "builtins.str" - case: testAttrsGenericInheritance2 main: | @@ -764,8 +762,7 @@ return 'hello' - case: testAttrsUsingBadConverter - mypy_config: - strict_optional = False + mypy_config: strict_optional = False main: | import attr from typing import overload @@ -792,8 +789,7 @@ main:17: note: Revealed type is "def (bad: Any, bad_overloaded: Any) -> main.A" - case: testAttrsUsingBadConverterReprocess - mypy_config: - strict_optional = False + mypy_config: strict_optional = False main: | import attr from typing import overload @@ -1169,7 +1165,6 @@ c = attr.ib(15) D(b=17) - - case: testAttrsKwOnlySubclass main: | import attr @@ -1211,8 +1206,7 @@ reveal_type(B) # N: Revealed type is "def (x: main.C) -> main.B" - case: testDisallowUntypedWorksForwardBad - mypy_config: - disallow_untyped_defs = True + mypy_config: disallow_untyped_defs = True main: | import attr @@ -1357,3 +1351,24 @@ foo = x reveal_type(B) # N: Revealed type is "def (foo: builtins.int) -> main.B" + +- case: testFields + main: | + from attrs import define, fields + + @define + class A: + a: int + b: str + + reveal_type(fields(A)) # N: Revealed type is "Any" + +- case: testFieldsError + main: | + from attrs import fields + + class A: + a: int + b: str + + fields(A) # E: Argument 1 to "fields" has incompatible type "Type[A]"; expected "Type[AttrsInstance]"