Replace DataLoader sampler once for IPUs (#8858)
parent 1d2f7e20c4
commit 93ab24d1ee
@@ -72,26 +72,11 @@ jobs:
      python -c "import poptorch; print(poptorch.__version__)"
    displayName: "Check poptorch installation"

  - bash: |
      wget https://pl-public-data.s3.amazonaws.com/legacy/checkpoints.zip -P legacy/
      unzip -o legacy/checkpoints.zip -d legacy/
      ls -l legacy/checkpoints/
    displayName: 'Get legacy checkpoints'

  - bash: |
      source ${{ variables.poplar_sdk }}/poplar-ubuntu*/enable.sh
      source ${{ variables.poplar_sdk }}/popart-ubuntu*/enable.sh
      export POPTORCH_WAIT_FOR_IPU=1
      python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --junitxml=$(Build.StagingDirectory)/test-results.xml --durations=50
      python -m coverage run --source pytorch_lightning -m pytest tests/accelerators/test_ipu.py -v --junitxml=$(Build.StagingDirectory)/test-results.xml --durations=50
    env:
      MKL_THREADING_LAYER: "GNU"
    displayName: 'Testing: standard'

  - bash: |
      source ${{ variables.poplar_sdk }}/poplar-ubuntu*/enable.sh
      source ${{ variables.poplar_sdk }}/popart-ubuntu*/enable.sh
      export POPTORCH_WAIT_FOR_IPU=1
      bash tests/special_tests.sh
    env:
      MKL_THREADING_LAYER: "GNU"
    displayName: 'Testing: special'
@@ -81,6 +81,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
  (https://github.com/PyTorchLightning/pytorch-lightning/pull/8608))


- `Trainer.request_dataloader` now takes a `RunningStage` enum instance ([#8858](https://github.com/PyTorchLightning/pytorch-lightning/pull/8858))


### Deprecated

- Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`

@@ -132,6 +134,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed `LightningModule.write_predictions` and `LightningModule.write_predictions_dict` ([#](https://github.com/PyTorchLightning/pytorch-lightning/pull/8850))


- Removed reset dataloader hooks to Training Plugins and Accelerators ([#8858](https://github.com/PyTorchLightning/pytorch-lightning/pull/8858))


### Fixed

@@ -176,6 +181,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed infinite loop with CycleIterator and multiple loaders ([#8889](https://github.com/PyTorchLightning/pytorch-lightning/pull/8889))


- Fixed bug where data-loading functions were not getting the correct running stage passed ([#8858](https://github.com/PyTorchLightning/pytorch-lightning/pull/8858))


## [1.4.0] - 2021-07-27

### Added
@@ -410,22 +410,6 @@ class Accelerator:
        """
        return self.training_type_plugin.process_dataloader(dataloader)

    def on_reset_train_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        """Called before resetting the train dataloader."""
        return self.training_type_plugin.on_reset_train_dataloader(dataloader)

    def on_reset_val_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        """Called before resetting the val dataloader."""
        return self.training_type_plugin.on_reset_val_dataloader(dataloader)

    def on_reset_test_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        """Called before resetting the test dataloader."""
        return self.training_type_plugin.on_reset_test_dataloader(dataloader)

    def on_reset_predict_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        """Called before resetting the predict dataloader."""
        return self.training_type_plugin.on_reset_predict_dataloader(dataloader)

    @property
    def results(self) -> Any:
        """
@@ -11,10 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import os
from typing import Any, Iterable, List, Optional, Union
from typing import Any, List, Optional, Union

import torch
from torch.utils.data import DataLoader
@@ -26,7 +25,6 @@ from pytorch_lightning.plugins.environments.cluster_environment import ClusterEn
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities import _POPTORCH_AVAILABLE
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.cloud_io import get_filesystem
@@ -112,6 +110,12 @@ class IPUPlugin(ParallelPlugin):
            options["autoReport.directory"] = self.autoreport_dir
            os.environ["POPLAR_ENGINE_OPTIONS"] = json.dumps(options)

    def setup(self) -> None:
        # patch the dataloader creation function with the custom `poptorch.DataLoader`.
        # this violates the intended control flow for the plugins, but since this is experimental, we have chosen
        # to use the simpler solution before adding abstractions to override the `DataLoader` class
        self.lightning_module.trainer.replace_sampler = self._convert_to_poptorch_loader

    def pre_dispatch(self) -> None:
        precision = self.lightning_module.trainer.precision
        model = LightningIPUModule(self.lightning_module, precision)
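Note: the `setup` hook above swaps the trainer's `replace_sampler` helper for the plugin's own converter, so every dataloader the trainer rebuilds comes back wrapped for the IPU, and `teardown` later restores the original function. Below is a minimal, self-contained sketch of that patch-and-restore pattern; `Trainer` and `IPUishPlugin` are hypothetical stand-ins here, not the real Lightning classes.

```python
# Minimal sketch of the patch-and-restore pattern used by the IPU plugin.
# `Trainer` and `IPUishPlugin` are illustrative stand-ins, not Lightning APIs.
class Trainer:
    @staticmethod
    def replace_sampler(dataloader, sampler, mode=None):
        return dataloader  # default behaviour: hand the loader back unchanged


class IPUishPlugin:
    def __init__(self, trainer):
        self.trainer = trainer

    def _convert_to_poptorch_loader(self, dataloader, sampler, mode=None):
        # the real plugin returns a poptorch.DataLoader; tag the loader instead
        return ("poptorch-wrapped", dataloader)

    def setup(self):
        # monkey-patch: every sampler replacement now goes through the plugin
        self.trainer.replace_sampler = self._convert_to_poptorch_loader

    def teardown(self):
        # undo the patch, mirroring the teardown hunk later in this diff
        self.trainer.replace_sampler = Trainer.replace_sampler


trainer = Trainer()
plugin = IPUishPlugin(trainer)
plugin.setup()
print(trainer.replace_sampler([1, 2, 3], sampler=None))  # ('poptorch-wrapped', [1, 2, 3])
plugin.teardown()
print(trainer.replace_sampler([1, 2, 3], sampler=None))  # [1, 2, 3]
```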
@@ -169,59 +173,16 @@ class IPUPlugin(ParallelPlugin):
    def lightning_module(self) -> Optional["pl.LightningModule"]:
        return self.model.module if isinstance(self.model, LightningIPUModule) else self.model

    def on_reset_train_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        return self._process_dataloader(dataloader, is_training=True)

    def on_reset_val_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        return self._process_dataloader(dataloader, is_training=False)

    def on_reset_test_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        return self._process_dataloader(dataloader, is_training=False)

    def on_reset_predict_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        return self._process_dataloader(dataloader, is_training=False)

    def _process_dataloader(
        self, dataloader: Union[Iterable, DataLoader], is_training: bool
    ) -> Union[Iterable, DataLoader]:
        if isinstance(dataloader, CombinedLoader):
            dataloader.loaders = apply_to_collection(
                dataloader.loaders, DataLoader, self._process_dataloader, is_training
            )
            return dataloader
        if isinstance(dataloader, list):
            dataloader = apply_to_collection(dataloader, DataLoader, self._process_dataloader, is_training)
            return dataloader
        if not isinstance(dataloader, poptorch.DataLoader):
            opts = self.training_opts if is_training else self.inference_opts
            dataloader = self._convert_to_poptorch_loader(dataloader=dataloader, opts=opts)
        return dataloader

    def _convert_to_poptorch_loader(
        self, dataloader: Union[Iterable, DataLoader], opts: "poptorch.Options"
    ) -> Union[Iterable, DataLoader]:
        skip_keys = ("sampler", "batch_sampler", "dataset_kind")

        attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}

        params = set(inspect.signature(dataloader.__init__).parameters)
        contains_dataset = True

        if type(dataloader) is not DataLoader:
            contains_dataset = "dataset" in params
            params.update(inspect.signature(DataLoader.__init__).parameters)

        dl_args = {name: attrs[name] for name in params if name in attrs and name not in skip_keys}

        multiprocessing_context = dataloader.multiprocessing_context
        dl_args["multiprocessing_context"] = multiprocessing_context
        if not contains_dataset:
            dl_args.pop("dataset")
        self, dataloader: DataLoader, sampler, mode: Optional[RunningStage] = None
    ) -> "poptorch.DataLoader":
        # use full path to avoid circular imports
        dl_kwargs = pl.trainer.trainer.TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler)
        # Override to drop last uneven batch, as IPUs do not support uneven inputs.
        dl_args["drop_last"] = True
        dl_kwargs["drop_last"] = True

        dataloader = poptorch.DataLoader(**dl_args, options=opts)
        dataloader.multiprocessing_context = multiprocessing_context
        opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts
        dataloader = poptorch.DataLoader(**dl_kwargs, options=opts)
        return dataloader

    @property
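One practical consequence of forcing `drop_last=True` above: any final partial batch is discarded, so every step sees the same batch geometry. A small back-of-the-envelope check of how that changes the number of batches per epoch (plain PyTorch, no IPU required):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 103 samples with batch_size=8: 13 batches normally, 12 once the uneven tail is dropped.
dataset = TensorDataset(torch.arange(103))

keep_tail = DataLoader(dataset, batch_size=8, drop_last=False)
drop_tail = DataLoader(dataset, batch_size=8, drop_last=True)

print(len(keep_tail))  # 13 (last batch has only 7 samples)
print(len(drop_tail))  # 12 (all batches are exactly 8 samples)
```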
@@ -291,6 +252,8 @@ class IPUPlugin(ParallelPlugin):
        return self.poptorch_models[RunningStage.PREDICTING](*args, **kwargs)

    def teardown(self) -> None:
        # undo dataloader patching
        self.lightning_module.trainer.replace_sampler = pl.trainer.trainer.TrainerDataLoadingMixin.replace_sampler
        for model in self.poptorch_models.values():
            model.destroy()
@@ -212,22 +212,6 @@ class TrainingTypePlugin(Plugin, ABC):
        """
        return dataloader

    def on_reset_train_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        """Called before resetting the train dataloader."""
        return dataloader

    def on_reset_val_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        """Called before resetting the val dataloader."""
        return dataloader

    def on_reset_test_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        """Called before resetting the test dataloader."""
        return dataloader

    def on_reset_predict_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        """Called before resetting the predict dataloader."""
        return dataloader

    def init_optimizers(self, trainer: "pl.Trainer", model: "pl.LightningModule"):
        return trainer.init_optimizers(model)
@@ -19,7 +19,7 @@ from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sampler, SequentialSampler
from torch.utils.data.dataset import IterableDataset
from torch.utils.data.distributed import DistributedSampler
@@ -121,7 +121,9 @@ class TrainerDataLoadingMixin(ABC):
    def auto_add_sampler(self, dataloader: Any, shuffle: bool, mode: Optional[RunningStage] = None) -> Any:
        if isinstance(dataloader, CombinedLoader):
            # apply `auto_add_sampler` on all the collection of loaders
            dataloader.loaders = apply_to_collection(dataloader.loaders, DataLoader, self.auto_add_sampler, shuffle)
            dataloader.loaders = apply_to_collection(
                dataloader.loaders, DataLoader, self.auto_add_sampler, shuffle, mode=mode
            )
            return dataloader

        # don't do anything if it's not a dataloader

@@ -151,7 +153,9 @@ class TrainerDataLoadingMixin(ABC):
        return dataloader

    @staticmethod
    def _resolve_batch_sampler(dataloader, sampler, mode: Optional[RunningStage] = None) -> Dict[str, Any]:
    def _resolve_batch_sampler(
        dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
    ) -> Dict[str, Any]:
        batch_sampler = getattr(dataloader, "batch_sampler")
        is_predicting = mode == RunningStage.PREDICTING
        # checking the batch sampler type is different than PyTorch default.
@@ -182,7 +186,10 @@ class TrainerDataLoadingMixin(ABC):

        return {"sampler": sampler, "shuffle": False, "batch_sampler": None}

    def replace_sampler(self, dataloader: DataLoader, sampler, mode: Optional[RunningStage] = None) -> DataLoader:
    @staticmethod
    def _get_dataloader_init_kwargs(
        dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
    ) -> Dict[str, Any]:
        if not isinstance(dataloader, DataLoader):
            raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")

@@ -201,7 +208,7 @@ class TrainerDataLoadingMixin(ABC):

        # kwargs to re-construct the dataloader
        dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
        dl_kwargs.update(self._resolve_batch_sampler(dataloader, sampler, mode=mode))
        dl_kwargs.update(TrainerDataLoadingMixin._resolve_batch_sampler(dataloader, sampler, mode=mode))

        required_args = {
            p.name
@@ -248,6 +255,11 @@ class TrainerDataLoadingMixin(ABC):
            del dl_kwargs["sampler"]
            del dl_kwargs["batch_sampler"]

        return dl_kwargs

    @staticmethod
    def replace_sampler(dataloader: DataLoader, sampler, mode: Optional[RunningStage] = None) -> DataLoader:
        dl_kwargs = TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler, mode=mode)
        dl_cls = type(dataloader)
        dataloader = dl_cls(**dl_kwargs)
        return dataloader
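The factored-out `_get_dataloader_init_kwargs` is what lets `replace_sampler` rebuild a dataloader of the same class with only the sampler swapped (and what the IPU plugin reuses above). A simplified, self-contained sketch of that rebuild idea, using standard PyTorch only; `rebuild_with_sampler` is a stand-in and the kwarg extraction is reduced to a few common attributes rather than Lightning's full introspection:

```python
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset


def rebuild_with_sampler(dataloader: DataLoader, sampler) -> DataLoader:
    # Simplified stand-in for _get_dataloader_init_kwargs: gather the kwargs
    # needed to re-instantiate the loader, then swap in the new sampler.
    dl_kwargs = {
        "dataset": dataloader.dataset,
        "batch_size": dataloader.batch_size,
        "num_workers": dataloader.num_workers,
        "collate_fn": dataloader.collate_fn,
        "drop_last": dataloader.drop_last,
        "sampler": sampler,
        "shuffle": False,       # ordering now comes from the sampler itself
        "batch_sampler": None,  # let DataLoader build it from sampler/batch_size
    }
    return type(dataloader)(**dl_kwargs)


dataset = TensorDataset(torch.arange(10))
loader = DataLoader(dataset, batch_size=4, shuffle=True)
sequential = rebuild_with_sampler(loader, SequentialSampler(dataset))
print([batch[0].tolist() for batch in sequential])  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```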
@@ -269,7 +281,7 @@ class TrainerDataLoadingMixin(ABC):
        Args:
            model: The `LightningModule` if calling this outside of the trainer scope.
        """
        self.train_dataloader = self.request_dataloader("train", model=model)
        self.train_dataloader = self.request_dataloader(RunningStage.TRAINING, model=model)

        if self.overfit_batches > 0:
            if hasattr(self.train_dataloader, "sampler") and isinstance(self.train_dataloader.sampler, RandomSampler):

@@ -278,7 +290,7 @@ class TrainerDataLoadingMixin(ABC):
                    " We are turning off the training dataloader shuffling for you."
                )
                self.train_dataloader = self.replace_sampler(
                    self.train_dataloader, SequentialSampler(self.train_dataloader.dataset)
                    self.train_dataloader, SequentialSampler(self.train_dataloader.dataset), mode=RunningStage.TRAINING
                )

        # debugging

@@ -286,11 +298,11 @@ class TrainerDataLoadingMixin(ABC):

        # automatically add samplers
        self.train_dataloader = apply_to_collection(
            self.train_dataloader, DataLoader, self.auto_add_sampler, shuffle=True
            self.train_dataloader, DataLoader, self.auto_add_sampler, shuffle=True, mode=RunningStage.TRAINING
        )

        # check the workers recursively
        apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, "train dataloader")
        apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, "train_dataloader")

        # add worker_init_fn for correct seeding in worker processes
        apply_to_collection(self.train_dataloader, DataLoader, self.auto_add_worker_init_fn)

@@ -302,9 +314,6 @@ class TrainerDataLoadingMixin(ABC):
        # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
        self.train_dataloader = CombinedLoader(self.train_dataloader, self.data_connector.multiple_trainloader_mode)

        # allow accelerator to modify dataloader
        self.train_dataloader = self.accelerator.on_reset_train_dataloader(self.train_dataloader)

        self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float("inf")

        if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
@@ -351,69 +360,61 @@ class TrainerDataLoadingMixin(ABC):
        )

    def _reset_eval_dataloader(
        self, mode: str, model: Optional["pl.LightningModule"] = None
        self, mode: RunningStage, model: Optional["pl.LightningModule"] = None
    ) -> Tuple[List[Union[int, float]], List[DataLoader]]:
        """Generic method to reset a dataloader for evaluation.

        Args:
            mode: Either `'val'`, `'test'` or `'predict'`
            model: The `LightningModule` if calling this outside of the trainer scope.
            mode: The running stage of the ``Trainer``
            model: The ``LightningModule`` if calling this outside of the trainer scope.

        Returns:
            Tuple (num_batches, dataloaders)
        """
        assert mode.evaluating or mode == RunningStage.PREDICTING

        # always get the loaders first so we can count how many there are
        loader_name = f"{mode}_dataloader"
        loader_name = f"{mode.dataloader_prefix}_dataloader"
        dataloaders = self.request_dataloader(mode, model=model)

        if not isinstance(dataloaders, list):
            dataloaders = [dataloaders]

        # when overfitting use the training loader as val and test
        # when overfitting, use the training loader as val and test
        # duplicate it the number of times needed to match the train loaders
        if self.overfit_batches > 0:
            num_loaders = len(dataloaders)
            train_dataloader = self.request_dataloader("train", model=model)
            dataloaders = [deepcopy(train_dataloader) for _ in range(num_loaders)]
            train_dataloader = self.request_dataloader(RunningStage.TRAINING, model=model)
            dataloaders = [deepcopy(train_dataloader) for _ in range(len(dataloaders))]

        self.dev_debugger.track_load_dataloader_call(loader_name, dataloaders=dataloaders)

        for loader_i in range(len(dataloaders)):
            loader = dataloaders[loader_i]

            # shuffling in val and test set is bad practice
            modes = ("val", "test", "predict")
            if mode in modes and hasattr(loader, "sampler") and isinstance(loader.sampler, RandomSampler):
            if hasattr(loader, "sampler") and isinstance(loader.sampler, RandomSampler):

                # when overfitting, the dataloader should not have sampler
                if self.overfit_batches > 0 and mode != "predict":
                if self.overfit_batches > 0 and mode.evaluating:
                    rank_zero_warn(
                        "You requested to overfit but enabled val/test dataloader shuffling."
                        " We are turning it off for you."
                    )
                    dataloaders[loader_i] = self.replace_sampler(loader, SequentialSampler(loader.dataset))
                    dataloaders[loader_i] = self.replace_sampler(loader, SequentialSampler(loader.dataset), mode=mode)
                else:
                    rank_zero_warn(
                        f"Your {mode}_dataloader has `shuffle=True`, it is best practice to turn"
                        " this off for val/test/predict dataloaders."
                        f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`,"
                        " it is strongly recommended that you turn this off for val/test/predict dataloaders."
                    )

        if any(dl is None for dl in dataloaders):
            rank_zero_warn("One of given dataloaders is None and it will be skipped.")

        # add samplers
        dataloaders = [
            self.auto_add_sampler(dl, shuffle=False, mode=self.state.stage) for dl in dataloaders if dl is not None
        ]
        dataloaders = [self.auto_add_sampler(dl, False, mode=mode) for dl in dataloaders if dl is not None]

        # add worker_init_fn for correct seeding in worker processes
        apply_to_collection(dataloaders, dtype=DataLoader, function=self.auto_add_worker_init_fn)

        # allow accelerator to modify dataloader
        hook_name = f"on_reset_{mode}_dataloader"
        dataloaders = getattr(self.accelerator, hook_name)(dataloaders)

        loader_num_batches = []

        # determine number of batches
@@ -421,10 +422,10 @@ class TrainerDataLoadingMixin(ABC):
        if len(dataloaders) != 0:
            for i, dataloader in enumerate(dataloaders):
                num_batches = len(dataloader) if has_len(dataloader) else float("inf")
                self._worker_check(dataloader, f"{mode} dataloader {i}")
                self._worker_check(dataloader, f"{mode.dataloader_prefix}_dataloader {i}")

                # percent or num_steps
                limit_eval_batches = getattr(self, f"limit_{mode}_batches")
                limit_eval_batches = getattr(self, f"limit_{mode.dataloader_prefix}_batches")

                # limit num batches either as a percent or num steps
                if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0:

@@ -433,17 +434,18 @@ class TrainerDataLoadingMixin(ABC):
                    num_batches = int(num_batches * limit_eval_batches)
                elif limit_eval_batches != 1.0:
                    raise MisconfigurationException(
                        "When using an IterableDataset for `limit_{mode}_batches`,"
                        f" `Trainer(limit_{mode}_batches)` must be `0.0`, `1.0` or an int. An int k specifies"
                        f" `num_{mode}_batches` to use."
                        f"When using an IterableDataset for `limit_{mode}_batches`,"
                        f" `Trainer(limit_{mode.dataloader_prefix}_batches)` must be `0.0`, `1.0` or an int. An int k"
                        f" specifies `num_{mode.dataloader_prefix}_batches` to use."
                    )

                if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float):
                    min_pct = 1.0 / len(dataloader)
                    raise MisconfigurationException(
                        f"you requested to check {limit_eval_batches} of the {mode} dataloader but"
                        f" {limit_eval_batches}*{num_batches} < 1. Please increase the limit_{mode}_batches."
                        f" Try at least limit_{mode}_batches={min_pct}"
                        f"you requested to check {limit_eval_batches} of the `{mode.dataloader_prefix}_dataloader` but"
                        f" {limit_eval_batches}*{num_batches} < 1. Please increase the"
                        f" `limit_{mode.dataloader_prefix}_batches` flag. Try at least"
                        f" `limit_{mode.dataloader_prefix}_batches={min_pct}`"
                    )

                loader_num_batches.append(num_batches)
@@ -460,7 +462,9 @@ class TrainerDataLoadingMixin(ABC):
        has_loader = is_overridden("val_dataloader", pl_module)
        has_step = is_overridden("validation_step", pl_module)
        if has_loader and has_step:
            self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader("val", model=pl_module)
            self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(
                RunningStage.VALIDATING, model=pl_module
            )

    def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
        """Resets the test dataloader and determines the number of batches.

@@ -472,7 +476,9 @@ class TrainerDataLoadingMixin(ABC):
        has_loader = is_overridden("test_dataloader", pl_module)
        has_step = is_overridden("test_step", pl_module)
        if has_loader and has_step:
            self.num_test_batches, self.test_dataloaders = self._reset_eval_dataloader("test", model=pl_module)
            self.num_test_batches, self.test_dataloaders = self._reset_eval_dataloader(
                RunningStage.TESTING, model=pl_module
            )

    def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
        """Resets the predict dataloader and determines the number of batches.

@@ -483,7 +489,9 @@ class TrainerDataLoadingMixin(ABC):
        pl_module = self.lightning_module or model
        has_loader = is_overridden("predict_dataloader", pl_module)
        if has_loader:
            self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader("predict", model=pl_module)
            self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader(
                RunningStage.PREDICTING, model=pl_module
            )

    def reset_train_val_dataloaders(self, model: Optional["pl.LightningModule"] = None) -> None:
        """
@@ -501,15 +509,15 @@ class TrainerDataLoadingMixin(ABC):
        self.reset_val_dataloader(model=model)

    def request_dataloader(
        self, stage: str, model: Optional["pl.LightningModule"] = None
        self, stage: RunningStage, model: Optional["pl.LightningModule"] = None
    ) -> Union[DataLoader, List[DataLoader]]:
        """Handles downloading data in the GPU or TPU case.

        Returns:
            The dataloader
        """
        self.call_hook(f"on_{stage}_dataloader")
        dataloader = getattr(model, f"{stage}_dataloader")()
        self.call_hook(f"on_{stage.dataloader_prefix}_dataloader")
        dataloader = getattr(model, f"{stage.dataloader_prefix}_dataloader")()
        if isinstance(dataloader, tuple):
            dataloader = list(dataloader)
        self.accelerator.barrier("get_dataloaders")
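`request_dataloader` now receives a `RunningStage` and derives the hook and attribute names from its `dataloader_prefix`. A standalone sketch of that prefix-based dispatch (`DemoModel` and `fetch_dataloader` are illustrative stand-ins, not Lightning APIs):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class DemoModel:
    def train_dataloader(self) -> DataLoader:
        return DataLoader(TensorDataset(torch.arange(8)), batch_size=4)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(TensorDataset(torch.arange(4)), batch_size=2)


def fetch_dataloader(model, prefix: str) -> DataLoader:
    # mirrors `getattr(model, f"{stage.dataloader_prefix}_dataloader")()` above
    return getattr(model, f"{prefix}_dataloader")()


model = DemoModel()
print(len(fetch_dataloader(model, "train")))  # 2 batches
print(len(fetch_dataloader(model, "val")))    # 2 batches
```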
@@ -78,6 +78,14 @@ class RunningStage(LightningEnum):
    def evaluating(self) -> bool:
        return self in (self.VALIDATING, self.TESTING)

    @property
    def dataloader_prefix(self) -> Optional[str]:
        if self in (self.SANITY_CHECKING, self.TUNING):
            return None
        if self == self.VALIDATING:
            return "val"
        return self.value


@dataclass
class TrainerState:
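The new `dataloader_prefix` property maps each running stage to the name used for its `*_dataloader` hooks. A standalone sketch of that mapping, using a plain `Enum` stand-in (`Stage`) rather than Lightning's `RunningStage`; the assigned string values mirror the convention above but are assumptions of this sketch:

```python
from enum import Enum
from typing import Optional


class Stage(Enum):
    TRAINING = "train"
    SANITY_CHECKING = "sanity_check"
    VALIDATING = "validate"
    TESTING = "test"
    PREDICTING = "predict"
    TUNING = "tune"

    @property
    def dataloader_prefix(self) -> Optional[str]:
        if self in (Stage.SANITY_CHECKING, Stage.TUNING):
            return None  # these stages have no dedicated *_dataloader hook
        if self is Stage.VALIDATING:
            return "val"  # the hook is `val_dataloader`, not `validate_dataloader`
        return self.value


for stage in Stage:
    print(stage.name, "->", stage.dataloader_prefix)
# TRAINING -> train
# SANITY_CHECKING -> None
# VALIDATING -> val
# TESTING -> test
# PREDICTING -> predict
# TUNING -> None
```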
@@ -54,115 +54,6 @@ def test_plugin_setup_optimizers_in_pre_dispatch(tmpdir, delay_dispatch):
    trainer.fit(model)


def test_accelerator_on_reset_dataloader_hooks(tmpdir):
    """
    Ensure data-loader hooks are called using an Accelerator.
    """

    class CustomAccelerator(CPUAccelerator):
        train_count: int = 0
        val_count: int = 0
        test_count: int = 0
        predict_count: int = 0

        def on_reset_train_dataloader(self, dataloader):
            self.train_count += 1
            assert self.lightning_module.trainer.training
            return super().on_reset_train_dataloader(dataloader)

        def on_reset_val_dataloader(self, dataloader):
            self.val_count += 1
            assert self.lightning_module.trainer.training or self.lightning_module.trainer.validating
            return super().on_reset_val_dataloader(dataloader)

        def on_reset_test_dataloader(self, dataloader):
            self.test_count += 1
            assert self.lightning_module.trainer.testing
            return super().on_reset_test_dataloader(dataloader)

        def on_reset_predict_dataloader(self, dataloader):
            self.predict_count += 1
            assert self.lightning_module.trainer.predicting
            return super().on_reset_predict_dataloader(dataloader)

    model = BoringModel()
    accelerator = CustomAccelerator(PrecisionPlugin(), SingleDevicePlugin(device=torch.device("cpu")))
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator=accelerator)
    trainer.fit(model)
    trainer.validate(model)
    trainer.test(model)
    trainer.predict(model, dataloaders=model.test_dataloader())
    # assert that all loader hooks were called
    assert accelerator.train_count == 1
    assert accelerator.val_count == 1  # only called once during the entire session
    assert accelerator.test_count == 1
    assert accelerator.predict_count == 1

    accelerator = CustomAccelerator(PrecisionPlugin(), SingleDevicePlugin(device=torch.device("cpu")))
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator=accelerator)
    trainer.validate(model)
    trainer.test(model)
    trainer.predict(model)
    # assert val/test/predict loader hooks were called
    assert accelerator.val_count == 1
    assert accelerator.test_count == 1
    assert accelerator.predict_count == 1


def test_plugin_on_reset_dataloader_hooks(tmpdir):
    """
    Ensure data-loader hooks are called using a Plugin.
    """

    class CustomPlugin(SingleDevicePlugin):
        train_count: int = 0
        val_count: int = 0
        test_count: int = 0
        predict_count: int = 0

        def on_reset_train_dataloader(self, dataloader):
            self.train_count += 1
            assert self.lightning_module.trainer.training
            return super().on_reset_train_dataloader(dataloader)

        def on_reset_val_dataloader(self, dataloader):
            self.val_count += 1
            assert self.lightning_module.trainer.training or self.lightning_module.trainer.validating
            return super().on_reset_val_dataloader(dataloader)

        def on_reset_test_dataloader(self, dataloader):
            self.test_count += 1
            assert self.lightning_module.trainer.testing
            return super().on_reset_test_dataloader(dataloader)

        def on_reset_predict_dataloader(self, dataloader):
            self.predict_count += 1
            assert self.lightning_module.trainer.predicting
            return super().on_reset_predict_dataloader(dataloader)

    plugin = CustomPlugin(device=torch.device("cpu"))
    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins=plugin)
    trainer.fit(model)
    trainer.validate(model)
    trainer.test(model)
    trainer.predict(model, dataloaders=model.test_dataloader())
    # assert that all loader hooks were called
    assert plugin.train_count == 1
    assert plugin.val_count == 1  # only called once during the entire session
    assert plugin.test_count == 1
    assert plugin.predict_count == 1
    plugin = CustomPlugin(device=torch.device("cpu"))
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins=plugin)
    trainer.validate(model)
    trainer.test(model)
    trainer.predict(model)
    # assert val/test/predict loader hooks were called
    assert plugin.val_count == 1
    assert plugin.test_count == 1
    assert plugin.predict_count == 1


def test_restore_checkpoint_after_pre_dispatch_default():
    """
    Assert default for restore_checkpoint_after_pre_dispatch is False.
@@ -132,7 +132,7 @@ def test_all_stages(tmpdir, ipus):
    trainer.fit(model)
    trainer.validate(model)
    trainer.test(model)
    trainer.predict(model, model.val_dataloader())
    trainer.predict(model)


@RunIf(ipu=True)

@@ -143,7 +143,7 @@ def test_inference_only(tmpdir, ipus):
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, ipus=ipus)
    trainer.validate(model)
    trainer.test(model)
    trainer.predict(model, model.val_dataloader())
    trainer.predict(model)


@RunIf(ipu=True)
@@ -681,7 +681,7 @@ def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage):

    with pytest.warns(
        UserWarning,
        match=f'The dataloader, {stage} dataloader{" 0" if stage != "train" else ""}, does not have many workers',
        match=f'The dataloader, {stage}_dataloader{" 0" if stage != "train" else ""}, does not have many workers',
    ):
        if stage == "test":
            if ckpt_path in ("specific", "best"):

@@ -720,7 +720,7 @@ def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage):

    with pytest.warns(
        UserWarning,
        match=f'The dataloader, {stage} dataloader{" 0" if stage != "train" else ""}, does not have many workers',
        match=f'The dataloader, {stage}_dataloader{" 0" if stage != "train" else ""}, does not have many workers',
    ):
        if stage == "test":
            if ckpt_path in ("specific", "best"):
@@ -15,6 +15,7 @@ import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import RunningStage
from tests.base import EvalModelTemplate


@@ -107,7 +108,7 @@ def test_overfit_batch_limits(tmpdir):
    # ------------------------------------------------------
    # run tests for both val and test
    # ------------------------------------------------------
    for split in ["val", "test"]:
    for split in (RunningStage.VALIDATING, RunningStage.TESTING):

        # ------------------------------------------------------
        # test overfit_batches as percent

@@ -134,7 +135,7 @@ def test_overfit_batch_limits(tmpdir):
    # ------------------------------------------------------
    # test limit_xxx_batches as percent AND int
    # ------------------------------------------------------
    if split == "val":
    if split == RunningStage.VALIDATING:
        loader_num_batches, dataloaders = Trainer(limit_val_batches=0.1)._reset_eval_dataloader(split, model=model)
        assert loader_num_batches[0] == int(0.1 * len(val_loader))