From 438f29f07a803309b2c83801cda1625d3bc8eb64 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Wed, 27 Mar 2024 23:45:50 +0100
Subject: [PATCH] Relax restrictions on wrapping a custom batch sampler in
 predict (#19678)

---
 src/lightning/pytorch/CHANGELOG.md          |  2 +-
 src/lightning/pytorch/utilities/data.py     |  9 +++++++++
 tests/tests_pytorch/utilities/test_data.py  | 13 ++++++++++---
 3 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index f400e21536..be7b66a27c 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -22,7 +22,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - The `prepare_data()` hook in `LightningModule` and `LightningDataModule` is now subject to a barrier without timeout to avoid long-running tasks to be interrupted ([#19448](https://github.com/Lightning-AI/lightning/pull/19448))
 
--
+- Relaxed the requirement for custom batch samplers to expose `drop_last` for prediction ([#19678](https://github.com/Lightning-AI/pytorch-lightning/pull/19678))
 
 -
 
diff --git a/src/lightning/pytorch/utilities/data.py b/src/lightning/pytorch/utilities/data.py
index fb1c0a4370..41c5ea86e5 100644
--- a/src/lightning/pytorch/utilities/data.py
+++ b/src/lightning/pytorch/utilities/data.py
@@ -28,6 +28,7 @@ from lightning.fabric.utilities.data import (
     has_iterable_dataset,
     sized_len,
 )
+from lightning.fabric.utilities.warnings import PossibleUserWarning
 from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
 from lightning.pytorch.trainer.states import RunningStage
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -301,6 +302,14 @@ def _dataloader_init_kwargs_resolve_sampler(
                     " or set `Trainer(use_distributed_sampler=False)`. If you choose the latter, you will be"
                     " responsible for handling the distributed sampling within your batch sampler."
                 ) from ex
+        elif is_predicting:
+            rank_zero_warn(
+                f"You are using a custom batch sampler `{batch_sampler_cls.__qualname__}` for prediction."
+                " Lightning would normally set `drop_last=False` to ensure all samples are returned, but for"
+                " custom samplers it can't guarantee this. Make sure your sampler is configured correctly to return"
+                " all indices.",
+                category=PossibleUserWarning,
+            )
         else:
             # The sampler is not a PyTorch `BatchSampler`, we don't know how to inject a custom sampler or
             # how to adjust the `drop_last` value
diff --git a/tests/tests_pytorch/utilities/test_data.py b/tests/tests_pytorch/utilities/test_data.py
index 6348dba002..e9c80d95c5 100644
--- a/tests/tests_pytorch/utilities/test_data.py
+++ b/tests/tests_pytorch/utilities/test_data.py
@@ -5,6 +5,7 @@ import numpy as np
 import pytest
 import torch
 from lightning.fabric.utilities.data import _replace_dunder_methods
+from lightning.fabric.utilities.warnings import PossibleUserWarning
 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import RandomDataset, RandomIterableDataset
 from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
@@ -230,7 +231,8 @@ def test_custom_torch_batch_sampler_doppelganger(predicting):
     assert batch_sampler.drop_last == (not predicting)
 
 
-def test_custom_batch_sampler():
+@pytest.mark.parametrize("predicting", [True, False])
+def test_custom_batch_sampler(predicting):
     """Test that a custom (non-PyTorch) batch sampler requires the user to set `use_distributed_sampler=False`."""
 
     class CustomBatchSampler:  # not inheriting from `BatchSampler`
@@ -240,8 +242,13 @@ def test_custom_batch_sampler():
 
     batch_sampler = CustomBatchSampler()
     dataloader = DataLoader(range(100), batch_sampler=batch_sampler)
-    with pytest.raises(TypeError, match=r"can't inject a \(distributed\) sampler into your batch sampler"):
-        _ = _update_dataloader(dataloader, sampler=Mock())
+
+    if predicting:
+        with pytest.warns(PossibleUserWarning, match=r"Make sure your sampler is configured correctly to return all"):
+            _ = _update_dataloader(dataloader, sampler=Mock(), mode=RunningStage.PREDICTING)
+    else:
+        with pytest.raises(TypeError, match=r"can't inject a \(distributed\) sampler into your batch sampler"):
+            _ = _update_dataloader(dataloader, sampler=Mock(), mode=None)
 
 
 def test_custom_batch_sampler_no_drop_last():
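
For illustration, a minimal sketch of the behavior this patch enables, assuming a Lightning build that includes it. The `FixedChunkBatchSampler` below is a hypothetical custom batch sampler, not part of this change: it does not subclass `torch.utils.data.BatchSampler` and has no `drop_last` attribute, so wrapping it during `trainer.predict` previously raised a TypeError; with this patch it only triggers the `PossibleUserWarning` added above.

    # Hypothetical example, not part of the patch: a custom batch sampler used in predict.
    from torch.utils.data import DataLoader

    from lightning.pytorch import Trainer
    from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset


    class FixedChunkBatchSampler:
        """Yields consecutive chunks of indices; there is no `drop_last` for Lightning to adjust."""

        def __init__(self, size, batch_size):
            self.size = size
            self.batch_size = batch_size

        def __iter__(self):
            # Emit every index, including the final short batch.
            for start in range(0, self.size, self.batch_size):
                yield list(range(start, min(start + self.batch_size, self.size)))

        def __len__(self):
            # Number of batches, rounding up to count the final short one.
            return (self.size + self.batch_size - 1) // self.batch_size


    dataset = RandomDataset(32, 100)
    loader = DataLoader(dataset, batch_sampler=FixedChunkBatchSampler(len(dataset), batch_size=8))

    # With this patch, predict emits the PossibleUserWarning instead of raising and
    # leaves it to the user to ensure the sampler yields every index.
    trainer = Trainer(accelerator="cpu", devices=1, logger=False)
    predictions = trainer.predict(BoringModel(), dataloaders=loader)

Since this sampler returns all 100 indices in 13 batches, `predictions` should contain 13 output tensors covering every sample. Under a distributed strategy you would still set `Trainer(use_distributed_sampler=False)` and handle the distributed sampling inside the sampler yourself, as the TypeError message above advises.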