diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7ef319d5f9..d1fc665048 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -96,6 +96,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed the name of the temporary checkpoint that the `DDPSpawnPlugin` and related plugins save ([#10934](https://github.com/PyTorchLightning/pytorch-lightning/pull/10934))
 
+- Redesigned process creation for spawn-based plugins (`DDPSpawnPlugin`, `TPUSpawnPlugin`, etc.) ([#10896](https://github.com/PyTorchLightning/pytorch-lightning/pull/10896))
+    * All spawn-based plugins now spawn processes immediately upon calling `Trainer.{fit,validate,test,predict}`
+    * The hooks/callbacks `prepare_data`, `setup`, `configure_sharded_model` and `teardown` now run under an initialized process group for spawn-based plugins, just like their non-spawn counterparts
+    * Some configuration errors that were previously raised as `MisconfigurationException`s will now be raised as `ProcessRaisedException` (torch>=1.8) or as `Exception` (torch<1.8)
+
+
 ### Deprecated
 
 - Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/issues/10103))
 
@@ -239,7 +245,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Removed method `training_step`, `test_step`, `validation_step` and `predict_step` from the `Accelerator` ([#10890](https://github.com/PyTorchLightning/pytorch-lightning/pull/10890))
 
-- Removed `HorovodPlugin.start_{training,evaluating,predicting}` hooks ([#10989](https://github.com/PyTorchLightning/pytorch-lightning/pull/10989))
+- Removed `TrainingTypePlugin.start_{training,evaluating,predicting}` hooks and the same in all subclasses ([#10989](https://github.com/PyTorchLightning/pytorch-lightning/pull/10989), [#10896](https://github.com/PyTorchLightning/pytorch-lightning/pull/10896))
 
 - Removed `Accelerator.on_train_start` ([#10999](https://github.com/PyTorchLightning/pytorch-lightning/pull/10999))
 
diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index 661ba20f2e..b5373e2202 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -132,23 +132,6 @@ class DDPSpawnPlugin(ParallelPlugin):
     def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]:
         return {"nprocs": self.num_processes}
 
-    def start_training(self, trainer: "pl.Trainer") -> Any:
-        spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
-        self._recover_results_in_main_process(spawn_output, trainer)
-        # reset optimizers, since main process is never used for training and thus does not have a valid optim state
-        trainer.optimizers = []
-        return spawn_output.trainer_results
-
-    def start_evaluating(self, trainer: "pl.Trainer") -> Any:
-        spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
-        self._recover_results_in_main_process(spawn_output, trainer)
-        return spawn_output.trainer_results
-
-    def start_predicting(self, trainer: "pl.Trainer") -> Any:
-        spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
-        self._recover_results_in_main_process(spawn_output, trainer)
-        return spawn_output.trainer_results
-
     def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]:
         """Spawn processes that run the given function.
 
@@ -184,7 +167,9 @@ class DDPSpawnPlugin(ParallelPlugin):
             self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size
         )
 
-    def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]:
+    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
+        super().pre_dispatch(trainer)
+
         # move the model to the correct device
         self.model_to_device()
 
@@ -196,15 +181,6 @@ class DDPSpawnPlugin(ParallelPlugin):
         if trainer_fn == TrainerFn.FITTING:
             self.configure_ddp()
 
-        self.barrier()
-
-        results = trainer.run_stage()
-        outputs = self._collect_rank_zero_results(trainer, results)
-
-        # ensure that spawned processes go through teardown before joining
-        trainer._call_teardown_hook()
-        return outputs
-
     def pre_configure_ddp(self):
         # if unset, default `find_unused_parameters` `True`
         # Many models require setting this parameter to True, as there are corner cases
@@ -268,7 +244,7 @@ class DDPSpawnPlugin(ParallelPlugin):
 
         return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra)
 
-    def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer) -> None:
+    def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer: "pl.Trainer") -> None:
         # transfer back the best path to the trainer
         if trainer.checkpoint_callback:
             trainer.checkpoint_callback.best_model_path = spawn_output.best_model_path
diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py
index 0483512592..951a0be78e 100644
--- a/pytorch_lightning/plugins/training_type/sharded_spawn.py
+++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -20,7 +20,7 @@ from torch.optim import Optimizer
 
 import pytorch_lightning as pl
 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
-from pytorch_lightning.plugins.training_type.ddp_spawn import _SpawnOutput, DDPSpawnPlugin
+from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
 from pytorch_lightning.utilities.enums import _StrategyType
@@ -114,12 +114,12 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
     def post_training_step(self):
         pass
 
-    def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]:
+    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
         # Ensure that the scaler points to the correct process group
         # which is re-initialized in a new process
         if isinstance(self.precision_plugin, ShardedNativeMixedPrecisionPlugin):
             self._precision_plugin.scaler = ShardedGradScaler()
-        return super().new_process(trainer)
+        return super().pre_dispatch(trainer)
 
     @classmethod
     def register_plugins(cls, plugin_registry: Dict) -> None:
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 379d0ffe7b..4050f71fc0 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -23,7 +23,6 @@ from torch.nn import Module
 from torch.utils.data import DataLoader
 
 import pytorch_lightning as pl
-from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
@@ -118,10 +117,23 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
         return super().connect(model)
 
     def pre_dispatch(self, trainer: "pl.Trainer") -> None:
-        super().pre_dispatch(trainer)
+        self._move_optimizer_state()
         if self.debug:
             os.environ["PT_XLA_DEBUG"] = str(1)
 
+        if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
+            trainer.progress_bar_callback.disable()
+
+        shared_params = find_shared_parameters(self.model)
+        self.model_to_device()
+        if is_overridden("on_post_move_to_device", self.lightning_module):
+            self.model.module.on_post_move_to_device()
+        else:
+            set_shared_parameters(self.model.module, shared_params)
+
+        self.setup_optimizers(trainer)
+        self.precision_plugin.connect(self._model, None, None)
+
     def setup(self, trainer: "pl.Trainer") -> None:
         self.start_method = "fork"
         super().setup(trainer)
@@ -154,37 +166,6 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
     def set_world_ranks(self, process_idx: int = 0) -> None:
         pass
 
-    def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]:
-        if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
-            trainer.progress_bar_callback.disable()
-
-        shared_params = find_shared_parameters(self.model)
-        self.model_to_device()
-        if is_overridden("on_post_move_to_device", self.lightning_module):
-            self.model.module.on_post_move_to_device()
-        else:
-            set_shared_parameters(self.model.module, shared_params)
-
-        trainer.training_type_plugin.setup_optimizers(trainer)
-        trainer.precision_plugin.connect(self._model, None, None)
-
-        self.barrier("pre-run-stage")
-
-        results = trainer.run_stage()
-
-        outputs = self._collect_rank_zero_results(trainer, results)
-
-        # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
-        self.barrier("end-process")
-
-        # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
-        if self.local_rank == 0:
-            time.sleep(2)
-
-        # ensure that spawned processes go through teardown before joining
-        trainer._call_teardown_hook()
-        return outputs
-
     def model_to_device(self) -> None:
         self.model = self.wrapped_model.to(self.root_device)
 
@@ -215,8 +196,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
         if is_overridden("add_to_queue", self.lightning_module):
             # TODO: Remove the if in v1.7
             self.lightning_module.add_to_queue(extra)
-        else:
-            self.add_to_queue(trainer, extra)
+        self.add_to_queue(trainer, extra)
 
         return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra)
 
@@ -263,6 +243,9 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
         }
 
     def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]:
+        # todo: precision plugin is called in accelerator setup and should be moved
+        if "XLA_USE_BF16" in os.environ:
+            del os.environ["XLA_USE_BF16"]
         context = mp.get_context(self.start_method or "fork")
         return_queue = context.SimpleQueue()
         xmp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), **self.get_mp_spawn_kwargs())
@@ -276,7 +259,10 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
         if self.local_rank == 0:
             return_queue.put(move_data_to_device(result, "cpu"))
 
+        # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
         self.barrier("end-process")
+
+        # Ensure that the rank 0 process is the one exiting last
         # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
         if self.local_rank == 0:
             time.sleep(2)
@@ -287,21 +273,6 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
         self.tpu_global_core_rank = xm.get_ordinal()
         rank_zero_only.rank = self.global_rank
 
-    def start_training(self, trainer: "pl.Trainer") -> Any:
-        # todo: precision pluging is call in accelerator setup and should be moved
-        if "XLA_USE_BF16" in os.environ:
-            del os.environ["XLA_USE_BF16"]
-        self._clean_logger(trainer)
-        return super().start_training(trainer)
-
-    def start_evaluating(self, trainer: "pl.Trainer") -> Any:
-        self._clean_logger(trainer)
-        return super().start_evaluating(trainer)
-
-    def start_predicting(self, trainer: "pl.Trainer") -> Any:
-        self._clean_logger(trainer)
-        return super().start_predicting(trainer)
-
     def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
         with self.precision_plugin.val_step_context():
             return self.model(*args, **kwargs)
@@ -358,9 +329,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
         return xm.all_gather(tensor)
 
     def teardown(self) -> None:
-        # TPU teardown
         os.environ.pop("PT_XLA_DEBUG", None)
-        self.barrier("teardown")
 
     @property
     def should_rank_save_checkpoint(self) -> bool:
@@ -377,13 +346,3 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
     @checkpoint_io.setter
     def checkpoint_io(self, plugin: CheckpointIO) -> None:
         raise MisconfigurationException("TPU Spawn Plugin currently does not support custom checkpoint plugins.")
-
-    @staticmethod
-    def _clean_logger(trainer: "pl.Trainer") -> None:
-        loggers = trainer.logger._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger]
-        for logger in loggers:
-            if isinstance(logger, TensorBoardLogger) and logger._experiment is not None:
-                # the experiment class of `TensorBoard` holds a multiprocessing queue which can make ours hang.
-                # we want to make sure these are closed before we spawn our own threads.
-                # assuming nothing else references the experiment object, python should instantly `__del__` it.
-                logger._experiment = None
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index ee1c41764a..171ce23f2f 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -307,18 +307,6 @@ class TrainingTypePlugin(ABC):
         for optimizer, opt_state in zip(self.optimizers, optimizer_states):
             optimizer.load_state_dict(opt_state)
 
-    def start_training(self, trainer: "pl.Trainer") -> Any:
-        # double dispatch to initiate the training loop
-        return trainer.run_stage()
-
-    def start_evaluating(self, trainer: "pl.Trainer") -> Any:
-        # double dispatch to initiate the test loop
-        return trainer.run_stage()
-
-    def start_predicting(self, trainer: "pl.Trainer") -> Any:
-        # double dispatch to initiate the predicting loop
-        return trainer.run_stage()
-
     def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
         """The actual training step.
 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index dde8edce21..bacb4cbbd5 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -49,6 +49,7 @@ from pytorch_lightning.plugins import (
     TrainingTypePlugin,
 )
 from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment
+from pytorch_lightning.plugins.training_type.ddp_spawn import _SpawnOutput
 from pytorch_lightning.profiler import (
     AdvancedProfiler,
     BaseProfiler,
@@ -673,7 +674,12 @@ class Trainer(
             **kwargs: keyword arguments to be passed to `trainer_fn`
         """
         try:
-            return trainer_fn(*args, **kwargs)
+            if isinstance(self.training_type_plugin, DDPSpawnPlugin):
+                spawn_output: _SpawnOutput = self.training_type_plugin.spawn(trainer_fn, *args, **kwargs)
+                self.training_type_plugin._recover_results_in_main_process(spawn_output, self)
+                return spawn_output.trainer_results
+            else:
+                return trainer_fn(*args, **kwargs)
         # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
         except KeyboardInterrupt as exception:
             rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
@@ -721,6 +727,7 @@ class Trainer(
             datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
         """
+        self.training_type_plugin.model = model
         self._call_and_handle_interrupt(
             self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
         )
 
@@ -756,10 +763,11 @@ class Trainer(
 
         # TODO: ckpt_path only in v1.7
         ckpt_path = ckpt_path or self.resume_from_checkpoint
-        self._run(model, ckpt_path=ckpt_path)
+        results = self._run(model, ckpt_path=ckpt_path)
 
         assert self.state.stopped
         self.training = False
+        return results
 
     def validate(
         self,
@@ -793,6 +801,7 @@ class Trainer(
             :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`, etc.
             The length of the list corresponds to the number of validation dataloaders used.
         """
+        self.training_type_plugin.model = model or self.lightning_module
         return self._call_and_handle_interrupt(self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule)
 
     def _validate_impl(
@@ -876,6 +885,7 @@ class Trainer(
             :meth:`~pytorch_lightning.core.lightning.LightningModule.test_epoch_end`, etc.
             The length of the list corresponds to the number of test dataloaders used.
         """
+        self.training_type_plugin.model = model or self.lightning_module
         return self._call_and_handle_interrupt(self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule)
 
     def _test_impl(
@@ -958,6 +968,7 @@ class Trainer(
         Returns:
             Returns a list of dictionaries, one for each provided dataloader containing their respective predictions.
""" + self.training_type_plugin.model = model or self.lightning_module return self._call_and_handle_interrupt( self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path ) @@ -1113,25 +1124,21 @@ class Trainer( Lightning internal flow looks like this: {Trainer.fit} or {Trainer.test} or {Trainer.predict} || | || - create accelerator || + spawn processes || + {self.accelerator.setup_environment} || | || - {self._dispatch} || - | || LIGHTNING - {self.training_type_plugin.start_training} || - or {self.training_type_plugin.start_evaluating} || - or {self.training_type_plugin.start_predicting} || FLOW + setup accelerator || + and strategy || LIGHTNING | || - {self.run_stage} || - | || DIRECTION - {self._run_train} || + {self.run_stage} || FLOW + | || + {self._run_train} || DIRECTION or {self._run_evaluate} || or {self._run_predict} || | || results \/ This is used to guide readers to the core loops: train, test, predict. {self._run_predict} is the simplest to understand, use `Go to Definition` to read it :) - Search for `start_training` or `start_evaluating` or `start_predicting` in - `pytorch_lightning/plugins/training_type_plugin` to find accelerator dispatch functions. """ # ---------------------------- @@ -1147,7 +1154,7 @@ class Trainer( self._call_callback_hooks("on_fit_start") self._call_lightning_module_hook("on_fit_start") - # plugin will setup fitting (e.g. ddp will launch child processes) + # plugin will move model to device self._pre_dispatch() if self.training_type_plugin.restore_checkpoint_after_pre_dispatch: @@ -1158,10 +1165,8 @@ class Trainer( self.checkpoint_connector.resume_end() - # dispatch `start_training` or `start_evaluating` or `start_predicting` - results = self._dispatch() - - self._post_dispatch() + results = self.run_stage() + self._teardown() # ---------------------------- # POST-Training CLEAN UP @@ -1171,15 +1176,15 @@ class Trainer( self._call_callback_hooks("on_fit_end") self._call_lightning_module_hook("on_fit_end") - # teardown if necessary (similar calls for spawn plugins are excluded as they have - # been included at the end of `new_process` functions) - if not isinstance(self.training_type_plugin, DDPSpawnPlugin): - self._call_teardown_hook() + self._call_teardown_hook() if self.state.status != TrainerStatus.INTERRUPTED: self.state.status = TrainerStatus.FINISHED self.state.stage = None + if isinstance(self.training_type_plugin, DDPSpawnPlugin): + results = self.training_type_plugin._collect_rank_zero_results(self, results) + return results def _pre_dispatch(self): @@ -1223,9 +1228,9 @@ class Trainer( self.logger.log_graph(self.lightning_module) self.logger.save() - def _post_dispatch(self): - # these `teardown` calls are here instead of in `_call_teardown_hook` since they are internal teardowns - # which need to happen before. 
+    def _teardown(self):
+        """This is the Trainer's internal teardown, unrelated to the `teardown` hooks in LightningModule and
+        Callback; those are handled by :meth:`_call_teardown_hook`."""
         self.training_type_plugin.post_dispatch(self)
         self.accelerator.teardown()
         self._data_connector.teardown()
@@ -1233,15 +1238,8 @@ class Trainer(
         self.logger_connector.teardown()
         self.signal_connector.teardown()
 
-    def _dispatch(self) -> Any:
-        if self.evaluating:
-            return self.training_type_plugin.start_evaluating(self)
-        elif self.predicting:
-            return self.training_type_plugin.start_predicting(self)
-        else:
-            return self.training_type_plugin.start_training(self)
-
     def run_stage(self):
+        self.training_type_plugin.barrier("run-stage")
         self.training_type_plugin.dispatch(self)
         self.__setup_profiler()
 
diff --git a/tests/plugins/test_ddp_spawn_plugin.py b/tests/plugins/test_ddp_spawn_plugin.py
index 44a97eaf4f..c8c861050d 100644
--- a/tests/plugins/test_ddp_spawn_plugin.py
+++ b/tests/plugins/test_ddp_spawn_plugin.py
@@ -80,8 +80,7 @@ def test_ddp_spawn_extra_parameters(tmpdir):
     val_name: str = "val_acc"
     model = BoringCallbackDDPSpawnModel(val_name, val)
     dm = BoringDataModule()
-    with pytest.deprecated_call(match="add_to_queue` method was deprecated in v1.5"):
-        trainer.fit(model, datamodule=dm)
+    trainer.fit(model, datamodule=dm)
 
     assert trainer.callback_metrics[val_name] == torch.tensor(val)
     assert model.test_val == "test_val"
@@ -107,8 +106,7 @@ def test_ddp_spawn_add_get_queue(tmpdir):
     val_name: str = "val_acc"
     model = BoringCallbackDDPSpawnModel(val_name, val)
     dm = BoringDataModule()
-    with pytest.deprecated_call(match="add_to_queue` method was deprecated in v1.5"):
-        trainer.fit(model, datamodule=dm)
+    trainer.fit(model, datamodule=dm)
 
     assert trainer.callback_metrics[val_name] == torch.tensor(val)
     assert ddp_spawn_plugin.new_test_val == "new_test_val"
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index baece7371d..bdcd2bbb10 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -50,7 +50,7 @@ from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _AcceleratorType, _StrategyType
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
-from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE
+from pytorch_lightning.utilities.imports import _IS_WINDOWS, _OMEGACONF_AVAILABLE, _TORCH_GREATER_EQUAL_1_8
 from pytorch_lightning.utilities.seed import seed_everything
 from tests.helpers import BoringModel, RandomDataset
 from tests.helpers.boring_model import RandomIterableDataset, RandomIterableDatasetWithLen
@@ -61,6 +61,11 @@ from tests.helpers.simple_models import ClassificationModel
 if _OMEGACONF_AVAILABLE:
     from omegaconf import OmegaConf
 
+if _TORCH_GREATER_EQUAL_1_8:
+    from torch.multiprocessing import ProcessRaisedException
+else:
+    ProcessRaisedException = Exception
+
 
 @pytest.mark.parametrize("url_ckpt", [True, False])
 def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
@@ -1419,7 +1424,7 @@ def predict(
         callbacks=[cb, cb_1] if use_callbacks else [],
     )
     if strategy == "ddp_spawn":
-        with pytest.raises(MisconfigurationException):
+        with pytest.raises(ProcessRaisedException, match="`return_predictions` should be set to `False`"):
             trainer.predict(model, datamodule=dm, return_predictions=True)
 
     if datamodule:
@@ -1517,6 +1522,7 @@ def test_index_batch_sampler_wrapper_with_iterable_dataset(dataset_cls, tmpdir):
     assert len(predictions) == 8
 
 
+@pytest.mark.skipif(_IS_WINDOWS and not _TORCH_GREATER_EQUAL_1_8, reason="torch.distributed support required")
 @patch("torch.cuda.device_count", return_value=2)
 @patch("torch.cuda.is_available", return_value=True)
 @pytest.mark.parametrize("accelerator", ("cpu", "gpu"))
@@ -1525,7 +1531,7 @@ def test_spawn_predict_return_predictions(_, __, accelerator):
     model = BoringModel()
     trainer = Trainer(accelerator=accelerator, strategy="ddp_spawn", devices=2, fast_dev_run=True)
     assert isinstance(trainer.training_type_plugin, DDPSpawnPlugin)
-    with pytest.raises(MisconfigurationException, match="`return_predictions` should be set to `False`"):
+    with pytest.raises(ProcessRaisedException, match="`return_predictions` should be set to `False`"):
         trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=True)