Add a BERT-based numericalizer
Compatible with wordpiece tokenization used by BERT, and also with Almond's unique requirements around how tokens are split
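The key convention the rest of this diff relies on is that every field now carries a parallel boolean word mask next to its tokens: True marks ordinary natural-language words that the BERT wordpiece tokenizer may split, False marks ThingTalk program tokens and entity placeholders that must pass through as opaque symbols. A made-up illustration of that convention (the sentence and placeholder name are hypothetical):

# Hypothetical example of the token / word-mask convention introduced by this commit.
context = ['show', 'me', 'tweets', 'from', 'QUOTED_STRING_0']
context_word_mask = [True, True, True, True, False]   # the entity placeholder must stay a single opaque token
assert len(context) == len(context_word_mask)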
parent ea8f7c5577
commit bbcdb1a3e2

Pipfile
@ -20,6 +20,7 @@ tabulate = "*"
tensorboardX = "*"
Babel = "*"
requests = "*"
transformers = "*"

[requires]
python_version = "3.7"
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"_meta": {
|
||||
"hash": {
|
||||
"sha256": "6bc781177e1bd4b396a93b6dbab0bf503d2d2d37b9dcecdbf6a667a47e4c8cc8"
|
||||
"sha256": "a49da22eb026763b9c0581905f6624ca97e443e40bb691e3d5531f58dbb199f7"
|
||||
},
|
||||
"pipfile-spec": 6,
|
||||
"requires": {
|
||||
|
@ -32,6 +32,20 @@
|
|||
],
|
||||
"version": "==4.8.2"
|
||||
},
|
||||
"boto3": {
|
||||
"hashes": [
|
||||
"sha256:916efdecda945daba08dca9032373440db681e53e6b46e7d0cae104ffd5ac7ca",
|
||||
"sha256:925899699b81977b75d4036fc478209916f4547f8683e5c7b1ad678317a6652e"
|
||||
],
|
||||
"version": "==1.11.3"
|
||||
},
|
||||
"botocore": {
|
||||
"hashes": [
|
||||
"sha256:799562d2af023f49518676c6b72ee07325da027513406b4b9d8e5b74ecea8257",
|
||||
"sha256:80e6522c4bb8d98dab9fd77708c6722bee775c7b85a77326ef44543062a17462"
|
||||
],
|
||||
"version": "==1.14.3"
|
||||
},
|
||||
"certifi": {
|
||||
"hashes": [
|
||||
"sha256:017c25db2a153ce562900032d5bc68e9f191e44e9a0f762f373977de9df1fbb3",
|
||||
|
@ -46,12 +60,27 @@
|
|||
],
|
||||
"version": "==3.0.4"
|
||||
},
|
||||
"click": {
|
||||
"hashes": [
|
||||
"sha256:2335065e6395b9e67ca716de5f7526736bfa6ceead690adf616d925bdc622b13",
|
||||
"sha256:5b94b49521f6456670fdb30cd82a4eca9412788a93fa6dd6df72c94d5a8ff2d7"
|
||||
],
|
||||
"version": "==7.0"
|
||||
},
|
||||
"docopt": {
|
||||
"hashes": [
|
||||
"sha256:49b3a825280bd66b3aa83585ef59c4a8c82f2c8a522dbe754a8bc8d08c85c491"
|
||||
],
|
||||
"version": "==0.6.2"
|
||||
},
|
||||
"docutils": {
|
||||
"hashes": [
|
||||
"sha256:6c4f696463b79f1fb8ba0c594b63840ebd41f059e92b31957c46b74a4599b6d0",
|
||||
"sha256:9e4d7ecfc600058e07ba661411a2b7de2fd0fafa17d1a7f7361cd47b1175c827",
|
||||
"sha256:a2aeea129088da402665e92e0b25b04b073c04b2dce4ab65caaa38b7ce2e1a99"
|
||||
],
|
||||
"version": "==0.15.2"
|
||||
},
|
||||
"et-xmlfile": {
|
||||
"hashes": [
|
||||
"sha256:614d9722d572f6246302c4491846d2c393c199cfa4edc9af593437691683335b"
|
||||
|
@ -72,6 +101,20 @@
|
|||
],
|
||||
"version": "==1.4.1"
|
||||
},
|
||||
"jmespath": {
|
||||
"hashes": [
|
||||
"sha256:3720a4b1bd659dd2eecad0666459b9788813e032b83e7ba58578e48254e0a0e6",
|
||||
"sha256:bde2aef6f44302dfb30320115b17d030798de8c4110e28d5cf6cf91a7a31074c"
|
||||
],
|
||||
"version": "==0.9.4"
|
||||
},
|
||||
"joblib": {
|
||||
"hashes": [
|
||||
"sha256:0630eea4f5664c463f23fbf5dcfc54a2bc6168902719fa8e19daf033022786c8",
|
||||
"sha256:bdb4fd9b72915ffb49fde2229ce482dd7ae79d842ed8c2b4c932441495af1403"
|
||||
],
|
||||
"version": "==0.14.1"
|
||||
},
|
||||
"nltk": {
|
||||
"hashes": [
|
||||
"sha256:2661f9971d983db314bbebd51ba770811a362c6597fd0f303bb1d3beadcb4834"
|
||||
|
@ -194,6 +237,32 @@
|
|||
"index": "pypi",
|
||||
"version": "==0.5.3"
|
||||
},
|
||||
"regex": {
|
||||
"hashes": [
|
||||
"sha256:07b39bf943d3d2fe63d46281d8504f8df0ff3fe4c57e13d1656737950e53e525",
|
||||
"sha256:0932941cdfb3afcbc26cc3bcf7c3f3d73d5a9b9c56955d432dbf8bbc147d4c5b",
|
||||
"sha256:0e182d2f097ea8549a249040922fa2b92ae28be4be4895933e369a525ba36576",
|
||||
"sha256:10671601ee06cf4dc1bc0b4805309040bb34c9af423c12c379c83d7895622bb5",
|
||||
"sha256:23e2c2c0ff50f44877f64780b815b8fd2e003cda9ce817a7fd00dea5600c84a0",
|
||||
"sha256:26ff99c980f53b3191d8931b199b29d6787c059f2e029b2b0c694343b1708c35",
|
||||
"sha256:27429b8d74ba683484a06b260b7bb00f312e7c757792628ea251afdbf1434003",
|
||||
"sha256:3e77409b678b21a056415da3a56abfd7c3ad03da71f3051bbcdb68cf44d3c34d",
|
||||
"sha256:4e8f02d3d72ca94efc8396f8036c0d3bcc812aefc28ec70f35bb888c74a25161",
|
||||
"sha256:4eae742636aec40cf7ab98171ab9400393360b97e8f9da67b1867a9ee0889b26",
|
||||
"sha256:6a6ae17bf8f2d82d1e8858a47757ce389b880083c4ff2498dba17c56e6c103b9",
|
||||
"sha256:6a6ba91b94427cd49cd27764679024b14a96874e0dc638ae6bdd4b1a3ce97be1",
|
||||
"sha256:7bcd322935377abcc79bfe5b63c44abd0b29387f267791d566bbb566edfdd146",
|
||||
"sha256:98b8ed7bb2155e2cbb8b76f627b2fd12cf4b22ab6e14873e8641f266e0fb6d8f",
|
||||
"sha256:bd25bb7980917e4e70ccccd7e3b5740614f1c408a642c245019cff9d7d1b6149",
|
||||
"sha256:d0f424328f9822b0323b3b6f2e4b9c90960b24743d220763c7f07071e0778351",
|
||||
"sha256:d58e4606da2a41659c84baeb3cfa2e4c87a74cec89a1e7c56bee4b956f9d7461",
|
||||
"sha256:e3cd21cc2840ca67de0bbe4071f79f031c81418deb544ceda93ad75ca1ee9f7b",
|
||||
"sha256:e6c02171d62ed6972ca8631f6f34fa3281d51db8b326ee397b9c83093a6b7242",
|
||||
"sha256:e7c7661f7276507bce416eaae22040fd91ca471b5b33c13f8ff21137ed6f248c",
|
||||
"sha256:ecc6de77df3ef68fee966bb8cb4e067e84d4d1f397d0ef6fce46913663540d77"
|
||||
],
|
||||
"version": "==2020.1.8"
|
||||
},
|
||||
"requests": {
|
||||
"hashes": [
|
||||
"sha256:11e007a8a2aa0323f5a921e9e6a2d7e4e67d9877e85773fba9ba6419025cbeb4",
|
||||
|
@ -207,6 +276,13 @@
|
|||
"git": "https://github.com/jekbradbury/revtok.git",
|
||||
"ref": "f1998b72a941d1e5f9578a66dc1c20b01913caab"
|
||||
},
|
||||
"s3transfer": {
|
||||
"hashes": [
|
||||
"sha256:2157640a47c8b8fa2071bdd7b0d57378ec8957eede3bd083949c2dcc4d9b0dd4",
|
||||
"sha256:e3343ae0f371781c17590cf06cb818a54484fbac9a65a5be7603a39b0a6d7b31"
|
||||
],
|
||||
"version": "==0.3.0"
|
||||
},
|
||||
"sacrebleu": {
|
||||
"hashes": [
|
||||
"sha256:0a4b9e53b742d95fcd2f32e4aaa42aadcf94121d998ca19c66c05e7037d1eeee",
|
||||
|
@ -215,12 +291,48 @@
|
|||
"index": "pypi",
|
||||
"version": "==1.4.3"
|
||||
},
|
||||
"sacremoses": {
|
||||
"hashes": [
|
||||
"sha256:34dcfaacf9fa34a6353424431f0e4fcc60e8ebb27ffee320d57396690b712a3b"
|
||||
],
|
||||
"version": "==0.0.38"
|
||||
},
|
||||
"sentencepiece": {
|
||||
"hashes": [
|
||||
"sha256:0a98ec863e541304df23a37787033001b62cb089f4ed9307911791d7e210c0b1",
|
||||
"sha256:0ad221ea7914d65f57d3e3af7ae48852b5035166493312b5025367585b43ac41",
|
||||
"sha256:0f72c4151791de7242e7184a9b7ef12503cef42e9a5a0c1b3510f2c68874e810",
|
||||
"sha256:22fe7d92203fadbb6a0dc7d767430d37cdf3a9da4a0f2c5302c7bf294f7bfd8f",
|
||||
"sha256:2a72d4c3d0dbb1e099ddd2dc6b724376d3d7ff77ba494756b894254485bec4b4",
|
||||
"sha256:30791ce80a557339e17f1290c68dccd3f661612fdc6b689b4e4f21d805b64952",
|
||||
"sha256:39904713b81869db10de53fe8b3719f35acf77f49351f28ceaad0d360f2f6305",
|
||||
"sha256:3d5a2163deea95271ce8e38dfd0c3c924bea92aaf63bdda69b5458628dacc8bd",
|
||||
"sha256:3f3dee204635c33ca2e450e17ee9e0e92f114a47f853c2e44e7f0f0ab444d8d0",
|
||||
"sha256:4dcea889af53f669dc39d1ca870c37c52bb3110fcd96a2e7330d288400958281",
|
||||
"sha256:4e36a92558ad9e2f91b311c5bcea90b7a63c567c0e7e20da44d6a6f01031b57e",
|
||||
"sha256:576bf820eb963e6f275d4005ed5334fbed59eb54bed508e5cae6d16c7179710f",
|
||||
"sha256:6d2bbdbf296d96304c6345675749981bb17dcf2a7163d2fec38f70a704b75669",
|
||||
"sha256:76fdce3e7e614e24b35167c22c9c388e0c843be53d99afb5e1f25f6bfe04e228",
|
||||
"sha256:97b8ee26892d236b2620af8ddae11713fbbb2dae9adf4ad5e988e5a82ce50a90",
|
||||
"sha256:b3b6fe02af7ea4823c19e0d8efddc10ff59b8449bc1ae9921f9dd8ad33802c33",
|
||||
"sha256:b416f514fff8785a1113e6c07f696e52967fc979d6cd946e454a8660cca72ef8",
|
||||
"sha256:bf0bad6ba01ace3e938ffdf05c42b24d8fd3740487ba865504795a0bb9b1f2b3",
|
||||
"sha256:c00387970360ec0369b5e7c75f3977fb14330df75465200c13bafb7a632d2e6b",
|
||||
"sha256:c23fb7bb949934998375d41dbe54d4df1778a3b9dcb24bc2ddaaa595819ed1da",
|
||||
"sha256:dfdcf48678656592b11d11e2102c52c38122e309f7a1a5272305d397cfe21ce0",
|
||||
"sha256:fb69c5ba325b900cf2b91f517b46eec8ce3c50995955e293b46681d832021c0e",
|
||||
"sha256:fba83bef6c7a7899cd811d9b1195e748722eb2a9737c3f3890160f0e01e3ad08",
|
||||
"sha256:fe115aee209197839b2a357e34523e23768d553e8a69eac2b558499ccda56f80",
|
||||
"sha256:ffdf51218a3d7e0dad79bdffd21ad15a23cbb9c572d2300c3295c6efc6c2357e"
|
||||
],
|
||||
"version": "==0.1.85"
|
||||
},
|
||||
"six": {
|
||||
"hashes": [
|
||||
"sha256:1f1b7d42e254082a9db6279deae68afb421ceba6158efa6131de7b3003ee93fd",
|
||||
"sha256:30f610279e8b2578cab6db20741130331735c781b56053c59c4076da27f06b66"
|
||||
"sha256:236bdbdce46e6e6a3d61a337c0f8b763ca1e8717c03b369e87a7ec7ce1319c0a",
|
||||
"sha256:8f3cd2e254d8f793e7f3d6d9df77b92252b52637291d0f0da013c76ea2724b6c"
|
||||
],
|
||||
"version": "==1.13.0"
|
||||
"version": "==1.14.0"
|
||||
},
|
||||
"soupsieve": {
|
||||
"hashes": [
|
||||
|
@ -260,18 +372,19 @@
|
|||
},
|
||||
"torch": {
|
||||
"hashes": [
|
||||
"sha256:0cec2e13a2e95c24c34f17d437f354ee2a40902e8d515a524556b350e12555dd",
|
||||
"sha256:134e8291a97151b1ffeea09cb9ddde5238beb4e6d9dfb66657143d6990bfb865",
|
||||
"sha256:31062923ac2e60eac676f6a0ae14702b051c158bbcf7f440eaba266b0defa197",
|
||||
"sha256:3b05233481b51bb636cee63dc761bb7f602e198178782ff4159d385d1759608b",
|
||||
"sha256:458f1d87e5b7064b2c39e36675d84e163be3143dd2fc806057b7878880c461bc",
|
||||
"sha256:72a1c85bffd2154f085bc0a1d378d8a54e55a57d49664b874fe7c949022bf071",
|
||||
"sha256:77fd8866c0bf529861ffd850a5dada2190a8d9c5167719fb0cfa89163e23b143",
|
||||
"sha256:b6f01d851d1c5989d4a99b50ae0187762b15b7718dcd1a33704b665daa2402f9",
|
||||
"sha256:d8e1d904a6193ed14a4fed220b00503b2baa576e71471286d1ebba899c851fae"
|
||||
"sha256:271d4d1e44df6ed57c530f8849b028447c62b8a19b8e8740dd9baa56e7f682c1",
|
||||
"sha256:30ce089475b287a37d6fbb8d71853e672edaf66699e3dd2eb19be6ce6296732a",
|
||||
"sha256:405b9eb40e44037d2525b3ddb5bc4c66b519cd742bff249d4207d23f83e88ea5",
|
||||
"sha256:504915c6bc6051ba6a4c2a43c446463dff04411e352f1e26fe13debeae431778",
|
||||
"sha256:54d06a0e8ee85e5a437c24f4af9f4196c819294c23ffb5914e177756f55f1829",
|
||||
"sha256:6f2fd9eb8c7eaf38a982ab266dbbfba0f29fb643bc74e677d045d6f2595e4692",
|
||||
"sha256:8856f334aa9ecb742c1504bd2563d0ffb8dceb97149c8d72a04afa357f667dbc",
|
||||
"sha256:8fff03bf7b474c16e4b50da65ea14200cc64553b67b9b2307f9dc7e8c69b9d28",
|
||||
"sha256:9a1b1db73d8dcfd94b2eee24b939368742aa85f1217c55b8f5681e76c581e99a",
|
||||
"sha256:bb1e87063661414e1149bef2e3a2499ce0b5060290799d7e26bc5578037075ba"
|
||||
],
|
||||
"index": "pypi",
|
||||
"version": "==1.3.1"
|
||||
"version": "==1.4.0"
|
||||
},
|
||||
"tqdm": {
|
||||
"hashes": [
|
||||
|
@ -281,6 +394,14 @@
|
|||
"index": "pypi",
|
||||
"version": "==4.41.1"
|
||||
},
|
||||
"transformers": {
|
||||
"hashes": [
|
||||
"sha256:2c237b06d60bb7f17f6b9e1ab9c5d4530508a287bc16ec64b5f7bb11d99df717",
|
||||
"sha256:d881aca9ff1d0d9cf500bda47d1cbe1b87d4297af75f2e1b9cf7ac0293dd9c38"
|
||||
],
|
||||
"index": "pypi",
|
||||
"version": "==2.3.0"
|
||||
},
|
||||
"typing": {
|
||||
"hashes": [
|
||||
"sha256:91dfe6f3f706ee8cc32d38edbbf304e9b7583fb37108fef38229617f8b3eba23",
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
#
|
||||
# Copyright (c) 2018, Salesforce, Inc.
|
||||
# The Board of Trustees of the Leland Stanford Junior University
|
||||
# Copyright (c) 2019-2020 The Board of Trustees of the Leland Stanford Junior University
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
|
@ -31,19 +30,21 @@
|
|||
import torch
|
||||
from typing import NamedTuple, List
|
||||
|
||||
|
||||
class SequentialField(NamedTuple):
|
||||
value : torch.tensor
|
||||
length : torch.tensor
|
||||
limited : torch.tensor
|
||||
tokens : List[str]
|
||||
from .numericalizer import SequentialField
|
||||
|
||||
|
||||
class Example(NamedTuple):
|
||||
example_id : str
|
||||
|
||||
# for each field in the example, we store the tokenized sentence and a boolean mask
# indicating whether each token is a real word (subject to word-piece tokenization)
# or whether it should be treated as an opaque symbol
|
||||
context : List[str]
|
||||
context_word_mask : List[bool]
|
||||
question : List[str]
|
||||
question_word_mask : List[bool]
|
||||
answer : List[str]
|
||||
answer_word_mask : List[bool]
|
||||
|
||||
vocab_fields = ['context', 'question', 'answer']
|
||||
|
||||
|
@ -51,10 +52,13 @@ class Example(NamedTuple):
|
|||
def from_raw(example_id : str, context : str, question : str, answer : str, tokenize, lower=False):
|
||||
args = [example_id]
|
||||
for argname, arg in (('context', context), ('question', question), ('answer', answer)):
|
||||
new_arg = tokenize(arg.rstrip('\n'), field_name=argname)
|
||||
words, mask = tokenize(arg.rstrip('\n'), field_name=argname)
|
||||
if mask is None:
|
||||
mask = [True for _ in words]
|
||||
if lower:
|
||||
new_arg = [word.lower() for word in new_arg]
|
||||
args.append(new_arg)
|
||||
words = [word.lower() for word in words]
|
||||
args.append(words)
|
||||
args.append(mask)
|
||||
return Example(*args)
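A minimal sketch of how from_raw behaves when the tokenizer returns no mask (as the generic tasks later in this diff do); assuming the Example class above is in scope, the mask defaults to all True:

# Illustrative only: a whitespace tokenizer that returns no mask, like BaseTask.tokenize below.
def toy_tokenize(sentence, field_name=None):
    return sentence.split(), None

ex = Example.from_raw('dev-0', 'Some context here', 'what is asked ?', 'the answer',
                      tokenize=toy_tokenize, lower=True)
assert ex.context == ['some', 'context', 'here']
assert ex.context_word_mask == [True, True, True]   # no mask supplied, so every token counts as a real word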
|
||||
|
||||
|
||||
|
@ -69,9 +73,9 @@ class Batch(NamedTuple):
|
|||
def from_examples(examples, numericalizer, device=None):
|
||||
assert all(isinstance(ex.example_id, str) for ex in examples)
|
||||
example_ids = [ex.example_id for ex in examples]
|
||||
context_input = [ex.context for ex in examples]
|
||||
question_input = [ex.question for ex in examples]
|
||||
answer_input = [ex.answer for ex in examples]
|
||||
context_input = [(ex.context, ex.context_word_mask) for ex in examples]
|
||||
question_input = [(ex.question, ex.question_word_mask) for ex in examples]
|
||||
answer_input = [(ex.answer, ex.answer_word_mask) for ex in examples]
|
||||
decoder_vocab = numericalizer.decoder_vocab.clone()
|
||||
|
||||
return Batch(example_ids,
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
#
|
||||
# Copyright (c) 2020 The Board of Trustees of the Leland Stanford Junior University
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# * Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
from .sequential_field import SequentialField
|
||||
|
||||
from .simple import SimpleNumericalizer
|
||||
from .bert import BertNumericalizer
|
|
@ -0,0 +1,185 @@
|
|||
#
|
||||
# Copyright (c) 2019-2020 The Board of Trustees of the Leland Stanford Junior University
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# * Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
import collections
|
||||
import os
|
||||
import torch
|
||||
from transformers import BertConfig
|
||||
|
||||
from .sequential_field import SequentialField
|
||||
from .decoder_vocab import DecoderVocabulary
|
||||
from .masked_bert_tokenizer import MaskedBertTokenizer
|
||||
|
||||
|
||||
class BertNumericalizer(object):
|
||||
"""
|
||||
Numericalizer that uses BertTokenizer from huggingface's transformers library.
|
||||
"""
|
||||
|
||||
def __init__(self, pretrained_tokenizer, max_generative_vocab, cache=None, fix_length=None):
|
||||
self._pretrained_name = pretrained_tokenizer
|
||||
self.max_generative_vocab = max_generative_vocab
|
||||
self._cache = cache
|
||||
self.config = None
|
||||
self._tokenizer = None
|
||||
|
||||
self.fix_length = fix_length
|
||||
|
||||
@property
|
||||
def num_tokens(self):
|
||||
return self._tokenizer.vocab_size
|
||||
|
||||
def load(self, save_dir):
|
||||
self.config = BertConfig.from_pretrained(os.path.join(save_dir, 'bert-config.json'), cache_dir=self._cache)
|
||||
self._tokenizer = MaskedBertTokenizer.from_pretrained(save_dir, config=self.config, cache_dir=self._cache)
|
||||
|
||||
with open(os.path.join(save_dir, 'decoder-vocab.txt'), 'r') as fp:
|
||||
self._decoder_words = [line.rstrip('\n') for line in fp]
|
||||
|
||||
self._init()
|
||||
|
||||
def save(self, save_dir):
|
||||
self.config.save_pretrained(os.path.join(save_dir, 'bert-config.json'))
|
||||
self._tokenizer.save_pretrained(os.path.join(save_dir))
|
||||
with open(os.path.join(save_dir, 'decoder-vocab.txt'), 'w') as fp:
|
||||
for word in self._decoder_words:
|
||||
fp.write(word + '\n')
|
||||
|
||||
def build_vocab(self, vectors, vocab_fields, vocab_sets):
|
||||
self.config = BertConfig.from_pretrained(self._pretrained_name, cache_dir=self._cache)
|
||||
self._tokenizer = MaskedBertTokenizer.from_pretrained(self._pretrained_name, config=self.config, cache_dir=self._cache)
|
||||
|
||||
# ensure that init, eos, unk and pad are set
|
||||
# this method has no effect if the tokens are already set according to the tokenizer class
|
||||
self._tokenizer.add_special_tokens({
|
||||
'bos_token': '[CLS]',
|
||||
'eos_token': '[SEP]',
|
||||
'unk_token': '[UNK]',
|
||||
'pad_token': '[PAD]'
|
||||
})
|
||||
|
||||
# do a pass over all the answers in the dataset, and construct a counter of wordpieces
|
||||
decoder_words = collections.Counter()
|
||||
for dataset in vocab_sets:
|
||||
for example in dataset:
|
||||
tokens = self._tokenizer.tokenize(example.answer, example.answer_word_mask)
|
||||
decoder_words.update(tokens)
|
||||
|
||||
self._decoder_words = [word for word, _freq in decoder_words.most_common(self.max_generative_vocab)]
|
||||
|
||||
self._init()
|
||||
|
||||
def grow_vocab(self, examples, vectors):
|
||||
# TODO
|
||||
return []
|
||||
|
||||
def _init(self):
|
||||
self.pad_first = self._tokenizer.padding_side == 'left'
|
||||
|
||||
self.init_token = self._tokenizer.bos_token
|
||||
self.eos_token = self._tokenizer.eos_token
|
||||
self.unk_token = self._tokenizer.unk_token
|
||||
self.pad_token = self._tokenizer.pad_token
|
||||
|
||||
self.init_id = self._tokenizer.bos_token_id
|
||||
self.eos_id = self._tokenizer.eos_token_id
|
||||
self.unk_id = self._tokenizer.unk_token_id
|
||||
self.pad_id = self._tokenizer.pad_token_id
|
||||
self.generative_vocab_size = len(self._decoder_words)
|
||||
|
||||
assert self.init_id < self.generative_vocab_size
|
||||
assert self.eos_id < self.generative_vocab_size
|
||||
assert self.unk_id < self.generative_vocab_size
|
||||
assert self.pad_id < self.generative_vocab_size
|
||||
|
||||
self.decoder_vocab = DecoderVocabulary(self._decoder_words, self._tokenizer)
|
||||
|
||||
def encode(self, minibatch, decoder_vocab, device=None):
|
||||
assert isinstance(minibatch, list)
|
||||
|
||||
# apply word-piece tokenization to everything first
|
||||
wp_tokenized = []
|
||||
for tokens, mask in minibatch:
|
||||
wp_tokenized.append(self._tokenizer.tokenize(tokens, mask))
|
||||
|
||||
if self.fix_length is None:
|
||||
# pad to the longest word-piece sequence in the minibatch, plus room for the init and eos tokens
max_len = max(len(x) for x in wp_tokenized) + 2
|
||||
else:
|
||||
max_len = self.fix_length + 2
|
||||
padded = []
|
||||
lengths = []
|
||||
numerical = []
|
||||
decoder_numerical = []
|
||||
for wp_tokens in wp_tokenized:
|
||||
if self.pad_first:
|
||||
padded_example = [self.pad_token] * max(0, max_len - len(wp_tokens)) + \
|
||||
[self.init_token] + \
|
||||
list(wp_tokens[:max_len]) + \
|
||||
[self.eos_token]
|
||||
else:
|
||||
padded_example = [self.init_token] + \
|
||||
list(wp_tokens[:max_len]) + \
|
||||
[self.eos_token] + \
|
||||
[self.pad_token] * max(0, max_len - len(wp_tokens))
|
||||
|
||||
padded.append(padded_example)
|
||||
lengths.append(len(padded_example) - max(0, max_len - len(wp_tokens)))
|
||||
|
||||
numerical.append(self._tokenizer.convert_tokens_to_ids(padded_example))
|
||||
decoder_numerical.append([decoder_vocab.encode(word) for word in padded_example])
|
||||
|
||||
length = torch.tensor(lengths, dtype=torch.int32, device=device)
|
||||
numerical = torch.tensor(numerical, dtype=torch.int64, device=device)
|
||||
decoder_numerical = torch.tensor(decoder_numerical, dtype=torch.int64, device=device)
|
||||
|
||||
return SequentialField(tokens=padded, length=length, value=numerical, limited=decoder_numerical)
|
||||
|
||||
def decode(self, tensor):
|
||||
return self._tokenizer.convert_ids_to_tokens(tensor)
|
||||
|
||||
def reverse(self, batch, detokenize, field_name=None):
|
||||
with torch.cuda.device_of(batch):
|
||||
batch = batch.tolist()
|
||||
|
||||
def reverse_one(tensor):
|
||||
tokens = []
|
||||
|
||||
# trim up to EOS, remove other special stuff, and undo wordpiece tokenization
|
||||
for token in self.decode(tensor):
|
||||
if token == self.eos_token:
|
||||
break
|
||||
if token in (self.init_token, self.pad_token):
|
||||
continue
|
||||
if token.startswith('##'):
|
||||
tokens[-1] += token[2:]
|
||||
else:
|
||||
tokens.append(token)
|
||||
|
||||
return detokenize(tokens, field_name=field_name)
|
||||
|
||||
return [reverse_one(tensor) for tensor in batch]
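The reverse step above trims at the eos token, drops init and pad tokens, and glues '##' continuation pieces back onto the previous token before detokenizing. A standalone sketch of that merge (special-token strings assumed to be the usual BERT ones):

# Standalone illustration of the un-wordpiece loop in reverse_one above.
decoded = ['[CLS]', 'show', 'me', 'pic', '##tures', 'of', 'cats', '[SEP]', '[PAD]', '[PAD]']
tokens = []
for token in decoded:
    if token == '[SEP]':                 # eos_token: stop here
        break
    if token in ('[CLS]', '[PAD]'):      # init_token / pad_token: skip
        continue
    if token.startswith('##'):           # continuation piece: merge with the previous token
        tokens[-1] += token[2:]
    else:
        tokens.append(token)
assert tokens == ['show', 'me', 'pictures', 'of', 'cats']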
|
|
@ -0,0 +1,67 @@
|
|||
#
|
||||
# Copyright (c) 2019-2020 The Board of Trustees of the Leland Stanford Junior University
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# * Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
class DecoderVocabulary(object):
|
||||
def __init__(self, words, full_vocab):
|
||||
self.full_vocab = full_vocab
|
||||
if words is not None:
|
||||
self.itos = words
|
||||
self.stoi = { word: idx for idx, word in enumerate(words) }
|
||||
else:
|
||||
self.itos = []
|
||||
self.stoi = dict()
|
||||
self.oov_itos = []
|
||||
self.oov_stoi = dict()
|
||||
|
||||
def clone(self):
|
||||
new_subset = DecoderVocabulary(None, self.full_vocab)
|
||||
new_subset.itos = self.itos
|
||||
new_subset.stoi = self.stoi
|
||||
return new_subset
|
||||
|
||||
def __len__(self):
|
||||
return len(self.itos) + len(self.oov_itos)
|
||||
|
||||
def encode(self, word):
|
||||
if word in self.stoi:
|
||||
lim_idx = self.stoi[word]
|
||||
elif word in self.oov_stoi:
|
||||
lim_idx = self.oov_stoi[word]
|
||||
else:
|
||||
lim_idx = len(self)
|
||||
self.oov_itos.append(word)
|
||||
self.oov_stoi[word] = lim_idx
|
||||
return lim_idx
|
||||
|
||||
def decode(self, lim_idx):
|
||||
if lim_idx < len(self.itos):
|
||||
return self.full_vocab.stoi[self.itos[lim_idx]]
|
||||
else:
|
||||
return self.full_vocab.stoi[self.oov_itos[lim_idx-len(self.itos)]]
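DecoderVocabulary maps the small set of generatable words onto a contiguous id space and grows on the fly when it meets out-of-vocabulary words; decode maps those limited ids back to ids in the full vocabulary. A usage sketch, assuming the class above is in scope and using a minimal stand-in for the full vocabulary (only its stoi mapping is needed):

# Usage sketch for DecoderVocabulary; the toy vocabulary and ids are made up.
class ToyFullVocab:
    stoi = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, 'show': 7, 'me': 8, 'cats': 42}

vocab = DecoderVocabulary(['[PAD]', '[CLS]', '[SEP]', 'show', 'me'], ToyFullVocab())
assert vocab.encode('show') == 3      # known decoder word keeps its limited id
oov = vocab.encode('cats')            # unseen word is appended after the known ones
assert oov == 5 and len(vocab) == 6
assert vocab.decode(oov) == 42        # limited id maps back to the full-vocabulary id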
|
|
@ -0,0 +1,122 @@
|
|||
#
|
||||
# Copyright (c) 2019-2020 The Board of Trustees of the Leland Stanford Junior University
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# * Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
# This file was partially copied from huggingface's tokenizers library
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from transformers import BertTokenizer
|
||||
|
||||
class MaskedWordPieceTokenizer:
|
||||
def __init__(self, vocab, added_tokens_encoder, added_tokens_decoder, unk_token, max_input_chars_per_word=100):
|
||||
self.vocab = vocab
|
||||
self.unk_token = unk_token
|
||||
self.max_input_chars_per_word = max_input_chars_per_word
|
||||
self.added_tokens_encoder = added_tokens_encoder
|
||||
self.added_tokens_decoder = added_tokens_decoder
|
||||
|
||||
def tokenize(self, tokens, mask):
|
||||
output_tokens = []
|
||||
for token, should_word_split in zip(tokens, mask):
|
||||
if not should_word_split:
|
||||
if token not in self.vocab and token not in self.added_tokens_encoder:
|
||||
token_id = len(self.added_tokens_encoder)
|
||||
self.added_tokens_encoder[token] = token_id
|
||||
self.added_tokens_decoder[token_id] = token
|
||||
output_tokens.append(token)
|
||||
continue
|
||||
|
||||
chars = list(token)
|
||||
if len(chars) > self.max_input_chars_per_word:
|
||||
output_tokens.append(self.unk_token)
|
||||
continue
|
||||
|
||||
is_bad = False
|
||||
start = 0
|
||||
sub_tokens = []
|
||||
while start < len(chars):
|
||||
end = len(chars)
|
||||
cur_substr = None
|
||||
while start < end:
|
||||
substr = "".join(chars[start:end])
|
||||
if start > 0:
|
||||
substr = "##" + substr
|
||||
if substr in self.vocab:
|
||||
cur_substr = substr
|
||||
break
|
||||
end -= 1
|
||||
if cur_substr is None:
|
||||
is_bad = True
|
||||
break
|
||||
sub_tokens.append(cur_substr)
|
||||
start = end
|
||||
|
||||
if is_bad:
|
||||
output_tokens.append(self.unk_token)
|
||||
else:
|
||||
output_tokens.extend(sub_tokens)
|
||||
return output_tokens
|
||||
|
||||
|
||||
class MaskedBertTokenizer(BertTokenizer):
|
||||
"""
|
||||
A modified BertTokenizer that respects a mask deciding whether a token should be split or not.
|
||||
"""
|
||||
def __init__(self, *args, do_lower_case, do_basic_tokenize, **kwargs):
|
||||
# override do_lower_case and do_basic_tokenize unconditionally
|
||||
super().__init__(*args, do_lower_case=False, do_basic_tokenize=False, **kwargs)
|
||||
|
||||
# replace the word piece tokenizer with ours
|
||||
self.wordpiece_tokenizer = MaskedWordPieceTokenizer(vocab=self.vocab,
|
||||
added_tokens_encoder=self.added_tokens_encoder,
|
||||
added_tokens_decoder=self.added_tokens_decoder,
|
||||
unk_token=self.unk_token)
|
||||
|
||||
def tokenize(self, tokens, mask=None):
|
||||
return self.wordpiece_tokenizer.tokenize(tokens, mask)
|
||||
|
||||
# provide the stoi/itos interface that DecoderVocabulary expects
|
||||
@property
|
||||
def stoi(self):
|
||||
return self.vocab
|
||||
|
||||
@property
|
||||
def itos(self):
|
||||
return self.ids_to_tokens
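MaskedWordPieceTokenizer differs from the stock wordpiece tokenizer in one way: tokens whose mask is False are passed through verbatim (and registered as added tokens), while masked-in tokens get the usual greedy longest-match-first splitting. A sketch with a toy vocabulary, assuming the class above (with the zip fix in tokenize) is in scope; the ThingTalk function name is made up:

# Sketch with a toy wordpiece vocabulary; illustrative only.
toy_vocab = {'play': 0, 'song': 1, 'titan': 2, '##ium': 3, '[UNK]': 4}
tok = MaskedWordPieceTokenizer(vocab=toy_vocab, added_tokens_encoder={},
                               added_tokens_decoder={}, unk_token='[UNK]')
tokens = ['play', 'titanium', '@com.spotify.play_song']
mask = [True, True, False]   # the ThingTalk function name must not be split
assert tok.tokenize(tokens, mask) == ['play', 'titan', '##ium', '@com.spotify.play_song']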
|
|
@ -0,0 +1,38 @@
|
|||
#
|
||||
# Copyright (c) 2019-2020 The Board of Trustees of the Leland Stanford Junior University
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# * Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import torch
|
||||
from typing import NamedTuple, List
|
||||
|
||||
|
||||
class SequentialField(NamedTuple):
|
||||
value : torch.tensor
|
||||
length : torch.tensor
|
||||
limited : torch.tensor
|
||||
tokens : List[List[str]]
|
|
@ -1,5 +1,5 @@
|
|||
#
|
||||
# Copyright (c) 2018, The Board of Trustees of the Leland Stanford Junior University
|
||||
# Copyright (c) 2019-2020 The Board of Trustees of the Leland Stanford Junior University
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
|
@ -31,51 +31,8 @@ import os
|
|||
import torch
|
||||
|
||||
from .vocab import Vocab
|
||||
from .example import Example, SequentialField
|
||||
|
||||
|
||||
class DecoderVocabulary(object):
|
||||
def __init__(self, words, full_vocab):
|
||||
self.full_vocab = full_vocab
|
||||
if words is not None:
|
||||
self.itos = words
|
||||
self.stoi = { word: idx for idx, word in enumerate(words) }
|
||||
else:
|
||||
self.itos = []
|
||||
self.stoi = dict()
|
||||
self.oov_itos = []
|
||||
self.oov_stoi = dict()
|
||||
|
||||
@property
|
||||
def max_generative_vocab(self):
|
||||
return len(self.itos)
|
||||
|
||||
def clone(self):
|
||||
new_subset = DecoderVocabulary(None, self.full_vocab)
|
||||
new_subset.itos = self.itos
|
||||
new_subset.stoi = self.stoi
|
||||
return new_subset
|
||||
|
||||
def __len__(self):
|
||||
return len(self.itos) + len(self.oov_itos)
|
||||
|
||||
def encode(self, word):
|
||||
if word in self.stoi:
|
||||
lim_idx = self.stoi[word]
|
||||
elif word in self.oov_stoi:
|
||||
lim_idx = self.oov_stoi[word]
|
||||
else:
|
||||
lim_idx = len(self)
|
||||
self.oov_itos.append(word)
|
||||
self.oov_stoi[word] = lim_idx
|
||||
return lim_idx
|
||||
|
||||
def decode(self, lim_idx):
|
||||
if lim_idx < len(self.itos):
|
||||
return self.full_vocab.stoi[self.itos[lim_idx]]
|
||||
else:
|
||||
return self.full_vocab.stoi[self.oov_itos[lim_idx-len(self.itos)]]
|
||||
|
||||
from .sequential_field import SequentialField
|
||||
from .decoder_vocab import DecoderVocabulary
|
||||
|
||||
class SimpleNumericalizer(object):
|
||||
def __init__(self, max_effective_vocab, max_generative_vocab, fix_length=None, pad_first=False):
|
||||
|
@ -101,8 +58,8 @@ class SimpleNumericalizer(object):
|
|||
def save(self, save_dir):
|
||||
torch.save(self.vocab, os.path.join(save_dir, 'vocab.pth'))
|
||||
|
||||
def build_vocab(self, vectors, vocab_sets):
|
||||
self.vocab = Vocab.build_from_data(Example.vocab_fields, *vocab_sets,
|
||||
def build_vocab(self, vectors, vocab_fields, vocab_sets):
|
||||
self.vocab = Vocab.build_from_data(vocab_fields, *vocab_sets,
|
||||
unk_token=self.unk_token,
|
||||
init_token=self.init_token,
|
||||
eos_token=self.eos_token,
|
||||
|
@ -152,31 +109,29 @@ class SimpleNumericalizer(object):
|
|||
self.decoder_vocab = DecoderVocabulary(self.vocab.itos[:self.max_generative_vocab], self.vocab)
|
||||
|
||||
def encode(self, minibatch, decoder_vocab, device=None):
|
||||
if not isinstance(minibatch, list):
|
||||
minibatch = list(minibatch)
|
||||
assert isinstance(minibatch, list)
|
||||
if self.fix_length is None:
|
||||
max_len = max(len(x) for x in minibatch)
|
||||
max_len = max(len(x[0]) for x in minibatch)
|
||||
else:
|
||||
max_len = self.fix_length + (
|
||||
self.init_token, self.eos_token).count(None) - 2
|
||||
max_len = self.fix_length + 2
|
||||
padded = []
|
||||
lengths = []
|
||||
numerical = []
|
||||
decoder_numerical = []
|
||||
for example in minibatch:
|
||||
for tokens, _mask in minibatch:
|
||||
if self.pad_first:
|
||||
padded_example = [self.pad_token] * max(0, max_len - len(example)) + \
|
||||
([] if self.init_token is None else [self.init_token]) + \
|
||||
list(example[:max_len]) + \
|
||||
([] if self.eos_token is None else [self.eos_token])
|
||||
padded_example = [self.pad_token] * max(0, max_len - len(tokens)) + \
|
||||
[self.init_token] + \
|
||||
list(tokens[:max_len]) + \
|
||||
[self.eos_token]
|
||||
else:
|
||||
padded_example = ([] if self.init_token is None else [self.init_token]) + \
|
||||
list(example[:max_len]) + \
|
||||
([] if self.eos_token is None else [self.eos_token]) + \
|
||||
[self.pad_token] * max(0, max_len - len(example))
|
||||
padded_example = [self.init_token] + \
|
||||
list(tokens[:max_len]) + \
|
||||
[self.eos_token] + \
|
||||
[self.pad_token] * max(0, max_len - len(tokens))
|
||||
|
||||
padded.append(padded_example)
|
||||
lengths.append(len(padded_example) - max(0, max_len - len(example)))
|
||||
lengths.append(len(padded_example) - max(0, max_len - len(tokens)))
|
||||
|
||||
numerical.append([self.vocab.stoi[word] for word in padded_example])
|
||||
decoder_numerical.append([decoder_vocab.encode(word) for word in padded_example])
|
|
@ -79,8 +79,8 @@ class MQANEncoder(nn.Module):
|
|||
self.encoder_embeddings.set_embeddings(embeddings)
|
||||
|
||||
def forward(self, batch):
|
||||
context, context_lengths, context_limited, context_tokens = batch.context.value, batch.context.length, batch.context.limited, batch.context.tokens
|
||||
question, question_lengths, question_limited, question_tokens = batch.question.value, batch.question.length, batch.question.limited, batch.question.tokens
|
||||
context, context_lengths = batch.context.value, batch.context.length
|
||||
question, question_lengths = batch.question.value, batch.question.length
|
||||
|
||||
context_embedded = self.encoder_embeddings(context)
|
||||
question_embedded = self.encoder_embeddings(question)
|
||||
|
|
|
@ -90,11 +90,11 @@ class AlmondDataset(generic_dataset.CQA):
|
|||
# the question is irrelevant, so the question says English and ThingTalk even if we're doing
|
||||
# a different language (like Chinese)
|
||||
if reverse_task:
|
||||
question = 'Translate from ThingTalk to English'
|
||||
question = 'translate from thingtalk to english'
|
||||
context = target_code
|
||||
answer = sentence
|
||||
else:
|
||||
question = 'Translate from English to ThingTalk'
|
||||
question = 'translate from english to thingtalk'
|
||||
context = sentence
|
||||
answer = target_code
|
||||
|
||||
|
@ -147,6 +147,10 @@ class AlmondDataset(generic_dataset.CQA):
|
|||
pass
|
||||
|
||||
|
||||
def is_entity(token):
|
||||
return token[0].isupper()
|
||||
|
||||
|
||||
class BaseAlmondTask(BaseTask):
|
||||
"""Base class for the Almond semantic parsing task
|
||||
i.e. natural language to formal language (ThingTalk) mapping"""
|
||||
|
@ -158,10 +162,31 @@ class BaseAlmondTask(BaseTask):
|
|||
def metrics(self):
|
||||
return ['em', 'nem', 'nf1', 'fm', 'dm', 'bleu']
|
||||
|
||||
def _is_program_field(self, field_name):
|
||||
raise NotImplementedError()
|
||||
|
||||
def tokenize(self, sentence, field_name=None):
|
||||
if not sentence:
|
||||
return []
|
||||
return sentence.split(' ')
|
||||
return [], []
|
||||
|
||||
if self._is_program_field(field_name):
|
||||
mask = []
|
||||
in_string = False
|
||||
tokens = sentence.split(' ')
|
||||
for token in tokens:
|
||||
if token == '"':
|
||||
in_string = not in_string
|
||||
mask.append(False)
|
||||
else:
|
||||
mask.append(in_string)
|
||||
|
||||
assert len(tokens) == len(mask)
|
||||
return tokens, mask
|
||||
|
||||
else:
|
||||
tokens = sentence.split(' ')
|
||||
mask = [not is_entity(token) for token in tokens]
|
||||
return tokens, mask
|
||||
|
||||
def detokenize(self, tokenized, field_name=None):
|
||||
return ' '.join(tokenized)
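A standalone sketch of the masking rule for program fields implemented above: quote tokens toggle an in-string flag and are themselves opaque, so only the free text between quotes is eligible for wordpiece splitting (the ThingTalk snippet is made up):

# Mirrors the _is_program_field branch of tokenize() above; illustrative only.
def program_mask(tokens):
    mask, in_string = [], False
    for token in tokens:
        if token == '"':
            in_string = not in_string
            mask.append(False)
        else:
            mask.append(in_string)
    return mask

program = ['now', '=>', '@com.twitter.post', 'param:status:String', '=', '"', 'hello', 'world', '"']
assert program_mask(program) == [False, False, False, False, False, False, True, True, False]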
|
||||
|
@ -172,12 +197,17 @@ class Almond(BaseAlmondTask):
|
|||
"""The Almond semantic parsing task
|
||||
i.e. natural language to formal language (ThingTalk) mapping"""
|
||||
|
||||
def _is_program_field(self, field_name):
|
||||
return field_name == 'answer'
|
||||
|
||||
def get_splits(self, root, **kwargs):
|
||||
return AlmondDataset.splits(root=root, tokenize=self.tokenize, **kwargs)
|
||||
|
||||
|
||||
@register_task('contextual_almond')
|
||||
class ContextualAlmond(BaseAlmondTask):
|
||||
def _is_program_field(self, field_name):
|
||||
return field_name in ('answer', 'context')
|
||||
|
||||
def get_splits(self, root, **kwargs):
|
||||
return AlmondDataset.splits(root=root, tokenize=self.tokenize, contextual=True, **kwargs)
|
||||
|
@ -192,5 +222,8 @@ class ReverseAlmond(BaseTask):
|
|||
def metrics(self):
|
||||
return ['bleu', 'em', 'nem', 'nf1']
|
||||
|
||||
def _is_program_field(self, field_name):
|
||||
return field_name == 'context'
|
||||
|
||||
def get_splits(self, root, **kwargs):
|
||||
return AlmondDataset.splits(root=root, reverse_task=True, tokenize=self.tokenize, **kwargs)
|
|
@ -53,8 +53,8 @@ class BaseTask:
|
|||
|
||||
def tokenize(self, sentence, field_name=None):
|
||||
if not sentence:
|
||||
return []
|
||||
return revtok.tokenize(sentence)
|
||||
return [], None
|
||||
return revtok.tokenize(sentence), None
|
||||
|
||||
def detokenize(self, tokenized, field_name=None):
|
||||
return revtok.detokenize(tokenized)
|
||||
|
@ -92,6 +92,3 @@ class BaseTask:
|
|||
:return: a list of metric names
|
||||
"""
|
||||
return ['em', 'nem', 'nf1']
|
||||
|
||||
tokenize = None
|
||||
detokenize = None
|
||||
|
|
|
@ -68,8 +68,8 @@ class SQuAD(BaseTask):
|
|||
|
||||
def tokenize(self, sentence, field_name=None):
|
||||
if not sentence:
|
||||
return []
|
||||
return sentence.split()
|
||||
return [], None
|
||||
return sentence.split(), None
|
||||
|
||||
def detokenize(self, tokenized, field_name=None):
|
||||
return ' '.join(tokenized)
|
||||
|
|
|
@ -41,6 +41,7 @@ import numpy as np
|
|||
import torch
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
from .data.example import Example
|
||||
from .utils.parallel_utils import NamedTupleCompatibleDataParallel
|
||||
from . import arguments
|
||||
from . import models
|
||||
|
@ -49,7 +50,6 @@ from .util import elapsed_time, set_seed, preprocess_examples, get_trainable_par
|
|||
init_devices, make_numericalizer
|
||||
from .utils.saver import Saver
|
||||
from .utils.embeddings import load_embeddings
|
||||
from .data.numericalizer import SimpleNumericalizer
|
||||
|
||||
|
||||
def initialize_logger(args):
|
||||
|
@ -118,7 +118,7 @@ def prepare_data(args, logger):
|
|||
vectors = load_embeddings(args, logger)
|
||||
vocab_sets = (train_sets + val_sets) if len(vocab_sets) == 0 else vocab_sets
|
||||
logger.info(f'Building vocabulary')
|
||||
numericalizer.build_vocab(vectors, vocab_sets)
|
||||
numericalizer.build_vocab(vectors, Example.vocab_fields, vocab_sets)
|
||||
numericalizer.save(args.save)
|
||||
|
||||
logger.info(f'Vocabulary has {numericalizer.num_tokens} tokens')
|
||||