diff --git a/CHANGELOG.md b/CHANGELOG.md index 5c1fab4259..d7b6ce0dbf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -114,6 +114,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). [#8627](https://github.com/PyTorchLightning/pytorch-lightning/pull/8627)) +- Fixed recursive call for `apply_to_collection(include_none=False)` ([#8719](https://github.com/PyTorchLightning/pytorch-lightning/pull/8719)) + + + ## [1.4.0] - 2021-07-27 ### Added diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 028a7791b3..b96a0110e5 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -101,7 +101,9 @@ def apply_to_collection( if isinstance(data, Mapping): out = [] for k, v in data.items(): - v = apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) + v = apply_to_collection( + v, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs + ) if include_none or v is not None: out.append((k, v)) return elem_type(OrderedDict(out)) @@ -111,7 +113,9 @@ def apply_to_collection( if is_namedtuple or is_sequence: out = [] for d in data: - v = apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) + v = apply_to_collection( + d, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs + ) if include_none or v is not None: out.append(v) return elem_type(*out) if is_namedtuple else elem_type(out) @@ -119,7 +123,15 @@ def apply_to_collection( if _is_dataclass_instance(data): out = {} for field in data.__dataclass_fields__: - v = apply_to_collection(getattr(data, field), dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) + v = apply_to_collection( + getattr(data, field), + dtype, + function, + *args, + wrong_dtype=wrong_dtype, + include_none=include_none, + **kwargs + ) if include_none or v is not None: out[field] = v return elem_type(**out) diff --git a/tests/utilities/test_apply_func.py b/tests/utilities/test_apply_func.py index 6c0dc9c9aa..9862da05bf 100644 --- a/tests/utilities/test_apply_func.py +++ b/tests/utilities/test_apply_func.py @@ -151,17 +151,17 @@ def test_recursive_application_to_collection(): def test_apply_to_collection_include_none(): - to_reduce = [1, 2, 3.4, 5.6, 7] + to_reduce = [1, 2, 3.4, 5.6, 7, (8, 9.1, {10: 10})] def fn(x): if isinstance(x, float): return x reduced = apply_to_collection(to_reduce, (int, float), fn) - assert reduced == [None, None, 3.4, 5.6, None] + assert reduced == [None, None, 3.4, 5.6, None, (None, 9.1, {10: None})] reduced = apply_to_collection(to_reduce, (int, float), fn, include_none=False) - assert reduced == [3.4, 5.6] + assert reduced == [3.4, 5.6, (9.1, {})] def test_apply_to_collections():