batch_postprocess_prediction_ids should return target_ids

This commit is contained in:
mehrad 2021-03-11 11:33:43 -08:00
parent 096af96e3f
commit 5f645e5c9c
3 changed files with 7 additions and 8 deletions

View File

@ -265,8 +265,8 @@ class BaseAlmondTask(BaseTask):
token_freqs = [[default_val] * default_size] * len(tokens)
return token_freqs
def batch_postprocess_prediction_ids(self, batch_example_ids, batch_src_ids, **kwargs):
return batch_src_ids
def batch_postprocess_prediction_ids(self, batch_example_ids, batch_src_ids, batch_tgt_ids, **kwargs):
return batch_tgt_ids
def postprocess_prediction(self, example_id, prediction):
@ -566,9 +566,8 @@ class Translate(NaturalSeq2Seq):
def postprocess_prediction(self, example_id, prediction):
return super().postprocess_prediction(example_id, prediction)
def batch_postprocess_prediction_ids(self, batch_example_ids, batch_src_ids, **kwargs):
def batch_postprocess_prediction_ids(self, batch_example_ids, batch_src_ids, batch_tgt_ids, **kwargs):
batch_tgt_ids = kwargs.pop('batch_tgt_ids')
numericalizer = kwargs.pop('numericalizer')
cross_attentions = kwargs.pop('cross_attentions')

View File

@ -69,8 +69,8 @@ class BaseTask:
"""
return generic_dataset.JSON.splits(root=root, name=self.name, **kwargs)
def batch_postprocess_prediction_ids(self, batch_example_ids, batch_src_ids, **kwargs):
return batch_src_ids
def batch_postprocess_prediction_ids(self, batch_example_ids, batch_src_ids, batch_tgt_ids, **kwargs):
return batch_tgt_ids
def postprocess_prediction(self, example_id, prediction):
return prediction

View File

@ -111,8 +111,8 @@ def generate_with_model(model, data_iterator, numericalizer, task, args,
cross_attentions = cross_attentions[-1, ...]
# postprocess prediction ids
kwargs = {'batch_tgt_ids': partial_batch_prediction_ids, 'numericalizer': numericalizer, 'cross_attentions': cross_attentions}
partial_batch_prediction_ids = task.batch_postprocess_prediction_ids(batch_example_ids, batch.context.value.data, **kwargs)
kwargs = {'numericalizer': numericalizer, 'cross_attentions': cross_attentions}
partial_batch_prediction_ids = task.batch_postprocess_prediction_ids(batch_example_ids, batch.context.value.data, partial_batch_prediction_ids, **kwargs)
if output_confidence_features or output_confidence_scores:
partial_batch_confidence_features = model.confidence_features(batch=batch, predictions=partial_batch_prediction_ids, mc_dropout_num=args.mc_dropout_num)