From bbcdb1a3e22962dbfa0c8ee82b84beed172ee809 Mon Sep 17 00:00:00 2001
From: Giovanni Campagna
Date: Wed, 15 Jan 2020 22:27:58 -0800
Subject: [PATCH] Add a BERT-based numericalizer

Compatible with wordpiece tokenization used by BERT, and also with
Almond's unique requirements around how tokens are split
---
 Pipfile                                       |   1 +
 Pipfile.lock                                  | 149 ++++++++++++--
 decanlp/data/example.py                       |  32 +--
 decanlp/data/numericalizer/__init__.py        |  33 ++++
 decanlp/data/numericalizer/bert.py            | 185 ++++++++++++++++++
 decanlp/data/numericalizer/decoder_vocab.py   |  67 +++++++
 .../numericalizer/masked_bert_tokenizer.py    | 122 ++++++++++++
 .../data/numericalizer/sequential_field.py    |  38 ++++
 .../simple.py}                                |  81 ++------
 decanlp/data/{ => numericalizer}/vocab.py     |   0
 .../multitask_question_answering_network.py   |   4 +-
 decanlp/tasks/almond/__init__.py              |  41 +++-
 decanlp/tasks/base.py                         |   7 +-
 decanlp/tasks/generic.py                      |   4 +-
 decanlp/train.py                              |   4 +-
 15 files changed, 662 insertions(+), 106 deletions(-)
 create mode 100644 decanlp/data/numericalizer/__init__.py
 create mode 100644 decanlp/data/numericalizer/bert.py
 create mode 100644 decanlp/data/numericalizer/decoder_vocab.py
 create mode 100644 decanlp/data/numericalizer/masked_bert_tokenizer.py
 create mode 100644 decanlp/data/numericalizer/sequential_field.py
 rename decanlp/data/{numericalizer.py => numericalizer/simple.py} (73%)
 rename decanlp/data/{ => numericalizer}/vocab.py (100%)

diff --git a/Pipfile b/Pipfile
index c8660806..a124400d 100644
--- a/Pipfile
+++ b/Pipfile
@@ -20,6 +20,7 @@ tabulate = "*"
 tensorboardX = "*"
 Babel = "*"
 requests = "*"
+transformers = "*"
 
 [requires]
 python_version = "3.7"
diff --git a/Pipfile.lock b/Pipfile.lock
index 736674fb..cdcafb58 100644
--- a/Pipfile.lock
+++ b/Pipfile.lock
@@ -1,7 +1,7 @@
 {
     "_meta": {
         "hash": {
-            "sha256": "6bc781177e1bd4b396a93b6dbab0bf503d2d2d37b9dcecdbf6a667a47e4c8cc8"
+            "sha256": "a49da22eb026763b9c0581905f6624ca97e443e40bb691e3d5531f58dbb199f7"
         },
         "pipfile-spec": 6,
         "requires": {
@@ -32,6 +32,20 @@
             ],
             "version": "==4.8.2"
         },
+        "boto3": {
+            "hashes": [
+                "sha256:916efdecda945daba08dca9032373440db681e53e6b46e7d0cae104ffd5ac7ca",
+                "sha256:925899699b81977b75d4036fc478209916f4547f8683e5c7b1ad678317a6652e"
+            ],
+            "version": "==1.11.3"
+        },
+        "botocore": {
+            "hashes": [
+                "sha256:799562d2af023f49518676c6b72ee07325da027513406b4b9d8e5b74ecea8257",
+                "sha256:80e6522c4bb8d98dab9fd77708c6722bee775c7b85a77326ef44543062a17462"
+            ],
+            "version": "==1.14.3"
+        },
         "certifi": {
             "hashes": [
                 "sha256:017c25db2a153ce562900032d5bc68e9f191e44e9a0f762f373977de9df1fbb3",
@@ -46,12 +60,27 @@
             ],
             "version": "==3.0.4"
         },
+        "click": {
+            "hashes": [
+                "sha256:2335065e6395b9e67ca716de5f7526736bfa6ceead690adf616d925bdc622b13",
+                "sha256:5b94b49521f6456670fdb30cd82a4eca9412788a93fa6dd6df72c94d5a8ff2d7"
+            ],
+            "version": "==7.0"
+        },
         "docopt": {
             "hashes": [
                 "sha256:49b3a825280bd66b3aa83585ef59c4a8c82f2c8a522dbe754a8bc8d08c85c491"
             ],
             "version": "==0.6.2"
         },
+        "docutils": {
+            "hashes": [
+                "sha256:6c4f696463b79f1fb8ba0c594b63840ebd41f059e92b31957c46b74a4599b6d0",
+                "sha256:9e4d7ecfc600058e07ba661411a2b7de2fd0fafa17d1a7f7361cd47b1175c827",
+                "sha256:a2aeea129088da402665e92e0b25b04b073c04b2dce4ab65caaa38b7ce2e1a99"
+            ],
+            "version": "==0.15.2"
+        },
         "et-xmlfile": {
             "hashes": [
                 "sha256:614d9722d572f6246302c4491846d2c393c199cfa4edc9af593437691683335b"
@@ -72,6 +101,20 @@
             ],
             "version": "==1.4.1"
         },
+        "jmespath": {
+            "hashes": [
+                "sha256:3720a4b1bd659dd2eecad0666459b9788813e032b83e7ba58578e48254e0a0e6",
"sha256:bde2aef6f44302dfb30320115b17d030798de8c4110e28d5cf6cf91a7a31074c" + ], + "version": "==0.9.4" + }, + "joblib": { + "hashes": [ + "sha256:0630eea4f5664c463f23fbf5dcfc54a2bc6168902719fa8e19daf033022786c8", + "sha256:bdb4fd9b72915ffb49fde2229ce482dd7ae79d842ed8c2b4c932441495af1403" + ], + "version": "==0.14.1" + }, "nltk": { "hashes": [ "sha256:2661f9971d983db314bbebd51ba770811a362c6597fd0f303bb1d3beadcb4834" @@ -194,6 +237,32 @@ "index": "pypi", "version": "==0.5.3" }, + "regex": { + "hashes": [ + "sha256:07b39bf943d3d2fe63d46281d8504f8df0ff3fe4c57e13d1656737950e53e525", + "sha256:0932941cdfb3afcbc26cc3bcf7c3f3d73d5a9b9c56955d432dbf8bbc147d4c5b", + "sha256:0e182d2f097ea8549a249040922fa2b92ae28be4be4895933e369a525ba36576", + "sha256:10671601ee06cf4dc1bc0b4805309040bb34c9af423c12c379c83d7895622bb5", + "sha256:23e2c2c0ff50f44877f64780b815b8fd2e003cda9ce817a7fd00dea5600c84a0", + "sha256:26ff99c980f53b3191d8931b199b29d6787c059f2e029b2b0c694343b1708c35", + "sha256:27429b8d74ba683484a06b260b7bb00f312e7c757792628ea251afdbf1434003", + "sha256:3e77409b678b21a056415da3a56abfd7c3ad03da71f3051bbcdb68cf44d3c34d", + "sha256:4e8f02d3d72ca94efc8396f8036c0d3bcc812aefc28ec70f35bb888c74a25161", + "sha256:4eae742636aec40cf7ab98171ab9400393360b97e8f9da67b1867a9ee0889b26", + "sha256:6a6ae17bf8f2d82d1e8858a47757ce389b880083c4ff2498dba17c56e6c103b9", + "sha256:6a6ba91b94427cd49cd27764679024b14a96874e0dc638ae6bdd4b1a3ce97be1", + "sha256:7bcd322935377abcc79bfe5b63c44abd0b29387f267791d566bbb566edfdd146", + "sha256:98b8ed7bb2155e2cbb8b76f627b2fd12cf4b22ab6e14873e8641f266e0fb6d8f", + "sha256:bd25bb7980917e4e70ccccd7e3b5740614f1c408a642c245019cff9d7d1b6149", + "sha256:d0f424328f9822b0323b3b6f2e4b9c90960b24743d220763c7f07071e0778351", + "sha256:d58e4606da2a41659c84baeb3cfa2e4c87a74cec89a1e7c56bee4b956f9d7461", + "sha256:e3cd21cc2840ca67de0bbe4071f79f031c81418deb544ceda93ad75ca1ee9f7b", + "sha256:e6c02171d62ed6972ca8631f6f34fa3281d51db8b326ee397b9c83093a6b7242", + "sha256:e7c7661f7276507bce416eaae22040fd91ca471b5b33c13f8ff21137ed6f248c", + "sha256:ecc6de77df3ef68fee966bb8cb4e067e84d4d1f397d0ef6fce46913663540d77" + ], + "version": "==2020.1.8" + }, "requests": { "hashes": [ "sha256:11e007a8a2aa0323f5a921e9e6a2d7e4e67d9877e85773fba9ba6419025cbeb4", @@ -207,6 +276,13 @@ "git": "https://github.com/jekbradbury/revtok.git", "ref": "f1998b72a941d1e5f9578a66dc1c20b01913caab" }, + "s3transfer": { + "hashes": [ + "sha256:2157640a47c8b8fa2071bdd7b0d57378ec8957eede3bd083949c2dcc4d9b0dd4", + "sha256:e3343ae0f371781c17590cf06cb818a54484fbac9a65a5be7603a39b0a6d7b31" + ], + "version": "==0.3.0" + }, "sacrebleu": { "hashes": [ "sha256:0a4b9e53b742d95fcd2f32e4aaa42aadcf94121d998ca19c66c05e7037d1eeee", @@ -215,12 +291,48 @@ "index": "pypi", "version": "==1.4.3" }, + "sacremoses": { + "hashes": [ + "sha256:34dcfaacf9fa34a6353424431f0e4fcc60e8ebb27ffee320d57396690b712a3b" + ], + "version": "==0.0.38" + }, + "sentencepiece": { + "hashes": [ + "sha256:0a98ec863e541304df23a37787033001b62cb089f4ed9307911791d7e210c0b1", + "sha256:0ad221ea7914d65f57d3e3af7ae48852b5035166493312b5025367585b43ac41", + "sha256:0f72c4151791de7242e7184a9b7ef12503cef42e9a5a0c1b3510f2c68874e810", + "sha256:22fe7d92203fadbb6a0dc7d767430d37cdf3a9da4a0f2c5302c7bf294f7bfd8f", + "sha256:2a72d4c3d0dbb1e099ddd2dc6b724376d3d7ff77ba494756b894254485bec4b4", + "sha256:30791ce80a557339e17f1290c68dccd3f661612fdc6b689b4e4f21d805b64952", + "sha256:39904713b81869db10de53fe8b3719f35acf77f49351f28ceaad0d360f2f6305", + 
"sha256:3d5a2163deea95271ce8e38dfd0c3c924bea92aaf63bdda69b5458628dacc8bd", + "sha256:3f3dee204635c33ca2e450e17ee9e0e92f114a47f853c2e44e7f0f0ab444d8d0", + "sha256:4dcea889af53f669dc39d1ca870c37c52bb3110fcd96a2e7330d288400958281", + "sha256:4e36a92558ad9e2f91b311c5bcea90b7a63c567c0e7e20da44d6a6f01031b57e", + "sha256:576bf820eb963e6f275d4005ed5334fbed59eb54bed508e5cae6d16c7179710f", + "sha256:6d2bbdbf296d96304c6345675749981bb17dcf2a7163d2fec38f70a704b75669", + "sha256:76fdce3e7e614e24b35167c22c9c388e0c843be53d99afb5e1f25f6bfe04e228", + "sha256:97b8ee26892d236b2620af8ddae11713fbbb2dae9adf4ad5e988e5a82ce50a90", + "sha256:b3b6fe02af7ea4823c19e0d8efddc10ff59b8449bc1ae9921f9dd8ad33802c33", + "sha256:b416f514fff8785a1113e6c07f696e52967fc979d6cd946e454a8660cca72ef8", + "sha256:bf0bad6ba01ace3e938ffdf05c42b24d8fd3740487ba865504795a0bb9b1f2b3", + "sha256:c00387970360ec0369b5e7c75f3977fb14330df75465200c13bafb7a632d2e6b", + "sha256:c23fb7bb949934998375d41dbe54d4df1778a3b9dcb24bc2ddaaa595819ed1da", + "sha256:dfdcf48678656592b11d11e2102c52c38122e309f7a1a5272305d397cfe21ce0", + "sha256:fb69c5ba325b900cf2b91f517b46eec8ce3c50995955e293b46681d832021c0e", + "sha256:fba83bef6c7a7899cd811d9b1195e748722eb2a9737c3f3890160f0e01e3ad08", + "sha256:fe115aee209197839b2a357e34523e23768d553e8a69eac2b558499ccda56f80", + "sha256:ffdf51218a3d7e0dad79bdffd21ad15a23cbb9c572d2300c3295c6efc6c2357e" + ], + "version": "==0.1.85" + }, "six": { "hashes": [ - "sha256:1f1b7d42e254082a9db6279deae68afb421ceba6158efa6131de7b3003ee93fd", - "sha256:30f610279e8b2578cab6db20741130331735c781b56053c59c4076da27f06b66" + "sha256:236bdbdce46e6e6a3d61a337c0f8b763ca1e8717c03b369e87a7ec7ce1319c0a", + "sha256:8f3cd2e254d8f793e7f3d6d9df77b92252b52637291d0f0da013c76ea2724b6c" ], - "version": "==1.13.0" + "version": "==1.14.0" }, "soupsieve": { "hashes": [ @@ -260,18 +372,19 @@ }, "torch": { "hashes": [ - "sha256:0cec2e13a2e95c24c34f17d437f354ee2a40902e8d515a524556b350e12555dd", - "sha256:134e8291a97151b1ffeea09cb9ddde5238beb4e6d9dfb66657143d6990bfb865", - "sha256:31062923ac2e60eac676f6a0ae14702b051c158bbcf7f440eaba266b0defa197", - "sha256:3b05233481b51bb636cee63dc761bb7f602e198178782ff4159d385d1759608b", - "sha256:458f1d87e5b7064b2c39e36675d84e163be3143dd2fc806057b7878880c461bc", - "sha256:72a1c85bffd2154f085bc0a1d378d8a54e55a57d49664b874fe7c949022bf071", - "sha256:77fd8866c0bf529861ffd850a5dada2190a8d9c5167719fb0cfa89163e23b143", - "sha256:b6f01d851d1c5989d4a99b50ae0187762b15b7718dcd1a33704b665daa2402f9", - "sha256:d8e1d904a6193ed14a4fed220b00503b2baa576e71471286d1ebba899c851fae" + "sha256:271d4d1e44df6ed57c530f8849b028447c62b8a19b8e8740dd9baa56e7f682c1", + "sha256:30ce089475b287a37d6fbb8d71853e672edaf66699e3dd2eb19be6ce6296732a", + "sha256:405b9eb40e44037d2525b3ddb5bc4c66b519cd742bff249d4207d23f83e88ea5", + "sha256:504915c6bc6051ba6a4c2a43c446463dff04411e352f1e26fe13debeae431778", + "sha256:54d06a0e8ee85e5a437c24f4af9f4196c819294c23ffb5914e177756f55f1829", + "sha256:6f2fd9eb8c7eaf38a982ab266dbbfba0f29fb643bc74e677d045d6f2595e4692", + "sha256:8856f334aa9ecb742c1504bd2563d0ffb8dceb97149c8d72a04afa357f667dbc", + "sha256:8fff03bf7b474c16e4b50da65ea14200cc64553b67b9b2307f9dc7e8c69b9d28", + "sha256:9a1b1db73d8dcfd94b2eee24b939368742aa85f1217c55b8f5681e76c581e99a", + "sha256:bb1e87063661414e1149bef2e3a2499ce0b5060290799d7e26bc5578037075ba" ], "index": "pypi", - "version": "==1.3.1" + "version": "==1.4.0" }, "tqdm": { "hashes": [ @@ -281,6 +394,14 @@ "index": "pypi", "version": "==4.41.1" }, + "transformers": { + "hashes": [ + 
"sha256:2c237b06d60bb7f17f6b9e1ab9c5d4530508a287bc16ec64b5f7bb11d99df717", + "sha256:d881aca9ff1d0d9cf500bda47d1cbe1b87d4297af75f2e1b9cf7ac0293dd9c38" + ], + "index": "pypi", + "version": "==2.3.0" + }, "typing": { "hashes": [ "sha256:91dfe6f3f706ee8cc32d38edbbf304e9b7583fb37108fef38229617f8b3eba23", diff --git a/decanlp/data/example.py b/decanlp/data/example.py index d9709204..555bf871 100644 --- a/decanlp/data/example.py +++ b/decanlp/data/example.py @@ -1,6 +1,5 @@ # -# Copyright (c) 2018, Salesforce, Inc. -# The Board of Trustees of the Leland Stanford Junior University +# Copyright (c) 2019-2020 The Board of Trustees of the Leland Stanford Junior University # All rights reserved. # # Redistribution and use in source and binary forms, with or without @@ -31,19 +30,21 @@ import torch from typing import NamedTuple, List - -class SequentialField(NamedTuple): - value : torch.tensor - length : torch.tensor - limited : torch.tensor - tokens : List[str] +from .numericalizer import SequentialField class Example(NamedTuple): example_id : str + + # for each field in the example, we store the tokenized sentence, and a boolean mask + # indicating whether the token is a real word (subject to word-piece tokenization) + # or it should be treated as an opaque symbol context : List[str] + context_word_mask : List[bool] question : List[str] + question_word_mask : List[bool] answer : List[str] + answer_word_mask : List[bool] vocab_fields = ['context', 'question', 'answer'] @@ -51,10 +52,13 @@ class Example(NamedTuple): def from_raw(example_id : str, context : str, question : str, answer : str, tokenize, lower=False): args = [example_id] for argname, arg in (('context', context), ('question', question), ('answer', answer)): - new_arg = tokenize(arg.rstrip('\n'), field_name=argname) + words, mask = tokenize(arg.rstrip('\n'), field_name=argname) + if mask is None: + mask = [True for _ in words] if lower: - new_arg = [word.lower() for word in new_arg] - args.append(new_arg) + words = [word.lower() for word in words] + args.append(words) + args.append(mask) return Example(*args) @@ -69,9 +73,9 @@ class Batch(NamedTuple): def from_examples(examples, numericalizer, device=None): assert all(isinstance(ex.example_id, str) for ex in examples) example_ids = [ex.example_id for ex in examples] - context_input = [ex.context for ex in examples] - question_input = [ex.question for ex in examples] - answer_input = [ex.answer for ex in examples] + context_input = [(ex.context, ex.context_word_mask) for ex in examples] + question_input = [(ex.question, ex.question_word_mask) for ex in examples] + answer_input = [(ex.answer, ex.answer_word_mask) for ex in examples] decoder_vocab = numericalizer.decoder_vocab.clone() return Batch(example_ids, diff --git a/decanlp/data/numericalizer/__init__.py b/decanlp/data/numericalizer/__init__.py new file mode 100644 index 00000000..0415154b --- /dev/null +++ b/decanlp/data/numericalizer/__init__.py @@ -0,0 +1,33 @@ +# +# Copyright (c) 2020 The Board of Trustees of the Leland Stanford Junior University +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. 
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from .sequential_field import SequentialField
+
+from .simple import SimpleNumericalizer
+from .bert import BertNumericalizer
\ No newline at end of file
diff --git a/decanlp/data/numericalizer/bert.py b/decanlp/data/numericalizer/bert.py
new file mode 100644
index 00000000..aeab8e4d
--- /dev/null
+++ b/decanlp/data/numericalizer/bert.py
@@ -0,0 +1,185 @@
+#
+# Copyright (c) 2019-2020 The Board of Trustees of the Leland Stanford Junior University
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+import collections
+import os
+import torch
+from transformers import BertConfig
+
+from .sequential_field import SequentialField
+from .decoder_vocab import DecoderVocabulary
+from .masked_bert_tokenizer import MaskedBertTokenizer
+
+
+class BertNumericalizer(object):
+    """
+    Numericalizer that uses BertTokenizer from huggingface's transformers library.
+ """ + + def __init__(self, pretrained_tokenizer, max_generative_vocab, cache=None, fix_length=None): + self._pretrained_name = pretrained_tokenizer + self.max_generative_vocab = max_generative_vocab + self._cache = cache + self.config = None + self._tokenizer = None + + self.fix_length = fix_length + + @property + def num_tokens(self): + return self._tokenizer.vocab_size + + def load(self, save_dir): + self.config = BertConfig.from_pretrained(os.path.join(save_dir, 'bert-config.json'), cache_dir=self._cache) + self._tokenizer = MaskedBertTokenizer.from_pretrained(save_dir, config=self.config, cache_dir=self._cache) + + with open(os.path.join(save_dir, 'decoder-vocab.txt'), 'r') as fp: + self._decoder_words = [line.rstrip('\n') for line in fp] + + self._init() + + def save(self, save_dir): + self.config.save_pretrained(os.path.join(save_dir, 'bert-config.json')) + self._tokenizer.save_pretrained(os.path.join(save_dir)) + with open(os.path.join(save_dir, 'decoder-vocab.txt'), 'w') as fp: + for word in self._decoder_words: + fp.write(word + '\n') + + def build_vocab(self, vectors, vocab_fields, vocab_sets): + self.config = BertConfig.from_pretrained(self._pretrained_name, cache_dir=self._cache) + self._tokenizer = MaskedBertTokenizer.from_pretrained(self._pretrained_name, config=self.config, cache_dir=self._cache) + + # ensure that init, eos, unk and pad are set + # this method has no effect if the tokens are already set according to the tokenizer class + self._tokenizer.add_special_tokens({ + 'bos_token': '[CLS]', + 'eos_token': '[SEP]', + 'unk_token': '[UNK]', + 'pad_token': '[PAD]' + }) + + # do a pass over all the answers in the dataset, and construct a counter of wordpieces + decoder_words = collections.Counter() + for dataset in vocab_sets: + for example in dataset: + tokens = self._tokenizer.tokenize(example.answer, example.answer_word_mask) + decoder_words.update(tokens) + + self._decoder_words = decoder_words.most_common(self.max_generative_vocab) + + self._init() + + def grow_vocab(self, examples, vectors): + # TODO + return [] + + def _init(self): + self.pad_first = self._tokenizer.padding_side == 'left' + + self.init_token = self._tokenizer.bos_token + self.eos_token = self._tokenizer.eos_token + self.unk_token = self._tokenizer.unk_token + self.pad_token = self._tokenizer.pad_token + + self.init_id = self._tokenizer.bos_token_id + self.eos_id = self._tokenizer.eos_token_id + self.unk_id = self._tokenizer.unk_token_id + self.pad_id = self._tokenizer.pad_token_id + self.generative_vocab_size = len(self._decoder_words) + + assert self.init_id < self.generative_vocab_size + assert self.eos_id < self.generative_vocab_size + assert self.unk_id < self.generative_vocab_size + assert self.pad_id < self.generative_vocab_size + + self.decoder_vocab = DecoderVocabulary(self._decoder_words, self._tokenizer) + + def encode(self, minibatch, decoder_vocab, device=None): + assert isinstance(minibatch, list) + + # apply word-piece tokenization to everything first + wp_tokenized = [] + for tokens, mask in minibatch: + wp_tokenized.append(self._tokenizer.tokenize(tokens, mask)) + + if self.fix_length is None: + max_len = max(len(x) for x in minibatch) + 2 + else: + max_len = self.fix_length + 2 + padded = [] + lengths = [] + numerical = [] + decoder_numerical = [] + for wp_tokens in wp_tokenized: + if self.pad_first: + padded_example = [self.pad_token] * max(0, max_len - len(wp_tokens)) + \ + [self.init_token] + \ + list(wp_tokens[:max_len]) + \ + [self.eos_token] + else: + padded_example = 
[self.init_token] + \ + list(wp_tokens[:max_len]) + \ + [self.eos_token] + \ + [self.pad_token] * max(0, max_len - len(wp_tokens)) + + padded.append(padded_example) + lengths.append(len(padded_example) - max(0, max_len - len(wp_tokens))) + + numerical.append(self._tokenizer.convert_tokens_to_ids(padded_example)) + decoder_numerical.append([decoder_vocab.encode(word) for word in padded_example]) + + length = torch.tensor(lengths, dtype=torch.int32, device=device) + numerical = torch.tensor(numerical, dtype=torch.int64, device=device) + decoder_numerical = torch.tensor(decoder_numerical, dtype=torch.int64, device=device) + + return SequentialField(tokens=padded, length=length, value=numerical, limited=decoder_numerical) + + def decode(self, tensor): + return self._tokenizer.convert_ids_to_tokens(tensor) + + def reverse(self, batch, detokenize, field_name=None): + with torch.cuda.device_of(batch): + batch = batch.tolist() + + def reverse_one(tensor): + tokens = [] + + # trim up to EOS, remove other special stuff, and undo wordpiece tokenization + for token in self.decode(tensor): + if token == self.eos_token: + break + if token in (self.init_token, self.pad_token): + continue + if token.startswith('##'): + tokens[-1] += token[2:] + else: + tokens.append(token) + + return detokenize(tokens, field_name=field_name) + + return [reverse_one(tensor) for tensor in batch] \ No newline at end of file diff --git a/decanlp/data/numericalizer/decoder_vocab.py b/decanlp/data/numericalizer/decoder_vocab.py new file mode 100644 index 00000000..43a9a63e --- /dev/null +++ b/decanlp/data/numericalizer/decoder_vocab.py @@ -0,0 +1,67 @@ +# +# Copyright (c) 2019-2020 The Board of Trustees of the Leland Stanford Junior University +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+
+
+class DecoderVocabulary(object):
+    def __init__(self, words, full_vocab):
+        self.full_vocab = full_vocab
+        if words is not None:
+            self.itos = words
+            self.stoi = { word: idx for idx, word in enumerate(words) }
+        else:
+            self.itos = []
+            self.stoi = dict()
+        self.oov_itos = []
+        self.oov_stoi = dict()
+
+    def clone(self):
+        new_subset = DecoderVocabulary(None, self.full_vocab)
+        new_subset.itos = self.itos
+        new_subset.stoi = self.stoi
+        return new_subset
+
+    def __len__(self):
+        return len(self.itos) + len(self.oov_itos)
+
+    def encode(self, word):
+        if word in self.stoi:
+            lim_idx = self.stoi[word]
+        elif word in self.oov_stoi:
+            lim_idx = self.oov_stoi[word]
+        else:
+            lim_idx = len(self)
+            self.oov_itos.append(word)
+            self.oov_stoi[word] = lim_idx
+        return lim_idx
+
+    def decode(self, lim_idx):
+        if lim_idx < len(self.itos):
+            return self.full_vocab.stoi[self.itos[lim_idx]]
+        else:
+            return self.full_vocab.stoi[self.oov_itos[lim_idx-len(self.itos)]]
\ No newline at end of file
diff --git a/decanlp/data/numericalizer/masked_bert_tokenizer.py b/decanlp/data/numericalizer/masked_bert_tokenizer.py
new file mode 100644
index 00000000..c85559c1
--- /dev/null
+++ b/decanlp/data/numericalizer/masked_bert_tokenizer.py
@@ -0,0 +1,122 @@
+#
+# Copyright (c) 2019-2020 The Board of Trustees of the Leland Stanford Junior University
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+# This file was partially copied from huggingface's tokenizers library
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from transformers import BertTokenizer
+
+class MaskedWordPieceTokenizer:
+    def __init__(self, vocab, added_tokens_encoder, added_tokens_decoder, unk_token, max_input_chars_per_word=100):
+        self.vocab = vocab
+        self.unk_token = unk_token
+        self.max_input_chars_per_word = max_input_chars_per_word
+        self.added_tokens_encoder = added_tokens_encoder
+        self.added_tokens_decoder = added_tokens_decoder
+
+    def tokenize(self, tokens, mask):
+        output_tokens = []
+        for token, should_word_split in zip(tokens, mask):
+            if not should_word_split:
+                if token not in self.vocab and token not in self.added_tokens_encoder:
+                    token_id = len(self.added_tokens_encoder)
+                    self.added_tokens_encoder[token] = token_id
+                    self.added_tokens_decoder[token_id] = token
+                output_tokens.append(token)
+                continue
+
+            chars = list(token)
+            if len(chars) > self.max_input_chars_per_word:
+                output_tokens.append(self.unk_token)
+                continue
+
+            is_bad = False
+            start = 0
+            sub_tokens = []
+            while start < len(chars):
+                end = len(chars)
+                cur_substr = None
+                while start < end:
+                    substr = "".join(chars[start:end])
+                    if start > 0:
+                        substr = "##" + substr
+                    if substr in self.vocab:
+                        cur_substr = substr
+                        break
+                    end -= 1
+                if cur_substr is None:
+                    is_bad = True
+                    break
+                sub_tokens.append(cur_substr)
+                start = end
+
+            if is_bad:
+                output_tokens.append(self.unk_token)
+            else:
+                output_tokens.extend(sub_tokens)
+        return output_tokens
+
+
+class MaskedBertTokenizer(BertTokenizer):
+    """
+    A modified BertTokenizer that respects a mask deciding whether a token should be split or not.
+    """
+    def __init__(self, *args, do_lower_case=False, do_basic_tokenize=False, **kwargs):
+        # override do_lower_case and do_basic_tokenize unconditionally
+        super().__init__(*args, do_lower_case=False, do_basic_tokenize=False, **kwargs)
+
+        # replace the word piece tokenizer with ours
+        self.wordpiece_tokenizer = MaskedWordPieceTokenizer(vocab=self.vocab,
+                                                            added_tokens_encoder=self.added_tokens_encoder,
+                                                            added_tokens_decoder=self.added_tokens_decoder,
+                                                            unk_token=self.unk_token)
+
+    def tokenize(self, tokens, mask=None):
+        return self.wordpiece_tokenizer.tokenize(tokens, mask)
+
+    # provide an interface that DecoderVocabulary can use
+    @property
+    def stoi(self):
+        return self.vocab
+
+    @property
+    def itos(self):
+        return self.ids_to_tokens
\ No newline at end of file
diff --git a/decanlp/data/numericalizer/sequential_field.py b/decanlp/data/numericalizer/sequential_field.py
new file mode 100644
index 00000000..bf81219b
--- /dev/null
+++ b/decanlp/data/numericalizer/sequential_field.py
@@ -0,0 +1,38 @@
+#
+# Copyright (c) 2019-2020 The Board of Trustees of the Leland Stanford Junior University
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import torch
+from typing import NamedTuple, List
+
+
+class SequentialField(NamedTuple):
+    value : torch.tensor
+    length : torch.tensor
+    limited : torch.tensor
+    tokens : List[List[str]]
\ No newline at end of file
diff --git a/decanlp/data/numericalizer.py b/decanlp/data/numericalizer/simple.py
similarity index 73%
rename from decanlp/data/numericalizer.py
rename to decanlp/data/numericalizer/simple.py
index 313c8dbd..44b41d0a 100644
--- a/decanlp/data/numericalizer.py
+++ b/decanlp/data/numericalizer/simple.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2018, The Board of Trustees of the Leland Stanford Junior University
+# Copyright (c) 2019-2020 The Board of Trustees of the Leland Stanford Junior University
 # All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
@@ -31,51 +31,8 @@
 import os
 import torch
 
 from .vocab import Vocab
-from .example import Example, SequentialField
-
-
-class DecoderVocabulary(object):
-    def __init__(self, words, full_vocab):
-        self.full_vocab = full_vocab
-        if words is not None:
-            self.itos = words
-            self.stoi = { word: idx for idx, word in enumerate(words) }
-        else:
-            self.itos = []
-            self.stoi = dict()
-        self.oov_itos = []
-        self.oov_stoi = dict()
-
-    @property
-    def max_generative_vocab(self):
-        return len(self.itos)
-
-    def clone(self):
-        new_subset = DecoderVocabulary(None, self.full_vocab)
-        new_subset.itos = self.itos
-        new_subset.stoi = self.stoi
-        return new_subset
-
-    def __len__(self):
-        return len(self.itos) + len(self.oov_itos)
-
-    def encode(self, word):
-        if word in self.stoi:
-            lim_idx = self.stoi[word]
-        elif word in self.oov_stoi:
-            lim_idx = self.oov_stoi[word]
-        else:
-            lim_idx = len(self)
-            self.oov_itos.append(word)
-            self.oov_stoi[word] = lim_idx
-        return lim_idx
-
-    def decode(self, lim_idx):
-        if lim_idx < len(self.itos):
-            return self.full_vocab.stoi[self.itos[lim_idx]]
-        else:
-            return self.full_vocab.stoi[self.oov_itos[lim_idx-len(self.itos)]]
-
+from .sequential_field import SequentialField
+from .decoder_vocab import DecoderVocabulary
 
 class SimpleNumericalizer(object):
     def __init__(self, max_effective_vocab, max_generative_vocab, fix_length=None, pad_first=False):
@@ -101,8 +58,8 @@
     def save(self, save_dir):
         torch.save(self.vocab, os.path.join(save_dir, 'vocab.pth'))
 
-    def build_vocab(self, vectors, vocab_sets):
-        self.vocab = Vocab.build_from_data(Example.vocab_fields, *vocab_sets,
+    def build_vocab(self, vectors, vocab_fields, vocab_sets):
+        self.vocab = Vocab.build_from_data(vocab_fields, *vocab_sets,
                                            unk_token=self.unk_token,
                                            init_token=self.init_token,
                                            eos_token=self.eos_token,
@@ -152,31 +109,29 @@
         self.decoder_vocab = DecoderVocabulary(self.vocab.itos[:self.max_generative_vocab], self.vocab)
 
     def encode(self, minibatch, decoder_vocab, device=None):
-        if not isinstance(minibatch, list):
-            minibatch = list(minibatch)
+        assert isinstance(minibatch, list)
         if self.fix_length is None:
-            max_len = max(len(x) for x in minibatch)
+            max_len = max(len(x[0]) for x in minibatch)
         else:
-            max_len = self.fix_length + (
-                self.init_token, self.eos_token).count(None) - 2
+            max_len = self.fix_length + 2
         padded = []
         lengths = []
         numerical = []
         decoder_numerical = []
-        for example in minibatch:
+        for tokens, _mask in minibatch:
             if self.pad_first:
-                padded_example = [self.pad_token] * max(0, max_len - len(example)) + \
-                                 ([] if self.init_token is None else [self.init_token]) + \
-                                 list(example[:max_len]) + \
-                                 ([] if self.eos_token is None else [self.eos_token])
+                padded_example = [self.pad_token] * max(0, max_len - len(tokens)) + \
+                                 [self.init_token] + \
+                                 list(tokens[:max_len]) + \
+                                 [self.eos_token]
             else:
-                padded_example = ([] if self.init_token is None else [self.init_token]) + \
-                                 list(example[:max_len]) + \
-                                 ([] if self.eos_token is None else [self.eos_token]) + \
-                                 [self.pad_token] * max(0, max_len - len(example))
+                padded_example = [self.init_token] + \
+                                 list(tokens[:max_len]) + \
+                                 [self.eos_token] + \
+                                 [self.pad_token] * max(0, max_len - len(tokens))
 
             padded.append(padded_example)
-            lengths.append(len(padded_example) - max(0, max_len - len(example)))
+            lengths.append(len(padded_example) - max(0, max_len - len(tokens)))
             numerical.append([self.vocab.stoi[word] for word in padded_example])
             decoder_numerical.append([decoder_vocab.encode(word) for word in padded_example])
 
diff --git a/decanlp/data/vocab.py b/decanlp/data/numericalizer/vocab.py
similarity index 100%
rename from decanlp/data/vocab.py
rename to decanlp/data/numericalizer/vocab.py
diff --git a/decanlp/models/multitask_question_answering_network.py b/decanlp/models/multitask_question_answering_network.py
index 59ef5b01..70549039 100644
--- a/decanlp/models/multitask_question_answering_network.py
+++ b/decanlp/models/multitask_question_answering_network.py
@@ -79,8 +79,8 @@ class MQANEncoder(nn.Module):
         self.encoder_embeddings.set_embeddings(embeddings)
 
     def forward(self, batch):
-        context, context_lengths, context_limited, context_tokens = batch.context.value, batch.context.length, batch.context.limited, batch.context.tokens
-        question, question_lengths, question_limited, question_tokens = batch.question.value, batch.question.length, batch.question.limited, batch.question.tokens
+        context, context_lengths = batch.context.value, batch.context.length
+        question, question_lengths = batch.question.value, batch.question.length
 
         context_embedded = self.encoder_embeddings(context)
         question_embedded = self.encoder_embeddings(question)
diff --git a/decanlp/tasks/almond/__init__.py b/decanlp/tasks/almond/__init__.py
index 253d6a76..8d527189 100644
--- a/decanlp/tasks/almond/__init__.py
+++ b/decanlp/tasks/almond/__init__.py
@@ -90,11 +90,11 @@
         # the question is irrelevant, so the question says English and ThingTalk even if we're doing
         # a different language (like Chinese)
         if reverse_task:
-            question = 'Translate from ThingTalk to English'
+            question = 'translate from thingtalk to english'
            context = target_code
             answer = sentence
         else:
-            question = 'Translate from English to ThingTalk'
+            question = 'translate from english to thingtalk'
             context = sentence
             answer = target_code
 
@@ -147,6 +147,10 @@ class AlmondDataset(generic_dataset.CQA):
         pass
 
 
+def is_entity(token):
+    return token[0].isupper()
+
+
 class BaseAlmondTask(BaseTask):
     """Base class for the Almond semantic parsing task
     i.e. natural language to formal language (ThingTalk) mapping"""
@@ -158,10 +162,31 @@
     def metrics(self):
         return ['em', 'nem', 'nf1', 'fm', 'dm', 'bleu']
 
+    def _is_program_field(self, field_name):
+        raise NotImplementedError()
+
     def tokenize(self, sentence, field_name=None):
         if not sentence:
-            return []
-        return sentence.split(' ')
+            return [], []
+
+        if self._is_program_field(field_name):
+            mask = []
+            in_string = False
+            tokens = sentence.split(' ')
+            for token in tokens:
+                if token == '"':
+                    in_string = not in_string
+                    mask.append(False)
+                else:
+                    mask.append(in_string)
+
+            assert len(tokens) == len(mask)
+            return tokens, mask
+
+        else:
+            tokens = sentence.split(' ')
+            mask = [not is_entity(token) for token in tokens]
+            return tokens, mask
 
     def detokenize(self, tokenized, field_name=None):
         return ' '.join(tokenized)
@@ -172,12 +197,17 @@ class Almond(BaseAlmondTask):
     """The Almond semantic parsing task
     i.e. natural language to formal language (ThingTalk) mapping"""
 
+    def _is_program_field(self, field_name):
+        return field_name == 'answer'
+
     def get_splits(self, root, **kwargs):
         return AlmondDataset.splits(root=root, tokenize=self.tokenize, **kwargs)
 
 
 @register_task('contextual_almond')
 class ContextualAlmond(BaseAlmondTask):
+    def _is_program_field(self, field_name):
+        return field_name in ('answer', 'context')
 
     def get_splits(self, root, **kwargs):
         return AlmondDataset.splits(root=root, tokenize=self.tokenize, contextual=True, **kwargs)
@@ -192,5 +222,8 @@ class ReverseAlmond(BaseTask):
     def metrics(self):
         return ['bleu', 'em', 'nem', 'nf1']
 
+    def _is_program_field(self, field_name):
+        return field_name == 'context'
+
     def get_splits(self, root, **kwargs):
         return AlmondDataset.splits(root=root, reverse_task=True, tokenize=self.tokenize, **kwargs)
\ No newline at end of file
diff --git a/decanlp/tasks/base.py b/decanlp/tasks/base.py
index a8505e04..4fd19d01 100644
--- a/decanlp/tasks/base.py
+++ b/decanlp/tasks/base.py
@@ -53,8 +53,8 @@ class BaseTask:
 
     def tokenize(self, sentence, field_name=None):
         if not sentence:
-            return []
-        return revtok.tokenize(sentence)
+            return [], None
+        return revtok.tokenize(sentence), None
 
     def detokenize(self, tokenized, field_name=None):
         return revtok.detokenize(tokenized)
@@ -92,6 +92,3 @@ class BaseTask:
         :return: a list of metric names
         """
         return ['em', 'nem', 'nf1']
-
-    tokenize = None
-    detokenize = None
diff --git a/decanlp/tasks/generic.py b/decanlp/tasks/generic.py
index 007f1d2a..08651854 100644
--- a/decanlp/tasks/generic.py
+++ b/decanlp/tasks/generic.py
@@ -68,8 +68,8 @@ class SQuAD(BaseTask):
 
     def tokenize(self, sentence, field_name=None):
         if not sentence:
-            return []
-        return sentence.split()
+            return [], None
+        return sentence.split(), None
 
     def detokenize(self, tokenized, field_name=None):
         return ' '.join(tokenized)
diff --git a/decanlp/train.py b/decanlp/train.py
index 3e0d53fc..82fbec4f 100644
--- a/decanlp/train.py
+++ b/decanlp/train.py
@@ -41,6 +41,7 @@
 import numpy as np
 import torch
 from tensorboardX import SummaryWriter
+from .data.example import Example
 from .utils.parallel_utils import NamedTupleCompatibleDataParallel
 from . import arguments
 from . import models
@@ -49,7 +50,6 @@ from .util import elapsed_time, set_seed, preprocess_examples, get_trainable_par
     init_devices, make_numericalizer
 from .utils.saver import Saver
 from .utils.embeddings import load_embeddings
-from .data.numericalizer import SimpleNumericalizer
 
 
 def initialize_logger(args):
@@ -118,7 +118,7 @@
     vectors = load_embeddings(args, logger)
     vocab_sets = (train_sets + val_sets) if len(vocab_sets) == 0 else vocab_sets
     logger.info(f'Building vocabulary')
-    numericalizer.build_vocab(vectors, vocab_sets)
+    numericalizer.build_vocab(vectors, Example.vocab_fields, vocab_sets)
     numericalizer.save(args.save)
 
     logger.info(f'Vocabulary has {numericalizer.num_tokens} tokens')
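
Illustrative usage (not part of the patch): the core of this change is the mask-aware word-piece tokenizer, which leaves Almond's opaque program tokens intact while still allowing natural-language words to be split. The minimal sketch below shows the intended call pattern, based only on the signatures visible in this diff; the 'bert-base-uncased' model name and the ThingTalk-like tokens are assumptions made for the example.

    from decanlp.data.numericalizer.masked_bert_tokenizer import MaskedBertTokenizer

    # do_lower_case / do_basic_tokenize are accepted but forced off by MaskedBertTokenizer
    tokenizer = MaskedBertTokenizer.from_pretrained('bert-base-uncased',
                                                    do_lower_case=False,
                                                    do_basic_tokenize=False)

    tokens = ['show', 'me', 'restaurants', '=>', '@com.yelp.restaurant']
    # True  = ordinary word, may be broken into word pieces
    # False = opaque program token, must survive tokenization unchanged
    mask = [True, True, True, False, False]

    print(tokenizer.tokenize(tokens, mask))
    # masked-off tokens come back verbatim (and are registered as added tokens);
    # the natural-language words may or may not be split, depending on the vocabulary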