From 0a6582646280415cf2bc2ac8c79b1939fd700413 Mon Sep 17 00:00:00 2001 From: Jeff Yang Date: Wed, 22 Jul 2020 20:28:24 +0630 Subject: [PATCH] metrics: add BLEU (#2535) * metrics: added bleu score and test bleu * metrics: fixed type hints in bleu * bleu score moved to metrics/functional/nlp.py * refactor with torch.Tensor * Update test_sequence.py * refactor as Borda requests and nltk==3.2 * locked nltk==3.3 * nltk>=3.3, parametrized smooth argument for test * fix bleu_score example * added class BLEUScore metrics and test * added class BLEUScore metrics and test * update CHANGELOG * refactor with torchtext * torchtext changed to optional import * fix E501 line too long * add else: in optional import * remove pragma: no-cover * constants changed to CAPITALS * remove class in tests * List -> Sequence, conda -> pip, cast with tensor * add torchtext in test.txt * remove torchtext from test.txt * bump torchtext to 0.5.0 * bump torchtext to 0.5.0 * Apply suggestions from code review * ignore bleu score in doctest, renamed to nlp.py * back to implementation with torch * remove --ignore in CI test, proper reference format * apply justus comment Co-authored-by: Jirka Borovec --- CHANGELOG.md | 1 + docs/source/metrics.rst | 34 ++++--- environment.yml | 1 + pytorch_lightning/metrics/__init__.py | 48 +++++----- .../metrics/functional/__init__.py | 3 +- pytorch_lightning/metrics/functional/nlp.py | 92 +++++++++++++++++++ pytorch_lightning/metrics/nlp.py | 46 ++++++++++ requirements/extra.txt | 2 +- requirements/test.txt | 1 + tests/metrics/functional/test_nlp.py | 66 +++++++++++++ tests/metrics/test_nlp.py | 29 ++++++ 11 files changed, 287 insertions(+), 36 deletions(-) create mode 100644 pytorch_lightning/metrics/functional/nlp.py create mode 100644 pytorch_lightning/metrics/nlp.py create mode 100644 tests/metrics/functional/test_nlp.py create mode 100644 tests/metrics/test_nlp.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a1e23adaa..2b4c0f0dc1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added BLEU metrics ([#2535](https://github.com/PyTorchLightning/pytorch-lightning/pull/2535)) ### Changed diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index a62ba72297..860d2fd5c7 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -28,7 +28,7 @@ Example:: .. warning:: The metrics package is still in development! If we're missing a metric or you find a mistake, please send a PR! - to a few metrics. Please feel free to create an issue/PR if you have a proposed metric or have found a bug. + to a few metrics. Please feel free to create an issue/PR if you have a proposed metric or have found a bug. ---------------- @@ -73,7 +73,7 @@ Here's an example showing how to implement a NumpyMetric class RMSE(NumpyMetric): def forward(self, x, y): return np.sqrt(np.mean(np.power(x-y, 2.0))) - + .. autoclass:: pytorch_lightning.metrics.metric.NumpyMetric :noindex: @@ -138,6 +138,12 @@ AUROC .. autoclass:: pytorch_lightning.metrics.classification.AUROC :noindex: +BLEUScore +^^^^^^^^^ + +.. autoclass:: pytorch_lightning.metrics.nlp.BLEUScore + :noindex: + ConfusionMatrix ^^^^^^^^^^^^^^^ @@ -283,6 +289,12 @@ average_precision (F) .. autofunction:: pytorch_lightning.metrics.functional.average_precision :noindex: +bleu_score (F) +^^^^^^^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.functional.bleu_score + :noindex: + confusion_matrix (F) ^^^^^^^^^^^^^^^^^^^^ @@ -418,22 +430,22 @@ to_onehot (F) Sklearn interface ----------------- - -Lightning supports `sklearns metrics module `_ -as a backend for calculating metrics. Sklearns metrics are well tested and robust, + +Lightning supports `sklearns metrics module `_ +as a backend for calculating metrics. Sklearns metrics are well tested and robust, but requires conversion between pytorch and numpy thus may slow down your computations. To use the sklearn backend of metrics simply import as .. code-block:: python - + import pytorch_lightning.metrics.sklearns import plm metric = plm.Accuracy(normalize=True) val = metric(pred, target) - -Each converted sklearn metric comes has the same interface as its -original counterpart (e.g. accuracy takes the additional `normalize` keyword). -Like the native Lightning metrics, these converted sklearn metrics also come + +Each converted sklearn metric comes has the same interface as its +original counterpart (e.g. accuracy takes the additional `normalize` keyword). +Like the native Lightning metrics, these converted sklearn metrics also come with built-in distributed (ddp) support. SklearnMetric (sk) @@ -460,7 +472,7 @@ AveragePrecision (sk) .. autofunction:: pytorch_lightning.metrics.sklearns.AveragePrecision :noindex: - + ConfusionMatrix (sk) ^^^^^^^^^^^^^^^^^^^^ diff --git a/environment.yml b/environment.yml index 25d1a0d6bd..817102453b 100644 --- a/environment.yml +++ b/environment.yml @@ -30,6 +30,7 @@ dependencies: - twine==1.13.0 - pillow<7.0.0 - scikit-image + - nltk>=3.3 # Optional - scipy>=0.13.3 diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py index 7a65077ebe..2cfbbfa01f 100644 --- a/pytorch_lightning/metrics/__init__.py +++ b/pytorch_lightning/metrics/__init__.py @@ -5,7 +5,7 @@ from pytorch_lightning.metrics.regression import ( MSE, PSNR, RMSE, - RMSLE + RMSLE, ) from pytorch_lightning.metrics.classification import ( Accuracy, @@ -28,30 +28,32 @@ from pytorch_lightning.metrics.sklearns import ( PrecisionRecallCurve, SklearnMetric, ) +from pytorch_lightning.metrics.nlp import BLEUScore __classification_metrics = [ - 'AUC', - 'AUROC', - 'Accuracy', - 'AveragePrecision', - 'ConfusionMatrix', - 'DiceCoefficient', - 'F1', - 'FBeta', - 'MulticlassPrecisionRecall', - 'MulticlassROC', - 'Precision', - 'PrecisionRecall', - 'PrecisionRecallCurve', - 'ROC', - 'Recall', - 'IoU', + "AUC", + "AUROC", + "Accuracy", + "AveragePrecision", + "ConfusionMatrix", + "DiceCoefficient", + "F1", + "FBeta", + "MulticlassPrecisionRecall", + "MulticlassROC", + "Precision", + "PrecisionRecall", + "PrecisionRecallCurve", + "ROC", + "Recall", + "IoU", ] __regression_metrics = [ - 'MAE', - 'MSE', - 'PSNR', - 'RMSE', - 'RMSLE' + "MAE", + "MSE", + "PSNR", + "RMSE", + "RMSLE", ] -__all__ = __regression_metrics + __classification_metrics + ['SklearnMetric'] +__sequence_metrics = ["BLEUScore"] +__all__ = __regression_metrics + __classification_metrics + ["SklearnMetric"] + __sequence_metrics diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py index 35cc286b5a..eb92cabf8e 100644 --- a/pytorch_lightning/metrics/functional/__init__.py +++ b/pytorch_lightning/metrics/functional/__init__.py @@ -25,5 +25,6 @@ from pytorch_lightning.metrics.functional.regression import ( mse, psnr, rmse, - rmsle + rmsle, ) +from pytorch_lightning.metrics.functional.nlp import bleu_score diff --git a/pytorch_lightning/metrics/functional/nlp.py b/pytorch_lightning/metrics/functional/nlp.py new file mode 100644 index 0000000000..e1bb86ab19 --- /dev/null +++ b/pytorch_lightning/metrics/functional/nlp.py @@ -0,0 +1,92 @@ +# referenced from +# Library Name: torchtext +# Authors: torchtext authors and @sluks +# Date: 2020-07-18 +# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score +from typing import Sequence, List +from collections import Counter + +import torch + + +def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter: + """Counting how many times each word appears in a given text with ngram + + Args: + ngram_input_list: A list of translated text or reference texts + n_gram: gram value ranged 1 to 4 + + Return: + ngram_counter: a collections.Counter object of ngram + """ + + ngram_counter = Counter() + + for i in range(1, n_gram + 1): + for j in range(len(ngram_input_list) - i + 1): + ngram_key = tuple(ngram_input_list[j : i + j]) + ngram_counter[ngram_key] += 1 + + return ngram_counter + + +def bleu_score( + translate_corpus: Sequence[str], reference_corpus: Sequence[str], n_gram: int = 4, smooth: bool = False +) -> torch.Tensor: + """Calculate BLEU score of machine translated text with one or more references. + + Args: + translate_corpus: An iterable of machine translated corpus + reference_corpus: An iterable of iterables of reference corpus + n_gram: Gram value ranged from 1 to 4 (Default 4) + smooth: Whether or not to apply smoothing – Lin et al. 2004 + + Return: + A Tensor with BLEU Score + + Example: + + >>> translate_corpus = ['the cat is on the mat'.split()] + >>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]] + >>> bleu_score(translate_corpus, reference_corpus) + tensor(0.7598) + """ + + assert len(translate_corpus) == len(reference_corpus) + numerator = torch.zeros(n_gram) + denominator = torch.zeros(n_gram) + precision_scores = torch.zeros(n_gram) + c = 0.0 + r = 0.0 + for (translation, references) in zip(translate_corpus, reference_corpus): + c += len(translation) + ref_len_list = [len(ref) for ref in references] + ref_len_diff = [abs(len(translation) - x) for x in ref_len_list] + r += ref_len_list[ref_len_diff.index(min(ref_len_diff))] + translation_counter = _count_ngram(translation, n_gram) + reference_counter = Counter() + for ref in references: + reference_counter |= _count_ngram(ref, n_gram) + + ngram_counter_clip = translation_counter & reference_counter + for counter_clip in ngram_counter_clip: + numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip] + + for counter in translation_counter: + denominator[len(counter) - 1] += translation_counter[counter] + + trans_len = torch.tensor(c) + ref_len = torch.tensor(r) + if min(numerator) == 0.0: + return torch.tensor(0.0) + + if smooth: + precision_scores = torch.add(numerator, torch.ones(n_gram)) / torch.add(denominator, torch.ones(n_gram)) + else: + precision_scores = numerator / denominator + log_precision_scores = torch.tensor([1.0 / n_gram] * n_gram) * torch.log(precision_scores) + geometric_mean = torch.exp(torch.sum(log_precision_scores)) + brevity_penalty = torch.tensor(1.0) if c > r else torch.exp(1 - (ref_len / trans_len)) + bleu = brevity_penalty * geometric_mean + + return bleu diff --git a/pytorch_lightning/metrics/nlp.py b/pytorch_lightning/metrics/nlp.py new file mode 100644 index 0000000000..a4284ada3f --- /dev/null +++ b/pytorch_lightning/metrics/nlp.py @@ -0,0 +1,46 @@ +import torch + +from pytorch_lightning.metrics.functional.nlp import bleu_score +from pytorch_lightning.metrics.metric import Metric + + +class BLEUScore(Metric): + """ + Calculate BLEU score of machine translated text with one or more references. + + Example: + + >>> translate_corpus = ['the cat is on the mat'.split()] + >>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]] + >>> metric = BLEUScore() + >>> metric(translate_corpus, reference_corpus) + tensor(0.7598) + """ + + def __init__(self, n_gram: int = 4, smooth: bool = False): + """ + Args: + n_gram: Gram value ranged from 1 to 4 (Default 4) + smooth: Whether or not to apply smoothing – Lin et al. 2004 + """ + super().__init__(name="bleu") + self.n_gram = n_gram + self.smooth = smooth + + def forward(self, translate_corpus: list, reference_corpus: list) -> torch.Tensor: + """ + Actual metric computation + + Args: + translate_corpus: An iterable of machine translated corpus + reference_corpus: An iterable of iterables of reference corpus + + Return: + torch.Tensor: BLEU Score + """ + return bleu_score( + translate_corpus=translate_corpus, + reference_corpus=reference_corpus, + n_gram=self.n_gram, + smooth=self.smooth, + ).to(self.device, self.dtype) diff --git a/requirements/extra.txt b/requirements/extra.txt index 71a9ea9ccc..e245a9512d 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -11,4 +11,4 @@ horovod>=0.19.1 omegaconf>=2.0.0 # scipy>=0.13.3 scikit-learn>=0.20.0 -torchtext>=0.3.1 \ No newline at end of file +torchtext>=0.3.1 diff --git a/requirements/test.txt b/requirements/test.txt index bd31596c21..8491cf8c83 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -12,3 +12,4 @@ black==19.10b0 pre-commit>=1.0 cloudpickle>=1.2 +nltk>=3.3 diff --git a/tests/metrics/functional/test_nlp.py b/tests/metrics/functional/test_nlp.py new file mode 100644 index 0000000000..2f1647270e --- /dev/null +++ b/tests/metrics/functional/test_nlp.py @@ -0,0 +1,66 @@ +import pytest +import torch +from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu, sentence_bleu + +from pytorch_lightning.metrics.functional.nlp import bleu_score + +# example taken from +# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.sentence_bleu +HYPOTHESIS1 = tuple( + "It is a guide to action which ensures that the military always obeys the commands of the party".split() +) +REFERENCE1 = tuple("It is a guide to action that ensures that the military will forever heed Party commands".split()) +REFERENCE2 = tuple( + "It is a guiding principle which makes the military forces always being under the command of the Party".split() +) +REFERENCE3 = tuple("It is the practical guide for the army always to heed the directions of the party".split()) + + +# example taken from +# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu +HYP1 = "It is a guide to action which ensures that the military always obeys the commands of the party".split() +HYP2 = "he read the book because he was interested in world history".split() + +REF1A = "It is a guide to action that ensures that the military will forever heed Party commands".split() +REF1B = "It is a guiding principle which makes the military force always being under the command of the Party".split() +REF1C = "It is the practical guide for the army always to heed the directions of the party".split() +REF2A = "he was interested in world history because he read the book".split() + +LIST_OF_REFERENCES = [[REF1A, REF1B, REF1C], [REF2A]] +HYPOTHESES = [HYP1, HYP2] + +# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.SmoothingFunction +smooth_func = SmoothingFunction().method2 + + +@pytest.mark.parametrize( + ["weights", "n_gram", "smooth_func", "smooth"], + [ + pytest.param([1], 1, None, False), + pytest.param([0.5, 0.5], 2, smooth_func, True), + pytest.param([0.333333, 0.333333, 0.333333], 3, None, False), + pytest.param([0.25, 0.25, 0.25, 0.25], 4, smooth_func, True), + ], +) +def test_bleu_score(weights, n_gram, smooth_func, smooth): + nltk_output = sentence_bleu( + [REFERENCE1, REFERENCE2, REFERENCE3], HYPOTHESIS1, weights=weights, smoothing_function=smooth_func + ) + pl_output = bleu_score([HYPOTHESIS1], [[REFERENCE1, REFERENCE2, REFERENCE3]], n_gram=n_gram, smooth=smooth) + assert torch.allclose(pl_output, torch.tensor(nltk_output)) + + nltk_output = corpus_bleu(LIST_OF_REFERENCES, HYPOTHESES, weights=weights, smoothing_function=smooth_func) + pl_output = bleu_score(HYPOTHESES, LIST_OF_REFERENCES, n_gram=n_gram, smooth=smooth) + assert torch.allclose(pl_output, torch.tensor(nltk_output)) + + +def test_bleu_empty(): + hyp = [[]] + ref = [[[]]] + assert bleu_score(hyp, ref) == torch.tensor(0.0) + + +def test_no_4_gram(): + hyps = [["My", "full", "pytorch-lightning"]] + refs = [[["My", "full", "pytorch-lightning", "test"], ["Completely", "Different"]]] + assert bleu_score(hyps, refs) == torch.tensor(0.0) diff --git a/tests/metrics/test_nlp.py b/tests/metrics/test_nlp.py new file mode 100644 index 0000000000..e58b1f3398 --- /dev/null +++ b/tests/metrics/test_nlp.py @@ -0,0 +1,29 @@ +import pytest +import torch + +from pytorch_lightning.metrics.nlp import BLEUScore + +# example taken from +# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu +HYP1 = "It is a guide to action which ensures that the military always obeys the commands of the party".split() +HYP2 = "he read the book because he was interested in world history".split() + +REF1A = "It is a guide to action that ensures that the military will forever heed Party commands".split() +REF1B = "It is a guiding principle which makes the military forces always being under the command of the Party".split() +REF1C = "It is the practical guide for the army always to heed the directions of the party".split() +REF2A = "he was interested in world history because he read the book".split() + +LIST_OF_REFERENCES = [[REF1A, REF1B, REF1C], [REF2A]] +HYPOTHESES = [HYP1, HYP2] + + +@pytest.mark.parametrize( + ["n_gram", "smooth"], + [pytest.param(1, True), pytest.param(2, False), pytest.param(3, True), pytest.param(4, False),], +) +def test_bleu(smooth, n_gram): + bleu = BLEUScore(n_gram=n_gram, smooth=smooth) + assert bleu.name == "bleu" + + pl_output = bleu(HYPOTHESES, LIST_OF_REFERENCES) + assert isinstance(pl_output, torch.Tensor)