#
# Copyright (c) 2018, Salesforce, Inc.
#                     The Board of Trustees of the Leland Stanford Junior University
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''
Created on April 2, 2019

author: mehrad
'''
import itertools
import logging
import os

import torch
from torch import nn

from ..util import set_seed
from decanlp.utils.model_utils import init_model
from .common import *

_logger = logging.getLogger(__name__)


class MultiLingualTranslationModel(nn.Module):
    '''
    Three-stage pipeline: source-to-English machine translation, English
    semantic parsing into ThingTalk, and a final ThingTalk machine
    translation step. Each stage is a pretrained
    MultitaskQuestionAnsweringNetwork loaded from `args.saved_models`.
    '''

    def __init__(self, field, args):
        super().__init__()
        self.field = field
        self.args = args
        self.device = set_seed(args)

        if args.use_google_translate:
            # the source-to-English translation is done externally, so skip that model
            model_names = ['semantic_parser', 'thingtalk_machine_translation']
        else:
            model_names = ['machine_translation', 'semantic_parser', 'thingtalk_machine_translation']

        for name in model_names:
            model_dict = torch.load(os.path.join(args.saved_models, f'{name}/best.pth'), map_location=self.device)
            model_field = model_dict['field']
            model = init_model(args, model_field, _logger, args.world_size, self.device,
                               model_name='MultitaskQuestionAnsweringNetwork')
            model.load_state_dict(model_dict['model_state_dict'])
            # model_opt_dict = torch.load(os.path.join(args.saved_models, f'{name}/best_optim.pth'), map_location=self.device)
            # model_opt = init_opt(args, model)
            # model_opt.load_state_dict(model_opt_dict)
            setattr(self, f'{name}_field', model_field)
            setattr(self, f'{name}_model', model)
            # setattr(self, f'{name}_opt', model_opt)

    def forward(self, batch, iteration):
        context, context_lengths, context_limited, context_tokens = \
            batch.context, batch.context_lengths, batch.context_limited, batch.context_tokens
        question, question_lengths, question_limited, question_tokens = \
            batch.question, batch.question_lengths, batch.question_limited, batch.question_tokens
        answer, answer_lengths, answer_limited, answer_tokens = \
            batch.answer, batch.answer_lengths, batch.answer_limited, batch.answer_tokens
        oov_to_limited_idx, limited_idx_to_full_idx = batch.oov_to_limited_idx, batch.limited_idx_to_full_idx

        # the limited (copy-extended) vocabularies differ per sub-model, so disable them
        setattr(batch, 'context_limited', None)
        setattr(batch, 'question_limited', None)
        setattr(batch, 'answer_limited', None)

        # keep the untouched inputs so they can be restored before returning
        original_context = context
        original_question = question
        batch_size = len(question_tokens)
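        # stage 1 of the pipeline: source-to-English machine translation;
        # stages 2 and 3 (semantic parsing, ThingTalk translation) follow below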
        # running through Farsi-English MT
        #### TODO fix this later
        if not self.args.use_google_translate:
            # question = 'Translate from Farsi to English'.split()
            question = 'Translate from English to ThingTalk'.split()
            question_tokens_single = [self.machine_translation_field.init_token] + question + [self.machine_translation_field.eos_token]
            question_tokens = list(itertools.repeat(question_tokens_single, batch_size))

            # numericalize time-major, then transpose to batch-major
            context_numericalized = [[self.machine_translation_field.decoder_stoi.get(sentence[time], self.machine_translation_field.decoder_stoi['<unk>'])
                                      for sentence in context_tokens] for time in range(len(context_tokens[0]))]
            question_numericalized = [[self.machine_translation_field.decoder_stoi[sentence[time]]
                                       for sentence in question_tokens] for time in range(len(question_tokens[0]))]

            context = torch.tensor(context_numericalized, dtype=torch.long, device=self.device).t_()
            question = torch.tensor(question_numericalized, dtype=torch.long, device=self.device).t_()

            setattr(batch, 'context', context)
            setattr(batch, 'question', question)

            self.machine_translation_model.eval()
            _, greedy_output_from_MT = self.machine_translation_model(batch, iteration)
            task = (self.machine_translation_model.args.train_tasks[0]
                    if hasattr(self.machine_translation_model.args, 'train_tasks')
                    else self.machine_translation_model.args.tasks[0])
            greedy_output_from_MT_tokens = self.machine_translation_field.reverse(greedy_output_from_MT, detokenize=task.detokenize, field_name='answer')

        # stage 2: numericalize input to semantic parser
        question = 'Translate from English to ThingTalk'.split()
        question_tokens_single = [self.semantic_parser_field.init_token] + question + [self.semantic_parser_field.eos_token]
        question_tokens = list(itertools.repeat(question_tokens_single, batch_size))

        if not self.args.use_google_translate:
            # with Google Translate the batch already holds English tokens;
            # otherwise wrap the MT output with init/eos tokens
            context_tokens = [[self.semantic_parser_field.init_token] + tokens.split(' ') + [self.semantic_parser_field.eos_token]
                              for tokens in greedy_output_from_MT_tokens]

        context_lengths = [len(sublist) for sublist in context_tokens]
        max_length = max(context_lengths)
        context_tokens = [tokens + [self.semantic_parser_field.pad_token] * (max_length - len(tokens))
                          for tokens in context_tokens]

        context_numericalized = [[self.semantic_parser_field.decoder_stoi.get(sentence[time], self.semantic_parser_field.decoder_stoi['<unk>'])
                                  for sentence in context_tokens] for time in range(len(context_tokens[0]))]
        question_numericalized = [[self.semantic_parser_field.decoder_stoi[sentence[time]]
                                   for sentence in question_tokens] for time in range(len(question_tokens[0]))]

        context = torch.tensor(context_numericalized, dtype=torch.long, device=self.device).t_()
        question = torch.tensor(question_numericalized, dtype=torch.long, device=self.device).t_()

        setattr(batch, 'context_tokens', context_tokens)
        setattr(batch, 'context', context)
        setattr(batch, 'context_lengths', context_lengths)
        setattr(batch, 'question', question)

        self.semantic_parser_model.eval()
        _, greedy_output_from_SP = self.semantic_parser_model(batch, iteration)
        task = (self.semantic_parser_model.args.train_tasks[0]
                if hasattr(self.semantic_parser_model.args, 'train_tasks')
                else self.semantic_parser_model.args.tasks[0])
        greedy_output_from_SP_tokens = self.semantic_parser_field.reverse(greedy_output_from_SP, detokenize=task.detokenize, field_name='answer')

        # numericalize input to thingtalk machine translation
        #### TODO fix this later
        # question = 'Translate from English to Farsi'.split()
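        # stage 3: translate the parser's ThingTalk program with the final MT model
        # (per the TODO above, the prompt should eventually target Farsi)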
        question = 'Translate from English to ThingTalk'.split()
        question_tokens_single = [self.semantic_parser_field.init_token] + question + [self.semantic_parser_field.eos_token]
        question_tokens = list(itertools.repeat(question_tokens_single, batch_size))

        context_tokens = [[self.semantic_parser_field.init_token] + tokens.split(' ') + [self.semantic_parser_field.eos_token]
                          for tokens in greedy_output_from_SP_tokens]
        context_lengths = [len(sublist) for sublist in context_tokens]
        max_length = max(context_lengths)
        context_tokens = [tokens + [self.thingtalk_machine_translation_field.pad_token] * (max_length - len(tokens))
                          for tokens in context_tokens]

        context_numericalized = [[self.thingtalk_machine_translation_field.decoder_stoi.get(sentence[time], self.thingtalk_machine_translation_field.decoder_stoi['<unk>'])
                                  for sentence in context_tokens] for time in range(len(context_tokens[0]))]
        question_numericalized = [[self.thingtalk_machine_translation_field.decoder_stoi[sentence[time]]
                                   for sentence in question_tokens] for time in range(len(question_tokens[0]))]
        ## TODO fix this
        answer_numericalized = [[self.field.decoder_stoi.get(sentence[time], self.field.decoder_stoi['<unk>'])
                                 for sentence in answer_tokens] for time in range(len(answer_tokens[0]))]

        context = torch.tensor(context_numericalized, dtype=torch.long, device=self.device).t_()
        question = torch.tensor(question_numericalized, dtype=torch.long, device=self.device).t_()
        answer = torch.tensor(answer_numericalized, dtype=torch.long, device=self.device).t_()

        setattr(batch, 'context_tokens', context_tokens)
        setattr(batch, 'context', context)
        setattr(batch, 'context_lengths', context_lengths)
        setattr(batch, 'question', question)
        setattr(batch, 'answer', answer)

        if self.training:
            # only the final stage is trained; the earlier stages stay frozen in eval mode
            self.thingtalk_machine_translation_model.train()
            loss, _ = self.thingtalk_machine_translation_model(batch, iteration)
            return loss, None
        else:
            self.thingtalk_machine_translation_model.eval()
            _, greedy_output_from_thingtalk_MT = self.thingtalk_machine_translation_model(batch, iteration)
            task = (self.thingtalk_machine_translation_model.args.train_tasks[0]
                    if hasattr(self.thingtalk_machine_translation_model.args, 'train_tasks')
                    else self.thingtalk_machine_translation_model.args.tasks[0])
            greedy_output_from_thingtalk_MT_tokens = self.thingtalk_machine_translation_field.reverse(greedy_output_from_thingtalk_MT, detokenize=task.detokenize, field_name='answer')
            _logger.info(f'**** greedy_output_from_thingtalk_MT_tokens: {greedy_output_from_thingtalk_MT_tokens[0]} ***')

            context_tokens = [[self.thingtalk_machine_translation_field.init_token] + tokens.split(' ') + [self.thingtalk_machine_translation_field.eos_token]
                              for tokens in greedy_output_from_thingtalk_MT_tokens]
            context_lengths = [len(sublist) for sublist in context_tokens]
            max_length = max(context_lengths)
            context_tokens = [tokens + [self.thingtalk_machine_translation_field.pad_token] * (max_length - len(tokens))
                              for tokens in context_tokens]

            prediction_tokens = context_tokens
            prediction_numericalized = [[self.field.decoder_stoi.get(sentence[time], self.field.decoder_stoi['<unk>'])
                                         for sentence in prediction_tokens] for time in range(len(prediction_tokens[0]))]
            prediction = torch.tensor(prediction_numericalized, dtype=torch.long, device=self.device).t_()

            # restore the original inputs before handing the batch back
            setattr(batch, 'context', original_context)
            setattr(batch, 'question', original_question)

            return None, prediction
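
# Usage sketch (hypothetical; argument names follow the rest of decanlp, and
# `batch` is assumed to be a torchtext-style batch with the attributes read in forward()):
#
#     model = MultiLingualTranslationModel(field, args)
#     model.train()
#     loss, _ = model(batch, iteration)        # training: loss from the final MT stage
#     model.eval()
#     _, prediction = model(batch, iteration)  # inference: numericalized prediction tokens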