Replace DataLoader sampler once for IPUs (#8858)

Carlos Mocholí 2021-08-16 11:28:05 +02:00 committed by GitHub
parent 1d2f7e20c4
commit 93ab24d1ee
11 changed files with 97 additions and 265 deletions

View File

@ -72,26 +72,11 @@ jobs:
python -c "import poptorch; print(poptorch.__version__)"
displayName: "Check poptorch installation"
- bash: |
wget https://pl-public-data.s3.amazonaws.com/legacy/checkpoints.zip -P legacy/
unzip -o legacy/checkpoints.zip -d legacy/
ls -l legacy/checkpoints/
displayName: 'Get legacy checkpoints'
- bash: |
source ${{ variables.poplar_sdk }}/poplar-ubuntu*/enable.sh
source ${{ variables.poplar_sdk }}/popart-ubuntu*/enable.sh
export POPTORCH_WAIT_FOR_IPU=1
python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --junitxml=$(Build.StagingDirectory)/test-results.xml --durations=50
python -m coverage run --source pytorch_lightning -m pytest tests/accelerators/test_ipu.py -v --junitxml=$(Build.StagingDirectory)/test-results.xml --durations=50
env:
MKL_THREADING_LAYER: "GNU"
displayName: 'Testing: standard'
- bash: |
source ${{ variables.poplar_sdk }}/poplar-ubuntu*/enable.sh
source ${{ variables.poplar_sdk }}/popart-ubuntu*/enable.sh
export POPTORCH_WAIT_FOR_IPU=1
bash tests/special_tests.sh
env:
MKL_THREADING_LAYER: "GNU"
displayName: 'Testing: special'

View File

@ -81,6 +81,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
(https://github.com/PyTorchLightning/pytorch-lightning/pull/8608))
- `Trainer.request_dataloader` now takes a `RunningStage` enum instance ([#8858](https://github.com/PyTorchLightning/pytorch-lightning/pull/8858))
### Deprecated
- Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`
@ -132,6 +134,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed `LightningModule.write_predictions` and `LightningModule.write_predictions_dict` ([#8850](https://github.com/PyTorchLightning/pytorch-lightning/pull/8850))
- Removed reset dataloader hooks to Training Plugins and Accelerators ([#8858](https://github.com/PyTorchLightning/pytorch-lightning/pull/8858))
### Fixed
@ -176,6 +181,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed infinite loop with CycleIterator and multiple loaders ([#8889](https://github.com/PyTorchLightning/pytorch-lightning/pull/8889))
- Fixed bug where data-loading functions were not getting the correct running stage passed ([#8858](https://github.com/PyTorchLightning/pytorch-lightning/pull/8858))
## [1.4.0] - 2021-07-27
### Added

View File

@ -410,22 +410,6 @@ class Accelerator:
"""
return self.training_type_plugin.process_dataloader(dataloader)
def on_reset_train_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
"""Called before resetting the train dataloader."""
return self.training_type_plugin.on_reset_train_dataloader(dataloader)
def on_reset_val_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
"""Called before resetting the val dataloader."""
return self.training_type_plugin.on_reset_val_dataloader(dataloader)
def on_reset_test_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
"""Called before resetting the test dataloader."""
return self.training_type_plugin.on_reset_test_dataloader(dataloader)
def on_reset_predict_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
"""Called before resetting the predict dataloader."""
return self.training_type_plugin.on_reset_predict_dataloader(dataloader)
@property
def results(self) -> Any:
"""

View File

@ -11,10 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import os
from typing import Any, Iterable, List, Optional, Union
from typing import Any, List, Optional, Union
import torch
from torch.utils.data import DataLoader
@ -26,7 +25,6 @@ from pytorch_lightning.plugins.environments.cluster_environment import ClusterEn
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities import _POPTORCH_AVAILABLE
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.cloud_io import get_filesystem
@ -112,6 +110,12 @@ class IPUPlugin(ParallelPlugin):
options["autoReport.directory"] = self.autoreport_dir
os.environ["POPLAR_ENGINE_OPTIONS"] = json.dumps(options)
def setup(self) -> None:
# patch the dataloader creation function with the custom `poptorch.DataLoader`.
# this violates the intended control flow for the plugins, but since this is experimental, we have chosen
# to use the simpler solution before adding abstractions to override the `DataLoader` class
self.lightning_module.trainer.replace_sampler = self._convert_to_poptorch_loader
def pre_dispatch(self) -> None:
precision = self.lightning_module.trainer.precision
model = LightningIPUModule(self.lightning_module, precision)
@ -169,59 +173,16 @@ class IPUPlugin(ParallelPlugin):
def lightning_module(self) -> Optional["pl.LightningModule"]:
return self.model.module if isinstance(self.model, LightningIPUModule) else self.model
def on_reset_train_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
return self._process_dataloader(dataloader, is_training=True)
def on_reset_val_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
return self._process_dataloader(dataloader, is_training=False)
def on_reset_test_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
return self._process_dataloader(dataloader, is_training=False)
def on_reset_predict_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
return self._process_dataloader(dataloader, is_training=False)
def _process_dataloader(
self, dataloader: Union[Iterable, DataLoader], is_training: bool
) -> Union[Iterable, DataLoader]:
if isinstance(dataloader, CombinedLoader):
dataloader.loaders = apply_to_collection(
dataloader.loaders, DataLoader, self._process_dataloader, is_training
)
return dataloader
if isinstance(dataloader, list):
dataloader = apply_to_collection(dataloader, DataLoader, self._process_dataloader, is_training)
return dataloader
if not isinstance(dataloader, poptorch.DataLoader):
opts = self.training_opts if is_training else self.inference_opts
dataloader = self._convert_to_poptorch_loader(dataloader=dataloader, opts=opts)
return dataloader
def _convert_to_poptorch_loader(
self, dataloader: Union[Iterable, DataLoader], opts: "poptorch.Options"
) -> Union[Iterable, DataLoader]:
skip_keys = ("sampler", "batch_sampler", "dataset_kind")
attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}
params = set(inspect.signature(dataloader.__init__).parameters)
contains_dataset = True
if type(dataloader) is not DataLoader:
contains_dataset = "dataset" in params
params.update(inspect.signature(DataLoader.__init__).parameters)
dl_args = {name: attrs[name] for name in params if name in attrs and name not in skip_keys}
multiprocessing_context = dataloader.multiprocessing_context
dl_args["multiprocessing_context"] = multiprocessing_context
if not contains_dataset:
dl_args.pop("dataset")
self, dataloader: DataLoader, sampler, mode: Optional[RunningStage] = None
) -> "poptorch.DataLoader":
# use full path to avoid circular imports
dl_kwargs = pl.trainer.trainer.TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler)
# Override to drop the last uneven batch, as IPUs do not support uneven inputs.
dl_args["drop_last"] = True
dl_kwargs["drop_last"] = True
dataloader = poptorch.DataLoader(**dl_args, options=opts)
dataloader.multiprocessing_context = multiprocessing_context
opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts
dataloader = poptorch.DataLoader(**dl_kwargs, options=opts)
return dataloader
@property
@ -291,6 +252,8 @@ class IPUPlugin(ParallelPlugin):
return self.poptorch_models[RunningStage.PREDICTING](*args, **kwargs)
def teardown(self) -> None:
# undo dataloader patching
self.lightning_module.trainer.replace_sampler = pl.trainer.trainer.TrainerDataLoadingMixin.replace_sampler
for model in self.poptorch_models.values():
model.destroy()
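The `setup`/`teardown` pair above amounts to a temporary monkey patch: while the IPU plugin is active, every call to the trainer's `replace_sampler` builds a `poptorch.DataLoader` with `drop_last=True`, and teardown restores the stock implementation. A minimal sketch of that patch-and-restore pattern, using hypothetical stand-in classes rather than the real Lightning or poptorch objects:

```python
# Sketch of the patch-and-restore pattern from IPUPlugin.setup()/teardown().
# MiniTrainer and MiniIPUPlugin are illustrative stand-ins, not the real API.
from torch.utils.data import DataLoader, Sampler


class MiniTrainer:
    @staticmethod
    def replace_sampler(dataloader: DataLoader, sampler: Sampler, mode=None) -> DataLoader:
        # default behaviour: rebuild the same kind of loader around the new sampler
        return DataLoader(dataloader.dataset, batch_size=dataloader.batch_size, sampler=sampler)


class MiniIPUPlugin:
    def __init__(self, trainer: MiniTrainer) -> None:
        self.trainer = trainer

    def setup(self) -> None:
        # patch: sampler replacement now routes through the plugin's converter
        self.trainer.replace_sampler = self._convert_loader

    def teardown(self) -> None:
        # undo the patch so later runs get the default implementation back
        self.trainer.replace_sampler = MiniTrainer.replace_sampler

    def _convert_loader(self, dataloader: DataLoader, sampler: Sampler, mode=None) -> DataLoader:
        # the real plugin builds a poptorch.DataLoader here, forcing drop_last=True
        return DataLoader(
            dataloader.dataset, batch_size=dataloader.batch_size, sampler=sampler, drop_last=True
        )
```

Because the conversion now happens inside `replace_sampler` itself, each loader is converted exactly once, which is what made the per-stage `on_reset_*_dataloader` hooks removable.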

View File

@ -212,22 +212,6 @@ class TrainingTypePlugin(Plugin, ABC):
"""
return dataloader
def on_reset_train_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
"""Called before resetting the train dataloader."""
return dataloader
def on_reset_val_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
"""Called before resetting the val dataloader."""
return dataloader
def on_reset_test_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
"""Called before resetting the test dataloader."""
return dataloader
def on_reset_predict_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
"""Called before resetting the predict dataloader."""
return dataloader
def init_optimizers(self, trainer: "pl.Trainer", model: "pl.LightningModule"):
return trainer.init_optimizers(model)

View File

@ -19,7 +19,7 @@ from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sampler, SequentialSampler
from torch.utils.data.dataset import IterableDataset
from torch.utils.data.distributed import DistributedSampler
@ -121,7 +121,9 @@ class TrainerDataLoadingMixin(ABC):
def auto_add_sampler(self, dataloader: Any, shuffle: bool, mode: Optional[RunningStage] = None) -> Any:
if isinstance(dataloader, CombinedLoader):
# apply `auto_add_sampler` to all dataloaders in the collection
dataloader.loaders = apply_to_collection(dataloader.loaders, DataLoader, self.auto_add_sampler, shuffle)
dataloader.loaders = apply_to_collection(
dataloader.loaders, DataLoader, self.auto_add_sampler, shuffle, mode=mode
)
return dataloader
# don't do anything if it's not a dataloader
@ -151,7 +153,9 @@ class TrainerDataLoadingMixin(ABC):
return dataloader
@staticmethod
def _resolve_batch_sampler(dataloader, sampler, mode: Optional[RunningStage] = None) -> Dict[str, Any]:
def _resolve_batch_sampler(
dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
) -> Dict[str, Any]:
batch_sampler = getattr(dataloader, "batch_sampler")
is_predicting = mode == RunningStage.PREDICTING
# checking the batch sampler type is different than PyTorch default.
@ -182,7 +186,10 @@ class TrainerDataLoadingMixin(ABC):
return {"sampler": sampler, "shuffle": False, "batch_sampler": None}
def replace_sampler(self, dataloader: DataLoader, sampler, mode: Optional[RunningStage] = None) -> DataLoader:
@staticmethod
def _get_dataloader_init_kwargs(
dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
) -> Dict[str, Any]:
if not isinstance(dataloader, DataLoader):
raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")
@ -201,7 +208,7 @@ class TrainerDataLoadingMixin(ABC):
# kwargs to re-construct the dataloader
dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
dl_kwargs.update(self._resolve_batch_sampler(dataloader, sampler, mode=mode))
dl_kwargs.update(TrainerDataLoadingMixin._resolve_batch_sampler(dataloader, sampler, mode=mode))
required_args = {
p.name
@ -248,6 +255,11 @@ class TrainerDataLoadingMixin(ABC):
del dl_kwargs["sampler"]
del dl_kwargs["batch_sampler"]
return dl_kwargs
@staticmethod
def replace_sampler(dataloader: DataLoader, sampler, mode: Optional[RunningStage] = None) -> DataLoader:
dl_kwargs = TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler, mode=mode)
dl_cls = type(dataloader)
dataloader = dl_cls(**dl_kwargs)
return dataloader
@ -269,7 +281,7 @@ class TrainerDataLoadingMixin(ABC):
Args:
model: The `LightningModule` if calling this outside of the trainer scope.
"""
self.train_dataloader = self.request_dataloader("train", model=model)
self.train_dataloader = self.request_dataloader(RunningStage.TRAINING, model=model)
if self.overfit_batches > 0:
if hasattr(self.train_dataloader, "sampler") and isinstance(self.train_dataloader.sampler, RandomSampler):
@ -278,7 +290,7 @@ class TrainerDataLoadingMixin(ABC):
" We are turning off the training dataloader shuffling for you."
)
self.train_dataloader = self.replace_sampler(
self.train_dataloader, SequentialSampler(self.train_dataloader.dataset)
self.train_dataloader, SequentialSampler(self.train_dataloader.dataset), mode=RunningStage.TRAINING
)
# debugging
@ -286,11 +298,11 @@ class TrainerDataLoadingMixin(ABC):
# automatically add samplers
self.train_dataloader = apply_to_collection(
self.train_dataloader, DataLoader, self.auto_add_sampler, shuffle=True
self.train_dataloader, DataLoader, self.auto_add_sampler, shuffle=True, mode=RunningStage.TRAINING
)
# check the workers recursively
apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, "train dataloader")
apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, "train_dataloader")
# add worker_init_fn for correct seeding in worker processes
apply_to_collection(self.train_dataloader, DataLoader, self.auto_add_worker_init_fn)
@ -302,9 +314,6 @@ class TrainerDataLoadingMixin(ABC):
# wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
self.train_dataloader = CombinedLoader(self.train_dataloader, self.data_connector.multiple_trainloader_mode)
# allow accelerator to modify dataloader
self.train_dataloader = self.accelerator.on_reset_train_dataloader(self.train_dataloader)
self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float("inf")
if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
@ -351,69 +360,61 @@ class TrainerDataLoadingMixin(ABC):
)
def _reset_eval_dataloader(
self, mode: str, model: Optional["pl.LightningModule"] = None
self, mode: RunningStage, model: Optional["pl.LightningModule"] = None
) -> Tuple[List[Union[int, float]], List[DataLoader]]:
"""Generic method to reset a dataloader for evaluation.
Args:
mode: Either `'val'`, `'test'` or `'predict'`
model: The `LightningModule` if calling this outside of the trainer scope.
mode: The running stage of the ``Trainer``
model: The ``LightningModule`` if calling this outside of the trainer scope.
Returns:
Tuple (num_batches, dataloaders)
"""
assert mode.evaluating or mode == RunningStage.PREDICTING
# always get the loaders first so we can count how many there are
loader_name = f"{mode}_dataloader"
loader_name = f"{mode.dataloader_prefix}_dataloader"
dataloaders = self.request_dataloader(mode, model=model)
if not isinstance(dataloaders, list):
dataloaders = [dataloaders]
# when overfitting use the training loader as val and test
# when overfitting, use the training loader as val and test
# duplicate it the number of times needed to match the train loaders
if self.overfit_batches > 0:
num_loaders = len(dataloaders)
train_dataloader = self.request_dataloader("train", model=model)
dataloaders = [deepcopy(train_dataloader) for _ in range(num_loaders)]
train_dataloader = self.request_dataloader(RunningStage.TRAINING, model=model)
dataloaders = [deepcopy(train_dataloader) for _ in range(len(dataloaders))]
self.dev_debugger.track_load_dataloader_call(loader_name, dataloaders=dataloaders)
for loader_i in range(len(dataloaders)):
loader = dataloaders[loader_i]
# shuffling in the val and test sets is bad practice
modes = ("val", "test", "predict")
if mode in modes and hasattr(loader, "sampler") and isinstance(loader.sampler, RandomSampler):
if hasattr(loader, "sampler") and isinstance(loader.sampler, RandomSampler):
# when overfitting, the dataloader should not have a sampler
if self.overfit_batches > 0 and mode != "predict":
if self.overfit_batches > 0 and mode.evaluating:
rank_zero_warn(
"You requested to overfit but enabled val/test dataloader shuffling."
" We are turning it off for you."
)
dataloaders[loader_i] = self.replace_sampler(loader, SequentialSampler(loader.dataset))
dataloaders[loader_i] = self.replace_sampler(loader, SequentialSampler(loader.dataset), mode=mode)
else:
rank_zero_warn(
f"Your {mode}_dataloader has `shuffle=True`, it is best practice to turn"
" this off for val/test/predict dataloaders."
f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`,"
"it is strongly recommended that you turn this off for val/test/predict dataloaders."
)
if any(dl is None for dl in dataloaders):
rank_zero_warn("One of given dataloaders is None and it will be skipped.")
# add samplers
dataloaders = [
self.auto_add_sampler(dl, shuffle=False, mode=self.state.stage) for dl in dataloaders if dl is not None
]
dataloaders = [self.auto_add_sampler(dl, False, mode=mode) for dl in dataloaders if dl is not None]
# add worker_init_fn for correct seeding in worker processes
apply_to_collection(dataloaders, dtype=DataLoader, function=self.auto_add_worker_init_fn)
# allow accelerator to modify dataloader
hook_name = f"on_reset_{mode}_dataloader"
dataloaders = getattr(self.accelerator, hook_name)(dataloaders)
loader_num_batches = []
# determine number of batches
@ -421,10 +422,10 @@ class TrainerDataLoadingMixin(ABC):
if len(dataloaders) != 0:
for i, dataloader in enumerate(dataloaders):
num_batches = len(dataloader) if has_len(dataloader) else float("inf")
self._worker_check(dataloader, f"{mode} dataloader {i}")
self._worker_check(dataloader, f"{mode.dataloader_prefix}_dataloader {i}")
# percent or num_steps
limit_eval_batches = getattr(self, f"limit_{mode}_batches")
limit_eval_batches = getattr(self, f"limit_{mode.dataloader_prefix}_batches")
# limit num batches either as a percent or num steps
if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0:
@ -433,17 +434,18 @@ class TrainerDataLoadingMixin(ABC):
num_batches = int(num_batches * limit_eval_batches)
elif limit_eval_batches != 1.0:
raise MisconfigurationException(
"When using an IterableDataset for `limit_{mode}_batches`,"
f" `Trainer(limit_{mode}_batches)` must be `0.0`, `1.0` or an int. An int k specifies"
f" `num_{mode}_batches` to use."
f"When using an IterableDataset for `limit_{mode}_batches`,"
f" `Trainer(limit_{mode.dataloader_prefix}_batches)` must be `0.0`, `1.0` or an int. An int k"
f" specifies `num_{mode.dataloader_prefix}_batches` to use."
)
if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float):
min_pct = 1.0 / len(dataloader)
raise MisconfigurationException(
f"you requested to check {limit_eval_batches} of the {mode} dataloader but"
f" {limit_eval_batches}*{num_batches} < 1. Please increase the limit_{mode}_batches."
f" Try at least limit_{mode}_batches={min_pct}"
f"you requested to check {limit_eval_batches} of the `{mode.dataloader_prefix}_dataloader` but"
f" {limit_eval_batches}*{num_batches} < 1. Please increase the"
f" `limit_{mode.dataloader_prefix}_batches` flag. Try at least"
f" `limit_{mode.dataloader_prefix}_batches={min_pct}`"
)
loader_num_batches.append(num_batches)
@ -460,7 +462,9 @@ class TrainerDataLoadingMixin(ABC):
has_loader = is_overridden("val_dataloader", pl_module)
has_step = is_overridden("validation_step", pl_module)
if has_loader and has_step:
self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader("val", model=pl_module)
self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(
RunningStage.VALIDATING, model=pl_module
)
def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
"""Resets the test dataloader and determines the number of batches.
@ -472,7 +476,9 @@ class TrainerDataLoadingMixin(ABC):
has_loader = is_overridden("test_dataloader", pl_module)
has_step = is_overridden("test_step", pl_module)
if has_loader and has_step:
self.num_test_batches, self.test_dataloaders = self._reset_eval_dataloader("test", model=pl_module)
self.num_test_batches, self.test_dataloaders = self._reset_eval_dataloader(
RunningStage.TESTING, model=pl_module
)
def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
"""Resets the predict dataloader and determines the number of batches.
@ -483,7 +489,9 @@ class TrainerDataLoadingMixin(ABC):
pl_module = self.lightning_module or model
has_loader = is_overridden("predict_dataloader", pl_module)
if has_loader:
self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader("predict", model=pl_module)
self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader(
RunningStage.PREDICTING, model=pl_module
)
def reset_train_val_dataloaders(self, model: Optional["pl.LightningModule"] = None) -> None:
"""
@ -501,15 +509,15 @@ class TrainerDataLoadingMixin(ABC):
self.reset_val_dataloader(model=model)
def request_dataloader(
self, stage: str, model: Optional["pl.LightningModule"] = None
self, stage: RunningStage, model: Optional["pl.LightningModule"] = None
) -> Union[DataLoader, List[DataLoader]]:
"""Handles downloading data in the GPU or TPU case.
Returns:
The dataloader
"""
self.call_hook(f"on_{stage}_dataloader")
dataloader = getattr(model, f"{stage}_dataloader")()
self.call_hook(f"on_{stage.dataloader_prefix}_dataloader")
dataloader = getattr(model, f"{stage.dataloader_prefix}_dataloader")()
if isinstance(dataloader, tuple):
dataloader = list(dataloader)
self.accelerator.barrier("get_dataloaders")

View File

@ -78,6 +78,14 @@ class RunningStage(LightningEnum):
def evaluating(self) -> bool:
return self in (self.VALIDATING, self.TESTING)
@property
def dataloader_prefix(self) -> Optional[str]:
if self in (self.SANITY_CHECKING, self.TUNING):
return None
if self == self.VALIDATING:
return "val"
return self.value
@dataclass
class TrainerState:
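The `dataloader_prefix` property added here becomes the single source of truth for the string used in hook and attribute names (`on_val_dataloader`, `limit_test_batches`, and so on). A hedged usage sketch, assuming the 1.4-era enum values such as `RunningStage.TRAINING == "train"`:

```python
# Usage sketch for RunningStage.dataloader_prefix. The concrete values are
# assumed from the 1.4-era enum; verify against your installed version.
from pytorch_lightning.trainer.states import RunningStage

assert RunningStage.TRAINING.dataloader_prefix == "train"
assert RunningStage.VALIDATING.dataloader_prefix == "val"
assert RunningStage.TESTING.dataloader_prefix == "test"
assert RunningStage.PREDICTING.dataloader_prefix == "predict"
assert RunningStage.SANITY_CHECKING.dataloader_prefix is None

# request_dataloader builds names from it, e.g.:
#   getattr(model, f"{RunningStage.VALIDATING.dataloader_prefix}_dataloader")()
```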

View File

@ -54,115 +54,6 @@ def test_plugin_setup_optimizers_in_pre_dispatch(tmpdir, delay_dispatch):
trainer.fit(model)
def test_accelerator_on_reset_dataloader_hooks(tmpdir):
"""
Ensure data-loader hooks are called using an Accelerator.
"""
class CustomAccelerator(CPUAccelerator):
train_count: int = 0
val_count: int = 0
test_count: int = 0
predict_count: int = 0
def on_reset_train_dataloader(self, dataloader):
self.train_count += 1
assert self.lightning_module.trainer.training
return super().on_reset_train_dataloader(dataloader)
def on_reset_val_dataloader(self, dataloader):
self.val_count += 1
assert self.lightning_module.trainer.training or self.lightning_module.trainer.validating
return super().on_reset_val_dataloader(dataloader)
def on_reset_test_dataloader(self, dataloader):
self.test_count += 1
assert self.lightning_module.trainer.testing
return super().on_reset_test_dataloader(dataloader)
def on_reset_predict_dataloader(self, dataloader):
self.predict_count += 1
assert self.lightning_module.trainer.predicting
return super().on_reset_predict_dataloader(dataloader)
model = BoringModel()
accelerator = CustomAccelerator(PrecisionPlugin(), SingleDevicePlugin(device=torch.device("cpu")))
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator=accelerator)
trainer.fit(model)
trainer.validate(model)
trainer.test(model)
trainer.predict(model, dataloaders=model.test_dataloader())
# assert that all loader hooks were called
assert accelerator.train_count == 1
assert accelerator.val_count == 1 # only called once during the entire session
assert accelerator.test_count == 1
assert accelerator.predict_count == 1
accelerator = CustomAccelerator(PrecisionPlugin(), SingleDevicePlugin(device=torch.device("cpu")))
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator=accelerator)
trainer.validate(model)
trainer.test(model)
trainer.predict(model)
# assert val/test/predict loader hooks were called
assert accelerator.val_count == 1
assert accelerator.test_count == 1
assert accelerator.predict_count == 1
def test_plugin_on_reset_dataloader_hooks(tmpdir):
"""
Ensure data-loader hooks are called using a Plugin.
"""
class CustomPlugin(SingleDevicePlugin):
train_count: int = 0
val_count: int = 0
test_count: int = 0
predict_count: int = 0
def on_reset_train_dataloader(self, dataloader):
self.train_count += 1
assert self.lightning_module.trainer.training
return super().on_reset_train_dataloader(dataloader)
def on_reset_val_dataloader(self, dataloader):
self.val_count += 1
assert self.lightning_module.trainer.training or self.lightning_module.trainer.validating
return super().on_reset_val_dataloader(dataloader)
def on_reset_test_dataloader(self, dataloader):
self.test_count += 1
assert self.lightning_module.trainer.testing
return super().on_reset_test_dataloader(dataloader)
def on_reset_predict_dataloader(self, dataloader):
self.predict_count += 1
assert self.lightning_module.trainer.predicting
return super().on_reset_predict_dataloader(dataloader)
plugin = CustomPlugin(device=torch.device("cpu"))
model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins=plugin)
trainer.fit(model)
trainer.validate(model)
trainer.test(model)
trainer.predict(model, dataloaders=model.test_dataloader())
# assert that all loader hooks were called
assert plugin.train_count == 1
assert plugin.val_count == 1 # only called once during the entire session
assert plugin.test_count == 1
assert plugin.predict_count == 1
plugin = CustomPlugin(device=torch.device("cpu"))
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins=plugin)
trainer.validate(model)
trainer.test(model)
trainer.predict(model)
# assert val/test/predict loader hooks were called
assert plugin.val_count == 1
assert plugin.test_count == 1
assert plugin.predict_count == 1
def test_restore_checkpoint_after_pre_dispatch_default():
"""
Assert default for restore_checkpoint_after_pre_dispatch is False.

View File

@ -132,7 +132,7 @@ def test_all_stages(tmpdir, ipus):
trainer.fit(model)
trainer.validate(model)
trainer.test(model)
trainer.predict(model, model.val_dataloader())
trainer.predict(model)
@RunIf(ipu=True)
@ -143,7 +143,7 @@ def test_inference_only(tmpdir, ipus):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, ipus=ipus)
trainer.validate(model)
trainer.test(model)
trainer.predict(model, model.val_dataloader())
trainer.predict(model)
@RunIf(ipu=True)

View File

@ -681,7 +681,7 @@ def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage):
with pytest.warns(
UserWarning,
match=f'The dataloader, {stage} dataloader{" 0" if stage != "train" else ""}, does not have many workers',
match=f'The dataloader, {stage}_dataloader{" 0" if stage != "train" else ""}, does not have many workers',
):
if stage == "test":
if ckpt_path in ("specific", "best"):
@ -720,7 +720,7 @@ def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage):
with pytest.warns(
UserWarning,
match=f'The dataloader, {stage} dataloader{" 0" if stage != "train" else ""}, does not have many workers',
match=f'The dataloader, {stage}_dataloader{" 0" if stage != "train" else ""}, does not have many workers',
):
if stage == "test":
if ckpt_path in ("specific", "best"):

View File

@ -15,6 +15,7 @@ import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import RunningStage
from tests.base import EvalModelTemplate
@ -107,7 +108,7 @@ def test_overfit_batch_limits(tmpdir):
# ------------------------------------------------------
# run tests for both val and test
# ------------------------------------------------------
for split in ["val", "test"]:
for split in (RunningStage.VALIDATING, RunningStage.TESTING):
# ------------------------------------------------------
# test overfit_batches as percent
@ -134,7 +135,7 @@ def test_overfit_batch_limits(tmpdir):
# ------------------------------------------------------
# test limit_xxx_batches as percent AND int
# ------------------------------------------------------
if split == "val":
if split == RunningStage.VALIDATING:
loader_num_batches, dataloaders = Trainer(limit_val_batches=0.1)._reset_eval_dataloader(split, model=model)
assert loader_num_batches[0] == int(0.1 * len(val_loader))