avoid unnecessary workers with sequential `CombinedLoader` (#17639)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
cf14d624ae
commit
c3ad7568e1
|
@ -108,6 +108,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
### Fixed
|
||||
|
||||
- `CombinedLoader` in sequential mode now starts `DataLoader` workers only when they are needed ([#17639](https://github.com/Lightning-AI/lightning/pull/17639))
|
||||
|
||||
|
||||
- Fixed issue where `Model.load_from_checkpoint("checkpoint.ckpt", map_location=map_location)` would always return model on CPU ([#17308](https://github.com/Lightning-AI/lightning/pull/17308))
|
||||
|
||||
|
||||
|
|
|
@ -108,7 +108,7 @@ class _Sequential(_ModeIterator[Tuple[Any, int, int]]):
|
|||
self._limits = limits
|
||||
|
||||
def __next__(self) -> Tuple[Any, int, int]:
|
||||
n = len(self.iterators)
|
||||
n = len(self.iterables)
|
||||
if n == 0 or self._iterator_idx >= n:
|
||||
raise StopIteration
|
||||
|
||||
|
@ -120,7 +120,7 @@ class _Sequential(_ModeIterator[Tuple[Any, int, int]]):
|
|||
raise StopIteration
|
||||
|
||||
try:
|
||||
out = next(self.iterators[self._iterator_idx])
|
||||
out = next(self.iterators[0])
|
||||
index = self._idx
|
||||
self._idx += 1
|
||||
# batch, batch_idx, dataloader_idx
|
||||
|
@ -131,9 +131,9 @@ class _Sequential(_ModeIterator[Tuple[Any, int, int]]):
|
|||
return self.__next__()
|
||||
|
||||
def __iter__(self) -> Self:
|
||||
super().__iter__()
|
||||
self._iterator_idx = 0
|
||||
self._idx = 0
|
||||
self._load_current_iterator()
|
||||
return self
|
||||
|
||||
def reset(self) -> None:
|
||||
|
@ -141,9 +141,18 @@ class _Sequential(_ModeIterator[Tuple[Any, int, int]]):
|
|||
self._iterator_idx = 0
|
||||
self._idx = 0
|
||||
|
||||
def _load_current_iterator(self) -> None:
|
||||
# Load a single DataLoader, prevents multiple sets of workers from starting unnecessarily
|
||||
if self._iterator_idx < len(self.iterables):
|
||||
self.iterators = [iter(self.iterables[self._iterator_idx])]
|
||||
else:
|
||||
# No more iterables to step through, return an empty list
|
||||
self.iterators = []
|
||||
|
||||
def _use_next_iterator(self) -> None:
|
||||
self._iterator_idx += 1
|
||||
self._idx = 0
|
||||
self._load_current_iterator()
|
||||
|
||||
|
||||
class _MaxSize(_ModeIterator[List]):
|
||||
|
|
|
@ -844,8 +844,7 @@ def test_workers_are_shutdown(tmpdir, should_fail, persistent_workers):
|
|||
# iterable check
|
||||
0,
|
||||
# epoch ends
|
||||
1,
|
||||
# teardown
|
||||
0,
|
||||
1,
|
||||
]
|
||||
else:
|
||||
|
@ -855,9 +854,8 @@ def test_workers_are_shutdown(tmpdir, should_fail, persistent_workers):
|
|||
# iterable check
|
||||
0,
|
||||
# epoch ends
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
# teardown
|
||||
3,
|
||||
]
|
||||
assert val_dataloader.shutdown_workers_epochs == expected
|
||||
|
|
|
@ -305,6 +305,40 @@ def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloader
|
|||
assert idx == expected - 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mode", ["min_size", "max_size_cycle", "max_size", "sequential"])
|
||||
def test_combined_loader_simultaneous_workers(mode):
|
||||
"""Test `CombinedLoader` to check how it initializes dataloader workers."""
|
||||
|
||||
class TestDataLoader(DataLoader):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.workers_active = False
|
||||
|
||||
def _get_iterator(self):
|
||||
self.workers_active = True
|
||||
return super()._get_iterator()
|
||||
|
||||
def _shutdown_workers(self):
|
||||
self.workers_active = False
|
||||
super()._shutdown_workers()
|
||||
|
||||
loaders = [
|
||||
TestDataLoader(range(10), batch_size=2, num_workers=0),
|
||||
TestDataLoader(range(20), batch_size=2, num_workers=0),
|
||||
]
|
||||
combined_loader = CombinedLoader(loaders, mode)
|
||||
# Start the dataloader
|
||||
_ = iter(combined_loader)
|
||||
|
||||
workers_active = []
|
||||
for loader in loaders:
|
||||
workers_active.append(loader.workers_active)
|
||||
|
||||
# Sequential only starts the first dataloader, other modes start both
|
||||
expected = [True, False] if mode == "sequential" else [True, True]
|
||||
assert workers_active == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("limits", "expected"),
|
||||
[
|
||||
|
|
Loading…
Reference in New Issue