metrics: add BLEU (#2535)

* metrics: added bleu score and test bleu

* metrics: fixed type hints in bleu

* bleu score moved to metrics/functional/nlp.py

* refactor with torch.Tensor

* Update test_sequence.py

* refactor as Borda requests and nltk==3.2

* locked nltk==3.3

* nltk>=3.3, parametrized smooth argument for test

* fix bleu_score example

* added class BLEUScore metrics and test

* added class BLEUScore metrics and test

* update CHANGELOG

* refactor with torchtext

* torchtext changed to optional import

* fix E501 line too long

* add else: in optional import

* remove pragma: no-cover

* constants changed to CAPITALS

* remove class in tests

* List -> Sequence, conda -> pip, cast with tensor

* add torchtext in test.txt

* remove torchtext from test.txt

* bump torchtext to 0.5.0

* bump torchtext to 0.5.0

* Apply suggestions from code review

* ignore bleu score in doctest, renamed to nlp.py

* back to implementation with torch

* remove --ignore in CI test, proper reference format

* apply justus comment

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Jeff Yang 2020-07-22 20:28:24 +06:30 committed by GitHub
parent 5025be7860
commit 0a65826462
11 changed files with 287 additions and 36 deletions


@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
- Added BLEU metrics ([#2535](https://github.com/PyTorchLightning/pytorch-lightning/pull/2535))
### Changed


@@ -28,7 +28,7 @@ Example::
.. warning::
The metrics package is still in development! If we're missing a metric or you find a mistake, please send a PR!
to a few metrics. Please feel free to create an issue/PR if you have a proposed metric or have found a bug.
----------------
@@ -73,7 +73,7 @@ Here's an example showing how to implement a NumpyMetric
class RMSE(NumpyMetric):
def forward(self, x, y):
return np.sqrt(np.mean(np.power(x-y, 2.0)))
.. autoclass:: pytorch_lightning.metrics.metric.NumpyMetric
:noindex:
@@ -138,6 +138,12 @@ AUROC
.. autoclass:: pytorch_lightning.metrics.classification.AUROC
:noindex:
BLEUScore
^^^^^^^^^
.. autoclass:: pytorch_lightning.metrics.nlp.BLEUScore
:noindex:
ConfusionMatrix
^^^^^^^^^^^^^^^
@@ -283,6 +289,12 @@ average_precision (F)
.. autofunction:: pytorch_lightning.metrics.functional.average_precision
:noindex:
bleu_score (F)
^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.functional.bleu_score
:noindex:
confusion_matrix (F)
^^^^^^^^^^^^^^^^^^^^
@@ -418,22 +430,22 @@ to_onehot (F)
Sklearn interface
-----------------
Lightning supports `sklearns metrics module <https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics>`_
as a backend for calculating metrics. Sklearns metrics are well tested and robust,
but requires conversion between pytorch and numpy thus may slow down your computations.
To use the sklearn backend of metrics simply import as
.. code-block:: python
    from pytorch_lightning.metrics import sklearns as plm
metric = plm.Accuracy(normalize=True)
val = metric(pred, target)
Each converted sklearn metric has the same interface as its
original counterpart (e.g. accuracy takes the additional `normalize` keyword).
Like the native Lightning metrics, these converted sklearn metrics also come
with built-in distributed (ddp) support.
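(Not part of this diff.) A self-contained version of the snippet above, as a minimal sketch; the toy `pred`/`target` tensors and the expected value of 0.75 are illustrative assumptions, and the import follows the corrected form shown above:

.. code-block:: python

    import torch
    from pytorch_lightning.metrics import sklearns as plm

    # 3 of the 4 predictions match the targets
    pred = torch.tensor([0, 1, 2, 2])
    target = torch.tensor([0, 1, 1, 2])

    metric = plm.Accuracy(normalize=True)  # same `normalize` keyword as sklearn.metrics.accuracy_score
    val = metric(pred, target)             # accuracy of 0.75
    print(val)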
SklearnMetric (sk)
@@ -460,7 +472,7 @@ AveragePrecision (sk)
.. autofunction:: pytorch_lightning.metrics.sklearns.AveragePrecision
:noindex:
ConfusionMatrix (sk)
^^^^^^^^^^^^^^^^^^^^


@@ -30,6 +30,7 @@ dependencies:
- twine==1.13.0
- pillow<7.0.0
- scikit-image
- nltk>=3.3
# Optional
- scipy>=0.13.3


@@ -5,7 +5,7 @@ from pytorch_lightning.metrics.regression import (
MSE,
PSNR,
RMSE,
RMSLE,
)
from pytorch_lightning.metrics.classification import (
Accuracy,
@@ -28,30 +28,32 @@ from pytorch_lightning.metrics.sklearns import (
PrecisionRecallCurve,
SklearnMetric,
)
from pytorch_lightning.metrics.nlp import BLEUScore
__classification_metrics = [
"AUC",
"AUROC",
"Accuracy",
"AveragePrecision",
"ConfusionMatrix",
"DiceCoefficient",
"F1",
"FBeta",
"MulticlassPrecisionRecall",
"MulticlassROC",
"Precision",
"PrecisionRecall",
"PrecisionRecallCurve",
"ROC",
"Recall",
"IoU",
]
__regression_metrics = [
"MAE",
"MSE",
"PSNR",
"RMSE",
"RMSLE",
]
__sequence_metrics = ["BLEUScore"]
__all__ = __regression_metrics + __classification_metrics + ["SklearnMetric"] + __sequence_metrics


@@ -25,5 +25,6 @@ from pytorch_lightning.metrics.functional.regression import (
mse,
psnr,
rmse,
rmsle,
)
from pytorch_lightning.metrics.functional.nlp import bleu_score
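With both `__init__.py` changes above in place, the new metric is reachable from either namespace; a quick sanity sketch (not part of the diff):

from pytorch_lightning.metrics import BLEUScore                # class-based interface
from pytorch_lightning.metrics.functional import bleu_score   # functional interface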


@@ -0,0 +1,92 @@
# referenced from
# Library Name: torchtext
# Authors: torchtext authors and @sluks
# Date: 2020-07-18
# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
from typing import Sequence, List
from collections import Counter
import torch
def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter:
"""Counting how many times each word appears in a given text with ngram
Args:
ngram_input_list: A list of translated text or reference texts
n_gram: gram value ranged 1 to 4
Return:
ngram_counter: a collections.Counter object of ngram
"""
ngram_counter = Counter()
for i in range(1, n_gram + 1):
for j in range(len(ngram_input_list) - i + 1):
ngram_key = tuple(ngram_input_list[j : i + j])
ngram_counter[ngram_key] += 1
return ngram_counter
def bleu_score(
translate_corpus: Sequence[str], reference_corpus: Sequence[str], n_gram: int = 4, smooth: bool = False
) -> torch.Tensor:
"""Calculate BLEU score of machine translated text with one or more references.
Args:
translate_corpus: An iterable of machine translated corpus
reference_corpus: An iterable of iterables of reference corpus
n_gram: Gram value ranged from 1 to 4 (Default 4)
smooth: Whether or not to apply smoothing Lin et al. 2004
Return:
A Tensor with BLEU Score
Example:
>>> translate_corpus = ['the cat is on the mat'.split()]
>>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
>>> bleu_score(translate_corpus, reference_corpus)
tensor(0.7598)
"""
assert len(translate_corpus) == len(reference_corpus)
numerator = torch.zeros(n_gram)
denominator = torch.zeros(n_gram)
precision_scores = torch.zeros(n_gram)
c = 0.0
r = 0.0
for (translation, references) in zip(translate_corpus, reference_corpus):
c += len(translation)
ref_len_list = [len(ref) for ref in references]
ref_len_diff = [abs(len(translation) - x) for x in ref_len_list]
r += ref_len_list[ref_len_diff.index(min(ref_len_diff))]
translation_counter = _count_ngram(translation, n_gram)
reference_counter = Counter()
for ref in references:
reference_counter |= _count_ngram(ref, n_gram)
ngram_counter_clip = translation_counter & reference_counter
for counter_clip in ngram_counter_clip:
numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip]
for counter in translation_counter:
denominator[len(counter) - 1] += translation_counter[counter]
trans_len = torch.tensor(c)
ref_len = torch.tensor(r)
if min(numerator) == 0.0:
return torch.tensor(0.0)
if smooth:
precision_scores = torch.add(numerator, torch.ones(n_gram)) / torch.add(denominator, torch.ones(n_gram))
else:
precision_scores = numerator / denominator
log_precision_scores = torch.tensor([1.0 / n_gram] * n_gram) * torch.log(precision_scores)
geometric_mean = torch.exp(torch.sum(log_precision_scores))
brevity_penalty = torch.tensor(1.0) if c > r else torch.exp(1 - (ref_len / trans_len))
bleu = brevity_penalty * geometric_mean
return bleu
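For reference (not part of the file above): the function computes corpus-level BLEU as the brevity penalty times the geometric mean of the clipped n-gram precisions, i.e. BP * exp(sum_n (1/N) * log p_n), where BP = exp(1 - r/c) when the candidate is shorter than the closest reference, with optional add-one smoothing of numerator and denominator. A minimal usage sketch reusing the doctest data; the bigram call is an illustrative assumption:

from pytorch_lightning.metrics.functional import bleu_score

translate_corpus = ['the cat is on the mat'.split()]
reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]

score = bleu_score(translate_corpus, reference_corpus, n_gram=4, smooth=False)
print(score)  # tensor(0.7598), matching the doctest above

# bigram-only BLEU with smoothing; the exact value is not asserted here
print(bleu_score(translate_corpus, reference_corpus, n_gram=2, smooth=True))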


@@ -0,0 +1,46 @@
import torch
from pytorch_lightning.metrics.functional.nlp import bleu_score
from pytorch_lightning.metrics.metric import Metric
class BLEUScore(Metric):
"""
Calculate BLEU score of machine translated text with one or more references.
Example:
>>> translate_corpus = ['the cat is on the mat'.split()]
>>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
>>> metric = BLEUScore()
>>> metric(translate_corpus, reference_corpus)
tensor(0.7598)
"""
def __init__(self, n_gram: int = 4, smooth: bool = False):
"""
Args:
n_gram: Gram value ranged from 1 to 4 (Default 4)
smooth: Whether or not to apply smoothing Lin et al. 2004
"""
super().__init__(name="bleu")
self.n_gram = n_gram
self.smooth = smooth
def forward(self, translate_corpus: list, reference_corpus: list) -> torch.Tensor:
"""
Actual metric computation
Args:
translate_corpus: An iterable of machine translated corpus
reference_corpus: An iterable of iterables of reference corpus
Return:
torch.Tensor: BLEU Score
"""
return bleu_score(
translate_corpus=translate_corpus,
reference_corpus=reference_corpus,
n_gram=self.n_gram,
smooth=self.smooth,
).to(self.device, self.dtype)
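A short usage sketch for the class interface (not part of the diff); the constructor arguments mirror the functional API and the output is cast to the metric's device/dtype by the `.to(...)` call above. The bigram/smoothing settings are illustrative:

import torch
from pytorch_lightning.metrics import BLEUScore

metric = BLEUScore(n_gram=2, smooth=True)  # bigram BLEU with add-one smoothing
translate_corpus = ['the cat is on the mat'.split()]
reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]

score = metric(translate_corpus, reference_corpus)
assert isinstance(score, torch.Tensor)
print(metric.name)  # 'bleu', as asserted in the tests below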


@@ -11,4 +11,4 @@ horovod>=0.19.1
omegaconf>=2.0.0
# scipy>=0.13.3
scikit-learn>=0.20.0
torchtext>=0.3.1


@@ -12,3 +12,4 @@ black==19.10b0
pre-commit>=1.0
cloudpickle>=1.2
nltk>=3.3


@@ -0,0 +1,66 @@
import pytest
import torch
from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu, sentence_bleu
from pytorch_lightning.metrics.functional.nlp import bleu_score
# example taken from
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.sentence_bleu
HYPOTHESIS1 = tuple(
"It is a guide to action which ensures that the military always obeys the commands of the party".split()
)
REFERENCE1 = tuple("It is a guide to action that ensures that the military will forever heed Party commands".split())
REFERENCE2 = tuple(
"It is a guiding principle which makes the military forces always being under the command of the Party".split()
)
REFERENCE3 = tuple("It is the practical guide for the army always to heed the directions of the party".split())
# example taken from
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu
HYP1 = "It is a guide to action which ensures that the military always obeys the commands of the party".split()
HYP2 = "he read the book because he was interested in world history".split()
REF1A = "It is a guide to action that ensures that the military will forever heed Party commands".split()
REF1B = "It is a guiding principle which makes the military force always being under the command of the Party".split()
REF1C = "It is the practical guide for the army always to heed the directions of the party".split()
REF2A = "he was interested in world history because he read the book".split()
LIST_OF_REFERENCES = [[REF1A, REF1B, REF1C], [REF2A]]
HYPOTHESES = [HYP1, HYP2]
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.SmoothingFunction
smooth_func = SmoothingFunction().method2
@pytest.mark.parametrize(
["weights", "n_gram", "smooth_func", "smooth"],
[
pytest.param([1], 1, None, False),
pytest.param([0.5, 0.5], 2, smooth_func, True),
pytest.param([0.333333, 0.333333, 0.333333], 3, None, False),
pytest.param([0.25, 0.25, 0.25, 0.25], 4, smooth_func, True),
],
)
def test_bleu_score(weights, n_gram, smooth_func, smooth):
nltk_output = sentence_bleu(
[REFERENCE1, REFERENCE2, REFERENCE3], HYPOTHESIS1, weights=weights, smoothing_function=smooth_func
)
pl_output = bleu_score([HYPOTHESIS1], [[REFERENCE1, REFERENCE2, REFERENCE3]], n_gram=n_gram, smooth=smooth)
assert torch.allclose(pl_output, torch.tensor(nltk_output))
nltk_output = corpus_bleu(LIST_OF_REFERENCES, HYPOTHESES, weights=weights, smoothing_function=smooth_func)
pl_output = bleu_score(HYPOTHESES, LIST_OF_REFERENCES, n_gram=n_gram, smooth=smooth)
assert torch.allclose(pl_output, torch.tensor(nltk_output))
def test_bleu_empty():
hyp = [[]]
ref = [[[]]]
assert bleu_score(hyp, ref) == torch.tensor(0.0)
def test_no_4_gram():
hyps = [["My", "full", "pytorch-lightning"]]
refs = [[["My", "full", "pytorch-lightning", "test"], ["Completely", "Different"]]]
assert bleu_score(hyps, refs) == torch.tensor(0.0)
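A behavior worth noting from `test_no_4_gram` together with the early return in `bleu_score`: if any n-gram order has zero matches (or cannot be formed at all), the score is exactly 0 regardless of `smooth`, because the `min(numerator) == 0` check runs before smoothing is applied. A small illustrative sketch (not part of the test file):

import torch
from pytorch_lightning.metrics.functional import bleu_score

hyps = [["My", "full", "pytorch-lightning"]]  # only three tokens, so no 4-grams can be formed
refs = [[["My", "full", "pytorch-lightning", "test"], ["Completely", "Different"]]]

assert bleu_score(hyps, refs, n_gram=4) == torch.tensor(0.0)
assert bleu_score(hyps, refs, n_gram=4, smooth=True) == torch.tensor(0.0)  # smoothing does not avoid the early exit
print(bleu_score(hyps, refs, n_gram=3))  # non-zero, since 1- to 3-gram matches exist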

tests/metrics/test_nlp.py

@@ -0,0 +1,29 @@
import pytest
import torch
from pytorch_lightning.metrics.nlp import BLEUScore
# example taken from
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu
HYP1 = "It is a guide to action which ensures that the military always obeys the commands of the party".split()
HYP2 = "he read the book because he was interested in world history".split()
REF1A = "It is a guide to action that ensures that the military will forever heed Party commands".split()
REF1B = "It is a guiding principle which makes the military forces always being under the command of the Party".split()
REF1C = "It is the practical guide for the army always to heed the directions of the party".split()
REF2A = "he was interested in world history because he read the book".split()
LIST_OF_REFERENCES = [[REF1A, REF1B, REF1C], [REF2A]]
HYPOTHESES = [HYP1, HYP2]
@pytest.mark.parametrize(
["n_gram", "smooth"],
[pytest.param(1, True), pytest.param(2, False), pytest.param(3, True), pytest.param(4, False),],
)
def test_bleu(smooth, n_gram):
bleu = BLEUScore(n_gram=n_gram, smooth=smooth)
assert bleu.name == "bleu"
pl_output = bleu(HYPOTHESES, LIST_OF_REFERENCES)
assert isinstance(pl_output, torch.Tensor)
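The class test above only checks the metric name and the output type; a hedged sketch of a value-level cross-check against NLTK, in the spirit of the functional tests (assumes `nltk>=3.3` as pinned in this PR; the single-reference corpus below is an illustrative reduction of the data above):

import torch
from nltk.translate.bleu_score import corpus_bleu
from pytorch_lightning.metrics import BLEUScore

hyp1 = "It is a guide to action which ensures that the military always obeys the commands of the party".split()
hyp2 = "he read the book because he was interested in world history".split()
ref1 = "It is a guide to action that ensures that the military will forever heed Party commands".split()
ref2 = "he was interested in world history because he read the book".split()

hypotheses = [hyp1, hyp2]
references = [[ref1], [ref2]]

nltk_score = corpus_bleu(references, hypotheses)                     # default uniform 4-gram weights
pl_score = BLEUScore(n_gram=4, smooth=False)(hypotheses, references)
print(float(pl_score), nltk_score)  # the two values should agree closely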