From e982800b81dbc636cc8505eecd47b6fdea02c4dc Mon Sep 17 00:00:00 2001 From: chaton Date: Tue, 16 Feb 2021 22:11:56 +0000 Subject: [PATCH] Add PredictLoop (#5752) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * integrate distrib_type * sync changes * sync * fixes * add forgotten generators * add missing logic * update * import * missed imports * import fixes * isort * mv f * changelog * format * move helper to parallel plugin * d * add world size * clean up * duplicate * activate ddp_sharded and tpu * set nvidia flags * remove unused colab var * use_tpu <-> on_tpu attrs * make some ddp_cpu and clusterplugin tests pass * Ref/accelerator connector (#5742) * final cleanup Co-authored-by: Adrian Wälchli * connector cleanup Co-authored-by: Adrian Wälchli * trainer cleanup Co-authored-by: Adrian Wälchli * accelerator cleanup + missing logic in accelerator connector Co-authored-by: Adrian Wälchli * add missing changes to callbacks Co-authored-by: Adrian Wälchli * reflect accelerator changes to lightning module Co-authored-by: Adrian Wälchli * clean cluster envs Co-authored-by: Adrian Wälchli * cleanup plugins Co-authored-by: Adrian Wälchli * add broadcasting Co-authored-by: Adrian Wälchli * yapf * remove plugin connector Co-authored-by: Adrian Wälchli * plugins * add predict_loop * manual optimization * clean predictloop * update optimizer routing * add predict loop on new accelerator * resolve a bug * add rank to torchelastic * add predict_loop * add predict loop on new accelerator * resolve a bug * fix memory mixed precision * update * setstate on trainer for pickling in ddp spawn * add predict_loop * clean predictloop * add predict loop on new accelerator * resolve a bug * add predict_loop * add predict loop on new accelerator * resolve a bug * add predict_loop * add predict loop on new accelerator * resolve a bug * add predict_loop * add predict loop on new accelerator * resolve a bug * add predict_loop * clean predictloop * add predict loop on new accelerator * resolve a bug * add predict_loop * add predict loop on new accelerator * resolve a bug * resolve tests * add predict method * add back commented accelerator code * adapt test for sync_batch_norm to new plugin * fix deprecated tests * fix ddp cpu choice when no num_processes are given * yapf format * skip a memory test that cannot pass anymore * remove sanetize * rename train to run_train * remove useless hooks * add misconfigurationException * remove wrong naming * resolve some legacy * udpate docstring * fix pickle error in spawn plugin * x * avoid * x * fix cyclic import in docs build * add support for sharded * update typing * add sharded and sharded_spawn to distributed types * make unwrap model default * refactor LightningShardedDataParallel similar to LightningDistributedDataParallel * update sharded spawn to reflect changes * update sharded to reflect changes * Merge 1.1.5 changes * fix merge * fix merge * yapf isort * fix merge * yapf isort * fix indentation in test * copy over reinit scheduler implementation from dev1.2 * fix apex tracking calls with dev_debugger * reduce diff to dev1.2, clean up * fix trainer config test when gpus>0 and num_processes >0 and ddp_cpu * sort plugin tests legacy/new * fix error handling for amp on cpu * fix merge fix merge fix merge * [Feat] Resolve manual_backward (#5837) * resolve manual_backward * resolve flake8 * update * resolve for ddp_spawn * resolve flake8 * resolve flake8 * resolve flake8 Co-authored-by: Ubuntu * fix tests/accelerator tests on 
cpu * [BugFix] Resolve manual optimization (#5852) * resolve manual_optimization * update * update Co-authored-by: Ubuntu * Remove copy trainer parameters to happen earlier within the loop and add safe guard to get ref model (#5856) * resovle a bug * Accelerator refactor sharded rpc (#5854) * rpc branch * merge * update handling of rpc * make devices etc. Optional in RPC * set devices etc. later if necessary * remove devices from sequential * make devices optional in rpc * fix import * uncomment everything * fix cluster selection Co-authored-by: Ubuntu * resolve bug * fix assert in rpc test * resolve a test * fix docs compilation * accelerator refactor - fix for sharded parity test (#5866) * fix memory issue with ddp_spawn * x x x x x x x x x * x * Remove DDP2 as this does not apply * Add missing pre optimizer hook to ensure lambda closure is called * fix apex docstring * [accelerator][BugFix] Resolve some test for 1 gpu (#5863) * update * revert init * resolve a bug * update * resolve flake8 * update * update * update * revert init * resolve a bug * update * resolve flake8 * update * update * update * update * update * revert init * resolve a bug * update * resolve flake8 * update * update * update * revert init * update * resolve flake8 * update * update * update * update * update * all_gather * update * make plugins work, add misconfig for RPC * update * update * remove breaking test * resolve some tests * resolve flake8 * revert to ddp_spawn Co-authored-by: root Co-authored-by: Ubuntu Co-authored-by: Justus Schock * yapf isort * resolve flake8 * fix apex doctests * fix apex doctests 2 * resolve docs * update drone * clean env * update * update * update * update * merge * Fix RPC related tests, clean out old API, update for new accelerator API [skip ci] (#5881) * Fix RPC related tests, clean out old API, update for new accelerator API * Move tests out of legacy folder, update paths and names * Update test_remove_1-4.py * Expose properties for tpu cores/gpus/num_gpus * Add root GPU property * Move properties to properties.py * move tests that were previously in drone * Fix root GPU property (#5908) * Move root GPU to property, remove horovod set as this is handled in horovod plugin, ensure we mock correctly to set GPU accelerator * Add missing tests back * fix best model path transfer when no checkpoint callback available * Fix setup hook order [wip] (#5858) * Call trainer setup hook before accelerator setup * Add test case * add new test * typo * fix callback order in test Co-authored-by: tchaton Co-authored-by: Adrian Wälchli * rename ddp sequential -> rpc sequential for special test * revert * fix stupid merge problem * Use property in connector for sampler (#5913) * merge the import conflicts * fix spawning of processes in slurm * [wip] Fix some bugs for TPU [skip ci] (#5878) * fixed for single tpu * fixed spawn * fixed spawn * update * update * wip * resolve bugs * resolve bug * update on comment * removed decorator * resolve comments * set to 4 * update * update * need cleaning * update * update * update * resolve flake8 * resolve bugs * exclude broadcast * resolve bugs * change test * update * update * skip if meet fails * properly raise trace * update * add catch * wrap test * resolve typo * update * typo Co-authored-by: Lezwon Castelino Co-authored-by: Your Name * resolve some tests * update * fix imports * update * resolve flake8 * update azure pipeline * skip a sharded test on cpu that requires a gpu * resolve tpus * resolve bug * resolve flake8 * update * updat utils * revert 
permission change on files * suggestions from carlos Co-authored-by: Carlos Mocholí * remove unrelated formatting changes * remove incomplete comment * Update pytorch_lightning/accelerators/__init__.py Co-authored-by: Carlos Mocholí * remove unrelated formatting change * add types * warn 1.7 ddp manual backward only if ddp kwarg unset * yapf + isort * pep8 unused imports * fix cyclic import in docs * Apply suggestions from code review * typer in accelerator.py * typo * resolve flake8 * update code * update * Update pytorch_lightning/trainer/predict_loop.py Co-authored-by: Carlos Mocholí * Update pytorch_lightning/trainer/predict_loop.py Co-authored-by: Carlos Mocholí * fix merge * fix merge * reset legacy accelerator * add missing rename dispatch * rename post traning * update code * resolved comments * typo * typo * add flow description * resolve comments * update on comments * update flow * add backticks * resolve tpu Co-authored-by: Adrian Wälchli Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> Co-authored-by: justusschock Co-authored-by: Justus Schock Co-authored-by: Ubuntu Co-authored-by: Sean Naren Co-authored-by: SeanNaren Co-authored-by: root Co-authored-by: Lezwon Castelino Co-authored-by: Your Name Co-authored-by: Carlos Mocholí Co-authored-by: Jirka Borovec Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> --- CHANGELOG.md | 3 + pytorch_lightning/accelerators/accelerator.py | 33 ++- pytorch_lightning/callbacks/progress.py | 63 +++++- pytorch_lightning/core/datamodule.py | 4 + pytorch_lightning/core/hooks.py | 37 +++- pytorch_lightning/overrides/base.py | 10 +- pytorch_lightning/plugins/base_plugin.py | 13 +- .../plugins/training_type/ddp.py | 6 +- .../plugins/training_type/ddp_spawn.py | 12 +- .../plugins/training_type/horovod.py | 14 +- .../plugins/training_type/rpc_sequential.py | 4 +- .../plugins/training_type/single_tpu.py | 4 +- .../plugins/training_type/tpu_spawn.py | 10 +- .../training_type/training_type_plugin.py | 6 +- .../trainer/configuration_validator.py | 4 +- .../trainer/connectors/data_connector.py | 14 +- .../trainer/connectors/debugging_connector.py | 3 + pytorch_lightning/trainer/data_loading.py | 16 +- pytorch_lightning/trainer/evaluation_loop.py | 11 +- pytorch_lightning/trainer/predict_loop.py | 97 ++++++++ pytorch_lightning/trainer/trainer.py | 209 ++++++++++++------ pytorch_lightning/trainer/training_loop.py | 4 +- tests/callbacks/test_progress_bar.py | 4 +- tests/overrides/test_data_parallel.py | 2 +- tests/trainer/test_trainer.py | 7 +- 25 files changed, 454 insertions(+), 136 deletions(-) create mode 100644 pytorch_lightning/trainer/predict_loop.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 13c11163bb..95ca3329f8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -78,6 +78,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added AUC/AUROC class interface ([#5479](https://github.com/PyTorchLightning/pytorch-lightning/pull/5479)) +- Added `PredictLoop` object ([#5752](https://github.com/PyTorchLightning/pytorch-lightning/pull/5752)) + + - Added `QuantizationAwareTraining` callback ([#5706](https://github.com/PyTorchLightning/pytorch-lightning/pull/5706)) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 893456f403..2e8e31139d 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -139,9 +139,8 @@ class Accelerator(object): args[0] = batch - with self.precision_plugin.train_step_context(): - with self.training_type_plugin.train_step_context(): - return self.training_type_plugin.training_step(*args) + with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context(): + return self.training_type_plugin.training_step(*args) def post_training_step(self): self.training_type_plugin.post_training_step() @@ -161,9 +160,8 @@ class Accelerator(object): args[0] = batch - with self.precision_plugin.val_step_context(): - with self.training_type_plugin.val_step_context(): - return self.training_type_plugin.validation_step(*args) + with self.precision_plugin.val_step_context(), self.training_type_plugin.val_step_context(): + return self.training_type_plugin.validation_step(*args) def test_step(self, args): """The actual test step. @@ -180,9 +178,26 @@ class Accelerator(object): args[0] = batch - with self.precision_plugin.test_step_context(): - with self.training_type_plugin.test_step_context(): - return self.training_type_plugin.test_step(*args) + with self.precision_plugin.test_step_context(), self.training_type_plugin.test_step_context(): + return self.training_type_plugin.test_step(*args) + + def predict(self, args): + """The actual predict step. + + Args: + args: the arguments for the models predict step. Can consist of the following: + batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): + The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. + batch_idx (int): The index of this batch. + dataloader_idx (int): The index of the dataloader that produced this batch + (only if multiple predict dataloaders used). + """ + batch = self.to_device(args[0]) + + args[0] = batch + + with self.precision_plugin.predict_context(), self.training_type_plugin.predict_context(): + return self.training_type_plugin.predict(*args) def training_step_end(self, output): """A hook to do something at the end of the training step diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index a37a979c9d..7de7982b4a 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -67,6 +67,7 @@ class ProgressBarBase(Callback): self._train_batch_idx = 0 self._val_batch_idx = 0 self._test_batch_idx = 0 + self._predict_batch_idx = 0 @property def trainer(self): @@ -96,6 +97,14 @@ class ProgressBarBase(Callback): """ return self._test_batch_idx + @property + def predict_batch_idx(self) -> int: + """ + The current batch index being processed during predicting. + Use this to update your progress bar. 
+ """ + return self._predict_batch_idx + @property def total_train_batches(self) -> int: """ @@ -108,7 +117,7 @@ class ProgressBarBase(Callback): @property def total_val_batches(self) -> int: """ - The total number of training batches during validation, which may change from epoch to epoch. + The total number of validation batches during validation, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the validation dataloader is of infinite size. """ @@ -121,12 +130,21 @@ class ProgressBarBase(Callback): @property def total_test_batches(self) -> int: """ - The total number of training batches during testing, which may change from epoch to epoch. + The total number of testing batches during testing, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the test dataloader is of infinite size. """ return sum(self.trainer.num_test_batches) + @property + def total_predict_batches(self) -> int: + """ + The total number of predicting batches during testing, which may change from epoch to epoch. + Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the + predict dataloader is of infinite size. + """ + return sum(self.trainer.num_predict_batches) + def disable(self): """ You should provide a way to disable the progress bar. @@ -168,6 +186,12 @@ class ProgressBarBase(Callback): def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): self._test_batch_idx += 1 + def on_predict_start(self, trainer, pl_module): + self._predict_batch_idx = 0 + + def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + self._predict_batch_idx += 1 + class ProgressBar(ProgressBarBase): r""" @@ -282,6 +306,20 @@ class ProgressBar(ProgressBarBase): ) return bar + def init_predict_tqdm(self) -> tqdm: + """ Override this to customize the tqdm bar for predicting. """ + bar = tqdm( + desc='Predicting', + initial=self.train_batch_idx, + position=(2 * self.process_position), + disable=self.is_disabled, + leave=True, + dynamic_ncols=True, + file=sys.stdout, + smoothing=0, + ) + return bar + def init_validation_tqdm(self) -> tqdm: """ Override this to customize the tqdm bar for validation. """ bar = tqdm( @@ -294,12 +332,10 @@ class ProgressBar(ProgressBarBase): ) return bar - def init_test_tqdm(self, trainer=None) -> tqdm: + def init_test_tqdm(self) -> tqdm: """ Override this to customize the tqdm bar for testing. 
""" - desc = "Testing" - desc = "Predicting" if trainer is not None and getattr(trainer, "is_predicting", False) else "Testing" bar = tqdm( - desc=desc, + desc="Testing", position=(2 * self.process_position), disable=self.is_disabled, leave=True, @@ -365,7 +401,7 @@ class ProgressBar(ProgressBarBase): def on_test_start(self, trainer, pl_module): super().on_test_start(trainer, pl_module) - self.test_progress_bar = self.init_test_tqdm(trainer=trainer) + self.test_progress_bar = self.init_test_tqdm() self.test_progress_bar.total = convert_inf(self.total_test_batches) def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): @@ -377,6 +413,19 @@ class ProgressBar(ProgressBarBase): super().on_test_end(trainer, pl_module) self.test_progress_bar.close() + def on_predict_start(self, trainer, pl_module): + super().on_predict_start(trainer, pl_module) + self.predict_progress_bar = self.init_predict_tqdm() + self.predict_progress_bar.total = convert_inf(self.total_predict_batches) + + def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) + if self._should_update(self.predict_batch_idx, self.total_predict_batches): + self._update_bar(self.predict_progress_bar) + + def on_predict_end(self, trainer, pl_module): + self.predict_progress_bar.close() + def _should_update(self, current, total): return self.is_enabled and (current % self.refresh_rate == 0 or current == total) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index d0e1725b2c..3b9d8e7de4 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -260,6 +260,10 @@ class LightningDataModule(DataHooks, CheckpointHooks, metaclass=_DataModuleWrapp def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: pass + @abstractmethod + def predict_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: + pass + @abstractmethod def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any: pass diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 23fd5d9b58..ac7bb2a1d2 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -204,17 +204,23 @@ class ModelHooks: """ # do something when the batch ends + def on_test_model_train(self) -> None: + """ + Sets the model to train during the test loop + """ + self.train() + def on_test_model_eval(self) -> None: """ Sets the model to eval during the test loop """ self.eval() - def on_test_model_train(self) -> None: + def on_predict_model_eval(self) -> None: """ - Sets the model to train during the test loop + Sets the model to eval during the predict loop """ - self.train() + self.eval() def on_epoch_start(self) -> None: """ @@ -518,6 +524,31 @@ class DataHooks: will have an argument ``dataloader_idx`` which matches the order here. """ + def predict_dataloader(self) -> Union[DataLoader, List[DataLoader]]: + r""" + Implement one or multiple PyTorch DataLoaders for prediction. + + It's recommended that all data downloads and preparation happen in :meth:`prepare_data`. + + - :meth:`~pytorch_lightning.trainer.Trainer.fit` + - ... + - :meth:`prepare_data` + - :meth:`train_dataloader` + - :meth:`val_dataloader` + - :meth:`test_dataloader` + + Note: + Lightning adds the correct sampler for distributed and arbitrary hardware + There is no need to set it yourself. 
+ + Return: + Single or multiple PyTorch DataLoaders. + + Note: + In the case where you return multiple prediction dataloaders, the :meth:`predict` + will have an argument ``dataloader_idx`` which matches the order here. + """ + def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any: """ Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors diff --git a/pytorch_lightning/overrides/base.py b/pytorch_lightning/overrides/base.py index 1a33556991..2fcb4b11a0 100644 --- a/pytorch_lightning/overrides/base.py +++ b/pytorch_lightning/overrides/base.py @@ -54,14 +54,22 @@ class _LightningModuleWrapperBase(torch.nn.Module): if not self.module.automatic_optimization: self.module.trainer.model.require_backward_grad_sync = False warn_if_output_is_none(output, "training_step") + elif running_stage == RunningStage.TESTING: output = self.module.test_step(*inputs, **kwargs) warn_if_output_is_none(output, "test_step") + elif running_stage == RunningStage.EVALUATING: output = self.module.validation_step(*inputs, **kwargs) warn_if_output_is_none(output, "validation_step") - else: + + elif running_stage == RunningStage.PREDICTING: output = self.module.predict(*inputs, **kwargs) + warn_if_output_is_none(output, "predict") + + else: + output = self.module(*inputs, **kwargs) + return output diff --git a/pytorch_lightning/plugins/base_plugin.py b/pytorch_lightning/plugins/base_plugin.py index b8bdf38a57..e495d9ffad 100644 --- a/pytorch_lightning/plugins/base_plugin.py +++ b/pytorch_lightning/plugins/base_plugin.py @@ -33,11 +33,11 @@ class Plugin(ABC): Will be called by the accelerator. """ - def pre_training(self) -> None: - """Hook to do something before the training starts.""" + def pre_dispatch(self) -> None: + """Hook to do something before the training/evaluation/prediction starts.""" - def post_training(self) -> None: - """Hook to do something after the training finishes.""" + def post_dispatch(self) -> None: + """Hook to do something after the training/evaluation/prediction finishes.""" @contextlib.contextmanager def train_step_context(self) -> Generator: @@ -53,3 +53,8 @@ class Plugin(ABC): def test_step_context(self) -> Generator: """A contextmanager for the teststep""" yield + + @contextlib.contextmanager + def predict_context(self) -> Generator: + """A contextmanager for the predict step""" + yield diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 52a24655f0..6e6c292eec 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -215,7 +215,7 @@ class DDPPlugin(ParallelPlugin): log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}") torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size) - def pre_training(self): + def pre_dispatch(self): # TODO: check if needed seed = os.environ.get("PL_GLOBAL_SEED") if seed is not None: @@ -232,7 +232,7 @@ class DDPPlugin(ParallelPlugin): # where to store ip_table self.init_ddp_connection(self.global_rank, self.world_size) - # TODO: we moved it to the trainer.fit after calling pre_training + # TODO: we moved it to the trainer.fit after calling pre_dispatch # ... 
need to double check that it is the correct place # self.trainer.call_setup_hook(self.model) @@ -257,7 +257,7 @@ class DDPPlugin(ParallelPlugin): self.barrier() - def post_training(self): + def post_dispatch(self): if "WORLD_SIZE" in os.environ: del os.environ["WORLD_SIZE"] diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 6b6d85ee0d..449373e2c3 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -110,6 +110,9 @@ class DDPSpawnPlugin(ParallelPlugin): def start_testing(self, trainer): mp.spawn(self.new_process, **self.mp_spawn_kwargs) + def start_predicting(self, trainer): + mp.spawn(self.new_process, **self.mp_spawn_kwargs) + def new_process(self, process_idx, trainer, mp_queue): self.mp_queue = mp_queue @@ -128,7 +131,7 @@ class DDPSpawnPlugin(ParallelPlugin): # where to store ip_table self.init_ddp_connection(self.global_rank, self.world_size) - # TODO: we moved it to the trainer.fit after calling pre_training + # TODO: we moved it to the trainer.fit after calling pre_dispatch # ... need to double check that it is the correct place # self.trainer.call_setup_hook(self.model) @@ -153,15 +156,12 @@ class DDPSpawnPlugin(ParallelPlugin): self.barrier() - if trainer.testing: - results = trainer.run_test() - else: - results = trainer.train() + results = trainer.train_or_test_or_predict() # persist info in ddp_spawn self.transfer_distrib_spawn_state_on_fit_end(results) - def post_training(self): + def post_dispatch(self): # restore main state with best weights best_path = self.mp_queue.get() last_path = self.mp_queue.get() diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 995c830799..c1de2d7833 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -50,7 +50,7 @@ class HorovodPlugin(ParallelPlugin): self.model_to_device() - def pre_training(self): + def pre_dispatch(self): def _unpack_lightning_optimizer(opt): return opt._optimizer if isinstance(opt, LightningOptimizer) else opt @@ -95,20 +95,26 @@ class HorovodPlugin(ParallelPlugin): stack.enter_context(optimizer.skip_synchronize()) # set up training routine - self._results = trainer.train() + self._results = trainer.run_train() # Make sure all workers have finished training before returning to the user hvd.join() def start_testing(self, trainer): with ExitStack() as stack: - # set up training routine - # self.trainer.train_loop.setup_training(self.trainer.model) self._results = trainer.run_test() # Make sure all workers have finished training before returning to the user hvd.join() + def start_predicting(self, trainer): + with ExitStack() as stack: + # set up training routine + self._results = trainer.run_predict() + + # Make sure all workers have finished training before returning to the user + hvd.join() + def barrier(self, *args, **kwargs): hvd.join() diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py index 345f208b97..fc707afb3e 100644 --- a/pytorch_lightning/plugins/training_type/rpc_sequential.py +++ b/pytorch_lightning/plugins/training_type/rpc_sequential.py @@ -325,9 +325,9 @@ class RPCSequentialPlugin(RPCPlugin): # Initialize optimizer step on main process self.worker_optimizer_step(model=self.lightning_module, opt_idx=optimizer_idx, **kwargs) - def 
post_training(self): + def post_training_step(self): if self.main_rpc_process: - super().post_training() + super().post_training_step() def start_training(self, trainer: 'Trainer') -> None: if self.main_rpc_process: diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py index 46df404bdc..40fc9fba3a 100644 --- a/pytorch_lightning/plugins/training_type/single_tpu.py +++ b/pytorch_lightning/plugins/training_type/single_tpu.py @@ -36,14 +36,14 @@ class SingleTPUPlugin(SingleDevicePlugin): def model_to_device(self) -> None: self._model.to(self.root_device) - def pre_training(self) -> None: + def pre_dispatch(self) -> None: if isinstance(self.device, int): self.device = xm.xla_device(self.device) self.tpu_local_core_rank = xm.get_local_ordinal() self.tpu_global_core_rank = xm.get_ordinal() - def post_training(self) -> None: + def post_dispatch(self) -> None: model = self.lightning_module if on_colab_kaggle(): diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 4c5844da94..d4374d0ef9 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -95,10 +95,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin): trainer.save_checkpoint = self.save_checkpoint self.barrier() - if trainer.testing: - results = trainer.run_test() - else: - results = trainer.train() + results = trainer.train_or_test_or_predict() self.__save_end_of_training_weights(self.lightning_module) self.transfer_distrib_spawn_state_on_fit_end(results) @@ -182,7 +179,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin): should_stop = int(stop.item()) == self.world_size return should_stop - def post_training(self) -> None: + def post_dispatch(self) -> None: # TODO: Check if trainer references can be resolved otherwise model = self.lightning_module @@ -233,6 +230,9 @@ class TPUSpawnPlugin(DDPSpawnPlugin): def start_testing(self, trainer) -> None: xmp.spawn(self.new_process, **self.xmp_spawn_kwargs) + def start_predicting(self, trainer) -> None: + xmp.spawn(self.new_process, **self.xmp_spawn_kwargs) + def training_step(self, *args, **kwargs): return self.lightning_module.training_step(*args, **kwargs) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 74f5837afc..cede3e5f98 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -112,12 +112,16 @@ class TrainingTypePlugin(Plugin, ABC): def start_training(self, trainer: 'Trainer') -> None: # double dispatch to initiate the training loop - self._results = trainer.train() + self._results = trainer.run_train() def start_testing(self, trainer: 'Trainer') -> None: # double dispatch to initiate the test loop self._results = trainer.run_test() + def start_predicting(self, trainer: 'Trainer') -> None: + # double dispatch to initiate the predicting loop + self._results = trainer.run_predict() + def training_step(self, *args, **kwargs): return self.lightning_module.training_step(*args, **kwargs) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index a7e13de8ed..9cb22f39b7 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -52,7 +52,7 @@ class ConfigValidator(object): # verify model 
has a train dataloader # ----------------------------------- has_train_dataloader = is_overridden('train_dataloader', model) - if not has_train_dataloader and not self.trainer._predicting: + if not has_train_dataloader: raise MisconfigurationException( 'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a' ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.' @@ -62,7 +62,7 @@ class ConfigValidator(object): # verify model has optimizer # ----------------------------------- has_optimizers = is_overridden('configure_optimizers', model) - if not has_optimizers and not self.trainer._predicting: + if not has_optimizers: raise MisconfigurationException( 'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a' ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.' diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 9161f3e875..2852d9dfaf 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -90,7 +90,14 @@ class DataConnector(object): 'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule' ) - def attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None): + def attach_dataloaders( + self, + model, + train_dataloader=None, + val_dataloaders=None, + test_dataloaders=None, + predict_dataloaders=None + ): # when dataloader is passed via fit, patch the train_dataloader # functions to overwrite with these implementations if train_dataloader is not None: @@ -102,6 +109,9 @@ class DataConnector(object): if test_dataloaders is not None: model.test_dataloader = _PatchDataLoader(test_dataloaders) + if predict_dataloaders is not None: + model.predict_dataloader = _PatchDataLoader(predict_dataloaders) + def attach_datamodule(self, model, datamodule: Optional[LightningDataModule], stage: str) -> None: # Todo: required argument `stage` is not used @@ -118,6 +128,8 @@ class DataConnector(object): model.val_dataloader = datamodule.val_dataloader if is_overridden('test_dataloader', datamodule): model.test_dataloader = datamodule.test_dataloader + if is_overridden('predict_dataloader', datamodule): + model.predict_dataloader = datamodule.predict_dataloader # Override transfer_batch_to_device if dataset-specific to_device logic has been defined in datamodule if is_overridden('transfer_batch_to_device', datamodule): diff --git a/pytorch_lightning/trainer/connectors/debugging_connector.py b/pytorch_lightning/trainer/connectors/debugging_connector.py index cb2ecc20f5..28c99f8f4d 100644 --- a/pytorch_lightning/trainer/connectors/debugging_connector.py +++ b/pytorch_lightning/trainer/connectors/debugging_connector.py @@ -29,6 +29,7 @@ class DebuggingConnector: limit_train_batches, limit_val_batches, limit_test_batches, + limit_predict_batches, val_check_interval, overfit_batches, fast_dev_run, @@ -56,6 +57,7 @@ class DebuggingConnector: limit_train_batches = fast_dev_run limit_val_batches = fast_dev_run limit_test_batches = fast_dev_run + limit_predict_batches = fast_dev_run self.trainer.max_steps = fast_dev_run self.trainer.num_sanity_val_steps = 0 self.trainer.max_epochs = 1 @@ -71,6 +73,7 @@ class DebuggingConnector: self.trainer.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches') self.trainer.limit_val_batches = 
_determine_batch_limits(limit_val_batches, 'limit_val_batches') self.trainer.limit_test_batches = _determine_batch_limits(limit_test_batches, 'limit_test_batches') + self.trainer.limit_predict_batches = _determine_batch_limits(limit_predict_batches, 'limit_predict_batches') self.trainer.val_check_interval = _determine_batch_limits(val_check_interval, 'val_check_interval') self.trainer.overfit_batches = _determine_batch_limits(overfit_batches, 'overfit_batches') self.determine_data_use_amount(self.trainer.overfit_batches) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index fd93c559ff..946a900644 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -293,7 +293,8 @@ class TrainerDataLoadingMixin(ABC): loader = dataloaders[loader_i] # shuffling in val and test set is bad practice - if mode in ('val', 'test') and hasattr(loader, 'sampler') and isinstance(loader.sampler, RandomSampler): + modes = ('val', 'test', 'predict') + if mode in modes and hasattr(loader, 'sampler') and isinstance(loader.sampler, RandomSampler): # when overfitting, the dataloader should not have sampler if self.overfit_batches > 0: @@ -363,7 +364,7 @@ class TrainerDataLoadingMixin(ABC): self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(model, 'val') def reset_test_dataloader(self, model) -> None: - """Resets the validation dataloader and determines the number of batches. + """Resets the test dataloader and determines the number of batches. Args: model: The current `LightningModule` @@ -374,6 +375,17 @@ class TrainerDataLoadingMixin(ABC): self.num_test_batches, self.test_dataloaders =\ self._reset_eval_dataloader(model, 'test') + def reset_predict_dataloader(self, model) -> None: + """Resets the predict dataloader and determines the number of batches. + + Args: + model: The current `LightningModule` + """ + has_loader = is_overridden('predict_dataloader', model) + if has_loader: + self.num_predict_batches, self.predict_dataloaders =\ + self._reset_eval_dataloader(model, 'predict') + def request_dataloader(self, dataloader_fx: Callable) -> DataLoader: """Handles downloading data in the GPU or TPU case. diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 1fbcc80ca4..fe3fc62ff1 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -154,16 +154,7 @@ class EvaluationLoop(object): model_ref = self.trainer.get_model() model_ref._results = Result() - if self.trainer._predicting: - model_ref._current_fx_name = "predict" - predictions = self.trainer.accelerator_backend.predict(args) - self._predictions[dataloader_idx].append(predictions) - self.trainer._progress_bar_callback.on_test_batch_end( - self.trainer, model_ref, predictions, batch, batch_idx, dataloader_idx - ) - return - - elif self.testing: + if self.testing: model_ref._current_fx_name = "test_step" with self.trainer.profiler.profile("test_step"): output = self.trainer.accelerator_backend.test_step(args) diff --git a/pytorch_lightning/trainer/predict_loop.py b/pytorch_lightning/trainer/predict_loop.py new file mode 100644 index 0000000000..43016b8943 --- /dev/null +++ b/pytorch_lightning/trainer/predict_loop.py @@ -0,0 +1,97 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from pytorch_lightning.utilities.apply_func import apply_to_collection + + +class PredictLoop(object): + + def __init__(self, trainer): + self.trainer = trainer + self.max_batches = None + self.num_dataloaders = None + + def on_trainer_init(self): + self.trainer.num_predict_batches = [] + + def get_predict_dataloaders(self, max_batches): + # select dataloaders + model = self.trainer.get_model() + self.trainer.reset_predict_dataloader(model) + dataloaders = self.trainer.predict_dataloaders + if max_batches is None: + max_batches = self.trainer.num_predict_batches + + return dataloaders, max_batches + + def should_skip_predict(self, dataloaders, max_batches): + return dataloaders is None or not sum(max_batches) + + def on_predict_model_eval(self, *_, **__): + model_ref = self.trainer.get_model() + model_ref.on_predict_model_eval() + + def setup(self, model, max_batches, dataloaders): + # copy properties for forward overrides + self.trainer.model_connector.copy_trainer_model_properties(model) + + # convert max_batches to list + if isinstance(max_batches, int): + max_batches = [max_batches] * len(dataloaders) + + self.max_batches = max_batches + self.num_dataloaders = self._get_num_dataloaders(dataloaders) + self._predictions = [[] for _ in range(self.num_dataloaders)] + + self.trainer._progress_bar_callback.on_predict_start(self.trainer, self.trainer.get_model()) + + def _get_num_dataloaders(self, dataloaders): + # case where user does: + # return dl1, dl2 + length = len(dataloaders) + if len(dataloaders) > 0 and isinstance(dataloaders[0], (list, tuple)): + length = len(dataloaders[0]) + return length + + def predict(self, batch, batch_idx, dataloader_idx): + # configure args + args = [batch, batch_idx] + if self.num_dataloaders: + args.append(dataloader_idx) + + model_ref = self.trainer.get_model() + + model_ref._current_fx_name = "predict" + predictions = self.trainer.accelerator_backend.predict(args) + self._predictions[dataloader_idx].append(predictions) + self.trainer._progress_bar_callback.on_predict_batch_end( + self.trainer, model_ref, predictions, batch, batch_idx, dataloader_idx + ) + return + + def on_predict_epoch_end(self): + self.trainer._progress_bar_callback.on_predict_end(self.trainer, self.trainer.get_model()) + + results = self._predictions + + def _convert_to_numpy(v): + return v.cpu().numpy() + + results = apply_to_collection(results, torch.Tensor, _convert_to_numpy) + + if len(results) == 1: + return results[0] + + return results diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index db04734d2f..2e45f9502e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -49,6 +49,7 @@ from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop from pytorch_lightning.trainer.logging import TrainerLoggingMixin from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin +from pytorch_lightning.trainer.predict_loop import PredictLoop from pytorch_lightning.trainer.properties import 
TrainerProperties from pytorch_lightning.trainer.states import RunningStage, TrainerState from pytorch_lightning.trainer.training_loop import TrainLoop @@ -57,6 +58,7 @@ from pytorch_lightning.tuner.tuning import Tuner from pytorch_lightning.utilities import DeviceType, rank_zero_warn from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.debugging import InternalDebugger +from pytorch_lightning.utilities.enums import LightningEnum from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.model_helpers import is_overridden @@ -107,6 +109,7 @@ class Trainer( limit_train_batches: Union[int, float] = 1.0, limit_val_batches: Union[int, float] = 1.0, limit_test_batches: Union[int, float] = 1.0, + limit_predict_batches: Union[int, float] = 1.0, val_check_interval: Union[int, float] = 1.0, flush_logs_every_n_steps: int = 100, log_every_n_steps: int = 50, @@ -296,7 +299,6 @@ class Trainer( """ super().__init__() self._running_stage = None - self._predicting = False distributed_backend = distributed_backend or accelerator @@ -319,8 +321,9 @@ class Trainer( self.checkpoint_connector = CheckpointConnector(self) self.slurm_connector = SLURMConnector(self) self.tuner = Tuner(self) - self.evaluation_loop = EvaluationLoop(self) self.train_loop = TrainLoop(self, multiple_trainloader_mode) + self.evaluation_loop = EvaluationLoop(self) + self.predict_loop = PredictLoop(self) # training state self.weights_summary = weights_summary @@ -393,6 +396,7 @@ class Trainer( limit_train_batches, limit_val_batches, limit_test_batches, + limit_predict_batches, val_check_interval, overfit_batches, fast_dev_run, @@ -440,7 +444,11 @@ class Trainer( """ # bookkeeping self._state = TrainerState.RUNNING - self._set_wide_running_stage(RunningStage.TRAINING) + + # bookkeeping + # we reuse fit in .test() and .predict(). When already set, it shouldn't be modified. + if self._running_stage is None: + self._set_running_stage(RunningStage.TRAINING, model) # set local properties on the model self.model_connector.copy_trainer_model_properties(model) @@ -463,27 +471,47 @@ class Trainer( self.accelerator_backend.setup(self, model) self.setup_trainer(model) + # ---------------------------- + # INSPECT THE CORE LOOPS + # ---------------------------- + # Lightning internal flow looks like this. + # + # trainer.fit(...) or trainer.test(...) or trainer.predict(...) || + # | || + # create accelerator || + # | || + # trainer.dispatch || LIGHTNING + # | || + # start_training or start_testing or start_predicting call || FLOW + # from `accelerator.training_type_plugin` || + # | || DIRECTION + # run_train or run_test or run_predict call || + # from `trainer` || + # | || + # results \/ + # This is used to guide readers to the core loops: train, test, predict. + # `run_predict` is the simplest to understand, use `Go to Definition` to read it :) + # Search for `start_training` or `start_testing` or `start_predicting` in + # `pytorch_lightning/plugins/training_type` folder to find accelerator dispatch functions. + self.accelerator.train_loop = self.run_train + self.accelerator.validation_loop = self.run_evaluation + self.accelerator.test_loop = self.run_evaluation + self.accelerator.predict_loop = self.run_predict + # ---------------------------- # TRAIN # ---------------------------- # hook self.call_hook("on_fit_start") - # plugin will setup training (e.g. 
ddp will launch child processes) - # TODO: the old setup is now called "pre_training", where should this hook be called now? - self.training_type_plugin.pre_training() - self.precision_plugin.pre_training() + # plugin will setup fitting (e.g. ddp will launch child processes) + self.pre_dispatch() - # double dispatch: let the plugin initiate the training/test loop. - if self.testing: - self.training_type_plugin.start_testing(self) - else: - self.training_type_plugin.start_training(self) + # dispath `start_training` or `start_testing` or `start_predicting` + self.dispatch() - self.precision_plugin.post_training() - self.training_type_plugin.post_training() - self.accelerator_backend.teardown() - results = self.training_type_plugin.results + # plugin will finalized fitting (e.g. ddp_spawn will load trained model) + self.post_dispatch() # ---------------------------- # POST-Training CLEAN UP @@ -501,31 +529,47 @@ class Trainer( if self._state != TrainerState.INTERRUPTED: self._state = TrainerState.FINISHED - self._set_wide_running_stage(None) + self._set_running_stage(None, model) - return results or 1 + return self.training_type_plugin.results or 1 - def _set_wide_running_stage(self, stage): - model_ref = self.get_model() + def pre_dispatch(self): + self.training_type_plugin.pre_dispatch() + self.precision_plugin.pre_dispatch() - if stage is None: - self._running_stage = stage - model_ref.running_stage = stage - return + def post_dispatch(self): + self.training_type_plugin.post_dispatch() + self.precision_plugin.post_dispatch() + self.accelerator_backend.teardown() - # todo: clean up this routing mess. - if self._running_stage == RunningStage.TESTING: - stage = RunningStage.TESTING + def dispatch(self): + if self.testing: + self.training_type_plugin.start_testing(self) - # WARNING: With predicting, - # trainer _running_state should be RunningStage.TESTING - # however, the model running_stage should be RunningStage.PREDICTING or None - if model_ref is not None: - if self._predicting: - model_ref.running_stage = RunningStage.PREDICTING - else: - model_ref.running_stage = stage + elif self.predicting: + self.training_type_plugin.start_predicting(self) + else: + self.training_type_plugin.start_training(self) + + def train_or_test_or_predict(self): + if self.testing: + results = self.run_test() + + elif self.predicting: + results = self.run_predict() + + else: + results = self.run_train() + + return results + + def _set_running_stage(self, stage: LightningEnum, model_ref: LightningModule): + """ + This function is used to set the running_state on both + the trainer and the model + """ + model_ref.running_stage = stage self._running_stage = stage def _pre_training_routine(self): @@ -560,7 +604,7 @@ class Trainer( if self.is_function_implemented("on_pretrain_routine_end"): ref_model.on_pretrain_routine_end() - def train(self): + def run_train(self): self._pre_training_routine() @@ -570,7 +614,7 @@ class Trainer( self.run_sanity_check(self.get_model()) # set stage for logging - self._set_wide_running_stage(RunningStage.TRAINING) + self._set_running_stage(RunningStage.TRAINING, self.get_model()) self.checkpoint_connector.has_trained = False @@ -634,7 +678,7 @@ class Trainer( def run_evaluation(self, max_batches=None, on_epoch=False): # used to know if we are logging for val, test + reset cached results - self._set_wide_running_stage(RunningStage.TESTING if self.testing else RunningStage.EVALUATING) + self._set_running_stage(RunningStage.TESTING if self.testing else RunningStage.EVALUATING, 
self.get_model()) self.logger_connector.reset() # bookkeeping @@ -647,11 +691,10 @@ class Trainer( if self.evaluation_loop.should_skip_evaluation(max_batches): return [], [] - # ref model - model = self.get_model() - # enable eval mode + no grads self.evaluation_loop.on_evaluation_model_eval() + # ref model + model = self.get_model() model.zero_grad() torch.set_grad_enabled(False) @@ -685,8 +728,6 @@ class Trainer( # lightning module methods with self.profiler.profile("evaluation_step_and_end"): output = self.evaluation_loop.evaluation_step(batch, batch_idx, dataloader_idx) - if self._predicting: - continue output = self.evaluation_loop.evaluation_step_end(output) # hook + store predictions @@ -701,9 +742,6 @@ class Trainer( # store batch level output per dataloader self.evaluation_loop.outputs.append(dl_outputs) - if self._predicting: - return self.evaluation_loop.on_predict_epoch_end() - # lightning module method deprecated_eval_results = self.evaluation_loop.evaluation_epoch_end() @@ -764,6 +802,45 @@ class Trainer( return eval_loop_results + def run_predict(self): + # prepare dataloaders + dataloaders, max_batches = self.predict_loop.get_predict_dataloaders(None) + + # check if we want to skip this evaluation + if self.predict_loop.should_skip_predict(dataloaders, max_batches): + return [] + + # ref model + model = self.get_model() + + # enable eval mode + no grads + self.predict_loop.on_predict_model_eval() + model.zero_grad() + torch.set_grad_enabled(False) + + # set up the eval loop + self.predict_loop.setup(model, max_batches, dataloaders) + + # run validation/testing + for dataloader_idx, dataloader in enumerate(dataloaders): + dataloader = self.accelerator_backend.process_dataloader(dataloader) + dl_max_batches = self.predict_loop.max_batches[dataloader_idx] + + for batch_idx, batch in enumerate(dataloader): + if batch is None: + continue + + # stop short when running on limited batches + if batch_idx >= dl_max_batches: + break + + # lightning module methods + with self.profiler.profile("predict"): + self.predict_loop.predict(batch, batch_idx, dataloader_idx) + + results = self.predict_loop.on_predict_epoch_end() + return results + def run_sanity_check(self, ref_model): using_val_step = ref_model.val_dataloader is not None and is_overridden('validation_step', ref_model) should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0 @@ -828,7 +905,7 @@ class Trainer( # -------------------- self.verbose_test = verbose - self._set_wide_running_stage(RunningStage.TESTING) + self._set_running_stage(RunningStage.TESTING, model or self.get_model()) # If you supply a datamodule you can't supply train_dataloader or val_dataloaders if test_dataloaders and datamodule: @@ -845,9 +922,7 @@ class Trainer( results = self.__test_using_best_weights(ckpt_path, test_dataloaders) self.teardown('test') - - self._set_wide_running_stage(None) - + self._set_running_stage(None, model or self.get_model()) return results def __test_using_best_weights(self, ckpt_path, test_dataloaders): @@ -935,35 +1010,28 @@ class Trainer( # -------------------- # SETUP HOOK # -------------------- - self._set_wide_running_stage(RunningStage.TESTING) - # If you supply a datamodule you can't supply dataloaders + + model = model or self.get_model() + + self._set_running_stage(RunningStage.PREDICTING, model) + if dataloaders and datamodule: raise MisconfigurationException( 'You cannot pass dataloaders to trainer.predict if you supply a datamodule.' 
) - if model is None: - raise MisconfigurationException('You need to pass a model to `trainer.predict`.') - if datamodule is not None: # Attach datamodule to get setup/prepare_data added to model before the call to it below - self.data_connector.attach_datamodule(model, datamodule, 'test') + self.data_connector.attach_datamodule(model, datamodule, 'predict') # attach data if dataloaders is not None: - self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders) + self.data_connector.attach_dataloaders(model, predict_dataloaders=dataloaders) - # set path variable - self._predicting = True self.model = model - results = self.fit(model) - - # unset path variable - self.teardown('test') - self._predicting = False - self._set_wide_running_stage(None) + self._set_running_stage(None, model) return results @@ -1069,6 +1137,17 @@ class Trainer( elif self.testing: self._running_stage = None + @property + def predicting(self) -> bool: + return self._running_stage == RunningStage.PREDICTING + + @predicting.setter + def predicting(self, val: bool) -> None: + if val: + self._running_stage = RunningStage.PREDICTING + elif self.predicting: + self._running_stage = None + @property def tuning(self) -> bool: return self._running_stage == RunningStage.TUNING diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 1640afe97f..0908e96bd1 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -547,7 +547,7 @@ class TrainLoop: self.trainer.run_evaluation() # reset stage to train - self.trainer._set_wide_running_stage(RunningStage.TRAINING) + self.trainer._set_running_stage(RunningStage.TRAINING, self.trainer.lightning_module) # ----------------------------------------- # SAVE LOGGERS (ie: Tensorboard, etc...) 
@@ -594,7 +594,7 @@ class TrainLoop: self.trainer.run_evaluation(on_epoch=True) # reset stage to train - self.trainer._set_wide_running_stage(RunningStage.TRAINING) + self.trainer._set_running_stage(RunningStage.TRAINING, self.trainer.lightning_module) should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches) should_train_only = self.trainer.disable_validation or should_skip_eval diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 4ea4d511e1..8398aec88f 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -288,8 +288,8 @@ class MockedUpdateProgressBars(ProgressBar): bar = super().init_validation_tqdm() return self._mock_bar_update(bar) - def init_test_tqdm(self, trainer=None): - bar = super().init_test_tqdm(trainer=trainer) + def init_test_tqdm(self): + bar = super().init_test_tqdm() return self._mock_bar_update(bar) diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py index 15faf787b5..64481bd703 100644 --- a/tests/overrides/test_data_parallel.py +++ b/tests/overrides/test_data_parallel.py @@ -39,7 +39,7 @@ def test_lightning_wrapper_module_methods(wrapper_class): wrapped_module(batch, batch_idx) pl_module.validation_step.assert_called_with(batch, batch_idx) - pl_module.running_stage = None + pl_module.running_stage = RunningStage.PREDICTING wrapped_module(batch) pl_module.predict.assert_called_with(batch) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 4e85a5695b..71caaaad4d 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1498,6 +1498,9 @@ class TestLightningDataModule(LightningDataModule): def test_dataloader(self): return self._dataloaders + def predict_dataloader(self): + return self._dataloaders + def predict(tmpdir, accelerator, gpus, num_processes, plugins=None, datamodule=True): @@ -1515,7 +1518,6 @@ def predict(tmpdir, accelerator, gpus, num_processes, plugins=None, datamodule=T gpus=gpus, num_processes=num_processes, plugins=plugins, - num_sanity_val_steps=0 ) if datamodule: results = trainer.predict(model, datamodule=datamodule) @@ -1529,9 +1531,6 @@ def predict(tmpdir, accelerator, gpus, num_processes, plugins=None, datamodule=T assert results[0][0].shape == torch.Size([1, 2]) -@pytest.mark.skipif( - not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" -) @pytest.mark.parametrize('datamodule', [False, True]) def test_trainer_predict_cpu(tmpdir, datamodule): predict(tmpdir, None, None, 1, datamodule=datamodule)
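
Example usage of the new predict path (a minimal sketch, assuming the `trainer.predict(model, ...)` entry point exercised in `tests/trainer/test_trainer.py` above; `RandomDataset` and `PredictExample` are illustrative stand-in names, while `predict`, `predict_dataloader` and `limit_predict_batches` are the hooks and Trainer argument wired up by this patch):

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    # toy dataset of random 32-dimensional vectors
    def __init__(self, length=64):
        self.data = torch.randn(length, 32)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class PredictExample(LightningModule):
    # minimal module: the training hooks are defined so the fit-path
    # configuration checks pass; the predict hooks are the ones used here
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        return self(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=2)

    def predict(self, batch, batch_idx, dataloader_idx=None):
        # called by PredictLoop.predict for every batch of predict_dataloader
        return self(batch)

    def predict_dataloader(self):
        # picked up by Trainer.reset_predict_dataloader via is_overridden
        return DataLoader(RandomDataset(), batch_size=1)


if __name__ == "__main__":
    trainer = Trainer(limit_predict_batches=2)
    # PredictLoop.on_predict_epoch_end converts the collected predictions to numpy
    predictions = trainer.predict(PredictExample())
    print(len(predictions))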