Add PredictLoop (#5752)

* integrate distrib_type

* sync changes

* sync

* fixes

* add forgotten generators

* add missing logic

* update

* import

* missed imports

* import fixes

* isort

* mv f

* changelog

* format

* move helper to parallel plugin

* d

* add world size

* clean up

* duplicate

* activate ddp_sharded and tpu

* set nvidia flags

* remove unused colab var

* use_tpu <-> on_tpu attrs

* make some ddp_cpu and clusterplugin tests pass

* Ref/accelerator connector (#5742)

* final cleanup

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* connector cleanup

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* trainer cleanup

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* accelerator cleanup + missing logic in accelerator connector

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* add missing changes to callbacks

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* reflect accelerator changes to lightning module

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* clean cluster envs

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* cleanup plugins

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* add broadcasting

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* yapf

* remove plugin connector

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* plugins

* add predict_loop

* manual optimization

* clean predictloop

* update optimizer routing

* add predict loop on new accelerator

* resolve a bug

* add rank to torchelastic

* fix memory mixed precision

* update

* setstate on trainer for pickling in ddp spawn

* resolve tests

* add predict method

* add back commented accelerator code

* adapt test for sync_batch_norm to new plugin

* fix deprecated tests

* fix ddp cpu choice when no num_processes are given

* yapf format

* skip a memory test that cannot pass anymore

* remove sanitize

* rename train to run_train

* remove useless hooks

* add misconfigurationException

* remove wrong naming

* resolve some legacy

* update docstring

* fix pickle error in spawn plugin

* x

* avoid

* fix cyclic import in docs build

* add support for sharded

* update typing

* add sharded and sharded_spawn to distributed types

* make unwrap model default

* refactor LightningShardedDataParallel similar to LightningDistributedDataParallel

* update sharded spawn to reflect changes

* update sharded to reflect changes

* Merge 1.1.5 changes

* fix merge

* fix merge

* yapf isort

* fix merge

* yapf isort

* fix indentation in test

* copy over reinit scheduler implementation from dev1.2

* fix apex tracking calls with dev_debugger

* reduce diff to dev1.2, clean up

* fix trainer config test when gpus > 0 and num_processes > 0 and ddp_cpu

* sort plugin tests legacy/new

* fix error handling for amp on cpu

* fix merge

* [Feat] Resolve manual_backward (#5837)

* resolve manual_backward

* resolve flake8

* update

* resolve for ddp_spawn

* resolve flake8

* resolve flake8

* resolve flake8

Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal>

* fix tests/accelerator tests on cpu

* [BugFix] Resolve manual optimization (#5852)

* resolve manual_optimization

* update

Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal>

* Remove copy trainer parameters to happen earlier within the loop and add safe guard to get ref model (#5856)

* resolve a bug

* Accelerator refactor sharded rpc (#5854)

* rpc branch

* merge

* update handling of rpc

* make devices etc. Optional in RPC

* set devices etc. later if necessary

* remove devices from sequential

* make devices optional in rpc

* fix import

* uncomment everything

* fix cluster selection

Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal>

* resolve bug

* fix assert in rpc test

* resolve a test

* fix docs compilation

* accelerator refactor - fix for sharded parity test (#5866)

* fix memory issue with ddp_spawn

* Remove DDP2 as this does not apply

* Add missing pre optimizer hook to ensure lambda closure is called

* fix apex docstring

* [accelerator][BugFix] Resolve some test for 1 gpu (#5863)

* update

* revert init

* resolve a bug

* resolve flake8

* all_gather

* update

* make plugins work, add misconfig for RPC

* update

* update

* remove breaking test

* resolve some tests

* resolve flake8

* revert to ddp_spawn

Co-authored-by: root <root@ip-172-31-88-60.ec2.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal>
Co-authored-by: Justus Schock <justus.schock@rwth-aachen.de>

* yapf isort

* resolve flake8

* fix apex doctests

* fix apex doctests 2

* resolve docs

* update drone

* clean env

* update

* merge

* Fix RPC related tests, clean out old API, update for new accelerator API [skip ci] (#5881)

* Fix RPC related tests, clean out old API, update for new accelerator API

* Move tests out of legacy folder, update paths and names

* Update test_remove_1-4.py

* Expose properties for tpu cores/gpus/num_gpus

* Add root GPU property

* Move properties to properties.py

* move tests that were previously in drone

* Fix root GPU property (#5908)

* Move root GPU to property, remove horovod set as this is handled in horovod plugin, ensure we mock correctly to set GPU accelerator

* Add missing tests back

* fix best model path transfer when no checkpoint callback available

* Fix setup hook order [wip] (#5858)

* Call trainer setup hook before accelerator setup

* Add test case

* add new test

* typo

* fix callback order in test

Co-authored-by: tchaton <thomas@grid.ai>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* rename ddp sequential -> rpc sequential for special test

* revert

* fix stupid merge problem

* Use property in connector for sampler (#5913)

* merge the import conflicts

* fix spawning of processes in slurm

* [wip] Fix some bugs for TPU [skip ci] (#5878)

* fixed for single tpu

* fixed spawn

* fixed spawn

* update

* wip

* resolve bugs

* resolve bug

* update on comment

* removed decorator

* resolve comments

* set to 4

* update

* need cleaning

* update

* resolve flake8

* resolve bugs

* exclude broadcast

* resolve bugs

* change test

* update

* skip if meet fails

* properly raise trace

* update

* add catch

* wrap test

* resolve typo

* update

* typo

Co-authored-by: Lezwon Castelino <lezwon@gmail.com>
Co-authored-by: Your Name <you@example.com>

* resolve some tests

* update

* fix imports

* update

* resolve flake8

* update azure pipeline

* skip a sharded test on cpu that requires a gpu

* resolve tpus

* resolve bug

* resolve flake8

* update

* update utils

* revert permission change on files

* suggestions from carlos

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* remove unrelated formatting changes

* remove incomplete comment

* Update pytorch_lightning/accelerators/__init__.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* remove unrelated formatting change

* add types

* warn 1.7 ddp manual backward only if ddp kwarg unset

* yapf + isort

* pep8 unused imports

* fix cyclic import in docs

* Apply suggestions from code review

* typing in accelerator.py

* typo

* resolve flake8

* update code

* update

* Update pytorch_lightning/trainer/predict_loop.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Update pytorch_lightning/trainer/predict_loop.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* fix merge

* fix merge

* reset legacy accelerator

* add missing rename dispatch

* rename post traning

* update code

* resolved comments

* typo

* add flow description

* resolve comments

* update on comments

* update flow

* add backticks

* resolve tpu

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: justusschock <justus.schock@posteo.de>
Co-authored-by: Justus Schock <justus.schock@rwth-aachen.de>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Co-authored-by: SeanNaren <sean@grid.ai>
Co-authored-by: root <root@ip-172-31-88-60.ec2.internal>
Co-authored-by: Lezwon Castelino <lezwon@gmail.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Author: chaton
Date: 2021-02-16 22:11:56 +00:00 (committed by GitHub)
Commit: e982800b81 (parent a52be5bb07)
25 changed files with 454 additions and 136 deletions

View File

@ -78,6 +78,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added AUC/AUROC class interface ([#5479](https://github.com/PyTorchLightning/pytorch-lightning/pull/5479))
- Added `PredictLoop` object ([#5752](https://github.com/PyTorchLightning/pytorch-lightning/pull/5752))
- Added `QuantizationAwareTraining` callback ([#5706](https://github.com/PyTorchLightning/pytorch-lightning/pull/5706))

View File

@ -139,9 +139,8 @@ class Accelerator(object):
args[0] = batch
with self.precision_plugin.train_step_context():
with self.training_type_plugin.train_step_context():
return self.training_type_plugin.training_step(*args)
with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context():
return self.training_type_plugin.training_step(*args)
def post_training_step(self):
self.training_type_plugin.post_training_step()
@ -161,9 +160,8 @@ class Accelerator(object):
args[0] = batch
with self.precision_plugin.val_step_context():
with self.training_type_plugin.val_step_context():
return self.training_type_plugin.validation_step(*args)
with self.precision_plugin.val_step_context(), self.training_type_plugin.val_step_context():
return self.training_type_plugin.validation_step(*args)
def test_step(self, args):
"""The actual test step.
@ -180,9 +178,26 @@ class Accelerator(object):
args[0] = batch
with self.precision_plugin.test_step_context():
with self.training_type_plugin.test_step_context():
return self.training_type_plugin.test_step(*args)
with self.precision_plugin.test_step_context(), self.training_type_plugin.test_step_context():
return self.training_type_plugin.test_step(*args)
def predict(self, args):
"""The actual predict step.
Args:
args: the arguments for the model's predict step. Can consist of the following:
batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
batch_idx (int): The index of this batch.
dataloader_idx (int): The index of the dataloader that produced this batch
(only if multiple predict dataloaders used).
"""
batch = self.to_device(args[0])
args[0] = batch
with self.precision_plugin.predict_context(), self.training_type_plugin.predict_context():
return self.training_type_plugin.predict(*args)
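# A minimal runnable sketch of the pattern above, with hypothetical stand-in
# names: both plugin contexts are entered in a single `with` statement before
# the step is delegated to the training-type plugin.
import contextlib

@contextlib.contextmanager
def _precision_predict_context():
    yield  # stand-in for precision_plugin.predict_context()

@contextlib.contextmanager
def _training_type_predict_context():
    yield  # stand-in for training_type_plugin.predict_context()

def _sketch_predict(step_fn, args):
    with _precision_predict_context(), _training_type_predict_context():
        return step_fn(*args)

assert _sketch_predict(lambda batch, batch_idx: batch * 2, [3, 0]) == 6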
def training_step_end(self, output):
"""A hook to do something at the end of the training step

View File

@ -67,6 +67,7 @@ class ProgressBarBase(Callback):
self._train_batch_idx = 0
self._val_batch_idx = 0
self._test_batch_idx = 0
self._predict_batch_idx = 0
@property
def trainer(self):
@ -96,6 +97,14 @@ class ProgressBarBase(Callback):
"""
return self._test_batch_idx
@property
def predict_batch_idx(self) -> int:
"""
The current batch index being processed during predicting.
Use this to update your progress bar.
"""
return self._predict_batch_idx
@property
def total_train_batches(self) -> int:
"""
@ -108,7 +117,7 @@ class ProgressBarBase(Callback):
@property
def total_val_batches(self) -> int:
"""
The total number of training batches during validation, which may change from epoch to epoch.
The total number of validation batches during validation, which may change from epoch to epoch.
Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
validation dataloader is of infinite size.
"""
@ -121,12 +130,21 @@ class ProgressBarBase(Callback):
@property
def total_test_batches(self) -> int:
"""
The total number of training batches during testing, which may change from epoch to epoch.
The total number of testing batches during testing, which may change from epoch to epoch.
Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
test dataloader is of infinite size.
"""
return sum(self.trainer.num_test_batches)
@property
def total_predict_batches(self) -> int:
"""
The total number of predicting batches, which may change from epoch to epoch.
Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
predict dataloader is of infinite size.
"""
return sum(self.trainer.num_predict_batches)
def disable(self):
"""
You should provide a way to disable the progress bar.
@ -168,6 +186,12 @@ class ProgressBarBase(Callback):
def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
self._test_batch_idx += 1
def on_predict_start(self, trainer, pl_module):
self._predict_batch_idx = 0
def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
self._predict_batch_idx += 1
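# A hedged sketch of a custom bar built on the counters added above; the class
# name is made up and only hooks shown in this hunk are assumed. `disable` and
# `enable` must be overridden when subclassing ProgressBarBase.
from pytorch_lightning.callbacks import ProgressBarBase

class PrintingPredictBar(ProgressBarBase):

    def __init__(self):
        super().__init__()
        self._active = True

    def disable(self):
        self._active = False

    def enable(self):
        self._active = True

    def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        if self._active:
            print(f"predict batch {self.predict_batch_idx}/{self.total_predict_batches}")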
class ProgressBar(ProgressBarBase):
r"""
@ -282,6 +306,20 @@ class ProgressBar(ProgressBarBase):
)
return bar
def init_predict_tqdm(self) -> tqdm:
""" Override this to customize the tqdm bar for predicting. """
bar = tqdm(
desc='Predicting',
initial=self.train_batch_idx,
position=(2 * self.process_position),
disable=self.is_disabled,
leave=True,
dynamic_ncols=True,
file=sys.stdout,
smoothing=0,
)
return bar
def init_validation_tqdm(self) -> tqdm:
""" Override this to customize the tqdm bar for validation. """
bar = tqdm(
@ -294,12 +332,10 @@ class ProgressBar(ProgressBarBase):
)
return bar
def init_test_tqdm(self, trainer=None) -> tqdm:
def init_test_tqdm(self) -> tqdm:
""" Override this to customize the tqdm bar for testing. """
desc = "Testing"
desc = "Predicting" if trainer is not None and getattr(trainer, "is_predicting", False) else "Testing"
bar = tqdm(
desc=desc,
desc="Testing",
position=(2 * self.process_position),
disable=self.is_disabled,
leave=True,
@ -365,7 +401,7 @@ class ProgressBar(ProgressBarBase):
def on_test_start(self, trainer, pl_module):
super().on_test_start(trainer, pl_module)
self.test_progress_bar = self.init_test_tqdm(trainer=trainer)
self.test_progress_bar = self.init_test_tqdm()
self.test_progress_bar.total = convert_inf(self.total_test_batches)
def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
@ -377,6 +413,19 @@ class ProgressBar(ProgressBarBase):
super().on_test_end(trainer, pl_module)
self.test_progress_bar.close()
def on_predict_start(self, trainer, pl_module):
super().on_predict_start(trainer, pl_module)
self.predict_progress_bar = self.init_predict_tqdm()
self.predict_progress_bar.total = convert_inf(self.total_predict_batches)
def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if self._should_update(self.predict_batch_idx, self.total_predict_batches):
self._update_bar(self.predict_progress_bar)
def on_predict_end(self, trainer, pl_module):
self.predict_progress_bar.close()
def _should_update(self, current, total):
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)

View File

@ -260,6 +260,10 @@ class LightningDataModule(DataHooks, CheckpointHooks, metaclass=_DataModuleWrapp
def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
pass
@abstractmethod
def predict_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
pass
@abstractmethod
def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
pass
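# A hedged sketch of a datamodule exposing the new hook; the class name and
# data are made up, and a concrete datamodule would also implement the other
# abstract hooks shown above.
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningDataModule

class PredictDataModuleSketch(LightningDataModule):

    def predict_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(8, 2)), batch_size=4)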

View File

@ -204,17 +204,23 @@ class ModelHooks:
"""
# do something when the batch ends
def on_test_model_train(self) -> None:
"""
Sets the model to train during the test loop
"""
self.train()
def on_test_model_eval(self) -> None:
"""
Sets the model to eval during the test loop
"""
self.eval()
def on_test_model_train(self) -> None:
def on_predict_model_eval(self) -> None:
"""
Sets the model to train during the test loop
Sets the model to eval during the predict loop
"""
self.train()
self.eval()
def on_epoch_start(self) -> None:
"""
@ -518,6 +524,31 @@ class DataHooks:
will have an argument ``dataloader_idx`` which matches the order here.
"""
def predict_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
r"""
Implement one or multiple PyTorch DataLoaders for prediction.
It's recommended that all data downloads and preparation happen in :meth:`prepare_data`.
- :meth:`~pytorch_lightning.trainer.Trainer.fit`
- ...
- :meth:`prepare_data`
- :meth:`train_dataloader`
- :meth:`val_dataloader`
- :meth:`test_dataloader`
Note:
Lightning adds the correct sampler for distributed and arbitrary hardware
There is no need to set it yourself.
Return:
Single or multiple PyTorch DataLoaders.
Note:
In the case where you return multiple prediction dataloaders, the :meth:`predict`
will have an argument ``dataloader_idx`` which matches the order here.
"""
def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any:
"""
Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors

View File

@ -54,14 +54,22 @@ class _LightningModuleWrapperBase(torch.nn.Module):
if not self.module.automatic_optimization:
self.module.trainer.model.require_backward_grad_sync = False
warn_if_output_is_none(output, "training_step")
elif running_stage == RunningStage.TESTING:
output = self.module.test_step(*inputs, **kwargs)
warn_if_output_is_none(output, "test_step")
elif running_stage == RunningStage.EVALUATING:
output = self.module.validation_step(*inputs, **kwargs)
warn_if_output_is_none(output, "validation_step")
else:
elif running_stage == RunningStage.PREDICTING:
output = self.module.predict(*inputs, **kwargs)
warn_if_output_is_none(output, "predict")
else:
output = self.module(*inputs, **kwargs)
return output
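# The routing above in miniature (hypothetical, heavily simplified): the
# wrapper picks the LightningModule method from the running stage and falls
# back to plain forward when no stage is set.
from enum import Enum

class _Stage(Enum):
    TRAINING = "train"
    EVALUATING = "eval"
    TESTING = "test"
    PREDICTING = "predict"

def _route(stage, module, *inputs, **kwargs):
    if stage is _Stage.TESTING:
        return module.test_step(*inputs, **kwargs)
    if stage is _Stage.EVALUATING:
        return module.validation_step(*inputs, **kwargs)
    if stage is _Stage.PREDICTING:
        return module.predict(*inputs, **kwargs)
    return module(*inputs, **kwargs)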

View File

@ -33,11 +33,11 @@ class Plugin(ABC):
Will be called by the accelerator.
"""
def pre_training(self) -> None:
"""Hook to do something before the training starts."""
def pre_dispatch(self) -> None:
"""Hook to do something before the training/evaluation/prediction starts."""
def post_training(self) -> None:
"""Hook to do something after the training finishes."""
def post_dispatch(self) -> None:
"""Hook to do something after the training/evaluation/prediction finishes."""
@contextlib.contextmanager
def train_step_context(self) -> Generator:
@ -53,3 +53,8 @@ class Plugin(ABC):
def test_step_context(self) -> Generator:
"""A contextmanager for the teststep"""
yield
@contextlib.contextmanager
def predict_context(self) -> Generator:
"""A contextmanager for the predict step"""
yield
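# A hedged sketch of overriding the new hook, e.g. to disable gradients while
# predicting; `NoGradPredictPlugin` is hypothetical, not a real Lightning class.
import contextlib
from typing import Generator
import torch

class NoGradPredictPlugin:

    @contextlib.contextmanager
    def predict_context(self) -> Generator:
        with torch.no_grad():
            yield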

View File

@ -215,7 +215,7 @@ class DDPPlugin(ParallelPlugin):
log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size)
def pre_training(self):
def pre_dispatch(self):
# TODO: check if needed
seed = os.environ.get("PL_GLOBAL_SEED")
if seed is not None:
@ -232,7 +232,7 @@ class DDPPlugin(ParallelPlugin):
# where to store ip_table
self.init_ddp_connection(self.global_rank, self.world_size)
# TODO: we moved it to the trainer.fit after calling pre_training
# TODO: we moved it to the trainer.fit after calling pre_dispatch
# ... need to double check that it is the correct place
# self.trainer.call_setup_hook(self.model)
@ -257,7 +257,7 @@ class DDPPlugin(ParallelPlugin):
self.barrier()
def post_training(self):
def post_dispatch(self):
if "WORLD_SIZE" in os.environ:
del os.environ["WORLD_SIZE"]

View File

@ -110,6 +110,9 @@ class DDPSpawnPlugin(ParallelPlugin):
def start_testing(self, trainer):
mp.spawn(self.new_process, **self.mp_spawn_kwargs)
def start_predicting(self, trainer):
mp.spawn(self.new_process, **self.mp_spawn_kwargs)
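# The spawn-based entry points above in miniature: `mp.spawn` launches the
# target on `nprocs` workers and passes the process index first. The worker
# below is a made-up stand-in for `new_process`.
import torch.multiprocessing as mp

def _worker(process_idx, total):
    print(f"worker {process_idx}/{total}")

if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)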
def new_process(self, process_idx, trainer, mp_queue):
self.mp_queue = mp_queue
@ -128,7 +131,7 @@ class DDPSpawnPlugin(ParallelPlugin):
# where to store ip_table
self.init_ddp_connection(self.global_rank, self.world_size)
# TODO: we moved it to the trainer.fit after calling pre_training
# TODO: we moved it to the trainer.fit after calling pre_dispatch
# ... need to double check that it is the correct place
# self.trainer.call_setup_hook(self.model)
@ -153,15 +156,12 @@ class DDPSpawnPlugin(ParallelPlugin):
self.barrier()
if trainer.testing:
results = trainer.run_test()
else:
results = trainer.train()
results = trainer.train_or_test_or_predict()
# persist info in ddp_spawn
self.transfer_distrib_spawn_state_on_fit_end(results)
def post_training(self):
def post_dispatch(self):
# restore main state with best weights
best_path = self.mp_queue.get()
last_path = self.mp_queue.get()

View File

@ -50,7 +50,7 @@ class HorovodPlugin(ParallelPlugin):
self.model_to_device()
def pre_training(self):
def pre_dispatch(self):
def _unpack_lightning_optimizer(opt):
return opt._optimizer if isinstance(opt, LightningOptimizer) else opt
@ -95,20 +95,26 @@ class HorovodPlugin(ParallelPlugin):
stack.enter_context(optimizer.skip_synchronize())
# set up training routine
self._results = trainer.train()
self._results = trainer.run_train()
# Make sure all workers have finished training before returning to the user
hvd.join()
def start_testing(self, trainer):
with ExitStack() as stack:
# set up training routine
# self.trainer.train_loop.setup_training(self.trainer.model)
self._results = trainer.run_test()
# Make sure all workers have finished training before returning to the user
hvd.join()
def start_predicting(self, trainer):
with ExitStack() as stack:
# set up training routine
self._results = trainer.run_predict()
# Make sure all workers have finished training before returning to the user
hvd.join()
def barrier(self, *args, **kwargs):
hvd.join()

View File

@ -325,9 +325,9 @@ class RPCSequentialPlugin(RPCPlugin):
# Initialize optimizer step on main process
self.worker_optimizer_step(model=self.lightning_module, opt_idx=optimizer_idx, **kwargs)
def post_training(self):
def post_training_step(self):
if self.main_rpc_process:
super().post_training()
super().post_training_step()
def start_training(self, trainer: 'Trainer') -> None:
if self.main_rpc_process:

View File

@ -36,14 +36,14 @@ class SingleTPUPlugin(SingleDevicePlugin):
def model_to_device(self) -> None:
self._model.to(self.root_device)
def pre_training(self) -> None:
def pre_dispatch(self) -> None:
if isinstance(self.device, int):
self.device = xm.xla_device(self.device)
self.tpu_local_core_rank = xm.get_local_ordinal()
self.tpu_global_core_rank = xm.get_ordinal()
def post_training(self) -> None:
def post_dispatch(self) -> None:
model = self.lightning_module
if on_colab_kaggle():

View File

@ -95,10 +95,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
trainer.save_checkpoint = self.save_checkpoint
self.barrier()
if trainer.testing:
results = trainer.run_test()
else:
results = trainer.train()
results = trainer.train_or_test_or_predict()
self.__save_end_of_training_weights(self.lightning_module)
self.transfer_distrib_spawn_state_on_fit_end(results)
@ -182,7 +179,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
should_stop = int(stop.item()) == self.world_size
return should_stop
def post_training(self) -> None:
def post_dispatch(self) -> None:
# TODO: Check if trainer references can be resolved otherwise
model = self.lightning_module
@ -233,6 +230,9 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
def start_testing(self, trainer) -> None:
xmp.spawn(self.new_process, **self.xmp_spawn_kwargs)
def start_predicting(self, trainer) -> None:
xmp.spawn(self.new_process, **self.xmp_spawn_kwargs)
def training_step(self, *args, **kwargs):
return self.lightning_module.training_step(*args, **kwargs)

View File

@ -112,12 +112,16 @@ class TrainingTypePlugin(Plugin, ABC):
def start_training(self, trainer: 'Trainer') -> None:
# double dispatch to initiate the training loop
self._results = trainer.train()
self._results = trainer.run_train()
def start_testing(self, trainer: 'Trainer') -> None:
# double dispatch to initiate the test loop
self._results = trainer.run_test()
def start_predicting(self, trainer: 'Trainer') -> None:
# double dispatch to initiate the predicting loop
self._results = trainer.run_predict()
def training_step(self, *args, **kwargs):
return self.lightning_module.training_step(*args, **kwargs)
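# The double dispatch above in miniature (hypothetical classes): the trainer
# hands control to the plugin, which immediately calls back into the matching
# trainer loop and stores the results.
class _MiniPlugin:

    def start_predicting(self, trainer):
        self._results = trainer.run_predict()

class _MiniTrainer:

    def run_predict(self):
        return ["predictions"]

plugin = _MiniPlugin()
plugin.start_predicting(_MiniTrainer())
assert plugin._results == ["predictions"]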

View File

@ -52,7 +52,7 @@ class ConfigValidator(object):
# verify model has a train dataloader
# -----------------------------------
has_train_dataloader = is_overridden('train_dataloader', model)
if not has_train_dataloader and not self.trainer._predicting:
if not has_train_dataloader:
raise MisconfigurationException(
'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a'
' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
@ -62,7 +62,7 @@ class ConfigValidator(object):
# verify model has optimizer
# -----------------------------------
has_optimizers = is_overridden('configure_optimizers', model)
if not has_optimizers and not self.trainer._predicting:
if not has_optimizers:
raise MisconfigurationException(
'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a'
' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'

View File

@ -90,7 +90,14 @@ class DataConnector(object):
'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule'
)
def attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None):
def attach_dataloaders(
self,
model,
train_dataloader=None,
val_dataloaders=None,
test_dataloaders=None,
predict_dataloaders=None
):
# when dataloader is passed via fit, patch the train_dataloader
# functions to overwrite with these implementations
if train_dataloader is not None:
@ -102,6 +109,9 @@ class DataConnector(object):
if test_dataloaders is not None:
model.test_dataloader = _PatchDataLoader(test_dataloaders)
if predict_dataloaders is not None:
model.predict_dataloader = _PatchDataLoader(predict_dataloaders)
def attach_datamodule(self, model, datamodule: Optional[LightningDataModule], stage: str) -> None:
# Todo: required argument `stage` is not used
@ -118,6 +128,8 @@ class DataConnector(object):
model.val_dataloader = datamodule.val_dataloader
if is_overridden('test_dataloader', datamodule):
model.test_dataloader = datamodule.test_dataloader
if is_overridden('predict_dataloader', datamodule):
model.predict_dataloader = datamodule.predict_dataloader
# Override transfer_batch_to_device if dataset-specific to_device logic has been defined in datamodule
if is_overridden('transfer_batch_to_device', datamodule):

View File

@ -29,6 +29,7 @@ class DebuggingConnector:
limit_train_batches,
limit_val_batches,
limit_test_batches,
limit_predict_batches,
val_check_interval,
overfit_batches,
fast_dev_run,
@ -56,6 +57,7 @@ class DebuggingConnector:
limit_train_batches = fast_dev_run
limit_val_batches = fast_dev_run
limit_test_batches = fast_dev_run
limit_predict_batches = fast_dev_run
self.trainer.max_steps = fast_dev_run
self.trainer.num_sanity_val_steps = 0
self.trainer.max_epochs = 1
@ -71,6 +73,7 @@ class DebuggingConnector:
self.trainer.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches')
self.trainer.limit_val_batches = _determine_batch_limits(limit_val_batches, 'limit_val_batches')
self.trainer.limit_test_batches = _determine_batch_limits(limit_test_batches, 'limit_test_batches')
self.trainer.limit_predict_batches = _determine_batch_limits(limit_predict_batches, 'limit_predict_batches')
self.trainer.val_check_interval = _determine_batch_limits(val_check_interval, 'val_check_interval')
self.trainer.overfit_batches = _determine_batch_limits(overfit_batches, 'overfit_batches')
self.determine_data_use_amount(self.trainer.overfit_batches)
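# Once wired through here, the new flag should behave like the existing
# limit_* flags; a hedged usage sketch:
from pytorch_lightning import Trainer

trainer = Trainer(limit_predict_batches=2)  # at most 2 batches per predict dataloader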

View File

@ -293,7 +293,8 @@ class TrainerDataLoadingMixin(ABC):
loader = dataloaders[loader_i]
# shuffling in val and test set is bad practice
if mode in ('val', 'test') and hasattr(loader, 'sampler') and isinstance(loader.sampler, RandomSampler):
modes = ('val', 'test', 'predict')
if mode in modes and hasattr(loader, 'sampler') and isinstance(loader.sampler, RandomSampler):
# when overfitting, the dataloader should not have sampler
if self.overfit_batches > 0:
@ -363,7 +364,7 @@ class TrainerDataLoadingMixin(ABC):
self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(model, 'val')
def reset_test_dataloader(self, model) -> None:
"""Resets the validation dataloader and determines the number of batches.
"""Resets the test dataloader and determines the number of batches.
Args:
model: The current `LightningModule`
@ -374,6 +375,17 @@ class TrainerDataLoadingMixin(ABC):
self.num_test_batches, self.test_dataloaders =\
self._reset_eval_dataloader(model, 'test')
def reset_predict_dataloader(self, model) -> None:
"""Resets the predict dataloader and determines the number of batches.
Args:
model: The current `LightningModule`
"""
has_loader = is_overridden('predict_dataloader', model)
if has_loader:
self.num_predict_batches, self.predict_dataloaders =\
self._reset_eval_dataloader(model, 'predict')
def request_dataloader(self, dataloader_fx: Callable) -> DataLoader:
"""Handles downloading data in the GPU or TPU case.

View File

@ -154,16 +154,7 @@ class EvaluationLoop(object):
model_ref = self.trainer.get_model()
model_ref._results = Result()
if self.trainer._predicting:
model_ref._current_fx_name = "predict"
predictions = self.trainer.accelerator_backend.predict(args)
self._predictions[dataloader_idx].append(predictions)
self.trainer._progress_bar_callback.on_test_batch_end(
self.trainer, model_ref, predictions, batch, batch_idx, dataloader_idx
)
return
elif self.testing:
if self.testing:
model_ref._current_fx_name = "test_step"
with self.trainer.profiler.profile("test_step"):
output = self.trainer.accelerator_backend.test_step(args)

View File

@ -0,0 +1,97 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection
class PredictLoop(object):
def __init__(self, trainer):
self.trainer = trainer
self.max_batches = None
self.num_dataloaders = None
def on_trainer_init(self):
self.trainer.num_predict_batches = []
def get_predict_dataloaders(self, max_batches):
# select dataloaders
model = self.trainer.get_model()
self.trainer.reset_predict_dataloader(model)
dataloaders = self.trainer.predict_dataloaders
if max_batches is None:
max_batches = self.trainer.num_predict_batches
return dataloaders, max_batches
def should_skip_predict(self, dataloaders, max_batches):
return dataloaders is None or not sum(max_batches)
def on_predict_model_eval(self, *_, **__):
model_ref = self.trainer.get_model()
model_ref.on_predict_model_eval()
def setup(self, model, max_batches, dataloaders):
# copy properties for forward overrides
self.trainer.model_connector.copy_trainer_model_properties(model)
# convert max_batches to list
if isinstance(max_batches, int):
max_batches = [max_batches] * len(dataloaders)
self.max_batches = max_batches
self.num_dataloaders = self._get_num_dataloaders(dataloaders)
self._predictions = [[] for _ in range(self.num_dataloaders)]
self.trainer._progress_bar_callback.on_predict_start(self.trainer, self.trainer.get_model())
def _get_num_dataloaders(self, dataloaders):
# case where user does:
# return dl1, dl2
length = len(dataloaders)
if len(dataloaders) > 0 and isinstance(dataloaders[0], (list, tuple)):
length = len(dataloaders[0])
return length
def predict(self, batch, batch_idx, dataloader_idx):
# configure args
args = [batch, batch_idx]
if self.num_dataloaders:
args.append(dataloader_idx)
model_ref = self.trainer.get_model()
model_ref._current_fx_name = "predict"
predictions = self.trainer.accelerator_backend.predict(args)
self._predictions[dataloader_idx].append(predictions)
self.trainer._progress_bar_callback.on_predict_batch_end(
self.trainer, model_ref, predictions, batch, batch_idx, dataloader_idx
)
return
def on_predict_epoch_end(self):
self.trainer._progress_bar_callback.on_predict_end(self.trainer, self.trainer.get_model())
results = self._predictions
def _convert_to_numpy(v):
return v.cpu().numpy()
results = apply_to_collection(results, torch.Tensor, _convert_to_numpy)
if len(results) == 1:
return results[0]
return results
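# A quick illustration of the `apply_to_collection` call above: it maps a
# function over every tensor inside a nested structure, here converting
# prediction tensors to numpy arrays.
import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection

nested = [[torch.ones(2)], [torch.zeros(2)]]
as_numpy = apply_to_collection(nested, torch.Tensor, lambda t: t.cpu().numpy())
assert type(as_numpy[0][0]).__name__ == "ndarray"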

View File

@ -49,6 +49,7 @@ from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
from pytorch_lightning.trainer.predict_loop import PredictLoop
from pytorch_lightning.trainer.properties import TrainerProperties
from pytorch_lightning.trainer.states import RunningStage, TrainerState
from pytorch_lightning.trainer.training_loop import TrainLoop
@ -57,6 +58,7 @@ from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.utilities import DeviceType, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.enums import LightningEnum
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.utilities.model_helpers import is_overridden
@ -107,6 +109,7 @@ class Trainer(
limit_train_batches: Union[int, float] = 1.0,
limit_val_batches: Union[int, float] = 1.0,
limit_test_batches: Union[int, float] = 1.0,
limit_predict_batches: Union[int, float] = 1.0,
val_check_interval: Union[int, float] = 1.0,
flush_logs_every_n_steps: int = 100,
log_every_n_steps: int = 50,
@ -296,7 +299,6 @@ class Trainer(
"""
super().__init__()
self._running_stage = None
self._predicting = False
distributed_backend = distributed_backend or accelerator
@ -319,8 +321,9 @@ class Trainer(
self.checkpoint_connector = CheckpointConnector(self)
self.slurm_connector = SLURMConnector(self)
self.tuner = Tuner(self)
self.evaluation_loop = EvaluationLoop(self)
self.train_loop = TrainLoop(self, multiple_trainloader_mode)
self.evaluation_loop = EvaluationLoop(self)
self.predict_loop = PredictLoop(self)
# training state
self.weights_summary = weights_summary
@ -393,6 +396,7 @@ class Trainer(
limit_train_batches,
limit_val_batches,
limit_test_batches,
limit_predict_batches,
val_check_interval,
overfit_batches,
fast_dev_run,
@ -440,7 +444,11 @@ class Trainer(
"""
# bookkeeping
self._state = TrainerState.RUNNING
self._set_wide_running_stage(RunningStage.TRAINING)
# bookkeeping
# we reuse fit in .test() and .predict(). When already set, it shouldn't be modified.
if self._running_stage is None:
self._set_running_stage(RunningStage.TRAINING, model)
# set local properties on the model
self.model_connector.copy_trainer_model_properties(model)
@ -463,27 +471,47 @@ class Trainer(
self.accelerator_backend.setup(self, model)
self.setup_trainer(model)
# ----------------------------
# INSPECT THE CORE LOOPS
# ----------------------------
# Lightning internal flow looks like this.
#
# trainer.fit(...) or trainer.test(...) or trainer.predict(...) ||
# | ||
# create accelerator ||
# | ||
# trainer.dispatch || LIGHTNING
# | ||
# start_training or start_testing or start_predicting call || FLOW
# from `accelerator.training_type_plugin` ||
# | || DIRECTION
# run_train or run_test or run_predict call ||
# from `trainer` ||
# | ||
# results \/
# This is used to guide readers to the core loops: train, test, predict.
# `run_predict` is the simplest to understand, use `Go to Definition` to read it :)
# Search for `start_training` or `start_testing` or `start_predicting` in
# `pytorch_lightning/plugins/training_type` folder to find accelerator dispatch functions.
self.accelerator.train_loop = self.run_train
self.accelerator.validation_loop = self.run_evaluation
self.accelerator.test_loop = self.run_evaluation
self.accelerator.predict_loop = self.run_predict
# ----------------------------
# TRAIN
# ----------------------------
# hook
self.call_hook("on_fit_start")
# plugin will setup training (e.g. ddp will launch child processes)
# TODO: the old setup is now called "pre_training", where should this hook be called now?
self.training_type_plugin.pre_training()
self.precision_plugin.pre_training()
# plugin will setup fitting (e.g. ddp will launch child processes)
self.pre_dispatch()
# double dispatch: let the plugin initiate the training/test loop.
if self.testing:
self.training_type_plugin.start_testing(self)
else:
self.training_type_plugin.start_training(self)
# dispatch `start_training` or `start_testing` or `start_predicting`
self.dispatch()
self.precision_plugin.post_training()
self.training_type_plugin.post_training()
self.accelerator_backend.teardown()
results = self.training_type_plugin.results
# plugin will finalize fitting (e.g. ddp_spawn will load trained model)
self.post_dispatch()
# ----------------------------
# POST-Training CLEAN UP
@ -501,31 +529,47 @@ class Trainer(
if self._state != TrainerState.INTERRUPTED:
self._state = TrainerState.FINISHED
self._set_wide_running_stage(None)
self._set_running_stage(None, model)
return results or 1
return self.training_type_plugin.results or 1
def _set_wide_running_stage(self, stage):
model_ref = self.get_model()
def pre_dispatch(self):
self.training_type_plugin.pre_dispatch()
self.precision_plugin.pre_dispatch()
if stage is None:
self._running_stage = stage
model_ref.running_stage = stage
return
def post_dispatch(self):
self.training_type_plugin.post_dispatch()
self.precision_plugin.post_dispatch()
self.accelerator_backend.teardown()
# todo: clean up this routing mess.
if self._running_stage == RunningStage.TESTING:
stage = RunningStage.TESTING
def dispatch(self):
if self.testing:
self.training_type_plugin.start_testing(self)
# WARNING: With predicting,
# trainer _running_state should be RunningStage.TESTING
# however, the model running_stage should be RunningStage.PREDICTING or None
if model_ref is not None:
if self._predicting:
model_ref.running_stage = RunningStage.PREDICTING
else:
model_ref.running_stage = stage
elif self.predicting:
self.training_type_plugin.start_predicting(self)
else:
self.training_type_plugin.start_training(self)
def train_or_test_or_predict(self):
if self.testing:
results = self.run_test()
elif self.predicting:
results = self.run_predict()
else:
results = self.run_train()
return results
def _set_running_stage(self, stage: LightningEnum, model_ref: LightningModule):
"""
This function is used to set the running_state on both
the trainer and the model
"""
model_ref.running_stage = stage
self._running_stage = stage
def _pre_training_routine(self):
@ -560,7 +604,7 @@ class Trainer(
if self.is_function_implemented("on_pretrain_routine_end"):
ref_model.on_pretrain_routine_end()
def train(self):
def run_train(self):
self._pre_training_routine()
@ -570,7 +614,7 @@ class Trainer(
self.run_sanity_check(self.get_model())
# set stage for logging
self._set_wide_running_stage(RunningStage.TRAINING)
self._set_running_stage(RunningStage.TRAINING, self.get_model())
self.checkpoint_connector.has_trained = False
@ -634,7 +678,7 @@ class Trainer(
def run_evaluation(self, max_batches=None, on_epoch=False):
# used to know if we are logging for val, test + reset cached results
self._set_wide_running_stage(RunningStage.TESTING if self.testing else RunningStage.EVALUATING)
self._set_running_stage(RunningStage.TESTING if self.testing else RunningStage.EVALUATING, self.get_model())
self.logger_connector.reset()
# bookkeeping
@ -647,11 +691,10 @@ class Trainer(
if self.evaluation_loop.should_skip_evaluation(max_batches):
return [], []
# ref model
model = self.get_model()
# enable eval mode + no grads
self.evaluation_loop.on_evaluation_model_eval()
# ref model
model = self.get_model()
model.zero_grad()
torch.set_grad_enabled(False)
@ -685,8 +728,6 @@ class Trainer(
# lightning module methods
with self.profiler.profile("evaluation_step_and_end"):
output = self.evaluation_loop.evaluation_step(batch, batch_idx, dataloader_idx)
if self._predicting:
continue
output = self.evaluation_loop.evaluation_step_end(output)
# hook + store predictions
@ -701,9 +742,6 @@ class Trainer(
# store batch level output per dataloader
self.evaluation_loop.outputs.append(dl_outputs)
if self._predicting:
return self.evaluation_loop.on_predict_epoch_end()
# lightning module method
deprecated_eval_results = self.evaluation_loop.evaluation_epoch_end()
@ -764,6 +802,45 @@ class Trainer(
return eval_loop_results
def run_predict(self):
# prepare dataloaders
dataloaders, max_batches = self.predict_loop.get_predict_dataloaders(None)
# check if we want to skip this evaluation
if self.predict_loop.should_skip_predict(dataloaders, max_batches):
return []
# ref model
model = self.get_model()
# enable eval mode + no grads
self.predict_loop.on_predict_model_eval()
model.zero_grad()
torch.set_grad_enabled(False)
# set up the predict loop
self.predict_loop.setup(model, max_batches, dataloaders)
# run prediction
for dataloader_idx, dataloader in enumerate(dataloaders):
dataloader = self.accelerator_backend.process_dataloader(dataloader)
dl_max_batches = self.predict_loop.max_batches[dataloader_idx]
for batch_idx, batch in enumerate(dataloader):
if batch is None:
continue
# stop short when running on limited batches
if batch_idx >= dl_max_batches:
break
# lightning module methods
with self.profiler.profile("predict"):
self.predict_loop.predict(batch, batch_idx, dataloader_idx)
results = self.predict_loop.on_predict_epoch_end()
return results
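# A hedged usage sketch for the loop above. The `model` argument is assumed to
# define `predict` and `predict_dataloader`, like the hypothetical
# LitPredictSketch module sketched earlier in this section; `dataloaders=`
# could be passed instead, as the `predict` entry point further below shows.
from pytorch_lightning import Trainer

def run_predict_sketch(model):
    trainer = Trainer(max_epochs=1)
    return trainer.predict(model)  # list of per-dataloader predictions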
def run_sanity_check(self, ref_model):
using_val_step = ref_model.val_dataloader is not None and is_overridden('validation_step', ref_model)
should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0
@ -828,7 +905,7 @@ class Trainer(
# --------------------
self.verbose_test = verbose
self._set_wide_running_stage(RunningStage.TESTING)
self._set_running_stage(RunningStage.TESTING, model or self.get_model())
# If you supply a datamodule you can't supply train_dataloader or val_dataloaders
if test_dataloaders and datamodule:
@ -845,9 +922,7 @@ class Trainer(
results = self.__test_using_best_weights(ckpt_path, test_dataloaders)
self.teardown('test')
self._set_wide_running_stage(None)
self._set_running_stage(None, model or self.get_model())
return results
def __test_using_best_weights(self, ckpt_path, test_dataloaders):
@ -935,35 +1010,28 @@ class Trainer(
# --------------------
# SETUP HOOK
# --------------------
self._set_wide_running_stage(RunningStage.TESTING)
# If you supply a datamodule you can't supply dataloaders
model = model or self.get_model()
self._set_running_stage(RunningStage.PREDICTING, model)
if dataloaders and datamodule:
raise MisconfigurationException(
'You cannot pass dataloaders to trainer.predict if you supply a datamodule.'
)
if model is None:
raise MisconfigurationException('You need to pass a model to `trainer.predict`.')
if datamodule is not None:
# Attach datamodule to get setup/prepare_data added to model before the call to it below
self.data_connector.attach_datamodule(model, datamodule, 'test')
self.data_connector.attach_datamodule(model, datamodule, 'predict')
# attach data
if dataloaders is not None:
self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders)
self.data_connector.attach_dataloaders(model, predict_dataloaders=dataloaders)
# set path variable
self._predicting = True
self.model = model
results = self.fit(model)
# unset path variable
self.teardown('test')
self._predicting = False
self._set_wide_running_stage(None)
self._set_running_stage(None, model)
return results
@ -1069,6 +1137,17 @@ class Trainer(
elif self.testing:
self._running_stage = None
@property
def predicting(self) -> bool:
return self._running_stage == RunningStage.PREDICTING
@predicting.setter
def predicting(self, val: bool) -> None:
if val:
self._running_stage = RunningStage.PREDICTING
elif self.predicting:
self._running_stage = None
@property
def tuning(self) -> bool:
return self._running_stage == RunningStage.TUNING

View File

@ -547,7 +547,7 @@ class TrainLoop:
self.trainer.run_evaluation()
# reset stage to train
self.trainer._set_wide_running_stage(RunningStage.TRAINING)
self.trainer._set_running_stage(RunningStage.TRAINING, self.trainer.lightning_module)
# -----------------------------------------
# SAVE LOGGERS (ie: Tensorboard, etc...)
@ -594,7 +594,7 @@ class TrainLoop:
self.trainer.run_evaluation(on_epoch=True)
# reset stage to train
self.trainer._set_wide_running_stage(RunningStage.TRAINING)
self.trainer._set_running_stage(RunningStage.TRAINING, self.trainer.lightning_module)
should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
should_train_only = self.trainer.disable_validation or should_skip_eval

View File

@ -288,8 +288,8 @@ class MockedUpdateProgressBars(ProgressBar):
bar = super().init_validation_tqdm()
return self._mock_bar_update(bar)
def init_test_tqdm(self, trainer=None):
bar = super().init_test_tqdm(trainer=trainer)
def init_test_tqdm(self):
bar = super().init_test_tqdm()
return self._mock_bar_update(bar)

View File

@ -39,7 +39,7 @@ def test_lightning_wrapper_module_methods(wrapper_class):
wrapped_module(batch, batch_idx)
pl_module.validation_step.assert_called_with(batch, batch_idx)
pl_module.running_stage = None
pl_module.running_stage = RunningStage.PREDICTING
wrapped_module(batch)
pl_module.predict.assert_called_with(batch)

View File

@ -1498,6 +1498,9 @@ class TestLightningDataModule(LightningDataModule):
def test_dataloader(self):
return self._dataloaders
def predict_dataloader(self):
return self._dataloaders
def predict(tmpdir, accelerator, gpus, num_processes, plugins=None, datamodule=True):
@ -1515,7 +1518,6 @@ def predict(tmpdir, accelerator, gpus, num_processes, plugins=None, datamodule=T
gpus=gpus,
num_processes=num_processes,
plugins=plugins,
num_sanity_val_steps=0
)
if datamodule:
results = trainer.predict(model, datamodule=datamodule)
@ -1529,9 +1531,6 @@ def predict(tmpdir, accelerator, gpus, num_processes, plugins=None, datamodule=T
assert results[0][0].shape == torch.Size([1, 2])
@pytest.mark.skipif(
not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
)
@pytest.mark.parametrize('datamodule', [False, True])
def test_trainer_predict_cpu(tmpdir, datamodule):
predict(tmpdir, None, None, 1, datamodule=datamodule)