diff --git a/CHANGELOG.md b/CHANGELOG.md index 287476cfc1..3794530077 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -544,6 +544,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed use of `LightningCLI` in computer_vision_fine_tuning.py example ([#9934](https://github.com/PyTorchLightning/pytorch-lightning/pull/9934)) +- Fixed issue with non-init dataclass fields in `apply_to_collection` ([#9963](https://github.com/PyTorchLightning/pytorch-lightning/issues/9963)) + + ## [1.4.9] - 2021-09-30 - Fixed `lr_find` to generate same results on multiple calls ([#9704](https://github.com/PyTorchLightning/pytorch-lightning/pull/9704)) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 2758262653..3bd920c2e3 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -118,18 +118,19 @@ def apply_to_collection( if _is_dataclass_instance(data): out_dict = {} - for field in data.__dataclass_fields__: - v = apply_to_collection( - getattr(data, field), - dtype, - function, - *args, - wrong_dtype=wrong_dtype, - include_none=include_none, - **kwargs, - ) - if include_none or v is not None: - out_dict[field] = v + for field in dataclasses.fields(data): + if field.init: + v = apply_to_collection( + getattr(data, field.name), + dtype, + function, + *args, + wrong_dtype=wrong_dtype, + include_none=include_none, + **kwargs, + ) + if include_none or v is not None: + out_dict[field.name] = v return elem_type(**out_dict) # data is neither of dtype, nor a collection diff --git a/tests/utilities/test_apply_func.py b/tests/utilities/test_apply_func.py index 9862da05bf..2c131f96ec 100644 --- a/tests/utilities/test_apply_func.py +++ b/tests/utilities/test_apply_func.py @@ -36,6 +36,10 @@ def test_recursive_application_to_collection(): example_ids: List[str] feature: Feature label: torch.Tensor + some_constant: int = dataclasses.field(init=False) + + def __post_init__(self): + self.some_constant = 7 to_reduce = { "a": torch.tensor([1.0]), # Tensor