fix recursive call for `apply_to_collection(include_none=False)` (#8719)
This commit is contained in:
parent
ed13040729
commit
963c267646
|
@ -114,6 +114,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
[#8627](https://github.com/PyTorchLightning/pytorch-lightning/pull/8627))
|
||||
|
||||
|
||||
- Fixed recursive call for `apply_to_collection(include_none=False)` ([#8719](https://github.com/PyTorchLightning/pytorch-lightning/pull/8719))
|
||||
|
||||
|
||||
|
||||
## [1.4.0] - 2021-07-27
|
||||
|
||||
### Added
|
||||
|
|
|
@ -101,7 +101,9 @@ def apply_to_collection(
|
|||
if isinstance(data, Mapping):
|
||||
out = []
|
||||
for k, v in data.items():
|
||||
v = apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
|
||||
v = apply_to_collection(
|
||||
v, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs
|
||||
)
|
||||
if include_none or v is not None:
|
||||
out.append((k, v))
|
||||
return elem_type(OrderedDict(out))
|
||||
|
@ -111,7 +113,9 @@ def apply_to_collection(
|
|||
if is_namedtuple or is_sequence:
|
||||
out = []
|
||||
for d in data:
|
||||
v = apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
|
||||
v = apply_to_collection(
|
||||
d, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs
|
||||
)
|
||||
if include_none or v is not None:
|
||||
out.append(v)
|
||||
return elem_type(*out) if is_namedtuple else elem_type(out)
|
||||
|
@ -119,7 +123,15 @@ def apply_to_collection(
|
|||
if _is_dataclass_instance(data):
|
||||
out = {}
|
||||
for field in data.__dataclass_fields__:
|
||||
v = apply_to_collection(getattr(data, field), dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
|
||||
v = apply_to_collection(
|
||||
getattr(data, field),
|
||||
dtype,
|
||||
function,
|
||||
*args,
|
||||
wrong_dtype=wrong_dtype,
|
||||
include_none=include_none,
|
||||
**kwargs
|
||||
)
|
||||
if include_none or v is not None:
|
||||
out[field] = v
|
||||
return elem_type(**out)
|
||||
|
|
|
@ -151,17 +151,17 @@ def test_recursive_application_to_collection():
|
|||
|
||||
|
||||
def test_apply_to_collection_include_none():
|
||||
to_reduce = [1, 2, 3.4, 5.6, 7]
|
||||
to_reduce = [1, 2, 3.4, 5.6, 7, (8, 9.1, {10: 10})]
|
||||
|
||||
def fn(x):
|
||||
if isinstance(x, float):
|
||||
return x
|
||||
|
||||
reduced = apply_to_collection(to_reduce, (int, float), fn)
|
||||
assert reduced == [None, None, 3.4, 5.6, None]
|
||||
assert reduced == [None, None, 3.4, 5.6, None, (None, 9.1, {10: None})]
|
||||
|
||||
reduced = apply_to_collection(to_reduce, (int, float), fn, include_none=False)
|
||||
assert reduced == [3.4, 5.6]
|
||||
assert reduced == [3.4, 5.6, (9.1, {})]
|
||||
|
||||
|
||||
def test_apply_to_collections():
|
||||
|
|
Loading…
Reference in New Issue