Add dataclass support to _extract_batch_size (#12573)
Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>
commit ae3226ced9 (parent b8d4b81221)
```diff
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -31,6 +31,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Support `strategy` argument being case insensitive ([#12528](https://github.com/PyTorchLightning/pytorch-lightning/pull/12528))
 
+
+- Added dataclass support to `extract_batch_size` ([#12573](https://github.com/PyTorchLightning/pytorch-lightning/pull/12573))
+
 
 - Changed checkpoints save path in the case of one logger and user-provided weights_save_path from `weights_save_path/name/version/checkpoints` to `weights_save_path/checkpoints` ([#12372](https://github.com/PyTorchLightning/pytorch-lightning/pull/12372))
 
```
```diff
--- a/pytorch_lightning/utilities/data.py
+++ b/pytorch_lightning/utilities/data.py
@@ -15,6 +15,7 @@ import functools
 import inspect
 import os
 from contextlib import contextmanager
+from dataclasses import fields
 from functools import partial
 from itertools import chain
 from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Set, Type, Union
@@ -25,6 +26,7 @@ from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler,
 import pytorch_lightning as pl
 from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
 from pytorch_lightning.trainer.states import RunningStage
+from pytorch_lightning.utilities.apply_func import _is_dataclass_instance
 from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset, FastForwardSampler
 from pytorch_lightning.utilities.enums import _FaultTolerantMode
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
```
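The new import pulls in `_is_dataclass_instance` rather than calling `dataclasses.is_dataclass` directly, because `is_dataclass` returns `True` for the dataclass *type* as well as for its instances. A minimal sketch of what the helper checks (the exact body in `pytorch_lightning.utilities.apply_func` may differ; treat this as an assumption):

```python
import dataclasses
from typing import Any


def _is_dataclass_instance(obj: Any) -> bool:
    # True only for an *instance* of a dataclass; `is_dataclass` alone would
    # also return True for the dataclass type itself.
    return dataclasses.is_dataclass(obj) and not isinstance(obj, type)
```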
```diff
@@ -49,6 +51,9 @@ def _extract_batch_size(batch: BType) -> Generator[int, None, None]:
 
         for sample in batch:
             yield from _extract_batch_size(sample)
+    elif _is_dataclass_instance(batch):
+        for field in fields(batch):
+            yield from _extract_batch_size(getattr(batch, field.name))
     else:
         yield None
 
```
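The added `elif` branch extends the existing recursion: just as `_extract_batch_size` already descends into mappings and iterables, it now descends into each field of a dataclass instance and yields whatever batch sizes the field values produce. A self-contained sketch of the traversal (the `Pair` dataclass and the `_batch_sizes` helper are illustrative names, not part of the codebase):

```python
from dataclasses import dataclass, fields, is_dataclass

import torch
from torch import Tensor


def _batch_sizes(batch):
    # Simplified stand-in for `_extract_batch_size`: yield a candidate batch
    # size for every tensor reachable inside the (possibly nested) batch.
    if isinstance(batch, Tensor):
        yield batch.size(0) if batch.ndim > 0 else 1
    elif is_dataclass(batch) and not isinstance(batch, type):
        for field in fields(batch):  # the new dataclass branch
            yield from _batch_sizes(getattr(batch, field.name))
    elif isinstance(batch, (list, tuple)):
        for sample in batch:
            yield from _batch_sizes(sample)
    else:
        yield None


@dataclass
class Pair:
    a: Tensor
    b: Tensor


print(list(_batch_sizes(Pair(torch.zeros(11, 10), torch.zeros(11, 4)))))  # [11, 11]
```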
```diff
--- a/tests/utilities/test_data.py
+++ b/tests/utilities/test_data.py
@@ -1,5 +1,8 @@
+from dataclasses import dataclass
+
 import pytest
 import torch
+from torch import Tensor
 from torch.utils.data.dataloader import DataLoader
 
 from pytorch_lightning import Trainer
@@ -36,6 +39,11 @@ def test_extract_batch_size():
     with pytest.raises(MisconfigurationException, match="We could not infer the batch_size"):
         extract_batch_size(batch)
 
+    @dataclass
+    class CustomDataclass:
+        a: Tensor
+        b: Tensor
+
     # Warning not raised
     batch = torch.zeros(11, 10, 9, 8)
     _check_warning_not_raised(batch, 11)
@@ -46,6 +54,9 @@
     batch = [torch.zeros(11, 10)]
     _check_warning_not_raised(batch, 11)
 
+    batch = CustomDataclass(torch.zeros(11, 10), torch.zeros(11, 10))
+    _check_warning_not_raised(batch, 11)
+
     batch = {"test": [{"test": [torch.zeros(11, 10)]}]}
     _check_warning_not_raised(batch, 11)
 
@@ -53,6 +64,9 @@
     batch = {"a": [torch.tensor(1), torch.tensor(2)], "b": torch.tensor([1, 2, 3, 4])}
     _check_warning_raised(batch, 1)
 
+    batch = CustomDataclass(torch.zeros(11, 10), torch.zeros(1))
+    _check_warning_raised(batch, 11)
+
     batch = {"test": [{"test": [torch.zeros(11, 10), torch.zeros(10, 10)]}]}
     _check_warning_raised(batch, 11)
 
```
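The two new test cases pin down the warning semantics: when every tensor in the dataclass agrees on the leading dimension, the batch size is inferred silently; when they disagree (11 vs. 1 above), the first size found wins and Lightning warns about the ambiguous collection. A quick usage example, assuming the import path `pytorch_lightning.utilities.data.extract_batch_size` from this version of the codebase:

```python
from dataclasses import dataclass

import torch
from torch import Tensor

from pytorch_lightning.utilities.data import extract_batch_size


@dataclass
class CustomDataclass:
    a: Tensor
    b: Tensor


# Fields agree on the leading dimension: inferred silently.
assert extract_batch_size(CustomDataclass(torch.zeros(11, 10), torch.zeros(11, 4))) == 11

# Fields disagree (11 vs. 1): the first size found is returned, with a warning
# that the collection is ambiguous.
assert extract_batch_size(CustomDataclass(torch.zeros(11, 10), torch.zeros(1))) == 11
```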