Add dataclass support to `apply_to_collection` (#7935)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
parent
cdd01f32da
commit
59d0c65613
|
@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
### Added
|
||||
|
||||
- Add `dataclass` support for `pytorch_lightning.utilities.apply_to_collection` ([#7935](https://github.com/PyTorchLightning/pytorch-lightning/pull/7935))
|
||||
|
||||
|
||||
- Added support to `LightningModule.to_torchscript` for saving to custom filesystems with fsspec ([#7617](https://github.com/PyTorchLightning/pytorch-lightning/pull/7617))
|
||||
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import dataclasses
|
||||
import operator
|
||||
from abc import ABC
|
||||
from collections import OrderedDict
|
||||
|
@ -60,6 +61,11 @@ def _is_namedtuple(obj: object) -> bool:
|
|||
return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
|
||||
|
||||
|
||||
def _is_dataclass_instance(obj):
|
||||
# https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions
|
||||
return dataclasses.is_dataclass(obj) and not isinstance(obj, type)
|
||||
|
||||
|
||||
def apply_to_collection(
|
||||
data: Any,
|
||||
dtype: Union[type, tuple],
|
||||
|
@ -110,6 +116,14 @@ def apply_to_collection(
|
|||
out.append(v)
|
||||
return elem_type(*out) if is_namedtuple else elem_type(out)
|
||||
|
||||
if _is_dataclass_instance(data):
|
||||
out = dict()
|
||||
for field in data.__dataclass_fields__:
|
||||
v = apply_to_collection(getattr(data, field), dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
|
||||
if include_none or v is not None:
|
||||
out[field] = v
|
||||
return elem_type(**out)
|
||||
|
||||
# data is neither of dtype, nor a collection
|
||||
return data
|
||||
|
||||
|
|
|
@ -11,8 +11,10 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import dataclasses
|
||||
import numbers
|
||||
from collections import namedtuple, OrderedDict
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
@ -24,6 +26,17 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to
|
|||
def test_recursive_application_to_collection():
|
||||
ntc = namedtuple('Foo', ['bar'])
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Feature:
|
||||
input_ids: torch.Tensor
|
||||
segment_ids: np.ndarray
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModelExample:
|
||||
example_ids: List[str]
|
||||
feature: Feature
|
||||
label: torch.Tensor
|
||||
|
||||
to_reduce = {
|
||||
'a': torch.tensor([1.]), # Tensor
|
||||
'b': [torch.tensor([2.])], # list
|
||||
|
@ -32,6 +45,12 @@ def test_recursive_application_to_collection():
|
|||
'e': np.array([10.]), # numpy array
|
||||
'f': 'this_is_a_dummy_str', # string
|
||||
'g': 12., # number
|
||||
'h': Feature(input_ids=torch.tensor([1., 2., 3.]), segment_ids=np.array([4., 5., 6.])), # dataclass
|
||||
'i': ModelExample(
|
||||
example_ids=['i-1', 'i-2', 'i-3'],
|
||||
feature=Feature(input_ids=torch.tensor([1., 2., 3.]), segment_ids=np.array([4., 5., 6.])),
|
||||
label=torch.tensor([7., 8., 9.])
|
||||
) # nested dataclass
|
||||
}
|
||||
|
||||
expected_result = {
|
||||
|
@ -42,6 +61,12 @@ def test_recursive_application_to_collection():
|
|||
'e': np.array([20.]),
|
||||
'f': 'this_is_a_dummy_str',
|
||||
'g': 24.,
|
||||
'h': Feature(input_ids=torch.tensor([2., 4., 6.]), segment_ids=np.array([8., 10., 12.])),
|
||||
'i': ModelExample(
|
||||
example_ids=['i-1', 'i-2', 'i-3'],
|
||||
feature=Feature(input_ids=torch.tensor([2., 4., 6.]), segment_ids=np.array([8., 10., 12.])),
|
||||
label=torch.tensor([14., 16., 18.])
|
||||
)
|
||||
}
|
||||
|
||||
reduced = apply_to_collection(to_reduce, (torch.Tensor, numbers.Number, np.ndarray), lambda x: x * 2)
|
||||
|
@ -78,6 +103,26 @@ def test_recursive_application_to_collection():
|
|||
assert isinstance(reduced['g'], numbers.Number), 'Reduction of a number should result in a number'
|
||||
assert reduced['g'] == expected_result['g'], 'Reduction of a number did not yield the desired result'
|
||||
|
||||
assert dataclasses.is_dataclass(reduced['h']) and not isinstance(reduced['h'], type), \
|
||||
'Reduction of a dataclass should result in a dataclass'
|
||||
assert torch.allclose(reduced['h'].input_ids, expected_result['h'].input_ids), \
|
||||
'Reduction of a dataclass did not yield the desired result'
|
||||
assert np.allclose(reduced['h'].segment_ids, expected_result['h'].segment_ids), \
|
||||
'Reduction of a dataclass did not yield the desired result'
|
||||
|
||||
assert dataclasses.is_dataclass(reduced['i']) and not isinstance(reduced['i'], type), \
|
||||
'Reduction of a dataclass should result in a dataclass'
|
||||
assert dataclasses.is_dataclass(reduced['i'].feature) and not isinstance(reduced['i'].feature, type), \
|
||||
'Reduction of a nested dataclass should result in a nested dataclass'
|
||||
assert reduced['i'].example_ids == expected_result['i'].example_ids, \
|
||||
'Reduction of a nested dataclass did not yield the desired result'
|
||||
assert torch.allclose(reduced['i'].label, expected_result['i'].label), \
|
||||
'Reduction of a nested dataclass did not yield the desired result'
|
||||
assert torch.allclose(reduced['i'].feature.input_ids, expected_result['i'].feature.input_ids), \
|
||||
'Reduction of a nested dataclass did not yield the desired result'
|
||||
assert np.allclose(reduced['i'].feature.segment_ids, expected_result['i'].feature.segment_ids), \
|
||||
'Reduction of a nested dataclass did not yield the desired result'
|
||||
|
||||
# mapping support
|
||||
reduced = apply_to_collection({'a': 1, 'b': 2}, int, lambda x: str(x))
|
||||
assert reduced == {'a': '1', 'b': '2'}
|
||||
|
|
Loading…
Reference in New Issue