Take care of cjk chars in run_generation and run_lm_finetuning
This commit is contained in:
parent
f0ff4d104f
commit
388c41deb8
|
@ -10,8 +10,7 @@ from ..data_utils.progbar import progress_bar
|
|||
from ..util import detokenize, tokenize, lower_case, SpecialTokenMap, remove_thingtalk_quotes
|
||||
|
||||
from genienlp.util import get_number_of_lines
|
||||
from ..tasks.almond.utils import is_entity, quoted_pattern_maybe_space, device_pattern
|
||||
|
||||
from ..tasks.almond.utils import is_entity, quoted_pattern_maybe_space, device_pattern, detokenize_cjk_chars
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -271,7 +270,8 @@ def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_colum
|
|||
# just make sure source language is used when tokenizing input sentence
|
||||
# tokenizer takes care of adding language code at the end of the sentence
|
||||
tokenizer.cur_lang_code = tokenizer.lang_code_to_id[src_lang]
|
||||
|
||||
|
||||
input_sequence = detokenize_cjk_chars(input_sequence)
|
||||
input_sequence_ids = tokenizer.encode(input_sequence, add_special_tokens=True)
|
||||
|
||||
prompt_ids = [] # includes the first few tokens of the output
|
||||
|
|
|
@ -9,6 +9,7 @@ from torch.nn.utils.rnn import pad_sequence
|
|||
|
||||
from ..util import get_number_of_lines
|
||||
from ..data_utils.progbar import progress_bar
|
||||
from ..tasks.almond.utils import detokenize_cjk_chars
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -88,6 +89,10 @@ class TextDataset(Dataset):
|
|||
def _add_example(self, input_sequence, output_sequence, args):
|
||||
# TODO we should make use of tokenizer.build_inputs_with_special_tokens(sequence1, sequence2). Add special tokens manualy only if our model does not support two sequences (like GPT2).
|
||||
|
||||
input_sequence = detokenize_cjk_chars(input_sequence)
|
||||
if output_sequence is not None:
|
||||
output_sequence = detokenize_cjk_chars(output_sequence)
|
||||
|
||||
input_token_ids = self.tokenizer.encode(input_sequence, add_special_tokens=False) + [self.tokenizer.convert_tokens_to_ids(args.start_special_token)]
|
||||
if output_sequence is None:
|
||||
output_token_ids = []
|
||||
|
@ -128,6 +133,10 @@ class TextDataset(Dataset):
|
|||
|
||||
def _add_seq2seq_example(self, input_sequence, output_sequence, args):
|
||||
|
||||
input_sequence = detokenize_cjk_chars(input_sequence)
|
||||
if output_sequence is not None:
|
||||
output_sequence = detokenize_cjk_chars(output_sequence)
|
||||
|
||||
if args.model_type == 'mbart':
|
||||
model_inputs = self.tokenizer.prepare_seq2seq_batch([input_sequence], args.src_lang, [output_sequence], args.tgt_lang)
|
||||
else:
|
||||
|
|
|
@ -63,7 +63,7 @@ def check_args(args):
|
|||
'you have to specify the --src_lang flag.')
|
||||
elif args.src_lang not in MARIAN_GROUP_MEMBERS[args.model_name_or_path.rsplit('-', 2)[1]]:
|
||||
raise ValueError(
|
||||
'Dource language is not in the model group languages, please specify the correct source language.')
|
||||
'Source language is not in the model group languages, please specify the correct source language.')
|
||||
|
||||
if args.model_type == 'marian' and args.model_name_or_path.rsplit('-', 1)[1] not in MARIAN_GROUP_MEMBERS and args.tgt_lang:
|
||||
logger.warning('Target language should not be provided when using models with single language pairs,'
|
||||
|
|
|
@ -34,6 +34,7 @@ from torch.multiprocessing import Process, set_start_method
|
|||
|
||||
from genienlp.paraphrase.data_utils import create_features_from_tsv_file, output_heuristics
|
||||
from genienlp.paraphrase.model_utils import compute_metrics, compute_attention, replace_quoted_params, force_replace_quoted_params
|
||||
from ..tasks.almond.utils import tokenize_cjk_chars
|
||||
|
||||
try:
|
||||
set_start_method('spawn')
|
||||
|
@ -496,6 +497,8 @@ def run_single_process_generation(args, config):
|
|||
text = re.sub('\s\s+', ' ', text) # remove duplicate white spaces
|
||||
text = text.strip()
|
||||
|
||||
text = tokenize_cjk_chars(text)
|
||||
|
||||
if not args.skip_heuristics:
|
||||
text = output_heuristics(text, batch_reverse_maps[sample_index])
|
||||
batch_outputs[sample_index].append(text)
|
||||
|
|
|
@ -37,7 +37,7 @@ from ..registry import register_task
|
|||
from ..generic_dataset import CQA, context_question_len, token_batch_fn, default_batch_fn
|
||||
from ...data_utils.example import Example
|
||||
from ...data_utils.progbar import progress_bar
|
||||
from .utils import ISO_to_LANG, is_device, is_entity, process_id, is_cjk_char
|
||||
from .utils import ISO_to_LANG, is_device, is_entity, process_id, is_cjk_char, detokenize_cjk_chars
|
||||
from ...util import multiwoz_specific_preprocess, multiwoz_specific_postprocess
|
||||
|
||||
from ..base_dataset import Split
|
||||
|
@ -132,20 +132,6 @@ class BaseAlmondTask(BaseTask):
|
|||
|
||||
def get_splits(self, root, **kwargs):
|
||||
return AlmondDataset.return_splits(path=os.path.join(root, 'almond'), make_example=self._make_example, **kwargs)
|
||||
|
||||
def _detokenize_cjk_chars(self, sentence):
|
||||
output = []
|
||||
i = 0
|
||||
while i < len(sentence):
|
||||
output.append(sentence[i])
|
||||
# skip space after cjk chars only if followed by another cjk char
|
||||
if is_cjk_char(ord(sentence[i])) and \
|
||||
i+1 < len(sentence) and sentence[i+1] == ' ' and \
|
||||
i+2 < len(sentence) and is_cjk_char(ord(sentence[i+2])):
|
||||
i += 2
|
||||
else:
|
||||
i += 1
|
||||
return "".join(output)
|
||||
|
||||
def tokenize(self, sentence, field_name=None):
|
||||
if not sentence:
|
||||
|
@ -154,7 +140,7 @@ class BaseAlmondTask(BaseTask):
|
|||
if self.force_subword_tokenize:
|
||||
return sentence.split(' '), None
|
||||
|
||||
sentence = self._detokenize_cjk_chars(sentence)
|
||||
sentence = detokenize_cjk_chars(sentence)
|
||||
|
||||
if self._dataset_specific_preprocess == 'multiwoz' and self._is_program_field(field_name):
|
||||
sentence = multiwoz_specific_preprocess(sentence)
|
||||
|
|
|
@ -52,5 +52,34 @@ def process_id(ex):
|
|||
id_ = id_[1:]
|
||||
return id_
|
||||
|
||||
def detokenize_cjk_chars(sentence):
|
||||
output = []
|
||||
i = 0
|
||||
while i < len(sentence):
|
||||
output.append(sentence[i])
|
||||
# skip space after cjk chars only if followed by another cjk char
|
||||
if is_cjk_char(ord(sentence[i])) and \
|
||||
i+1 < len(sentence) and sentence[i+1] == ' ' and \
|
||||
i+2 < len(sentence) and is_cjk_char(ord(sentence[i+2])):
|
||||
i += 2
|
||||
else:
|
||||
i += 1
|
||||
|
||||
return "".join(output)
|
||||
|
||||
|
||||
def tokenize_cjk_chars(sentence):
|
||||
output = []
|
||||
i = 0
|
||||
while i < len(sentence):
|
||||
output.append(sentence[i])
|
||||
if is_cjk_char(ord(sentence[i])) and i+1 < len(sentence) and sentence[i+1] != ' ':
|
||||
output.append(' ')
|
||||
elif not is_cjk_char(ord(sentence[i])) and i + 1 < len(sentence) and is_cjk_char(ord(sentence[i + 1])):
|
||||
output.append(' ')
|
||||
i += 1
|
||||
|
||||
output = "".join(output)
|
||||
output = output.replace(' ', ' ')
|
||||
|
||||
return output
|
||||
|
|
Loading…
Reference in New Issue