Take care of CJK chars in run_generation and run_lm_finetuning

parent f0ff4d104f
commit 388c41deb8

@@ -10,8 +10,7 @@ from ..data_utils.progbar import progress_bar
 from ..util import detokenize, tokenize, lower_case, SpecialTokenMap, remove_thingtalk_quotes
 
 from genienlp.util import get_number_of_lines
-from ..tasks.almond.utils import is_entity, quoted_pattern_maybe_space, device_pattern
+from ..tasks.almond.utils import is_entity, quoted_pattern_maybe_space, device_pattern, detokenize_cjk_chars
 
-
 logger = logging.getLogger(__name__)
 
@@ -271,7 +270,8 @@ def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_colum
             # just make sure source language is used when tokenizing input sentence
             # tokenizer takes care of adding language code at the end of the sentence
             tokenizer.cur_lang_code = tokenizer.lang_code_to_id[src_lang]
 
+        input_sequence = detokenize_cjk_chars(input_sequence)
         input_sequence_ids = tokenizer.encode(input_sequence, add_special_tokens=True)
 
         prompt_ids = [] # includes the first few tokens of the output
 
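Note: the detokenization runs before tokenizer.encode because subword tokenizers treat whitespace as a segment boundary, so pre-tokenized CJK text would encode differently from naturally written text. A minimal sketch of the effect (the example string is illustrative, not from the repository):

    detokenize_cjk_chars('天 气 怎么 样 ?')  # -> '天气怎么样 ?'
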
@@ -9,6 +9,7 @@ from torch.nn.utils.rnn import pad_sequence
 
 from ..util import get_number_of_lines
 from ..data_utils.progbar import progress_bar
+from ..tasks.almond.utils import detokenize_cjk_chars
 
 logger = logging.getLogger(__name__)
 
@@ -88,6 +89,10 @@ class TextDataset(Dataset):
     def _add_example(self, input_sequence, output_sequence, args):
         # TODO we should make use of tokenizer.build_inputs_with_special_tokens(sequence1, sequence2). Add special tokens manualy only if our model does not support two sequences (like GPT2).
 
+        input_sequence = detokenize_cjk_chars(input_sequence)
+        if output_sequence is not None:
+            output_sequence = detokenize_cjk_chars(output_sequence)
+
         input_token_ids = self.tokenizer.encode(input_sequence, add_special_tokens=False) + [self.tokenizer.convert_tokens_to_ids(args.start_special_token)]
         if output_sequence is None:
             output_token_ids = []
 
@@ -128,6 +133,10 @@ class TextDataset(Dataset):
 
     def _add_seq2seq_example(self, input_sequence, output_sequence, args):
 
+        input_sequence = detokenize_cjk_chars(input_sequence)
+        if output_sequence is not None:
+            output_sequence = detokenize_cjk_chars(output_sequence)
+
         if args.model_type == 'mbart':
             model_inputs = self.tokenizer.prepare_seq2seq_batch([input_sequence], args.src_lang, [output_sequence], args.tgt_lang)
         else:
 
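Both _add_example and _add_seq2seq_example now normalize CJK spacing before any encoding happens. A sketch of what the seq2seq path hands to the tokenizer (the string is illustrative):

    input_sequence = detokenize_cjk_chars('你 好 吗 ?')  # -> '你好吗 ?'
    # only the cleaned string reaches tokenizer.prepare_seq2seq_batch(...)
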
@@ -63,7 +63,7 @@ def check_args(args):
                              'you have to specify the --src_lang flag.')
         elif args.src_lang not in MARIAN_GROUP_MEMBERS[args.model_name_or_path.rsplit('-', 2)[1]]:
             raise ValueError(
-                'Dource language is not in the model group languages, please specify the correct source language.')
+                'Source language is not in the model group languages, please specify the correct source language.')
 
     if args.model_type == 'marian' and args.model_name_or_path.rsplit('-', 1)[1] not in MARIAN_GROUP_MEMBERS and args.tgt_lang:
         logger.warning('Target language should not be provided when using models with single language pairs,'
 
@@ -34,6 +34,7 @@ from torch.multiprocessing import Process, set_start_method
 
 from genienlp.paraphrase.data_utils import create_features_from_tsv_file, output_heuristics
 from genienlp.paraphrase.model_utils import compute_metrics, compute_attention, replace_quoted_params, force_replace_quoted_params
+from ..tasks.almond.utils import tokenize_cjk_chars
 
 try:
     set_start_method('spawn')
 
@@ -496,6 +497,8 @@ def run_single_process_generation(args, config):
             text = re.sub('\s\s+', ' ', text) # remove duplicate white spaces
             text = text.strip()
 
+            text = tokenize_cjk_chars(text)
+
             if not args.skip_heuristics:
                 text = output_heuristics(text, batch_reverse_maps[sample_index])
             batch_outputs[sample_index].append(text)
 
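On the generation side the transformation is reversed: model output comes back as natural CJK text, and tokenize_cjk_chars re-splits it so that token-based postprocessing such as output_heuristics sees space-separated units. An illustrative example:

    tokenize_cjk_chars('天气怎么样?')  # -> '天 气 怎 么 样 ?'
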
@@ -37,7 +37,7 @@ from ..registry import register_task
 from ..generic_dataset import CQA, context_question_len, token_batch_fn, default_batch_fn
 from ...data_utils.example import Example
 from ...data_utils.progbar import progress_bar
-from .utils import ISO_to_LANG, is_device, is_entity, process_id, is_cjk_char
+from .utils import ISO_to_LANG, is_device, is_entity, process_id, is_cjk_char, detokenize_cjk_chars
 from ...util import multiwoz_specific_preprocess, multiwoz_specific_postprocess
 
 from ..base_dataset import Split
 
@@ -132,20 +132,6 @@ class BaseAlmondTask(BaseTask):
 
     def get_splits(self, root, **kwargs):
         return AlmondDataset.return_splits(path=os.path.join(root, 'almond'), make_example=self._make_example, **kwargs)
 
-    def _detokenize_cjk_chars(self, sentence):
-        output = []
-        i = 0
-        while i < len(sentence):
-            output.append(sentence[i])
-            # skip space after cjk chars only if followed by another cjk char
-            if is_cjk_char(ord(sentence[i])) and \
-                    i+1 < len(sentence) and sentence[i+1] == ' ' and \
-                    i+2 < len(sentence) and is_cjk_char(ord(sentence[i+2])):
-                i += 2
-            else:
-                i += 1
-        return "".join(output)
-
     def tokenize(self, sentence, field_name=None):
         if not sentence:
 
@@ -154,7 +140,7 @@ class BaseAlmondTask(BaseTask):
         if self.force_subword_tokenize:
             return sentence.split(' '), None
 
-        sentence = self._detokenize_cjk_chars(sentence)
+        sentence = detokenize_cjk_chars(sentence)
 
         if self._dataset_specific_preprocess == 'multiwoz' and self._is_program_field(field_name):
             sentence = multiwoz_specific_preprocess(sentence)
 
@@ -52,5 +52,34 @@ def process_id(ex):
         id_ = id_[1:]
     return id_
 
+
+def detokenize_cjk_chars(sentence):
+    output = []
+    i = 0
+    while i < len(sentence):
+        output.append(sentence[i])
+        # skip space after cjk chars only if followed by another cjk char
+        if is_cjk_char(ord(sentence[i])) and \
+                i+1 < len(sentence) and sentence[i+1] == ' ' and \
+                i+2 < len(sentence) and is_cjk_char(ord(sentence[i+2])):
+            i += 2
+        else:
+            i += 1
+
+    return "".join(output)
+
+
+def tokenize_cjk_chars(sentence):
+    output = []
+    i = 0
+    while i < len(sentence):
+        output.append(sentence[i])
+        if is_cjk_char(ord(sentence[i])) and i+1 < len(sentence) and sentence[i+1] != ' ':
+            output.append(' ')
+        elif not is_cjk_char(ord(sentence[i])) and i + 1 < len(sentence) and is_cjk_char(ord(sentence[i + 1])):
+            output.append(' ')
+        i += 1
+
+    output = "".join(output)
+    output = output.replace('  ', ' ')
+
+    return output
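
For reference, the intended behavior of the two new helpers, assuming the post-commit genienlp package is importable (the example strings are illustrative; note the round trip is not exact, since tokenize_cjk_chars separates every CJK character):

    from genienlp.tasks.almond.utils import detokenize_cjk_chars, tokenize_cjk_chars

    detokenize_cjk_chars('我 喜欢 nlp')  # -> '我喜欢 nlp' (space between two CJK chars dropped)
    tokenize_cjk_chars('我喜欢 nlp')     # -> '我 喜 欢 nlp' (every CJK char split off)
    tokenize_cjk_chars('abc中')          # -> 'abc 中' (space also added before a CJK char)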