Add PredictLoop (#5752)
* integrate distrib_type * sync changes * sync * fixes * add forgotten generators * add missing logic * update * import * missed imports * import fixes * isort * mv f * changelog * format * move helper to parallel plugin * d * add world size * clean up * duplicate * activate ddp_sharded and tpu * set nvidia flags * remove unused colab var * use_tpu <-> on_tpu attrs * make some ddp_cpu and clusterplugin tests pass * Ref/accelerator connector (#5742) * final cleanup Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> * connector cleanup Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> * trainer cleanup Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> * accelerator cleanup + missing logic in accelerator connector Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> * add missing changes to callbacks Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> * reflect accelerator changes to lightning module Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> * clean cluster envs Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> * cleanup plugins Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> * add broadcasting Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> * yapf * remove plugin connector Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> * plugins * add predict_loop * manual optimization * clean predictloop * update optimizer routing * add predict loop on new accelerator * resolve a bug * add rank to torchelastic * add predict_loop * add predict loop on new accelerator * resolve a bug * fix memory mixed precision * update * setstate on trainer for pickling in ddp spawn * add predict_loop * clean predictloop * add predict loop on new accelerator * resolve a bug * add predict_loop * add predict loop on new accelerator * resolve a bug * add predict_loop * add predict loop on new accelerator * resolve a bug * add predict_loop * add predict loop on new accelerator * resolve a bug * add predict_loop * clean predictloop * add predict loop on new accelerator * resolve a bug * add predict_loop * add predict loop on new accelerator * resolve a bug * resolve tests * add predict method * add back commented accelerator code * adapt test for sync_batch_norm to new plugin * fix deprecated tests * fix ddp cpu choice when no num_processes are given * yapf format * skip a memory test that cannot pass anymore * remove sanetize * rename train to run_train * remove useless hooks * add misconfigurationException * remove wrong naming * resolve some legacy * udpate docstring * fix pickle error in spawn plugin * x * avoid * x * fix cyclic import in docs build * add support for sharded * update typing * add sharded and sharded_spawn to distributed types * make unwrap model default * refactor LightningShardedDataParallel similar to LightningDistributedDataParallel * update sharded spawn to reflect changes * update sharded to reflect changes * Merge 1.1.5 changes * fix merge * fix merge * yapf isort * fix merge * yapf isort * fix indentation in test * copy over reinit scheduler implementation from dev1.2 * fix apex tracking calls with dev_debugger * reduce diff to dev1.2, clean up * fix trainer config test when gpus>0 and num_processes >0 and ddp_cpu * sort plugin tests legacy/new * fix error handling for amp on cpu * fix merge fix merge fix merge * [Feat] Resolve manual_backward (#5837) * resolve manual_backward * resolve flake8 * update * resolve for ddp_spawn * resolve flake8 * resolve flake8 * resolve flake8 Co-authored-by: Ubuntu 
<ubuntu@ip-172-31-88-60.ec2.internal> * fix tests/accelerator tests on cpu * [BugFix] Resolve manual optimization (#5852) * resolve manual_optimization * update * update Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal> * Remove copy trainer parameters to happen earlier within the loop and add safe guard to get ref model (#5856) * resovle a bug * Accelerator refactor sharded rpc (#5854) * rpc branch * merge * update handling of rpc * make devices etc. Optional in RPC * set devices etc. later if necessary * remove devices from sequential * make devices optional in rpc * fix import * uncomment everything * fix cluster selection Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal> * resolve bug * fix assert in rpc test * resolve a test * fix docs compilation * accelerator refactor - fix for sharded parity test (#5866) * fix memory issue with ddp_spawn * x x x x x x x x x * x * Remove DDP2 as this does not apply * Add missing pre optimizer hook to ensure lambda closure is called * fix apex docstring * [accelerator][BugFix] Resolve some test for 1 gpu (#5863) * update * revert init * resolve a bug * update * resolve flake8 * update * update * update * revert init * resolve a bug * update * resolve flake8 * update * update * update * update * update * revert init * resolve a bug * update * resolve flake8 * update * update * update * revert init * update * resolve flake8 * update * update * update * update * update * all_gather * update * make plugins work, add misconfig for RPC * update * update * remove breaking test * resolve some tests * resolve flake8 * revert to ddp_spawn Co-authored-by: root <root@ip-172-31-88-60.ec2.internal> Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal> Co-authored-by: Justus Schock <justus.schock@rwth-aachen.de> * yapf isort * resolve flake8 * fix apex doctests * fix apex doctests 2 * resolve docs * update drone * clean env * update * update * update * update * merge * Fix RPC related tests, clean out old API, update for new accelerator API [skip ci] (#5881) * Fix RPC related tests, clean out old API, update for new accelerator API * Move tests out of legacy folder, update paths and names * Update test_remove_1-4.py * Expose properties for tpu cores/gpus/num_gpus * Add root GPU property * Move properties to properties.py * move tests that were previously in drone * Fix root GPU property (#5908) * Move root GPU to property, remove horovod set as this is handled in horovod plugin, ensure we mock correctly to set GPU accelerator * Add missing tests back * fix best model path transfer when no checkpoint callback available * Fix setup hook order [wip] (#5858) * Call trainer setup hook before accelerator setup * Add test case * add new test * typo * fix callback order in test Co-authored-by: tchaton <thomas@grid.ai> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> * rename ddp sequential -> rpc sequential for special test * revert * fix stupid merge problem * Use property in connector for sampler (#5913) * merge the import conflicts * fix spawning of processes in slurm * [wip] Fix some bugs for TPU [skip ci] (#5878) * fixed for single tpu * fixed spawn * fixed spawn * update * update * wip * resolve bugs * resolve bug * update on comment * removed decorator * resolve comments * set to 4 * update * update * need cleaning * update * update * update * resolve flake8 * resolve bugs * exclude broadcast * resolve bugs * change test * update * update * skip if meet fails * properly raise trace * update * add catch * wrap test * resolve typo 
* update * typo Co-authored-by: Lezwon Castelino <lezwon@gmail.com> Co-authored-by: Your Name <you@example.com> * resolve some tests * update * fix imports * update * resolve flake8 * update azure pipeline * skip a sharded test on cpu that requires a gpu * resolve tpus * resolve bug * resolve flake8 * update * updat utils * revert permission change on files * suggestions from carlos Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> * remove unrelated formatting changes * remove incomplete comment * Update pytorch_lightning/accelerators/__init__.py Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> * remove unrelated formatting change * add types * warn 1.7 ddp manual backward only if ddp kwarg unset * yapf + isort * pep8 unused imports * fix cyclic import in docs * Apply suggestions from code review * typer in accelerator.py * typo * resolve flake8 * update code * update * Update pytorch_lightning/trainer/predict_loop.py Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> * Update pytorch_lightning/trainer/predict_loop.py Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> * fix merge * fix merge * reset legacy accelerator * add missing rename dispatch * rename post traning * update code * resolved comments * typo * typo * add flow description * resolve comments * update on comments * update flow * add backticks * resolve tpu Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> Co-authored-by: justusschock <justus.schock@posteo.de> Co-authored-by: Justus Schock <justus.schock@rwth-aachen.de> Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal> Co-authored-by: Sean Naren <sean.narenthiran@gmail.com> Co-authored-by: SeanNaren <sean@grid.ai> Co-authored-by: root <root@ip-172-31-88-60.ec2.internal> Co-authored-by: Lezwon Castelino <lezwon@gmail.com> Co-authored-by: Your Name <you@example.com> Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
parent a52be5bb07
commit e982800b81
@@ -78,6 +78,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added AUC/AUROC class interface ([#5479](https://github.com/PyTorchLightning/pytorch-lightning/pull/5479))


+- Added `PredictLoop` object ([#5752](https://github.com/PyTorchLightning/pytorch-lightning/pull/5752))
+
+
 - Added `QuantizationAwareTraining` callback ([#5706](https://github.com/PyTorchLightning/pytorch-lightning/pull/5706))

@@ -139,9 +139,8 @@ class Accelerator(object):

         args[0] = batch

-        with self.precision_plugin.train_step_context():
-            with self.training_type_plugin.train_step_context():
-                return self.training_type_plugin.training_step(*args)
+        with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context():
+            return self.training_type_plugin.training_step(*args)

     def post_training_step(self):
         self.training_type_plugin.post_training_step()

@@ -161,9 +160,8 @@ class Accelerator(object):

         args[0] = batch

-        with self.precision_plugin.val_step_context():
-            with self.training_type_plugin.val_step_context():
-                return self.training_type_plugin.validation_step(*args)
+        with self.precision_plugin.val_step_context(), self.training_type_plugin.val_step_context():
+            return self.training_type_plugin.validation_step(*args)

     def test_step(self, args):
         """The actual test step.

@@ -180,9 +178,26 @@ class Accelerator(object):

         args[0] = batch

-        with self.precision_plugin.test_step_context():
-            with self.training_type_plugin.test_step_context():
-                return self.training_type_plugin.test_step(*args)
+        with self.precision_plugin.test_step_context(), self.training_type_plugin.test_step_context():
+            return self.training_type_plugin.test_step(*args)
+
+    def predict(self, args):
+        """The actual predict step.
+
+        Args:
+            args: the arguments for the models predict step. Can consist of the following:
+                batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
+                    The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
+                batch_idx (int): The index of this batch.
+                dataloader_idx (int): The index of the dataloader that produced this batch
+                    (only if multiple predict dataloaders used).
+        """
+        batch = self.to_device(args[0])
+
+        args[0] = batch
+
+        with self.precision_plugin.predict_context(), self.training_type_plugin.predict_context():
+            return self.training_type_plugin.predict(*args)

     def training_step_end(self, output):
         """A hook to do something at the end of the training step

@@ -67,6 +67,7 @@ class ProgressBarBase(Callback):
         self._train_batch_idx = 0
         self._val_batch_idx = 0
         self._test_batch_idx = 0
+        self._predict_batch_idx = 0

     @property
     def trainer(self):

@@ -96,6 +97,14 @@ class ProgressBarBase(Callback):
         """
         return self._test_batch_idx

+    @property
+    def predict_batch_idx(self) -> int:
+        """
+        The current batch index being processed during predicting.
+        Use this to update your progress bar.
+        """
+        return self._predict_batch_idx
+
     @property
     def total_train_batches(self) -> int:
         """

@@ -108,7 +117,7 @@ class ProgressBarBase(Callback):
     @property
     def total_val_batches(self) -> int:
         """
-        The total number of training batches during validation, which may change from epoch to epoch.
+        The total number of validation batches during validation, which may change from epoch to epoch.
         Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
         validation dataloader is of infinite size.
         """

@@ -121,12 +130,21 @@ class ProgressBarBase(Callback):
     @property
     def total_test_batches(self) -> int:
         """
-        The total number of training batches during testing, which may change from epoch to epoch.
+        The total number of testing batches during testing, which may change from epoch to epoch.
         Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
         test dataloader is of infinite size.
         """
         return sum(self.trainer.num_test_batches)

+    @property
+    def total_predict_batches(self) -> int:
+        """
+        The total number of predicting batches during testing, which may change from epoch to epoch.
+        Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
+        predict dataloader is of infinite size.
+        """
+        return sum(self.trainer.num_predict_batches)
+
     def disable(self):
         """
         You should provide a way to disable the progress bar.

@@ -168,6 +186,12 @@ class ProgressBarBase(Callback):
     def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         self._test_batch_idx += 1

+    def on_predict_start(self, trainer, pl_module):
+        self._predict_batch_idx = 0
+
+    def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+        self._predict_batch_idx += 1
+

 class ProgressBar(ProgressBarBase):
     r"""

@@ -282,6 +306,20 @@ class ProgressBar(ProgressBarBase):
         )
         return bar

+    def init_predict_tqdm(self) -> tqdm:
+        """ Override this to customize the tqdm bar for predicting. """
+        bar = tqdm(
+            desc='Predicting',
+            initial=self.train_batch_idx,
+            position=(2 * self.process_position),
+            disable=self.is_disabled,
+            leave=True,
+            dynamic_ncols=True,
+            file=sys.stdout,
+            smoothing=0,
+        )
+        return bar
+
     def init_validation_tqdm(self) -> tqdm:
         """ Override this to customize the tqdm bar for validation. """
         bar = tqdm(

@@ -294,12 +332,10 @@ class ProgressBar(ProgressBarBase):
         )
         return bar

-    def init_test_tqdm(self, trainer=None) -> tqdm:
+    def init_test_tqdm(self) -> tqdm:
         """ Override this to customize the tqdm bar for testing. """
-        desc = "Testing"
-        desc = "Predicting" if trainer is not None and getattr(trainer, "is_predicting", False) else "Testing"
         bar = tqdm(
-            desc=desc,
+            desc="Testing",
             position=(2 * self.process_position),
             disable=self.is_disabled,
             leave=True,

@@ -365,7 +401,7 @@ class ProgressBar(ProgressBarBase):

     def on_test_start(self, trainer, pl_module):
         super().on_test_start(trainer, pl_module)
-        self.test_progress_bar = self.init_test_tqdm(trainer=trainer)
+        self.test_progress_bar = self.init_test_tqdm()
         self.test_progress_bar.total = convert_inf(self.total_test_batches)

     def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):

@@ -377,6 +413,19 @@ class ProgressBar(ProgressBarBase):
         super().on_test_end(trainer, pl_module)
         self.test_progress_bar.close()

+    def on_predict_start(self, trainer, pl_module):
+        super().on_predict_start(trainer, pl_module)
+        self.predict_progress_bar = self.init_predict_tqdm()
+        self.predict_progress_bar.total = convert_inf(self.total_predict_batches)
+
+    def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+        super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
+        if self._should_update(self.predict_batch_idx, self.total_predict_batches):
+            self._update_bar(self.predict_progress_bar)
+
+    def on_predict_end(self, trainer, pl_module):
+        self.predict_progress_bar.close()
+
     def _should_update(self, current, total):
         return self.is_enabled and (current % self.refresh_rate == 0 or current == total)

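Because the predict bar goes through the same `init_*_tqdm` / `on_predict_*` hooks as the other bars, it can be restyled by subclassing `ProgressBar`. A minimal sketch, relying only on the hook names added above (the subclass name is made up for illustration):

```python
from tqdm import tqdm

from pytorch_lightning.callbacks import ProgressBar


class InferenceBar(ProgressBar):
    def init_predict_tqdm(self) -> tqdm:
        # reuse the default predict bar and only change its description
        bar = super().init_predict_tqdm()
        bar.set_description("Running inference")
        return bar
```

Passing `InferenceBar()` via `Trainer(callbacks=[...])` would replace the default bar, exactly as for the existing train/val/test bars.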
@@ -260,6 +260,10 @@ class LightningDataModule(DataHooks, CheckpointHooks, metaclass=_DataModuleWrapp
     def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
         pass

+    @abstractmethod
+    def predict_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
+        pass
+
     @abstractmethod
     def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
         pass

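With `predict_dataloader` joining the other dataloader hooks, a datamodule can serve prediction data the same way it serves train/val/test data. A minimal sketch; the `SketchDataModule` name and the random tensors are purely illustrative:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class SketchDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int = 8):
        super().__init__()
        self.batch_size = batch_size

    def predict_dataloader(self) -> DataLoader:
        # the data trainer.predict(...) will iterate over
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=self.batch_size)
```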
@@ -204,17 +204,23 @@ class ModelHooks:
         """
         # do something when the batch ends

+    def on_test_model_train(self) -> None:
+        """
+        Sets the model to train during the test loop
+        """
+        self.train()
+
     def on_test_model_eval(self) -> None:
         """
         Sets the model to eval during the test loop
         """
         self.eval()

-    def on_test_model_train(self) -> None:
+    def on_predict_model_eval(self) -> None:
         """
-        Sets the model to train during the test loop
+        Sets the model to eval during the predict loop
         """
-        self.train()
+        self.eval()

     def on_epoch_start(self) -> None:
         """

@@ -518,6 +524,31 @@ class DataHooks:
             will have an argument ``dataloader_idx`` which matches the order here.
         """

+    def predict_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
+        r"""
+        Implement one or multiple PyTorch DataLoaders for prediction.
+
+        It's recommended that all data downloads and preparation happen in :meth:`prepare_data`.
+
+        - :meth:`~pytorch_lightning.trainer.Trainer.fit`
+        - ...
+        - :meth:`prepare_data`
+        - :meth:`train_dataloader`
+        - :meth:`val_dataloader`
+        - :meth:`test_dataloader`
+
+        Note:
+            Lightning adds the correct sampler for distributed and arbitrary hardware
+            There is no need to set it yourself.
+
+        Return:
+            Single or multiple PyTorch DataLoaders.
+
+        Note:
+            In the case where you return multiple prediction dataloaders, the :meth:`predict`
+            will have an argument ``dataloader_idx`` which matches the order here.
+        """
+
     def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any:
         """
         Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors

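On the LightningModule side, only two hooks are needed for the new loop: `predict_dataloader` to provide data and `predict` to produce outputs per batch (routed through the wrapper change below). A minimal sketch; `LitModel`, its layer sizes, and the random data are assumptions for illustration only:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def predict(self, batch, batch_idx, dataloader_idx=None):
        # called by the predict loop for every batch of predict_dataloader()
        x = batch[0]
        return self(x)

    def predict_dataloader(self):
        # returning a list/tuple of loaders would make dataloader_idx meaningful
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
```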
@@ -54,14 +54,22 @@ class _LightningModuleWrapperBase(torch.nn.Module):
             if not self.module.automatic_optimization:
                 self.module.trainer.model.require_backward_grad_sync = False
             warn_if_output_is_none(output, "training_step")
+
         elif running_stage == RunningStage.TESTING:
             output = self.module.test_step(*inputs, **kwargs)
             warn_if_output_is_none(output, "test_step")
+
         elif running_stage == RunningStage.EVALUATING:
             output = self.module.validation_step(*inputs, **kwargs)
             warn_if_output_is_none(output, "validation_step")
-        else:
+
+        elif running_stage == RunningStage.PREDICTING:
+            output = self.module.predict(*inputs, **kwargs)
+            warn_if_output_is_none(output, "predict")
+
+        else:
             output = self.module(*inputs, **kwargs)

         return output

@@ -33,11 +33,11 @@ class Plugin(ABC):
         Will be called by the accelerator.
         """

-    def pre_training(self) -> None:
-        """Hook to do something before the training starts."""
+    def pre_dispatch(self) -> None:
+        """Hook to do something before the training/evaluation/prediction starts."""

-    def post_training(self) -> None:
-        """Hook to do something after the training finishes."""
+    def post_dispatch(self) -> None:
+        """Hook to do something after the training/evaluation/prediction finishes."""

     @contextlib.contextmanager
     def train_step_context(self) -> Generator:

@@ -53,3 +53,8 @@ class Plugin(ABC):
     def test_step_context(self) -> Generator:
         """A contextmanager for the teststep"""
         yield
+
+    @contextlib.contextmanager
+    def predict_context(self) -> Generator:
+        """A contextmanager for the predict step"""
+        yield

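`predict_context` gives plugins the same hook point for inference that the step contexts give for train/val/test. A hedged, self-contained sketch of the override shape (the mixin class here is invented for illustration and is not part of the plugin hierarchy in this diff):

```python
import contextlib
from typing import Generator

import torch


class InferenceContextSketch:
    """Illustrative only: shows how a plugin-style predict_context override could look."""

    @contextlib.contextmanager
    def predict_context(self) -> Generator:
        # e.g. make sure no autograd state is recorded while predicting
        with torch.no_grad():
            yield
```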
@@ -215,7 +215,7 @@ class DDPPlugin(ParallelPlugin):
         log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
         torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size)

-    def pre_training(self):
+    def pre_dispatch(self):
         # TODO: check if needed
         seed = os.environ.get("PL_GLOBAL_SEED")
         if seed is not None:

@@ -232,7 +232,7 @@ class DDPPlugin(ParallelPlugin):
         # where to store ip_table
         self.init_ddp_connection(self.global_rank, self.world_size)

-        # TODO: we moved it to the trainer.fit after calling pre_training
+        # TODO: we moved it to the trainer.fit after calling pre_dispatch
         # ... need to double check that it is the correct place
         # self.trainer.call_setup_hook(self.model)

@@ -257,7 +257,7 @@ class DDPPlugin(ParallelPlugin):

         self.barrier()

-    def post_training(self):
+    def post_dispatch(self):
         if "WORLD_SIZE" in os.environ:
             del os.environ["WORLD_SIZE"]

@@ -110,6 +110,9 @@ class DDPSpawnPlugin(ParallelPlugin):
     def start_testing(self, trainer):
         mp.spawn(self.new_process, **self.mp_spawn_kwargs)

+    def start_predicting(self, trainer):
+        mp.spawn(self.new_process, **self.mp_spawn_kwargs)
+
     def new_process(self, process_idx, trainer, mp_queue):
         self.mp_queue = mp_queue

@@ -128,7 +131,7 @@ class DDPSpawnPlugin(ParallelPlugin):
         # where to store ip_table
         self.init_ddp_connection(self.global_rank, self.world_size)

-        # TODO: we moved it to the trainer.fit after calling pre_training
+        # TODO: we moved it to the trainer.fit after calling pre_dispatch
         # ... need to double check that it is the correct place
         # self.trainer.call_setup_hook(self.model)

@@ -153,15 +156,12 @@ class DDPSpawnPlugin(ParallelPlugin):

         self.barrier()

-        if trainer.testing:
-            results = trainer.run_test()
-        else:
-            results = trainer.train()
+        results = trainer.train_or_test_or_predict()

         # persist info in ddp_spawn
         self.transfer_distrib_spawn_state_on_fit_end(results)

-    def post_training(self):
+    def post_dispatch(self):
         # restore main state with best weights
         best_path = self.mp_queue.get()
         last_path = self.mp_queue.get()

@@ -50,7 +50,7 @@ class HorovodPlugin(ParallelPlugin):

         self.model_to_device()

-    def pre_training(self):
+    def pre_dispatch(self):

         def _unpack_lightning_optimizer(opt):
             return opt._optimizer if isinstance(opt, LightningOptimizer) else opt

@@ -95,20 +95,26 @@ class HorovodPlugin(ParallelPlugin):
             stack.enter_context(optimizer.skip_synchronize())

             # set up training routine
-            self._results = trainer.train()
+            self._results = trainer.run_train()

         # Make sure all workers have finished training before returning to the user
         hvd.join()

     def start_testing(self, trainer):
         with ExitStack() as stack:
             # set up training routine
             # self.trainer.train_loop.setup_training(self.trainer.model)
             self._results = trainer.run_test()

         # Make sure all workers have finished training before returning to the user
         hvd.join()

+    def start_predicting(self, trainer):
+        with ExitStack() as stack:
+            # set up training routine
+            self._results = trainer.run_predict()
+
+        # Make sure all workers have finished training before returning to the user
+        hvd.join()
+
     def barrier(self, *args, **kwargs):
         hvd.join()

@@ -325,9 +325,9 @@ class RPCSequentialPlugin(RPCPlugin):
         # Initialize optimizer step on main process
         self.worker_optimizer_step(model=self.lightning_module, opt_idx=optimizer_idx, **kwargs)

-    def post_training(self):
+    def post_training_step(self):
         if self.main_rpc_process:
-            super().post_training()
+            super().post_training_step()

     def start_training(self, trainer: 'Trainer') -> None:
         if self.main_rpc_process:

@@ -36,14 +36,14 @@ class SingleTPUPlugin(SingleDevicePlugin):
     def model_to_device(self) -> None:
         self._model.to(self.root_device)

-    def pre_training(self) -> None:
+    def pre_dispatch(self) -> None:
         if isinstance(self.device, int):
             self.device = xm.xla_device(self.device)

         self.tpu_local_core_rank = xm.get_local_ordinal()
         self.tpu_global_core_rank = xm.get_ordinal()

-    def post_training(self) -> None:
+    def post_dispatch(self) -> None:
         model = self.lightning_module

         if on_colab_kaggle():

@@ -95,10 +95,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
         trainer.save_checkpoint = self.save_checkpoint
         self.barrier()

-        if trainer.testing:
-            results = trainer.run_test()
-        else:
-            results = trainer.train()
+        results = trainer.train_or_test_or_predict()

         self.__save_end_of_training_weights(self.lightning_module)
         self.transfer_distrib_spawn_state_on_fit_end(results)

@@ -182,7 +179,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
         should_stop = int(stop.item()) == self.world_size
         return should_stop

-    def post_training(self) -> None:
+    def post_dispatch(self) -> None:
         # TODO: Check if trainer references can be resolved otherwise
         model = self.lightning_module

@@ -233,6 +230,9 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
     def start_testing(self, trainer) -> None:
         xmp.spawn(self.new_process, **self.xmp_spawn_kwargs)

+    def start_predicting(self, trainer) -> None:
+        xmp.spawn(self.new_process, **self.xmp_spawn_kwargs)
+
     def training_step(self, *args, **kwargs):
         return self.lightning_module.training_step(*args, **kwargs)

@@ -112,12 +112,16 @@ class TrainingTypePlugin(Plugin, ABC):

     def start_training(self, trainer: 'Trainer') -> None:
         # double dispatch to initiate the training loop
-        self._results = trainer.train()
+        self._results = trainer.run_train()

     def start_testing(self, trainer: 'Trainer') -> None:
         # double dispatch to initiate the test loop
         self._results = trainer.run_test()

+    def start_predicting(self, trainer: 'Trainer') -> None:
+        # double dispatch to initiate the predicting loop
+        self._results = trainer.run_predict()
+
     def training_step(self, *args, **kwargs):
         return self.lightning_module.training_step(*args, **kwargs)

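These `start_*` methods are the second leg of the double dispatch described in the trainer changes further down: the trainer calls into the plugin, and the default plugin immediately calls back into the matching `trainer.run_*` loop, while spawn-based plugins (DDP spawn, TPU spawn above) launch worker processes first. A simplified, illustrative sketch of that hand-off, using stand-in classes rather than the real ones:

```python
class ToyTrainer:
    def __init__(self, plugin):
        self.training_type_plugin = plugin

    def predict_entry(self):
        # first dispatch: trainer -> plugin
        self.training_type_plugin.start_predicting(self)
        return self.training_type_plugin._results

    def run_predict(self):
        return ["predictions go here"]


class ToySingleDevicePlugin:
    def start_predicting(self, trainer):
        # second dispatch: plugin -> trainer; a spawn-style plugin would call
        # mp.spawn(...) and run this line inside each child process instead
        self._results = trainer.run_predict()


print(ToyTrainer(ToySingleDevicePlugin()).predict_entry())
```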
@@ -52,7 +52,7 @@ class ConfigValidator(object):
         # verify model has a train dataloader
         # -----------------------------------
         has_train_dataloader = is_overridden('train_dataloader', model)
-        if not has_train_dataloader and not self.trainer._predicting:
+        if not has_train_dataloader:
             raise MisconfigurationException(
                 'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a'
                 ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'

@@ -62,7 +62,7 @@ class ConfigValidator(object):
         # verify model has optimizer
         # -----------------------------------
         has_optimizers = is_overridden('configure_optimizers', model)
-        if not has_optimizers and not self.trainer._predicting:
+        if not has_optimizers:
             raise MisconfigurationException(
                 'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a'
                 ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'

@@ -90,7 +90,14 @@ class DataConnector(object):
                 'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule'
             )

-    def attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None):
+    def attach_dataloaders(
+        self,
+        model,
+        train_dataloader=None,
+        val_dataloaders=None,
+        test_dataloaders=None,
+        predict_dataloaders=None
+    ):
         # when dataloader is passed via fit, patch the train_dataloader
         # functions to overwrite with these implementations
         if train_dataloader is not None:

@@ -102,6 +109,9 @@ class DataConnector(object):
         if test_dataloaders is not None:
             model.test_dataloader = _PatchDataLoader(test_dataloaders)

+        if predict_dataloaders is not None:
+            model.predict_dataloader = _PatchDataLoader(predict_dataloaders)
+
     def attach_datamodule(self, model, datamodule: Optional[LightningDataModule], stage: str) -> None:
         # Todo: required argument `stage` is not used

@@ -118,6 +128,8 @@ class DataConnector(object):
             model.val_dataloader = datamodule.val_dataloader
         if is_overridden('test_dataloader', datamodule):
             model.test_dataloader = datamodule.test_dataloader
+        if is_overridden('predict_dataloader', datamodule):
+            model.predict_dataloader = datamodule.predict_dataloader

         # Override transfer_batch_to_device if dataset-specific to_device logic has been defined in datamodule
         if is_overridden('transfer_batch_to_device', datamodule):

@@ -29,6 +29,7 @@ class DebuggingConnector:
         limit_train_batches,
         limit_val_batches,
         limit_test_batches,
+        limit_predict_batches,
         val_check_interval,
         overfit_batches,
         fast_dev_run,

@@ -56,6 +57,7 @@ class DebuggingConnector:
             limit_train_batches = fast_dev_run
             limit_val_batches = fast_dev_run
             limit_test_batches = fast_dev_run
+            limit_predict_batches = fast_dev_run
             self.trainer.max_steps = fast_dev_run
             self.trainer.num_sanity_val_steps = 0
             self.trainer.max_epochs = 1

@@ -71,6 +73,7 @@ class DebuggingConnector:
         self.trainer.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches')
         self.trainer.limit_val_batches = _determine_batch_limits(limit_val_batches, 'limit_val_batches')
         self.trainer.limit_test_batches = _determine_batch_limits(limit_test_batches, 'limit_test_batches')
+        self.trainer.limit_predict_batches = _determine_batch_limits(limit_predict_batches, 'limit_predict_batches')
         self.trainer.val_check_interval = _determine_batch_limits(val_check_interval, 'val_check_interval')
         self.trainer.overfit_batches = _determine_batch_limits(overfit_batches, 'overfit_batches')
         self.determine_data_use_amount(self.trainer.overfit_batches)

@@ -293,7 +293,8 @@ class TrainerDataLoadingMixin(ABC):
             loader = dataloaders[loader_i]

             # shuffling in val and test set is bad practice
-            if mode in ('val', 'test') and hasattr(loader, 'sampler') and isinstance(loader.sampler, RandomSampler):
+            modes = ('val', 'test', 'predict')
+            if mode in modes and hasattr(loader, 'sampler') and isinstance(loader.sampler, RandomSampler):

                 # when overfitting, the dataloader should not have sampler
                 if self.overfit_batches > 0:

@@ -363,7 +364,7 @@ class TrainerDataLoadingMixin(ABC):
         self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(model, 'val')

     def reset_test_dataloader(self, model) -> None:
-        """Resets the validation dataloader and determines the number of batches.
+        """Resets the test dataloader and determines the number of batches.

         Args:
             model: The current `LightningModule`

@@ -374,6 +375,17 @@ class TrainerDataLoadingMixin(ABC):
             self.num_test_batches, self.test_dataloaders =\
                 self._reset_eval_dataloader(model, 'test')

+    def reset_predict_dataloader(self, model) -> None:
+        """Resets the predict dataloader and determines the number of batches.
+
+        Args:
+            model: The current `LightningModule`
+        """
+        has_loader = is_overridden('predict_dataloader', model)
+        if has_loader:
+            self.num_predict_batches, self.predict_dataloaders =\
+                self._reset_eval_dataloader(model, 'predict')
+
     def request_dataloader(self, dataloader_fx: Callable) -> DataLoader:
         """Handles downloading data in the GPU or TPU case.

@@ -154,16 +154,7 @@ class EvaluationLoop(object):
         model_ref = self.trainer.get_model()
         model_ref._results = Result()

-        if self.trainer._predicting:
-            model_ref._current_fx_name = "predict"
-            predictions = self.trainer.accelerator_backend.predict(args)
-            self._predictions[dataloader_idx].append(predictions)
-            self.trainer._progress_bar_callback.on_test_batch_end(
-                self.trainer, model_ref, predictions, batch, batch_idx, dataloader_idx
-            )
-            return
-
-        elif self.testing:
+        if self.testing:
             model_ref._current_fx_name = "test_step"
             with self.trainer.profiler.profile("test_step"):
                 output = self.trainer.accelerator_backend.test_step(args)

@@ -0,0 +1,97 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+from pytorch_lightning.utilities.apply_func import apply_to_collection
+
+
+class PredictLoop(object):
+
+    def __init__(self, trainer):
+        self.trainer = trainer
+        self.max_batches = None
+        self.num_dataloaders = None
+
+    def on_trainer_init(self):
+        self.trainer.num_predict_batches = []
+
+    def get_predict_dataloaders(self, max_batches):
+        # select dataloaders
+        model = self.trainer.get_model()
+        self.trainer.reset_predict_dataloader(model)
+        dataloaders = self.trainer.predict_dataloaders
+        if max_batches is None:
+            max_batches = self.trainer.num_predict_batches
+
+        return dataloaders, max_batches
+
+    def should_skip_predict(self, dataloaders, max_batches):
+        return dataloaders is None or not sum(max_batches)
+
+    def on_predict_model_eval(self, *_, **__):
+        model_ref = self.trainer.get_model()
+        model_ref.on_predict_model_eval()
+
+    def setup(self, model, max_batches, dataloaders):
+        # copy properties for forward overrides
+        self.trainer.model_connector.copy_trainer_model_properties(model)
+
+        # convert max_batches to list
+        if isinstance(max_batches, int):
+            max_batches = [max_batches] * len(dataloaders)
+
+        self.max_batches = max_batches
+        self.num_dataloaders = self._get_num_dataloaders(dataloaders)
+        self._predictions = [[] for _ in range(self.num_dataloaders)]
+
+        self.trainer._progress_bar_callback.on_predict_start(self.trainer, self.trainer.get_model())
+
+    def _get_num_dataloaders(self, dataloaders):
+        # case where user does:
+        # return dl1, dl2
+        length = len(dataloaders)
+        if len(dataloaders) > 0 and isinstance(dataloaders[0], (list, tuple)):
+            length = len(dataloaders[0])
+        return length
+
+    def predict(self, batch, batch_idx, dataloader_idx):
+        # configure args
+        args = [batch, batch_idx]
+        if self.num_dataloaders:
+            args.append(dataloader_idx)
+
+        model_ref = self.trainer.get_model()
+
+        model_ref._current_fx_name = "predict"
+        predictions = self.trainer.accelerator_backend.predict(args)
+        self._predictions[dataloader_idx].append(predictions)
+        self.trainer._progress_bar_callback.on_predict_batch_end(
+            self.trainer, model_ref, predictions, batch, batch_idx, dataloader_idx
+        )
+        return
+
+    def on_predict_epoch_end(self):
+        self.trainer._progress_bar_callback.on_predict_end(self.trainer, self.trainer.get_model())
+
+        results = self._predictions
+
+        def _convert_to_numpy(v):
+            return v.cpu().numpy()
+
+        results = apply_to_collection(results, torch.Tensor, _convert_to_numpy)
+
+        if len(results) == 1:
+            return results[0]
+
+        return results
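Worth noting from `on_predict_epoch_end`: predictions are collected per dataloader and any tensors are converted to NumPy via `apply_to_collection`, so callers get plain arrays back (and the outer list is unwrapped when there is a single dataloader). A small, standalone illustration of that conversion step under those assumptions:

```python
import torch

from pytorch_lightning.utilities.apply_func import apply_to_collection

# mimic two batches of predictions collected for a single dataloader
collected = [[torch.ones(2, 3), torch.zeros(1, 3)]]

as_numpy = apply_to_collection(collected, torch.Tensor, lambda t: t.cpu().numpy())
print(type(as_numpy[0][0]))  # <class 'numpy.ndarray'>
```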
@@ -49,6 +49,7 @@ from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
 from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
 from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
+from pytorch_lightning.trainer.predict_loop import PredictLoop
 from pytorch_lightning.trainer.properties import TrainerProperties
 from pytorch_lightning.trainer.states import RunningStage, TrainerState
 from pytorch_lightning.trainer.training_loop import TrainLoop

@@ -57,6 +58,7 @@ from pytorch_lightning.tuner.tuning import Tuner
 from pytorch_lightning.utilities import DeviceType, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.debugging import InternalDebugger
+from pytorch_lightning.utilities.enums import LightningEnum
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.memory import recursive_detach
 from pytorch_lightning.utilities.model_helpers import is_overridden

@@ -107,6 +109,7 @@ class Trainer(
         limit_train_batches: Union[int, float] = 1.0,
         limit_val_batches: Union[int, float] = 1.0,
         limit_test_batches: Union[int, float] = 1.0,
+        limit_predict_batches: Union[int, float] = 1.0,
         val_check_interval: Union[int, float] = 1.0,
         flush_logs_every_n_steps: int = 100,
         log_every_n_steps: int = 50,

@@ -296,7 +299,6 @@ class Trainer(
         """
         super().__init__()
         self._running_stage = None
-        self._predicting = False

         distributed_backend = distributed_backend or accelerator

@@ -319,8 +321,9 @@ class Trainer(
         self.checkpoint_connector = CheckpointConnector(self)
         self.slurm_connector = SLURMConnector(self)
         self.tuner = Tuner(self)
-        self.evaluation_loop = EvaluationLoop(self)
         self.train_loop = TrainLoop(self, multiple_trainloader_mode)
+        self.evaluation_loop = EvaluationLoop(self)
+        self.predict_loop = PredictLoop(self)

         # training state
         self.weights_summary = weights_summary

@@ -393,6 +396,7 @@ class Trainer(
             limit_train_batches,
             limit_val_batches,
             limit_test_batches,
+            limit_predict_batches,
             val_check_interval,
             overfit_batches,
             fast_dev_run,

@@ -440,7 +444,11 @@ class Trainer(
         """
         # bookkeeping
         self._state = TrainerState.RUNNING
-        self._set_wide_running_stage(RunningStage.TRAINING)
+
+        # bookkeeping
+        # we reuse fit in .test() and .predict(). When already set, it shouldn't be modified.
+        if self._running_stage is None:
+            self._set_running_stage(RunningStage.TRAINING, model)

         # set local properties on the model
         self.model_connector.copy_trainer_model_properties(model)

@@ -463,27 +471,47 @@ class Trainer(
         self.accelerator_backend.setup(self, model)
         self.setup_trainer(model)

+        # ----------------------------
+        # INSPECT THE CORE LOOPS
+        # ----------------------------
+        # Lightning internal flow looks like this.
+        #
+        # trainer.fit(...) or trainer.test(...) or trainer.predict(...)      ||
+        #                              |                                     ||
+        #                     create accelerator                             ||
+        #                              |                                     ||
+        #                      trainer.dispatch                              ||  LIGHTNING
+        #                              |                                     ||
+        #  start_training or start_testing or start_predicting call          ||  FLOW
+        #        from `accelerator.training_type_plugin`                     ||
+        #                              |                                     ||  DIRECTION
+        #         run_train or run_test or run_predict call                  ||
+        #                     from `trainer`                                 ||
+        #                              |                                     ||
+        #                           results                                  \/
+        # This is used to guide readers to the core loops: train, test, predict.
+        # `run_predict` is the simplest to understand, use `Go to Definition` to read it :)
+        # Search for `start_training` or `start_testing` or `start_predicting` in
+        # `pytorch_lightning/plugins/training_type` folder to find accelerator dispatch functions.
+        self.accelerator.train_loop = self.run_train
+        self.accelerator.validation_loop = self.run_evaluation
+        self.accelerator.test_loop = self.run_evaluation
+        self.accelerator.predict_loop = self.run_predict
+
         # ----------------------------
         # TRAIN
         # ----------------------------
         # hook
         self.call_hook("on_fit_start")

-        # plugin will setup training (e.g. ddp will launch child processes)
-        # TODO: the old setup is now called "pre_training", where should this hook be called now?
-        self.training_type_plugin.pre_training()
-        self.precision_plugin.pre_training()
+        # plugin will setup fitting (e.g. ddp will launch child processes)
+        self.pre_dispatch()

-        # double dispatch: let the plugin initiate the training/test loop.
-        if self.testing:
-            self.training_type_plugin.start_testing(self)
-        else:
-            self.training_type_plugin.start_training(self)
+        # dispatch `start_training` or `start_testing` or `start_predicting`
+        self.dispatch()

-        self.precision_plugin.post_training()
-        self.training_type_plugin.post_training()
-        self.accelerator_backend.teardown()
-        results = self.training_type_plugin.results
+        # plugin will finalize fitting (e.g. ddp_spawn will load trained model)
+        self.post_dispatch()

         # ----------------------------
         # POST-Training CLEAN UP

@@ -501,31 +529,47 @@ class Trainer(
         if self._state != TrainerState.INTERRUPTED:
             self._state = TrainerState.FINISHED

-        self._set_wide_running_stage(None)
+        self._set_running_stage(None, model)

-        return results or 1
+        return self.training_type_plugin.results or 1

-    def _set_wide_running_stage(self, stage):
-        model_ref = self.get_model()
+    def pre_dispatch(self):
+        self.training_type_plugin.pre_dispatch()
+        self.precision_plugin.pre_dispatch()

-        if stage is None:
-            self._running_stage = stage
-            model_ref.running_stage = stage
-            return
+    def post_dispatch(self):
+        self.training_type_plugin.post_dispatch()
+        self.precision_plugin.post_dispatch()
+        self.accelerator_backend.teardown()

-        # todo: clean up this routing mess.
-        if self._running_stage == RunningStage.TESTING:
-            stage = RunningStage.TESTING
+    def dispatch(self):
+        if self.testing:
+            self.training_type_plugin.start_testing(self)

-        # WARNING: With predicting,
-        # trainer _running_state should be RunningStage.TESTING
-        # however, the model running_stage should be RunningStage.PREDICTING or None
-        if model_ref is not None:
-            if self._predicting:
-                model_ref.running_stage = RunningStage.PREDICTING
-            else:
-                model_ref.running_stage = stage
+        elif self.predicting:
+            self.training_type_plugin.start_predicting(self)
+
+        else:
+            self.training_type_plugin.start_training(self)
+
+    def train_or_test_or_predict(self):
+        if self.testing:
+            results = self.run_test()
+
+        elif self.predicting:
+            results = self.run_predict()
+
+        else:
+            results = self.run_train()
+
+        return results
+
+    def _set_running_stage(self, stage: LightningEnum, model_ref: LightningModule):
+        """
+        This function is used to set the running_state on both
+        the trainer and the model
+        """
+        model_ref.running_stage = stage
+        self._running_stage = stage

     def _pre_training_routine(self):

@@ -560,7 +604,7 @@ class Trainer(
         if self.is_function_implemented("on_pretrain_routine_end"):
             ref_model.on_pretrain_routine_end()

-    def train(self):
+    def run_train(self):

         self._pre_training_routine()

@@ -570,7 +614,7 @@ class Trainer(
             self.run_sanity_check(self.get_model())

         # set stage for logging
-        self._set_wide_running_stage(RunningStage.TRAINING)
+        self._set_running_stage(RunningStage.TRAINING, self.get_model())

         self.checkpoint_connector.has_trained = False

@@ -634,7 +678,7 @@ class Trainer(
     def run_evaluation(self, max_batches=None, on_epoch=False):

         # used to know if we are logging for val, test + reset cached results
-        self._set_wide_running_stage(RunningStage.TESTING if self.testing else RunningStage.EVALUATING)
+        self._set_running_stage(RunningStage.TESTING if self.testing else RunningStage.EVALUATING, self.get_model())
         self.logger_connector.reset()

         # bookkeeping

@@ -647,11 +691,10 @@ class Trainer(
         if self.evaluation_loop.should_skip_evaluation(max_batches):
             return [], []

-        # ref model
-        model = self.get_model()
-
         # enable eval mode + no grads
         self.evaluation_loop.on_evaluation_model_eval()
+        # ref model
+        model = self.get_model()
         model.zero_grad()
         torch.set_grad_enabled(False)

@@ -685,8 +728,6 @@ class Trainer(
                 # lightning module methods
                 with self.profiler.profile("evaluation_step_and_end"):
                     output = self.evaluation_loop.evaluation_step(batch, batch_idx, dataloader_idx)
-                    if self._predicting:
-                        continue
                     output = self.evaluation_loop.evaluation_step_end(output)

                 # hook + store predictions

@@ -701,9 +742,6 @@ class Trainer(
             # store batch level output per dataloader
             self.evaluation_loop.outputs.append(dl_outputs)

-        if self._predicting:
-            return self.evaluation_loop.on_predict_epoch_end()
-
         # lightning module method
         deprecated_eval_results = self.evaluation_loop.evaluation_epoch_end()

@@ -764,6 +802,45 @@ class Trainer(

         return eval_loop_results

+    def run_predict(self):
+        # prepare dataloaders
+        dataloaders, max_batches = self.predict_loop.get_predict_dataloaders(None)
+
+        # check if we want to skip this evaluation
+        if self.predict_loop.should_skip_predict(dataloaders, max_batches):
+            return []
+
+        # ref model
+        model = self.get_model()
+
+        # enable eval mode + no grads
+        self.predict_loop.on_predict_model_eval()
+        model.zero_grad()
+        torch.set_grad_enabled(False)
+
+        # set up the eval loop
+        self.predict_loop.setup(model, max_batches, dataloaders)
+
+        # run validation/testing
+        for dataloader_idx, dataloader in enumerate(dataloaders):
+            dataloader = self.accelerator_backend.process_dataloader(dataloader)
+            dl_max_batches = self.predict_loop.max_batches[dataloader_idx]
+
+            for batch_idx, batch in enumerate(dataloader):
+                if batch is None:
+                    continue
+
+                # stop short when running on limited batches
+                if batch_idx >= dl_max_batches:
+                    break
+
+                # lightning module methods
+                with self.profiler.profile("predict"):
+                    self.predict_loop.predict(batch, batch_idx, dataloader_idx)
+
+        results = self.predict_loop.on_predict_epoch_end()
+        return results
+
     def run_sanity_check(self, ref_model):
         using_val_step = ref_model.val_dataloader is not None and is_overridden('validation_step', ref_model)
         should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0

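End to end, the new loop is driven through `trainer.predict`, which reuses `fit` for setup and then dispatches into `run_predict`. A hedged usage sketch, assuming the `LitModel` from the earlier snippet; depending on the Trainer's configuration validation, the module may also need its usual training hooks defined:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

model = LitModel()
trainer = pl.Trainer(limit_predict_batches=2)  # new Trainer flag introduced in this PR

predict_loader = DataLoader(TensorDataset(torch.randn(32, 32)), batch_size=8)
results = trainer.predict(model, dataloaders=predict_loader)
# with a single dataloader, results is a list of per-batch predictions
# (tensors converted to NumPy by PredictLoop.on_predict_epoch_end)
```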
@@ -828,7 +905,7 @@ class Trainer(
         # --------------------
         self.verbose_test = verbose

-        self._set_wide_running_stage(RunningStage.TESTING)
+        self._set_running_stage(RunningStage.TESTING, model or self.get_model())

         # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
         if test_dataloaders and datamodule:

@@ -845,9 +922,7 @@ class Trainer(
             results = self.__test_using_best_weights(ckpt_path, test_dataloaders)

         self.teardown('test')
-
-        self._set_wide_running_stage(None)
-
+        self._set_running_stage(None, model or self.get_model())
         return results

     def __test_using_best_weights(self, ckpt_path, test_dataloaders):

@@ -935,35 +1010,28 @@ class Trainer(
         # --------------------
         # SETUP HOOK
         # --------------------
-        self._set_wide_running_stage(RunningStage.TESTING)
-
         # If you supply a datamodule you can't supply dataloaders

         model = model or self.get_model()

+        self._set_running_stage(RunningStage.PREDICTING, model)
+
         if dataloaders and datamodule:
             raise MisconfigurationException(
                 'You cannot pass dataloaders to trainer.predict if you supply a datamodule.'
             )

-        if model is None:
-            raise MisconfigurationException('You need to pass a model to `trainer.predict`.')
-
         if datamodule is not None:
             # Attach datamodule to get setup/prepare_data added to model before the call to it below
-            self.data_connector.attach_datamodule(model, datamodule, 'test')
+            self.data_connector.attach_datamodule(model, datamodule, 'predict')

         # attach data
         if dataloaders is not None:
-            self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders)
+            self.data_connector.attach_dataloaders(model, predict_dataloaders=dataloaders)

-        # set path variable
-        self._predicting = True
         self.model = model

         results = self.fit(model)

-        # unset path variable
         self.teardown('test')
-        self._predicting = False
-        self._set_wide_running_stage(None)
+        self._set_running_stage(None, model)

         return results

@@ -1069,6 +1137,17 @@ class Trainer(
         elif self.testing:
             self._running_stage = None

+    @property
+    def predicting(self) -> bool:
+        return self._running_stage == RunningStage.PREDICTING
+
+    @predicting.setter
+    def predicting(self, val: bool) -> None:
+        if val:
+            self._running_stage = RunningStage.PREDICTING
+        elif self.predicting:
+            self._running_stage = None
+
     @property
     def tuning(self) -> bool:
         return self._running_stage == RunningStage.TUNING

@@ -547,7 +547,7 @@ class TrainLoop:
             self.trainer.run_evaluation()

             # reset stage to train
-            self.trainer._set_wide_running_stage(RunningStage.TRAINING)
+            self.trainer._set_running_stage(RunningStage.TRAINING, self.trainer.lightning_module)

         # -----------------------------------------
         # SAVE LOGGERS (ie: Tensorboard, etc...)

@@ -594,7 +594,7 @@ class TrainLoop:
             self.trainer.run_evaluation(on_epoch=True)

             # reset stage to train
-            self.trainer._set_wide_running_stage(RunningStage.TRAINING)
+            self.trainer._set_running_stage(RunningStage.TRAINING, self.trainer.lightning_module)

         should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
         should_train_only = self.trainer.disable_validation or should_skip_eval

@@ -288,8 +288,8 @@ class MockedUpdateProgressBars(ProgressBar):
         bar = super().init_validation_tqdm()
         return self._mock_bar_update(bar)

-    def init_test_tqdm(self, trainer=None):
-        bar = super().init_test_tqdm(trainer=trainer)
+    def init_test_tqdm(self):
+        bar = super().init_test_tqdm()
         return self._mock_bar_update(bar)

@@ -39,7 +39,7 @@ def test_lightning_wrapper_module_methods(wrapper_class):
     wrapped_module(batch, batch_idx)
     pl_module.validation_step.assert_called_with(batch, batch_idx)

-    pl_module.running_stage = None
+    pl_module.running_stage = RunningStage.PREDICTING
     wrapped_module(batch)
     pl_module.predict.assert_called_with(batch)

@@ -1498,6 +1498,9 @@ class TestLightningDataModule(LightningDataModule):
     def test_dataloader(self):
         return self._dataloaders

+    def predict_dataloader(self):
+        return self._dataloaders
+

 def predict(tmpdir, accelerator, gpus, num_processes, plugins=None, datamodule=True):

@@ -1515,7 +1518,6 @@ def predict(tmpdir, accelerator, gpus, num_processes, plugins=None, datamodule=T
         gpus=gpus,
         num_processes=num_processes,
         plugins=plugins,
-        num_sanity_val_steps=0
     )
     if datamodule:
         results = trainer.predict(model, datamodule=datamodule)

@@ -1529,9 +1531,6 @@ def predict(tmpdir, accelerator, gpus, num_processes, plugins=None, datamodule=T
     assert results[0][0].shape == torch.Size([1, 2])


-@pytest.mark.skipif(
-    not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
-)
 @pytest.mark.parametrize('datamodule', [False, True])
 def test_trainer_predict_cpu(tmpdir, datamodule):
     predict(tmpdir, None, None, 1, datamodule=datamodule)