From 0a2fb05aacb35218a85d0d719a8cd9f4f3f7ff02 Mon Sep 17 00:00:00 2001
From: Akihiro Nitta
Date: Mon, 15 Feb 2021 19:24:36 +0900
Subject: [PATCH] Document exceptions in callbacks (#5541)

* Add Raises: section to docstring

* Add Raises section to the docs

* Add raises section to the docs

* Apply suggestions from code review

Co-authored-by: Jirka Borovec

* fix

* Remove unnecessary instance check

Co-authored-by: Jirka Borovec
---
 pytorch_lightning/callbacks/early_stopping.py     |  6 ++++++
 pytorch_lightning/callbacks/finetuning.py         |  8 ++++++--
 pytorch_lightning/callbacks/gpu_stats_monitor.py  |  4 ++++
 .../callbacks/gradient_accumulation_scheduler.py  |  7 +++++++
 pytorch_lightning/callbacks/lr_monitor.py         |  8 ++++++++
 pytorch_lightning/callbacks/model_checkpoint.py   |  8 ++++++++
 pytorch_lightning/callbacks/pruning.py            | 13 +++++++++++++
 7 files changed, 52 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 7f42af82c4..384ce9699f 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -54,6 +54,12 @@ class EarlyStopping(Callback):
         strict: whether to crash the training if `monitor` is not found in the
             validation metrics. Default: ``True``.
 
+    Raises:
+        MisconfigurationException:
+            If ``mode`` is not one of ``"min"``, ``"max"``, or ``"auto"``.
+        RuntimeError:
+            If the metric ``monitor`` is not available.
+
     Example::
 
         >>> from pytorch_lightning import Trainer
diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py
index 02e7180a47..9f2697a9f9 100644
--- a/pytorch_lightning/callbacks/finetuning.py
+++ b/pytorch_lightning/callbacks/finetuning.py
@@ -318,8 +318,12 @@ class BackboneFinetuning(BaseFinetuning):
         self.verbose = verbose
 
     def on_fit_start(self, trainer, pl_module):
-        if hasattr(pl_module, "backbone") and \
-                (isinstance(pl_module.backbone, Module) or isinstance(pl_module.backbone, Sequential)):
+        """
+        Raises:
+            MisconfigurationException:
+                If the LightningModule has no nn.Module `backbone` attribute.
+        """
+        if hasattr(pl_module, "backbone") and isinstance(pl_module.backbone, Module):
             return
         raise MisconfigurationException("The LightningModule should have a nn.Module `backbone` attribute")
 
diff --git a/pytorch_lightning/callbacks/gpu_stats_monitor.py b/pytorch_lightning/callbacks/gpu_stats_monitor.py
index 2c1c6df18f..ace69b0234 100644
--- a/pytorch_lightning/callbacks/gpu_stats_monitor.py
+++ b/pytorch_lightning/callbacks/gpu_stats_monitor.py
@@ -48,6 +48,10 @@ class GPUStatsMonitor(Callback):
         temperature: Set to ``True`` to monitor the memory and gpu temperature in degree Celsius.
             Default: ``False``.
 
+    Raises:
+        MisconfigurationException:
+            If the NVIDIA driver is not installed, not running on GPUs, or ``Trainer`` has no logger.
+
     Example::
 
         >>> from pytorch_lightning import Trainer
diff --git a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py
index ed935a67bf..0af7d61bf5 100644
--- a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py
+++ b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py
@@ -32,6 +32,13 @@ class GradientAccumulationScheduler(Callback):
     Args:
         scheduling: scheduling in format {epoch: accumulation_factor}
 
+    Raises:
+        TypeError:
+            If ``scheduling`` is an empty ``dict``,
+            or not all keys and values of ``scheduling`` are integers.
+        IndexError:
+            If ``minimal_epoch`` is less than 0.
+
     Example::
 
         >>> from pytorch_lightning import Trainer
diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py
index 726286ed61..7530bfaa9d 100644
--- a/pytorch_lightning/callbacks/lr_monitor.py
+++ b/pytorch_lightning/callbacks/lr_monitor.py
@@ -38,6 +38,10 @@ class LearningRateMonitor(Callback):
         log_momentum: option to also log the momentum values of the optimizer, if the optimizer
             has the ``momentum`` or ``betas`` attribute. Defaults to ``False``.
 
+    Raises:
+        MisconfigurationException:
+            If ``logging_interval`` is not one of ``"step"``, ``"epoch"``, or ``None``.
+
     Example::
 
         >>> from pytorch_lightning import Trainer
@@ -77,6 +81,10 @@ class LearningRateMonitor(Callback):
         Called before training, determines unique names for all lr
         schedulers in the case of multiple of the same type or in
         the case of multiple parameter groups
+
+        Raises:
+            MisconfigurationException:
+                If ``Trainer`` has no ``logger``.
         """
         if not trainer.logger:
             raise MisconfigurationException(
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index e6de1737b3..a3f048c67c 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -115,6 +115,14 @@ class ModelCheckpoint(Callback):
         For example, you can change the default last checkpoint name by doing
         ``checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"``
 
+    Raises:
+        MisconfigurationException:
+            If ``save_top_k`` is neither ``None`` nor an integer greater than or equal to ``-1``,
+            if ``monitor`` is ``None`` and ``save_top_k`` is not one of ``None``, ``-1``, or ``0``, or
+            if ``mode`` is not one of ``"min"``, ``"max"``, or ``"auto"``.
+        ValueError:
+            If ``trainer.save_checkpoint`` is ``None``.
+
     Example::
 
         >>> from pytorch_lightning import Trainer
diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py
index 253cd0bbc4..d3d280dbaa 100644
--- a/pytorch_lightning/callbacks/pruning.py
+++ b/pytorch_lightning/callbacks/pruning.py
@@ -135,6 +135,14 @@ class ModelPruning(Callback):
             verbose: Verbosity level. 0 to disable, 1 to log overall sparsity, 2 to log per-layer sparsity
 
+        Raises:
+            MisconfigurationException:
+                If ``parameter_names`` is neither ``"weight"`` nor ``"bias"``,
+                if the provided ``pruning_fn`` is not supported,
+                if ``pruning_dim`` is not provided when ``"unstructured"``,
+                if ``pruning_norm`` is not provided when ``"ln_structured"``,
+                if ``pruning_fn`` is neither a ``str`` nor a :class:`torch.nn.utils.prune.BasePruningMethod`, or
+                if ``amount`` is not an ``int``, ``float``, or ``Callable``.
         """
         self._use_global_unstructured = use_global_unstructured
 
@@ -382,6 +390,11 @@ class ModelPruning(Callback):
         """
         This function is responsible of sanitizing ``parameters_to_prune`` and ``parameter_names``.
         If ``parameters_to_prune is None``, it will be generated with all parameters of the model.
+
+        Raises:
+            MisconfigurationException:
+                If ``parameters_to_prune`` doesn't exist in the model, or
+                if ``parameters_to_prune`` is neither a list of tuples nor ``None``.
         """
         parameters = parameter_names or ModelPruning.PARAMETER_NAMES
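
The ``Raises:`` sections added by this patch describe checks that fire as soon as a
callback is constructed or attached to a ``Trainer``. The snippet below is an
editor's sketch, not part of the patch: it triggers two of the documented
``MisconfigurationException`` paths, assuming a pytorch-lightning install from the
v1.2 era in which this commit landed (exact error messages may differ by version).

    from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    # ``mode`` must be one of "min", "max", or "auto" (see early_stopping.py above).
    try:
        EarlyStopping(monitor="val_loss", mode="maximize")
    except MisconfigurationException as err:
        print(f"EarlyStopping rejected mode: {err}")

    # ``save_top_k`` must be None or an integer >= -1 (see model_checkpoint.py above).
    try:
        ModelCheckpoint(save_top_k=-2)
    except MisconfigurationException as err:
        print(f"ModelCheckpoint rejected save_top_k: {err}")

Because both checks run in ``__init__``, neither requires a ``Trainer`` or a model,
which is what makes the misconfigurations cheap to surface before training starts.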