From c3ad7568e114c4cd357274f8f86c563005fb850c Mon Sep 17 00:00:00 2001
From: Ryan Mukherjee
Date: Tue, 30 May 2023 00:02:50 -0400
Subject: [PATCH] avoid unnecessary workers with sequential `CombinedLoader`
 (#17639)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
Co-authored-by: Adrian Wälchli
---
 src/lightning/pytorch/CHANGELOG.md            |  3 ++
 .../pytorch/utilities/combined_loader.py      | 15 ++++++--
 tests/tests_pytorch/loops/test_loops.py       |  6 ++--
 .../utilities/test_combined_loader.py         | 34 +++++++++++++++++++
 4 files changed, 51 insertions(+), 7 deletions(-)

diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 1e3c905b7e..67badc0d4b 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -108,6 +108,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- `CombinedLoader` only starts DataLoader workers when necessary when operating in sequential mode ([#17639](https://github.com/Lightning-AI/lightning/pull/17639))
+
+
 - Fixed issue where `Model.load_from_checkpoint("checkpoint.ckpt", map_location=map_location)` would always return model on CPU ([#17308](https://github.com/Lightning-AI/lightning/pull/17308))
 
 
diff --git a/src/lightning/pytorch/utilities/combined_loader.py b/src/lightning/pytorch/utilities/combined_loader.py
index 96299126b4..0e012dbae1 100644
--- a/src/lightning/pytorch/utilities/combined_loader.py
+++ b/src/lightning/pytorch/utilities/combined_loader.py
@@ -108,7 +108,7 @@ class _Sequential(_ModeIterator[Tuple[Any, int, int]]):
         self._limits = limits
 
     def __next__(self) -> Tuple[Any, int, int]:
-        n = len(self.iterators)
+        n = len(self.iterables)
         if n == 0 or self._iterator_idx >= n:
             raise StopIteration
 
@@ -120,7 +120,7 @@ class _Sequential(_ModeIterator[Tuple[Any, int, int]]):
             raise StopIteration
 
         try:
-            out = next(self.iterators[self._iterator_idx])
+            out = next(self.iterators[0])
            index = self._idx
            self._idx += 1
            # batch, batch_idx, dataloader_idx
@@ -131,9 +131,9 @@ class _Sequential(_ModeIterator[Tuple[Any, int, int]]):
            return self.__next__()
 
     def __iter__(self) -> Self:
-        super().__iter__()
         self._iterator_idx = 0
         self._idx = 0
+        self._load_current_iterator()
         return self
 
     def reset(self) -> None:
@@ -141,9 +141,18 @@ class _Sequential(_ModeIterator[Tuple[Any, int, int]]):
         self._iterator_idx = 0
         self._idx = 0
 
+    def _load_current_iterator(self) -> None:
+        # Load a single DataLoader, prevents multiple sets of workers from starting unnecessarily
+        if self._iterator_idx < len(self.iterables):
+            self.iterators = [iter(self.iterables[self._iterator_idx])]
+        else:
+            # No more iterables to step through, return an empty list
+            self.iterators = []
+
     def _use_next_iterator(self) -> None:
         self._iterator_idx += 1
         self._idx = 0
+        self._load_current_iterator()
 
 
 class _MaxSize(_ModeIterator[List]):
diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py
index 15efb392a2..e38d4459e2 100644
--- a/tests/tests_pytorch/loops/test_loops.py
+++ b/tests/tests_pytorch/loops/test_loops.py
@@ -844,8 +844,7 @@ def test_workers_are_shutdown(tmpdir, should_fail, persistent_workers):
             # iterable check
             0,
             # epoch ends
-            1,
-            # teardown
+            0,
             1,
         ]
     else:
@@ -855,9 +854,8 @@ def test_workers_are_shutdown(tmpdir, should_fail, persistent_workers):
             # iterable check
             0,
             # epoch ends
+            0,
             1,
             2,
-            # teardown
-            3,
         ]
     assert val_dataloader.shutdown_workers_epochs == expected
diff --git a/tests/tests_pytorch/utilities/test_combined_loader.py b/tests/tests_pytorch/utilities/test_combined_loader.py
index fa6d33120a..7109523b37 100644
--- a/tests/tests_pytorch/utilities/test_combined_loader.py
+++ b/tests/tests_pytorch/utilities/test_combined_loader.py
@@ -305,6 +305,40 @@ def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloader
     assert idx == expected - 1
 
 
+@pytest.mark.parametrize("mode", ["min_size", "max_size_cycle", "max_size", "sequential"])
+def test_combined_loader_simultaneous_workers(mode):
+    """Test `CombinedLoader` to check how it initializes dataloader workers."""
+
+    class TestDataLoader(DataLoader):
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.workers_active = False
+
+        def _get_iterator(self):
+            self.workers_active = True
+            return super()._get_iterator()
+
+        def _shutdown_workers(self):
+            self.workers_active = False
+            super()._shutdown_workers()
+
+    loaders = [
+        TestDataLoader(range(10), batch_size=2, num_workers=0),
+        TestDataLoader(range(20), batch_size=2, num_workers=0),
+    ]
+    combined_loader = CombinedLoader(loaders, mode)
+    # Start the dataloader
+    _ = iter(combined_loader)
+
+    workers_active = []
+    for loader in loaders:
+        workers_active.append(loader.workers_active)
+
+    # Sequential only starts the first dataloader, other modes start both
+    expected = [True, False] if mode == "sequential" else [True, True]
+    assert workers_active == expected
+
+
 @pytest.mark.parametrize(
     ("limits", "expected"),
     [
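As a quick illustration of the behavior verified by the new `test_combined_loader_simultaneous_workers` test above, here is a minimal standalone sketch. It is adapted from that test rather than taken from the patch: the `TrackingDataLoader` name and the final `print` are illustrative, the import path is simply the module this patch modifies, and worker start-up is observed the same way the test does it, by overriding `DataLoader._get_iterator` and `_shutdown_workers`.

```python
# Sketch: with mode="sequential", creating the CombinedLoader iterator should
# only start the first DataLoader; the second one is started later, when the
# sequential iterator advances to it.
from torch.utils.data import DataLoader

from lightning.pytorch.utilities.combined_loader import CombinedLoader


class TrackingDataLoader(DataLoader):
    """DataLoader that records whether it currently has an active iterator."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.workers_active = False

    def _get_iterator(self):
        self.workers_active = True
        return super()._get_iterator()

    def _shutdown_workers(self):
        self.workers_active = False
        super()._shutdown_workers()


loaders = [
    TrackingDataLoader(range(10), batch_size=2, num_workers=0),
    TrackingDataLoader(range(20), batch_size=2, num_workers=0),
]
combined = CombinedLoader(loaders, "sequential")
iter(combined)  # with this patch, only the first loader is started here

print([dl.workers_active for dl in loaders])  # expected: [True, False]
```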