From 388c41deb88780927197d3104f2369a84341565a Mon Sep 17 00:00:00 2001
From: mehrad
Date: Wed, 16 Dec 2020 15:21:57 -0800
Subject: [PATCH] Take care of cjk chars in run_generation and
 run_lm_finetuning

---
 genienlp/paraphrase/data_utils.py     |  6 +++---
 genienlp/paraphrase/dataset.py        |  9 +++++++++
 genienlp/paraphrase/model_utils.py    |  2 +-
 genienlp/paraphrase/run_generation.py |  3 +++
 genienlp/tasks/almond/__init__.py     | 18 ++----------------
 genienlp/tasks/almond/utils.py        | 29 +++++++++++++++++++++++++++
 6 files changed, 47 insertions(+), 20 deletions(-)

diff --git a/genienlp/paraphrase/data_utils.py b/genienlp/paraphrase/data_utils.py
index 4478a534..7c52937f 100644
--- a/genienlp/paraphrase/data_utils.py
+++ b/genienlp/paraphrase/data_utils.py
@@ -10,8 +10,7 @@
 from ..data_utils.progbar import progress_bar
 from ..util import detokenize, tokenize, lower_case, SpecialTokenMap, remove_thingtalk_quotes
 from genienlp.util import get_number_of_lines
-from ..tasks.almond.utils import is_entity, quoted_pattern_maybe_space, device_pattern
-
+from ..tasks.almond.utils import is_entity, quoted_pattern_maybe_space, device_pattern, detokenize_cjk_chars
 
 logger = logging.getLogger(__name__)
 
@@ -271,7 +270,8 @@ def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_colum
                 # just make sure source language is used when tokenizing input sentence
                 # tokenizer takes care of adding language code at the end of the sentence
                 tokenizer.cur_lang_code = tokenizer.lang_code_to_id[src_lang]
-
+
+            input_sequence = detokenize_cjk_chars(input_sequence)
             input_sequence_ids = tokenizer.encode(input_sequence, add_special_tokens=True)
 
             prompt_ids = [] # includes the first few tokens of the output
diff --git a/genienlp/paraphrase/dataset.py b/genienlp/paraphrase/dataset.py
index 0eb9ab76..045bf7a6 100644
--- a/genienlp/paraphrase/dataset.py
+++ b/genienlp/paraphrase/dataset.py
@@ -9,6 +9,7 @@
 from torch.nn.utils.rnn import pad_sequence
 
 from ..util import get_number_of_lines
 from ..data_utils.progbar import progress_bar
+from ..tasks.almond.utils import detokenize_cjk_chars
 
 logger = logging.getLogger(__name__)
@@ -88,6 +89,10 @@ class TextDataset(Dataset):
 
     def _add_example(self, input_sequence, output_sequence, args):
         # TODO we should make use of tokenizer.build_inputs_with_special_tokens(sequence1, sequence2). Add special tokens manualy only if our model does not support two sequences (like GPT2).
+        input_sequence = detokenize_cjk_chars(input_sequence)
+        if output_sequence is not None:
+            output_sequence = detokenize_cjk_chars(output_sequence)
+
         input_token_ids = self.tokenizer.encode(input_sequence, add_special_tokens=False) + [self.tokenizer.convert_tokens_to_ids(args.start_special_token)]
         if output_sequence is None:
             output_token_ids = []
@@ -128,6 +133,10 @@
 
     def _add_seq2seq_example(self, input_sequence, output_sequence, args):
+        input_sequence = detokenize_cjk_chars(input_sequence)
+        if output_sequence is not None:
+            output_sequence = detokenize_cjk_chars(output_sequence)
+
         if args.model_type == 'mbart':
             model_inputs = self.tokenizer.prepare_seq2seq_batch([input_sequence], args.src_lang, [output_sequence], args.tgt_lang)
         else:
diff --git a/genienlp/paraphrase/model_utils.py b/genienlp/paraphrase/model_utils.py
index a6c60c0b..faefc168 100644
--- a/genienlp/paraphrase/model_utils.py
+++ b/genienlp/paraphrase/model_utils.py
@@ -63,7 +63,7 @@ def check_args(args):
                          'you have to specify the --src_lang flag.')
     elif args.src_lang not in MARIAN_GROUP_MEMBERS[args.model_name_or_path.rsplit('-', 2)[1]]:
         raise ValueError(
-            'Dource language is not in the model group languages, please specify the correct source language.')
+            'Source language is not in the model group languages, please specify the correct source language.')
 
     if args.model_type == 'marian' and args.model_name_or_path.rsplit('-', 1)[1] not in MARIAN_GROUP_MEMBERS and args.tgt_lang:
         logger.warning('Target language should not be provided when using models with single language pairs,'
diff --git a/genienlp/paraphrase/run_generation.py b/genienlp/paraphrase/run_generation.py
index 04f89800..cafea0dc 100644
--- a/genienlp/paraphrase/run_generation.py
+++ b/genienlp/paraphrase/run_generation.py
@@ -34,6 +34,7 @@ from torch.multiprocessing import Process, set_start_method
 
 from genienlp.paraphrase.data_utils import create_features_from_tsv_file, output_heuristics
 from genienlp.paraphrase.model_utils import compute_metrics, compute_attention, replace_quoted_params, force_replace_quoted_params
+from ..tasks.almond.utils import tokenize_cjk_chars
 
 try:
     set_start_method('spawn')
@@ -496,6 +497,8 @@
             text = re.sub('\s\s+', ' ', text) # remove duplicate white spaces
             text = text.strip()
 
+            text = tokenize_cjk_chars(text)
+
             if not args.skip_heuristics:
                 text = output_heuristics(text, batch_reverse_maps[sample_index])
             batch_outputs[sample_index].append(text)
diff --git a/genienlp/tasks/almond/__init__.py b/genienlp/tasks/almond/__init__.py
index ed9d1215..f0e64c9d 100644
--- a/genienlp/tasks/almond/__init__.py
+++ b/genienlp/tasks/almond/__init__.py
@@ -37,7 +37,7 @@
 from ..registry import register_task
 from ..generic_dataset import CQA, context_question_len, token_batch_fn, default_batch_fn
 from ...data_utils.example import Example
 from ...data_utils.progbar import progress_bar
-from .utils import ISO_to_LANG, is_device, is_entity, process_id, is_cjk_char
+from .utils import ISO_to_LANG, is_device, is_entity, process_id, is_cjk_char, detokenize_cjk_chars
 from ...util import multiwoz_specific_preprocess, multiwoz_specific_postprocess
 from ..base_dataset import Split
@@ -132,20 +132,6 @@ class BaseAlmondTask(BaseTask):
 
     def get_splits(self, root, **kwargs):
         return AlmondDataset.return_splits(path=os.path.join(root, 'almond'), make_example=self._make_example, **kwargs)
-
-    def _detokenize_cjk_chars(self, sentence):
-        output = []
-        i = 0
-        while i < len(sentence):
-            output.append(sentence[i])
-            # skip space after cjk chars only if followed by another cjk char
-            if is_cjk_char(ord(sentence[i])) and \
-                    i+1 < len(sentence) and sentence[i+1] == ' ' and \
-                    i+2 < len(sentence) and is_cjk_char(ord(sentence[i+2])):
-                i += 2
-            else:
-                i += 1
-        return "".join(output)
 
     def tokenize(self, sentence, field_name=None):
         if not sentence:
@@ -154,7 +140,7 @@ class BaseAlmondTask(BaseTask):
         if self.force_subword_tokenize:
             return sentence.split(' '), None
 
-        sentence = self._detokenize_cjk_chars(sentence)
+        sentence = detokenize_cjk_chars(sentence)
 
         if self._dataset_specific_preprocess == 'multiwoz' and self._is_program_field(field_name):
             sentence = multiwoz_specific_preprocess(sentence)
diff --git a/genienlp/tasks/almond/utils.py b/genienlp/tasks/almond/utils.py
index 9da09f6e..73709833 100644
--- a/genienlp/tasks/almond/utils.py
+++ b/genienlp/tasks/almond/utils.py
@@ -52,5 +52,34 @@ def process_id(ex):
         id_ = id_[1:]
     return id_
 
+def detokenize_cjk_chars(sentence):
+    output = []
+    i = 0
+    while i < len(sentence):
+        output.append(sentence[i])
+        # skip space after cjk chars only if followed by another cjk char
+        if is_cjk_char(ord(sentence[i])) and \
+                i+1 < len(sentence) and sentence[i+1] == ' ' and \
+                i+2 < len(sentence) and is_cjk_char(ord(sentence[i+2])):
+            i += 2
+        else:
+            i += 1
+
+    return "".join(output)
+def tokenize_cjk_chars(sentence):
+    output = []
+    i = 0
+    while i < len(sentence):
+        output.append(sentence[i])
+        if is_cjk_char(ord(sentence[i])) and i+1 < len(sentence) and sentence[i+1] != ' ':
+            output.append(' ')
+        elif not is_cjk_char(ord(sentence[i])) and i + 1 < len(sentence) and is_cjk_char(ord(sentence[i + 1])):
+            output.append(' ')
+        i += 1
+
+    output = "".join(output)
+    output = output.replace('  ', ' ')
+
+    return output
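
For reference, below is a minimal self-contained sketch of the round trip the two new helpers implement: detokenize_cjk_chars is applied to inputs before tokenizer.encode, and tokenize_cjk_chars restores the spaced form on generated text. The two helper bodies are copied from the patch; is_cjk_char is not shown in this patch, so the Unicode-range check here is an assumed stand-in for the real implementation in genienlp/tasks/almond/utils.py.

def is_cjk_char(codepoint):
    # Assumption: only the most common CJK blocks are checked here; the real
    # genienlp implementation may cover more ranges.
    return (0x4E00 <= codepoint <= 0x9FFF      # CJK Unified Ideographs
            or 0x3040 <= codepoint <= 0x30FF   # Hiragana and Katakana
            or 0xAC00 <= codepoint <= 0xD7AF)  # Hangul syllables

def detokenize_cjk_chars(sentence):
    # Copied from the patch: drop a space only when it sits between two CJK chars.
    output = []
    i = 0
    while i < len(sentence):
        output.append(sentence[i])
        if is_cjk_char(ord(sentence[i])) and \
                i + 1 < len(sentence) and sentence[i + 1] == ' ' and \
                i + 2 < len(sentence) and is_cjk_char(ord(sentence[i + 2])):
            i += 2  # keep the current char, skip the following space
        else:
            i += 1
    return "".join(output)

def tokenize_cjk_chars(sentence):
    # Copied from the patch: insert a space after a CJK char not already
    # followed by one, and before a CJK char that follows a non-CJK char,
    # then collapse any double spaces that result.
    output = []
    i = 0
    while i < len(sentence):
        output.append(sentence[i])
        if is_cjk_char(ord(sentence[i])) and i + 1 < len(sentence) and sentence[i + 1] != ' ':
            output.append(' ')
        elif not is_cjk_char(ord(sentence[i])) and i + 1 < len(sentence) and is_cjk_char(ord(sentence[i + 1])):
            output.append(' ')
        i += 1
    return "".join(output).replace('  ', ' ')

# Round trip on a space-tokenized Chinese sentence:
assert detokenize_cjk_chars('天气 怎么 样 ?') == '天气怎么样 ?'
assert tokenize_cjk_chars('天气怎么样 ?') == '天 气 怎 么 样 ?'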