From ae3226ced96e2bc7e62f298d532aaf2290e6ef34 Mon Sep 17 00:00:00 2001
From: twsl <45483159+twsl@users.noreply.github.com>
Date: Fri, 15 Apr 2022 14:13:33 +0200
Subject: [PATCH] Add dataclass support to _extract_batch_size (#12573)

Co-authored-by: Akihiro Nitta
---
 CHANGELOG.md                        |  3 +++
 pytorch_lightning/utilities/data.py |  5 +++++
 tests/utilities/test_data.py        | 14 ++++++++++++++
 3 files changed, 22 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 46722e5b6e..d2dbbf5039 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -31,6 +31,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Support `strategy` argument being case insensitive ([#12528](https://github.com/PyTorchLightning/pytorch-lightning/pull/12528))
 
+- Added dataclass support to `extract_batch_size` ([#12573](https://github.com/PyTorchLightning/pytorch-lightning/pull/12573))
+
+
 - Changed checkpoints save path in the case of one logger and user-provided weights_save_path from `weights_save_path/name/version/checkpoints` to `weights_save_path/checkpoints` ([#12372](https://github.com/PyTorchLightning/pytorch-lightning/pull/12372))
 
 
diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py
index 5d54c8e53f..80bb7747c5 100644
--- a/pytorch_lightning/utilities/data.py
+++ b/pytorch_lightning/utilities/data.py
@@ -15,6 +15,7 @@ import functools
 import inspect
 import os
 from contextlib import contextmanager
+from dataclasses import fields
 from functools import partial
 from itertools import chain
 from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Set, Type, Union
@@ -25,6 +26,7 @@ from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler,
 import pytorch_lightning as pl
 from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
 from pytorch_lightning.trainer.states import RunningStage
+from pytorch_lightning.utilities.apply_func import _is_dataclass_instance
 from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset, FastForwardSampler
 from pytorch_lightning.utilities.enums import _FaultTolerantMode
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -49,6 +51,9 @@ def _extract_batch_size(batch: BType) -> Generator[int, None, None]:
 
         for sample in batch:
             yield from _extract_batch_size(sample)
+    elif _is_dataclass_instance(batch):
+        for field in fields(batch):
+            yield from _extract_batch_size(getattr(batch, field.name))
     else:
         yield None
 
diff --git a/tests/utilities/test_data.py b/tests/utilities/test_data.py
index f056fe99cd..ced076540b 100644
--- a/tests/utilities/test_data.py
+++ b/tests/utilities/test_data.py
@@ -1,5 +1,8 @@
+from dataclasses import dataclass
+
 import pytest
 import torch
+from torch import Tensor
 from torch.utils.data.dataloader import DataLoader
 
 from pytorch_lightning import Trainer
@@ -36,6 +39,11 @@ def test_extract_batch_size():
         with pytest.raises(MisconfigurationException, match="We could not infer the batch_size"):
             extract_batch_size(batch)
 
+    @dataclass
+    class CustomDataclass:
+        a: Tensor
+        b: Tensor
+
     # Warning not raised
     batch = torch.zeros(11, 10, 9, 8)
     _check_warning_not_raised(batch, 11)
@@ -46,6 +54,9 @@ def test_extract_batch_size():
     batch = [torch.zeros(11, 10)]
     _check_warning_not_raised(batch, 11)
 
+    batch = CustomDataclass(torch.zeros(11, 10), torch.zeros(11, 10))
+    _check_warning_not_raised(batch, 11)
+
     batch = {"test": [{"test": [torch.zeros(11, 10)]}]}
     _check_warning_not_raised(batch, 11)
 
@@ -53,6 +64,9 @@ def test_extract_batch_size():
     batch = {"a": [torch.tensor(1), torch.tensor(2)], "b": torch.tensor([1, 2, 3, 4])}
     _check_warning_raised(batch, 1)
 
+    batch = CustomDataclass(torch.zeros(11, 10), torch.zeros(1))
+    _check_warning_raised(batch, 11)
+
     batch = {"test": [{"test": [torch.zeros(11, 10), torch.zeros(10, 10)]}]}
     _check_warning_raised(batch, 11)
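
Note (not part of the patch): after this change, `extract_batch_size` can infer the batch size from a dataclass batch by recursing over its fields, just as it already does for mappings and iterables. A minimal sketch of the resulting behavior, assuming a pytorch_lightning build that contains this patch; `TrainingBatch` is a hypothetical user-defined dataclass used only for illustration:

from dataclasses import dataclass

import torch
from torch import Tensor

from pytorch_lightning.utilities.data import extract_batch_size


@dataclass
class TrainingBatch:
    inputs: Tensor
    targets: Tensor


batch = TrainingBatch(inputs=torch.zeros(32, 10), targets=torch.zeros(32))
# _extract_batch_size recurses over the dataclass fields and yields the first
# dimension of each tensor it finds, so 32 is inferred from `inputs`.
assert extract_batch_size(batch) == 32

As the tests above show, a warning is still raised when the fields disagree (e.g. `torch.zeros(11, 10)` next to `torch.zeros(1)`), and the first inferred size wins.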