Fix issue with no-init dataclass fields in move_to_device (#9963)
Co-authored-by: ronif <ronif@users.noreply.github.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
e5dfdf34f9
commit
7b4df7bf91
|
@ -544,6 +544,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed use of `LightningCLI` in computer_vision_fine_tuning.py example ([#9934](https://github.com/PyTorchLightning/pytorch-lightning/pull/9934))
|
||||
|
||||
|
||||
- Fixed issue with non-init dataclass fields in `apply_to_collection` ([#9963](https://github.com/PyTorchLightning/pytorch-lightning/issues/9963))
|
||||
|
||||
|
||||
## [1.4.9] - 2021-09-30
|
||||
|
||||
- Fixed `lr_find` to generate same results on multiple calls ([#9704](https://github.com/PyTorchLightning/pytorch-lightning/pull/9704))
|
||||
|
|
|
@ -118,18 +118,19 @@ def apply_to_collection(
|
|||
|
||||
if _is_dataclass_instance(data):
|
||||
out_dict = {}
|
||||
for field in data.__dataclass_fields__:
|
||||
v = apply_to_collection(
|
||||
getattr(data, field),
|
||||
dtype,
|
||||
function,
|
||||
*args,
|
||||
wrong_dtype=wrong_dtype,
|
||||
include_none=include_none,
|
||||
**kwargs,
|
||||
)
|
||||
if include_none or v is not None:
|
||||
out_dict[field] = v
|
||||
for field in dataclasses.fields(data):
|
||||
if field.init:
|
||||
v = apply_to_collection(
|
||||
getattr(data, field.name),
|
||||
dtype,
|
||||
function,
|
||||
*args,
|
||||
wrong_dtype=wrong_dtype,
|
||||
include_none=include_none,
|
||||
**kwargs,
|
||||
)
|
||||
if include_none or v is not None:
|
||||
out_dict[field.name] = v
|
||||
return elem_type(**out_dict)
|
||||
|
||||
# data is neither of dtype, nor a collection
|
||||
|
|
|
@ -36,6 +36,10 @@ def test_recursive_application_to_collection():
|
|||
example_ids: List[str]
|
||||
feature: Feature
|
||||
label: torch.Tensor
|
||||
some_constant: int = dataclasses.field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
self.some_constant = 7
|
||||
|
||||
to_reduce = {
|
||||
"a": torch.tensor([1.0]), # Tensor
|
||||
|
|
Loading…
Reference in New Issue