fix recursive call for `apply_to_collection(include_none=False)` (#8719)

This commit is contained in:
Adrian Wälchli 2021-08-04 20:31:35 +02:00 committed by GitHub
parent ed13040729
commit 963c267646
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 22 additions and 6 deletions

View File

@ -114,6 +114,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
[#8627](https://github.com/PyTorchLightning/pytorch-lightning/pull/8627))
- Fixed recursive call for `apply_to_collection(include_none=False)` ([#8719](https://github.com/PyTorchLightning/pytorch-lightning/pull/8719))
## [1.4.0] - 2021-07-27
### Added

View File

@ -101,7 +101,9 @@ def apply_to_collection(
if isinstance(data, Mapping):
out = []
for k, v in data.items():
v = apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
v = apply_to_collection(
v, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs
)
if include_none or v is not None:
out.append((k, v))
return elem_type(OrderedDict(out))
@ -111,7 +113,9 @@ def apply_to_collection(
if is_namedtuple or is_sequence:
out = []
for d in data:
v = apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
v = apply_to_collection(
d, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs
)
if include_none or v is not None:
out.append(v)
return elem_type(*out) if is_namedtuple else elem_type(out)
@ -119,7 +123,15 @@ def apply_to_collection(
if _is_dataclass_instance(data):
out = {}
for field in data.__dataclass_fields__:
v = apply_to_collection(getattr(data, field), dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
v = apply_to_collection(
getattr(data, field),
dtype,
function,
*args,
wrong_dtype=wrong_dtype,
include_none=include_none,
**kwargs
)
if include_none or v is not None:
out[field] = v
return elem_type(**out)

View File

@ -151,17 +151,17 @@ def test_recursive_application_to_collection():
def test_apply_to_collection_include_none():
to_reduce = [1, 2, 3.4, 5.6, 7]
to_reduce = [1, 2, 3.4, 5.6, 7, (8, 9.1, {10: 10})]
def fn(x):
if isinstance(x, float):
return x
reduced = apply_to_collection(to_reduce, (int, float), fn)
assert reduced == [None, None, 3.4, 5.6, None]
assert reduced == [None, None, 3.4, 5.6, None, (None, 9.1, {10: None})]
reduced = apply_to_collection(to_reduce, (int, float), fn, include_none=False)
assert reduced == [3.4, 5.6]
assert reduced == [3.4, 5.6, (9.1, {})]
def test_apply_to_collections():