Add dataclass support to `apply_to_collection` (#7935)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
Yuanzheng Wang 2021-06-12 19:42:49 +08:00 committed by GitHub
parent cdd01f32da
commit 59d0c65613
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 62 additions and 0 deletions

View File

@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
- Add `dataclass` support for `pytorch_lightning.utilities.apply_to_collection` ([#7935](https://github.com/PyTorchLightning/pytorch-lightning/pull/7935))
- Added support to `LightningModule.to_torchscript` for saving to custom filesystems with fsspec ([#7617](https://github.com/PyTorchLightning/pytorch-lightning/pull/7617))

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import operator
from abc import ABC
from collections import OrderedDict
@ -60,6 +61,11 @@ def _is_namedtuple(obj: object) -> bool:
return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
def _is_dataclass_instance(obj):
# https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions
return dataclasses.is_dataclass(obj) and not isinstance(obj, type)
def apply_to_collection(
data: Any,
dtype: Union[type, tuple],
@ -110,6 +116,14 @@ def apply_to_collection(
out.append(v)
return elem_type(*out) if is_namedtuple else elem_type(out)
if _is_dataclass_instance(data):
out = dict()
for field in data.__dataclass_fields__:
v = apply_to_collection(getattr(data, field), dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
if include_none or v is not None:
out[field] = v
return elem_type(**out)
# data is neither of dtype, nor a collection
return data

View File

@ -11,8 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import numbers
from collections import namedtuple, OrderedDict
from typing import List
import numpy as np
import pytest
@ -24,6 +26,17 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to
def test_recursive_application_to_collection():
ntc = namedtuple('Foo', ['bar'])
@dataclasses.dataclass
class Feature:
input_ids: torch.Tensor
segment_ids: np.ndarray
@dataclasses.dataclass
class ModelExample:
example_ids: List[str]
feature: Feature
label: torch.Tensor
to_reduce = {
'a': torch.tensor([1.]), # Tensor
'b': [torch.tensor([2.])], # list
@ -32,6 +45,12 @@ def test_recursive_application_to_collection():
'e': np.array([10.]), # numpy array
'f': 'this_is_a_dummy_str', # string
'g': 12., # number
'h': Feature(input_ids=torch.tensor([1., 2., 3.]), segment_ids=np.array([4., 5., 6.])), # dataclass
'i': ModelExample(
example_ids=['i-1', 'i-2', 'i-3'],
feature=Feature(input_ids=torch.tensor([1., 2., 3.]), segment_ids=np.array([4., 5., 6.])),
label=torch.tensor([7., 8., 9.])
) # nested dataclass
}
expected_result = {
@ -42,6 +61,12 @@ def test_recursive_application_to_collection():
'e': np.array([20.]),
'f': 'this_is_a_dummy_str',
'g': 24.,
'h': Feature(input_ids=torch.tensor([2., 4., 6.]), segment_ids=np.array([8., 10., 12.])),
'i': ModelExample(
example_ids=['i-1', 'i-2', 'i-3'],
feature=Feature(input_ids=torch.tensor([2., 4., 6.]), segment_ids=np.array([8., 10., 12.])),
label=torch.tensor([14., 16., 18.])
)
}
reduced = apply_to_collection(to_reduce, (torch.Tensor, numbers.Number, np.ndarray), lambda x: x * 2)
@ -78,6 +103,26 @@ def test_recursive_application_to_collection():
assert isinstance(reduced['g'], numbers.Number), 'Reduction of a number should result in a number'
assert reduced['g'] == expected_result['g'], 'Reduction of a number did not yield the desired result'
assert dataclasses.is_dataclass(reduced['h']) and not isinstance(reduced['h'], type), \
'Reduction of a dataclass should result in a dataclass'
assert torch.allclose(reduced['h'].input_ids, expected_result['h'].input_ids), \
'Reduction of a dataclass did not yield the desired result'
assert np.allclose(reduced['h'].segment_ids, expected_result['h'].segment_ids), \
'Reduction of a dataclass did not yield the desired result'
assert dataclasses.is_dataclass(reduced['i']) and not isinstance(reduced['i'], type), \
'Reduction of a dataclass should result in a dataclass'
assert dataclasses.is_dataclass(reduced['i'].feature) and not isinstance(reduced['i'].feature, type), \
'Reduction of a nested dataclass should result in a nested dataclass'
assert reduced['i'].example_ids == expected_result['i'].example_ids, \
'Reduction of a nested dataclass did not yield the desired result'
assert torch.allclose(reduced['i'].label, expected_result['i'].label), \
'Reduction of a nested dataclass did not yield the desired result'
assert torch.allclose(reduced['i'].feature.input_ids, expected_result['i'].feature.input_ids), \
'Reduction of a nested dataclass did not yield the desired result'
assert np.allclose(reduced['i'].feature.segment_ids, expected_result['i'].feature.segment_ids), \
'Reduction of a nested dataclass did not yield the desired result'
# mapping support
reduced = apply_to_collection({'a': 1, 'b': 2}, int, lambda x: str(x))
assert reduced == {'a': '1', 'b': '2'}