Fix issue with no-init dataclass fields in move_to_device (#9963)

Co-authored-by: ronif <ronif@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
ronif 2021-10-17 10:10:47 +03:00 committed by GitHub
parent e5dfdf34f9
commit 7b4df7bf91
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 20 additions and 12 deletions

View File

@ -544,6 +544,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed use of `LightningCLI` in computer_vision_fine_tuning.py example ([#9934](https://github.com/PyTorchLightning/pytorch-lightning/pull/9934))
- Fixed issue with non-init dataclass fields in `apply_to_collection` ([#9963](https://github.com/PyTorchLightning/pytorch-lightning/issues/9963))
## [1.4.9] - 2021-09-30
- Fixed `lr_find` to generate same results on multiple calls ([#9704](https://github.com/PyTorchLightning/pytorch-lightning/pull/9704))

View File

@ -118,18 +118,19 @@ def apply_to_collection(
if _is_dataclass_instance(data):
out_dict = {}
for field in data.__dataclass_fields__:
v = apply_to_collection(
getattr(data, field),
dtype,
function,
*args,
wrong_dtype=wrong_dtype,
include_none=include_none,
**kwargs,
)
if include_none or v is not None:
out_dict[field] = v
for field in dataclasses.fields(data):
if field.init:
v = apply_to_collection(
getattr(data, field.name),
dtype,
function,
*args,
wrong_dtype=wrong_dtype,
include_none=include_none,
**kwargs,
)
if include_none or v is not None:
out_dict[field.name] = v
return elem_type(**out_dict)
# data is neither of dtype, nor a collection

View File

@ -36,6 +36,10 @@ def test_recursive_application_to_collection():
example_ids: List[str]
feature: Feature
label: torch.Tensor
some_constant: int = dataclasses.field(init=False)
def __post_init__(self):
self.some_constant = 7
to_reduce = {
"a": torch.tensor([1.0]), # Tensor