Take care of cjk chars in run_generation and run_lm_finetuning

mehrad 2020-12-16 15:21:57 -08:00
parent f0ff4d104f
commit 388c41deb8
6 changed files with 47 additions and 20 deletions

View File

@@ -10,8 +10,7 @@ from ..data_utils.progbar import progress_bar
 from ..util import detokenize, tokenize, lower_case, SpecialTokenMap, remove_thingtalk_quotes
 from genienlp.util import get_number_of_lines
-from ..tasks.almond.utils import is_entity, quoted_pattern_maybe_space, device_pattern
+from ..tasks.almond.utils import is_entity, quoted_pattern_maybe_space, device_pattern, detokenize_cjk_chars
 
 logger = logging.getLogger(__name__)
@@ -271,7 +270,8 @@ def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_colum
             # just make sure source language is used when tokenizing input sentence
             # tokenizer takes care of adding language code at the end of the sentence
             tokenizer.cur_lang_code = tokenizer.lang_code_to_id[src_lang]
+        input_sequence = detokenize_cjk_chars(input_sequence)
         input_sequence_ids = tokenizer.encode(input_sequence, add_special_tokens=True)
         prompt_ids = [] # includes the first few tokens of the output
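
Note: the added detokenize_cjk_chars call collapses the single spaces that Almond-style tokenization puts between CJK characters, so the pretrained subword tokenizer sees natural text. A minimal sketch of the effect (hypothetical input sentence; assumes genienlp at this commit is importable):

    from genienlp.tasks.almond.utils import detokenize_cjk_chars

    # hypothetical Almond-tokenized input: CJK characters are space-separated
    input_sequence = "帮 我 找 一 家 餐厅 help me find a restaurant"
    print(detokenize_cjk_chars(input_sequence))
    # prints: 帮我找一家餐厅 help me find a restaurant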

View File

@@ -9,6 +9,7 @@ from torch.nn.utils.rnn import pad_sequence
 from ..util import get_number_of_lines
 from ..data_utils.progbar import progress_bar
+from ..tasks.almond.utils import detokenize_cjk_chars
 
 logger = logging.getLogger(__name__)
@@ -88,6 +89,10 @@ class TextDataset(Dataset):
     def _add_example(self, input_sequence, output_sequence, args):
         # TODO we should make use of tokenizer.build_inputs_with_special_tokens(sequence1, sequence2). Add special tokens manually only if our model does not support two sequences (like GPT2).
+        input_sequence = detokenize_cjk_chars(input_sequence)
+        if output_sequence is not None:
+            output_sequence = detokenize_cjk_chars(output_sequence)
+
         input_token_ids = self.tokenizer.encode(input_sequence, add_special_tokens=False) + [self.tokenizer.convert_tokens_to_ids(args.start_special_token)]
         if output_sequence is None:
             output_token_ids = []
@@ -128,6 +133,10 @@ class TextDataset(Dataset):
     def _add_seq2seq_example(self, input_sequence, output_sequence, args):
+        input_sequence = detokenize_cjk_chars(input_sequence)
+        if output_sequence is not None:
+            output_sequence = detokenize_cjk_chars(output_sequence)
+
         if args.model_type == 'mbart':
             model_inputs = self.tokenizer.prepare_seq2seq_batch([input_sequence], args.src_lang, [output_sequence], args.tgt_lang)
         else:

View File

@@ -63,7 +63,7 @@ def check_args(args):
                              'you have to specify the --src_lang flag.')
         elif args.src_lang not in MARIAN_GROUP_MEMBERS[args.model_name_or_path.rsplit('-', 2)[1]]:
             raise ValueError(
-                'Dource language is not in the model group languages, please specify the correct source language.')
+                'Source language is not in the model group languages, please specify the correct source language.')
     if args.model_type == 'marian' and args.model_name_or_path.rsplit('-', 1)[1] not in MARIAN_GROUP_MEMBERS and args.tgt_lang:
         logger.warning('Target language should not be provided when using models with single language pairs,'

View File

@@ -34,6 +34,7 @@ from torch.multiprocessing import Process, set_start_method
 from genienlp.paraphrase.data_utils import create_features_from_tsv_file, output_heuristics
 from genienlp.paraphrase.model_utils import compute_metrics, compute_attention, replace_quoted_params, force_replace_quoted_params
+from ..tasks.almond.utils import tokenize_cjk_chars
 
 try:
     set_start_method('spawn')
@@ -496,6 +497,8 @@ def run_single_process_generation(args, config):
             text = re.sub('\s\s+', ' ', text)  # remove duplicate white spaces
             text = text.strip()
+            text = tokenize_cjk_chars(text)
+
             if not args.skip_heuristics:
                 text = output_heuristics(text, batch_reverse_maps[sample_index])
             batch_outputs[sample_index].append(text)
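
Note: generation mirrors the input preprocessing in reverse: after whitespace cleanup, tokenize_cjk_chars re-inserts the per-character spaces that downstream Almond processing expects. A minimal sketch (hypothetical model output; assumes genienlp at this commit is importable):

    import re
    from genienlp.tasks.almond.utils import tokenize_cjk_chars

    text = "帮我找一家餐厅  help me"  # hypothetical raw model output
    text = re.sub(r'\s\s+', ' ', text)  # remove duplicate white spaces
    text = text.strip()
    print(tokenize_cjk_chars(text))
    # prints: 帮 我 找 一 家 餐 厅 help me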

View File

@@ -37,7 +37,7 @@ from ..registry import register_task
 from ..generic_dataset import CQA, context_question_len, token_batch_fn, default_batch_fn
 from ...data_utils.example import Example
 from ...data_utils.progbar import progress_bar
-from .utils import ISO_to_LANG, is_device, is_entity, process_id, is_cjk_char
+from .utils import ISO_to_LANG, is_device, is_entity, process_id, is_cjk_char, detokenize_cjk_chars
 from ...util import multiwoz_specific_preprocess, multiwoz_specific_postprocess
 from ..base_dataset import Split
@@ -132,20 +132,6 @@ class BaseAlmondTask(BaseTask):
     def get_splits(self, root, **kwargs):
         return AlmondDataset.return_splits(path=os.path.join(root, 'almond'), make_example=self._make_example, **kwargs)
-
-    def _detokenize_cjk_chars(self, sentence):
-        output = []
-        i = 0
-        while i < len(sentence):
-            output.append(sentence[i])
-            # skip space after cjk chars only if followed by another cjk char
-            if is_cjk_char(ord(sentence[i])) and \
-                    i+1 < len(sentence) and sentence[i+1] == ' ' and \
-                    i+2 < len(sentence) and is_cjk_char(ord(sentence[i+2])):
-                i += 2
-            else:
-                i += 1
-        return "".join(output)
 
     def tokenize(self, sentence, field_name=None):
         if not sentence:
@@ -154,7 +140,7 @@ class BaseAlmondTask(BaseTask):
         if self.force_subword_tokenize:
             return sentence.split(' '), None
-        sentence = self._detokenize_cjk_chars(sentence)
+        sentence = detokenize_cjk_chars(sentence)
         if self._dataset_specific_preprocess == 'multiwoz' and self._is_program_field(field_name):
             sentence = multiwoz_specific_preprocess(sentence)

View File

@@ -52,5 +52,34 @@ def process_id(ex):
         id_ = id_[1:]
     return id_
+
+def detokenize_cjk_chars(sentence):
+    output = []
+    i = 0
+    while i < len(sentence):
+        output.append(sentence[i])
+        # skip space after cjk chars only if followed by another cjk char
+        if is_cjk_char(ord(sentence[i])) and \
+                i+1 < len(sentence) and sentence[i+1] == ' ' and \
+                i+2 < len(sentence) and is_cjk_char(ord(sentence[i+2])):
+            i += 2
+        else:
+            i += 1
+    return "".join(output)
+
+def tokenize_cjk_chars(sentence):
+    output = []
+    i = 0
+    while i < len(sentence):
+        output.append(sentence[i])
+        if is_cjk_char(ord(sentence[i])) and i+1 < len(sentence) and sentence[i+1] != ' ':
+            output.append(' ')
+        elif not is_cjk_char(ord(sentence[i])) and i+1 < len(sentence) and is_cjk_char(ord(sentence[i+1])):
+            output.append(' ')
+        i += 1
+    output = "".join(output)
+    output = output.replace('  ', ' ')
+    return output
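
Note: the two helpers are close to inverses on well-formed text: detokenize_cjk_chars drops the single spaces between adjacent CJK characters, while tokenize_cjk_chars re-inserts them and also adds a space at CJK/non-CJK boundaries (is_cjk_char is defined earlier in this file and is not part of this diff). A quick round-trip check, assuming this module is importable:

    from genienlp.tasks.almond.utils import detokenize_cjk_chars, tokenize_cjk_chars

    assert detokenize_cjk_chars("你 好 world") == "你好 world"
    assert tokenize_cjk_chars("你好world") == "你 好 world"
    # round trip on space-separated CJK text
    assert tokenize_cjk_chars(detokenize_cjk_chars("你 好 吗")) == "你 好 吗"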