batch_postprocess_prediction_ids should return target_ids
This commit is contained in:
parent
096af96e3f
commit
5f645e5c9c
|
@ -265,8 +265,8 @@ class BaseAlmondTask(BaseTask):
|
|||
token_freqs = [[default_val] * default_size] * len(tokens)
|
||||
return token_freqs
|
||||
|
||||
def batch_postprocess_prediction_ids(self, batch_example_ids, batch_src_ids, **kwargs):
|
||||
return batch_src_ids
|
||||
def batch_postprocess_prediction_ids(self, batch_example_ids, batch_src_ids, batch_tgt_ids, **kwargs):
|
||||
return batch_tgt_ids
|
||||
|
||||
def postprocess_prediction(self, example_id, prediction):
|
||||
|
||||
|
@ -566,9 +566,8 @@ class Translate(NaturalSeq2Seq):
|
|||
def postprocess_prediction(self, example_id, prediction):
|
||||
return super().postprocess_prediction(example_id, prediction)
|
||||
|
||||
def batch_postprocess_prediction_ids(self, batch_example_ids, batch_src_ids, **kwargs):
|
||||
def batch_postprocess_prediction_ids(self, batch_example_ids, batch_src_ids, batch_tgt_ids, **kwargs):
|
||||
|
||||
batch_tgt_ids = kwargs.pop('batch_tgt_ids')
|
||||
numericalizer = kwargs.pop('numericalizer')
|
||||
cross_attentions = kwargs.pop('cross_attentions')
|
||||
|
||||
|
|
|
@ -69,8 +69,8 @@ class BaseTask:
|
|||
"""
|
||||
return generic_dataset.JSON.splits(root=root, name=self.name, **kwargs)
|
||||
|
||||
def batch_postprocess_prediction_ids(self, batch_example_ids, batch_src_ids, **kwargs):
|
||||
return batch_src_ids
|
||||
def batch_postprocess_prediction_ids(self, batch_example_ids, batch_src_ids, batch_tgt_ids, **kwargs):
|
||||
return batch_tgt_ids
|
||||
|
||||
def postprocess_prediction(self, example_id, prediction):
|
||||
return prediction
|
||||
|
|
|
@ -111,8 +111,8 @@ def generate_with_model(model, data_iterator, numericalizer, task, args,
|
|||
cross_attentions = cross_attentions[-1, ...]
|
||||
|
||||
# postprocess prediction ids
|
||||
kwargs = {'batch_tgt_ids': partial_batch_prediction_ids, 'numericalizer': numericalizer, 'cross_attentions': cross_attentions}
|
||||
partial_batch_prediction_ids = task.batch_postprocess_prediction_ids(batch_example_ids, batch.context.value.data, **kwargs)
|
||||
kwargs = {'numericalizer': numericalizer, 'cross_attentions': cross_attentions}
|
||||
partial_batch_prediction_ids = task.batch_postprocess_prediction_ids(batch_example_ids, batch.context.value.data, partial_batch_prediction_ids, **kwargs)
|
||||
|
||||
if output_confidence_features or output_confidence_scores:
|
||||
partial_batch_confidence_features = model.confidence_features(batch=batch, predictions=partial_batch_prediction_ids, mc_dropout_num=args.mc_dropout_num)
|
||||
|
|
Loading…
Reference in New Issue