Relax restrictions on wrapping a custom batch sampler in predict (#19678)
This commit is contained in:
parent
94167d6e65
commit
438f29f07a
|
@ -22,7 +22,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- The `prepare_data()` hook in `LightningModule` and `LightningDataModule` is now subject to a barrier without timeout to avoid long-running tasks to be interrupted ([#19448](https://github.com/Lightning-AI/lightning/pull/19448))
|
||||
|
||||
-
|
||||
- Relaxed the requirement for custom batch samplers to expose `drop_last` for prediction ([#19678](https://github.com/Lightning-AI/pytorch-lightning/pull/19678))
|
||||
|
||||
-
|
||||
|
||||
|
|
|
@ -28,6 +28,7 @@ from lightning.fabric.utilities.data import (
|
|||
has_iterable_dataset,
|
||||
sized_len,
|
||||
)
|
||||
from lightning.fabric.utilities.warnings import PossibleUserWarning
|
||||
from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
|
||||
from lightning.pytorch.trainer.states import RunningStage
|
||||
from lightning.pytorch.utilities.exceptions import MisconfigurationException
|
||||
|
@ -301,6 +302,14 @@ def _dataloader_init_kwargs_resolve_sampler(
|
|||
" or set `Trainer(use_distributed_sampler=False)`. If you choose the latter, you will be"
|
||||
" responsible for handling the distributed sampling within your batch sampler."
|
||||
) from ex
|
||||
elif is_predicting:
|
||||
rank_zero_warn(
|
||||
f"You are using a custom batch sampler `{batch_sampler_cls.__qualname__}` for prediction."
|
||||
" Lightning would normally set `drop_last=False` to ensure all samples are returned, but for"
|
||||
" custom samplers it can't guarantee this. Make sure your sampler is configured correctly to return"
|
||||
" all indices.",
|
||||
category=PossibleUserWarning,
|
||||
)
|
||||
else:
|
||||
# The sampler is not a PyTorch `BatchSampler`, we don't know how to inject a custom sampler or
|
||||
# how to adjust the `drop_last` value
|
||||
|
|
|
@ -5,6 +5,7 @@ import numpy as np
|
|||
import pytest
|
||||
import torch
|
||||
from lightning.fabric.utilities.data import _replace_dunder_methods
|
||||
from lightning.fabric.utilities.warnings import PossibleUserWarning
|
||||
from lightning.pytorch import Trainer
|
||||
from lightning.pytorch.demos.boring_classes import RandomDataset, RandomIterableDataset
|
||||
from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
|
||||
|
@ -230,7 +231,8 @@ def test_custom_torch_batch_sampler_doppelganger(predicting):
|
|||
assert batch_sampler.drop_last == (not predicting)
|
||||
|
||||
|
||||
def test_custom_batch_sampler():
|
||||
@pytest.mark.parametrize("predicting", [True, False])
|
||||
def test_custom_batch_sampler(predicting):
|
||||
"""Test that a custom (non-PyTorch) batch sampler requires the user to set `use_distributed_sampler=False`."""
|
||||
|
||||
class CustomBatchSampler: # not inheriting from `BatchSampler`
|
||||
|
@ -240,8 +242,13 @@ def test_custom_batch_sampler():
|
|||
|
||||
batch_sampler = CustomBatchSampler()
|
||||
dataloader = DataLoader(range(100), batch_sampler=batch_sampler)
|
||||
|
||||
if predicting:
|
||||
with pytest.warns(PossibleUserWarning, match=r"Make sure your sampler is configured correctly to return all"):
|
||||
_ = _update_dataloader(dataloader, sampler=Mock(), mode=RunningStage.PREDICTING)
|
||||
else:
|
||||
with pytest.raises(TypeError, match=r"can't inject a \(distributed\) sampler into your batch sampler"):
|
||||
_ = _update_dataloader(dataloader, sampler=Mock())
|
||||
_ = _update_dataloader(dataloader, sampler=Mock(), mode=None)
|
||||
|
||||
|
||||
def test_custom_batch_sampler_no_drop_last():
|
||||
|
|
Loading…
Reference in New Issue