document exceptions for metrics/regression (#6202)

Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>
Co-authored-by: Prajakta Phadke <pphadke@iu.edu>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
This commit is contained in:
prajakta0111 2021-02-28 09:52:26 -05:00 committed by GitHub
parent 111d9c7267
commit 15c477e9fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 15 additions and 1 deletions

View File

@ -59,6 +59,10 @@ class ExplainedVariance(Metric):
process_group: process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world) Specify the process group on which synchronization is called. default: None (which selects the entire world)
Raises:
ValueError:
If ``multioutput`` is not one of ``"raw_values"``, ``"uniform_average"`` or ``"variance_weighted"``.
Example: Example:
>>> from pytorch_lightning.metrics import ExplainedVariance >>> from pytorch_lightning.metrics import ExplainedVariance

View File

@ -51,6 +51,10 @@ class PSNR(Metric):
process_group: process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world) Specify the process group on which synchronization is called. default: None (which selects the entire world)
Raises:
ValueError:
If ``dim`` is not ``None`` and ``data_range`` is not given.
Example: Example:
>>> from pytorch_lightning.metrics import PSNR >>> from pytorch_lightning.metrics import PSNR

View File

@ -66,6 +66,12 @@ class R2Score(Metric):
process_group: process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world) Specify the process group on which synchronization is called. default: None (which selects the entire world)
Raises:
ValueError:
If ``adjusted`` parameter is not an integer larger or equal to 0.
ValueError:
If ``multioutput`` is not one of ``"raw_values"``, ``"uniform_average"`` or ``"variance_weighted"``.
Example: Example:
>>> from pytorch_lightning.metrics import R2Score >>> from pytorch_lightning.metrics import R2Score
@ -102,7 +108,7 @@ class R2Score(Metric):
self.num_outputs = num_outputs self.num_outputs = num_outputs
if adjusted < 0 or not isinstance(adjusted, int): if adjusted < 0 or not isinstance(adjusted, int):
raise ValueError('`adjusted` parameter should be an integer larger or' ' equal to 0.') raise ValueError('`adjusted` parameter should be an integer larger or equal to 0.')
self.adjusted = adjusted self.adjusted = adjusted
allowed_multioutput = ('raw_values', 'uniform_average', 'variance_weighted') allowed_multioutput = ('raw_values', 'uniform_average', 'variance_weighted')