Add dataclass support to _extract_batch_size (#12573)

Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>
This commit is contained in:
twsl 2022-04-15 14:13:33 +02:00 committed by GitHub
parent b8d4b81221
commit ae3226ced9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 22 additions and 0 deletions

View File

@ -31,6 +31,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Support `strategy` argument being case insensitive ([#12528](https://github.com/PyTorchLightning/pytorch-lightning/pull/12528))
- Added dataclass support to `extract_batch_size` ([#12573](https://github.com/PyTorchLightning/pytorch-lightning/pull/12573))
- Changed checkpoints save path in the case of one logger and user-provided weights_save_path from `weights_save_path/name/version/checkpoints` to `weights_save_path/checkpoints` ([#12372](https://github.com/PyTorchLightning/pytorch-lightning/pull/12372))

View File

@ -15,6 +15,7 @@ import functools
import inspect
import os
from contextlib import contextmanager
from dataclasses import fields
from functools import partial
from itertools import chain
from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Set, Type, Union
@ -25,6 +26,7 @@ from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler,
import pytorch_lightning as pl
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.apply_func import _is_dataclass_instance
from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset, FastForwardSampler
from pytorch_lightning.utilities.enums import _FaultTolerantMode
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@ -49,6 +51,9 @@ def _extract_batch_size(batch: BType) -> Generator[int, None, None]:
for sample in batch:
yield from _extract_batch_size(sample)
elif _is_dataclass_instance(batch):
for field in fields(batch):
yield from _extract_batch_size(getattr(batch, field.name))
else:
yield None

View File

@ -1,5 +1,8 @@
from dataclasses import dataclass
import pytest
import torch
from torch import Tensor
from torch.utils.data.dataloader import DataLoader
from pytorch_lightning import Trainer
@ -36,6 +39,11 @@ def test_extract_batch_size():
with pytest.raises(MisconfigurationException, match="We could not infer the batch_size"):
extract_batch_size(batch)
@dataclass
class CustomDataclass:
a: Tensor
b: Tensor
# Warning not raised
batch = torch.zeros(11, 10, 9, 8)
_check_warning_not_raised(batch, 11)
@ -46,6 +54,9 @@ def test_extract_batch_size():
batch = [torch.zeros(11, 10)]
_check_warning_not_raised(batch, 11)
batch = CustomDataclass(torch.zeros(11, 10), torch.zeros(11, 10))
_check_warning_not_raised(batch, 11)
batch = {"test": [{"test": [torch.zeros(11, 10)]}]}
_check_warning_not_raised(batch, 11)
@ -53,6 +64,9 @@ def test_extract_batch_size():
batch = {"a": [torch.tensor(1), torch.tensor(2)], "b": torch.tensor([1, 2, 3, 4])}
_check_warning_raised(batch, 1)
batch = CustomDataclass(torch.zeros(11, 10), torch.zeros(1))
_check_warning_raised(batch, 11)
batch = {"test": [{"test": [torch.zeros(11, 10), torch.zeros(10, 10)]}]}
_check_warning_raised(batch, 11)