avoid unnecessary workers with sequential `CombinedLoader` (#17639)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
Ryan Mukherjee 2023-05-30 00:02:50 -04:00 committed by GitHub
parent cf14d624ae
commit c3ad7568e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 51 additions and 7 deletions

View File

@ -108,6 +108,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed
- `CombinedLoader` only starts DataLoader workers when necessary when operating in sequential mode ([#17639](https://github.com/Lightning-AI/lightning/pull/17639))
- Fixed issue where `Model.load_from_checkpoint("checkpoint.ckpt", map_location=map_location)` would always return model on CPU ([#17308](https://github.com/Lightning-AI/lightning/pull/17308))

View File

@ -108,7 +108,7 @@ class _Sequential(_ModeIterator[Tuple[Any, int, int]]):
self._limits = limits
def __next__(self) -> Tuple[Any, int, int]:
n = len(self.iterators)
n = len(self.iterables)
if n == 0 or self._iterator_idx >= n:
raise StopIteration
@ -120,7 +120,7 @@ class _Sequential(_ModeIterator[Tuple[Any, int, int]]):
raise StopIteration
try:
out = next(self.iterators[self._iterator_idx])
out = next(self.iterators[0])
index = self._idx
self._idx += 1
# batch, batch_idx, dataloader_idx
@ -131,9 +131,9 @@ class _Sequential(_ModeIterator[Tuple[Any, int, int]]):
return self.__next__()
def __iter__(self) -> Self:
    """Reset iteration state and load the first iterator, returning self.

    Only the iterator for the current (first) iterable is created here, so
    DataLoader workers for the remaining iterables are not started yet.
    """
    super().__iter__()
    # rewind both the iterable index and the running batch index
    self._idx = 0
    self._iterator_idx = 0
    self._load_current_iterator()
    return self
def reset(self) -> None:
@ -141,9 +141,18 @@ class _Sequential(_ModeIterator[Tuple[Any, int, int]]):
self._iterator_idx = 0
self._idx = 0
def _load_current_iterator(self) -> None:
# Load a single DataLoader, prevents multiple sets of workers from starting unnecessarily
if self._iterator_idx < len(self.iterables):
self.iterators = [iter(self.iterables[self._iterator_idx])]
else:
# No more iterables to step through, return an empty list
self.iterators = []
def _use_next_iterator(self) -> None:
self._iterator_idx += 1
self._idx = 0
self._load_current_iterator()
class _MaxSize(_ModeIterator[List]):

View File

@ -844,8 +844,7 @@ def test_workers_are_shutdown(tmpdir, should_fail, persistent_workers):
# iterable check
0,
# epoch ends
1,
# teardown
0,
1,
]
else:
@ -855,9 +854,8 @@ def test_workers_are_shutdown(tmpdir, should_fail, persistent_workers):
# iterable check
0,
# epoch ends
0,
1,
2,
# teardown
3,
]
assert val_dataloader.shutdown_workers_epochs == expected

View File

@ -305,6 +305,40 @@ def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloader
assert idx == expected - 1
@pytest.mark.parametrize("mode", ["min_size", "max_size_cycle", "max_size", "sequential"])
@pytest.mark.parametrize("mode", ["min_size", "max_size_cycle", "max_size", "sequential"])
def test_combined_loader_simultaneous_workers(mode):
    """Test `CombinedLoader` to check how it initializes dataloader workers."""

    class WorkerTrackingDataLoader(DataLoader):
        """DataLoader that records whether its workers have been started."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.workers_active = False

        def _get_iterator(self):
            # creating an iterator is what spins up the workers
            self.workers_active = True
            return super()._get_iterator()

        def _shutdown_workers(self):
            self.workers_active = False
            super()._shutdown_workers()

    loaders = [
        WorkerTrackingDataLoader(range(10), batch_size=2, num_workers=0),
        WorkerTrackingDataLoader(range(20), batch_size=2, num_workers=0),
    ]
    combined_loader = CombinedLoader(loaders, mode)
    # Kick off iteration so iterators (and thus workers) get created
    iter(combined_loader)
    workers_active = [dl.workers_active for dl in loaders]
    # Sequential only starts the first dataloader, other modes start both
    expected = [True, False] if mode == "sequential" else [True, True]
    assert workers_active == expected
@pytest.mark.parametrize(
("limits", "expected"),
[