Relax restrictions on wrapping a custom batch sampler in predict (#19678)

This commit is contained in:
awaelchli 2024-03-27 23:45:50 +01:00 committed by GitHub
parent 94167d6e65
commit 438f29f07a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 20 additions and 4 deletions

View File

@ -22,7 +22,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- The `prepare_data()` hook in `LightningModule` and `LightningDataModule` is now subject to a barrier without timeout to avoid long-running tasks to be interrupted ([#19448](https://github.com/Lightning-AI/lightning/pull/19448))
-
- Relaxed the requirement for custom batch samplers to expose `drop_last` for prediction ([#19678](https://github.com/Lightning-AI/pytorch-lightning/pull/19678))
-

View File

@ -28,6 +28,7 @@ from lightning.fabric.utilities.data import (
has_iterable_dataset,
sized_len,
)
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
from lightning.pytorch.trainer.states import RunningStage
from lightning.pytorch.utilities.exceptions import MisconfigurationException
@ -301,6 +302,14 @@ def _dataloader_init_kwargs_resolve_sampler(
" or set `Trainer(use_distributed_sampler=False)`. If you choose the latter, you will be"
" responsible for handling the distributed sampling within your batch sampler."
) from ex
elif is_predicting:
rank_zero_warn(
f"You are using a custom batch sampler `{batch_sampler_cls.__qualname__}` for prediction."
" Lightning would normally set `drop_last=False` to ensure all samples are returned, but for"
" custom samplers it can't guarantee this. Make sure your sampler is configured correctly to return"
" all indices.",
category=PossibleUserWarning,
)
else:
# The sampler is not a PyTorch `BatchSampler`, we don't know how to inject a custom sampler or
# how to adjust the `drop_last` value

View File

@ -5,6 +5,7 @@ import numpy as np
import pytest
import torch
from lightning.fabric.utilities.data import _replace_dunder_methods
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import RandomDataset, RandomIterableDataset
from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
@ -230,7 +231,8 @@ def test_custom_torch_batch_sampler_doppelganger(predicting):
assert batch_sampler.drop_last == (not predicting)
def test_custom_batch_sampler():
@pytest.mark.parametrize("predicting", [True, False])
def test_custom_batch_sampler(predicting):
"""Test that a custom (non-PyTorch) batch sampler requires the user to set `use_distributed_sampler=False`."""
class CustomBatchSampler: # not inheriting from `BatchSampler`
@ -240,8 +242,13 @@ def test_custom_batch_sampler():
batch_sampler = CustomBatchSampler()
dataloader = DataLoader(range(100), batch_sampler=batch_sampler)
if predicting:
with pytest.warns(PossibleUserWarning, match=r"Make sure your sampler is configured correctly to return all"):
_ = _update_dataloader(dataloader, sampler=Mock(), mode=RunningStage.PREDICTING)
else:
with pytest.raises(TypeError, match=r"can't inject a \(distributed\) sampler into your batch sampler"):
_ = _update_dataloader(dataloader, sampler=Mock())
_ = _update_dataloader(dataloader, sampler=Mock(), mode=None)
def test_custom_batch_sampler_no_drop_last():