From 59d0c65613d7543edee51cd05f45e21b507ed9c1 Mon Sep 17 00:00:00 2001 From: Yuanzheng Wang <31960962+dalek-who@users.noreply.github.com> Date: Sat, 12 Jun 2021 19:42:49 +0800 Subject: [PATCH] Add dataclass support to `apply_to_collection` (#7935) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Carlos Mocholí Co-authored-by: Jirka Borovec --- CHANGELOG.md | 3 ++ pytorch_lightning/utilities/apply_func.py | 14 +++++++ tests/utilities/test_apply_func.py | 45 +++++++++++++++++++++++ 3 files changed, 62 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7a9ec6564c..e9bb747b70 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Add `dataclass` support for `pytorch_lightning.utilities.apply_to_collection` ([#7935](https://github.com/PyTorchLightning/pytorch-lightning/pull/7935)) + + - Added support to `LightningModule.to_torchscript` for saving to custom filesystems with fsspec ([#7617](https://github.com/PyTorchLightning/pytorch-lightning/pull/7617)) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 2f46ff0569..42a694ebad 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import dataclasses import operator from abc import ABC from collections import OrderedDict @@ -60,6 +61,11 @@ def _is_namedtuple(obj: object) -> bool: return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields") +def _is_dataclass_instance(obj): + # https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions + return dataclasses.is_dataclass(obj) and not isinstance(obj, type) + + def apply_to_collection( data: Any, dtype: Union[type, tuple], @@ -110,6 +116,14 @@ def apply_to_collection( out.append(v) return elem_type(*out) if is_namedtuple else elem_type(out) + if _is_dataclass_instance(data): + out = dict() + for field in data.__dataclass_fields__: + v = apply_to_collection(getattr(data, field), dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) + if include_none or v is not None: + out[field] = v + return elem_type(**out) + # data is neither of dtype, nor a collection return data diff --git a/tests/utilities/test_apply_func.py b/tests/utilities/test_apply_func.py index 2457cf998c..8959a3283d 100644 --- a/tests/utilities/test_apply_func.py +++ b/tests/utilities/test_apply_func.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import dataclasses import numbers from collections import namedtuple, OrderedDict +from typing import List import numpy as np import pytest @@ -24,6 +26,17 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to def test_recursive_application_to_collection(): ntc = namedtuple('Foo', ['bar']) + @dataclasses.dataclass + class Feature: + input_ids: torch.Tensor + segment_ids: np.ndarray + + @dataclasses.dataclass + class ModelExample: + example_ids: List[str] + feature: Feature + label: torch.Tensor + to_reduce = { 'a': torch.tensor([1.]), # Tensor 'b': [torch.tensor([2.])], # list @@ -32,6 +45,12 @@ def test_recursive_application_to_collection(): 'e': np.array([10.]), # numpy array 'f': 'this_is_a_dummy_str', # string 'g': 12., # number + 'h': Feature(input_ids=torch.tensor([1., 2., 3.]), segment_ids=np.array([4., 5., 6.])), # dataclass + 'i': ModelExample( + example_ids=['i-1', 'i-2', 'i-3'], + feature=Feature(input_ids=torch.tensor([1., 2., 3.]), segment_ids=np.array([4., 5., 6.])), + label=torch.tensor([7., 8., 9.]) + ) # nested dataclass } expected_result = { @@ -42,6 +61,12 @@ def test_recursive_application_to_collection(): 'e': np.array([20.]), 'f': 'this_is_a_dummy_str', 'g': 24., + 'h': Feature(input_ids=torch.tensor([2., 4., 6.]), segment_ids=np.array([8., 10., 12.])), + 'i': ModelExample( + example_ids=['i-1', 'i-2', 'i-3'], + feature=Feature(input_ids=torch.tensor([2., 4., 6.]), segment_ids=np.array([8., 10., 12.])), + label=torch.tensor([14., 16., 18.]) + ) } reduced = apply_to_collection(to_reduce, (torch.Tensor, numbers.Number, np.ndarray), lambda x: x * 2) @@ -78,6 +103,26 @@ def test_recursive_application_to_collection(): assert isinstance(reduced['g'], numbers.Number), 'Reduction of a number should result in a number' assert reduced['g'] == expected_result['g'], 'Reduction of a number did not yield the desired result' + assert dataclasses.is_dataclass(reduced['h']) and not isinstance(reduced['h'], type), \ + 'Reduction of 
a dataclass should result in a dataclass' + assert torch.allclose(reduced['h'].input_ids, expected_result['h'].input_ids), \ + 'Reduction of a dataclass did not yield the desired result' + assert np.allclose(reduced['h'].segment_ids, expected_result['h'].segment_ids), \ + 'Reduction of a dataclass did not yield the desired result' + + assert dataclasses.is_dataclass(reduced['i']) and not isinstance(reduced['i'], type), \ + 'Reduction of a dataclass should result in a dataclass' + assert dataclasses.is_dataclass(reduced['i'].feature) and not isinstance(reduced['i'].feature, type), \ + 'Reduction of a nested dataclass should result in a nested dataclass' + assert reduced['i'].example_ids == expected_result['i'].example_ids, \ + 'Reduction of a nested dataclass did not yield the desired result' + assert torch.allclose(reduced['i'].label, expected_result['i'].label), \ + 'Reduction of a nested dataclass did not yield the desired result' + assert torch.allclose(reduced['i'].feature.input_ids, expected_result['i'].feature.input_ids), \ + 'Reduction of a nested dataclass did not yield the desired result' + assert np.allclose(reduced['i'].feature.segment_ids, expected_result['i'].feature.segment_ids), \ + 'Reduction of a nested dataclass did not yield the desired result' + # mapping support reduced = apply_to_collection({'a': 1, 'b': 2}, int, lambda x: str(x)) assert reduced == {'a': '1', 'b': '2'}