Update ErrorClassificationTask
This commit is contained in:
parent
0071321c12
commit
73075497f9
|
@ -48,45 +48,6 @@ class E2EDialogueDataset(CQA):
|
|||
train=train_path, eval=validation_path, test=test_path
|
||||
)
|
||||
|
||||
class AnnotationClassifierDataset(CQA):
|
||||
|
||||
class ErrorClassificationDataset(E2EDialogueDataset):
    """Sequence-classification dataset over dialogue turns.

    Each example is built from one turn of a JSON dialogue file via the
    ``make_example`` callback supplied by the owning task; turns for which
    the callback returns a falsy value are skipped.
    """

    # Examples carry a single class label rather than a generated sequence.
    is_sequence_classification = True

    def __init__(self, path, *, make_example, **kwargs):
        """Load turns from *path* and build classification examples.

        Args:
            path: JSON file with a top-level ``data`` list of turn dicts.
            make_example: callable mapping a turn dict (plus a
                ``train_target`` keyword) to an example, or a falsy value
                to skip the turn.
            **kwargs: forwarded to the parent constructor. ``subsample``
                is popped here and, when not None, caps the number of
                examples kept.
        """
        subsample = kwargs.pop('subsample')
        examples = []

        with open(path) as fin:
            data = ujson.load(fin)['data']
        for turn in data:
            processed = make_example(turn, train_target=kwargs.get('train_target', False))
            if processed:
                examples.append(processed)

            if subsample is not None and len(examples) >= subsample:
                break

        super().__init__(examples, **kwargs)

        # do not sort eval/ test set so we can compute individual scores for each subtask (e2e_dialogue_score)
        self.eval_sort_key_fn = None

        # in e2e evaluation use 1 batch at a time
        if kwargs.get('e2e_evaluation', False):
            self.eval_batch_size_fn = default_batch_fn

    @classmethod
    def return_splits(cls, path='.data', train='train', validation='valid', test='test', **kwargs):
        """Build dataset and path ``Split`` tuples for the requested splits.

        A split whose name is falsy (None or '') is skipped entirely and
        reported as None in both returned ``Split`` tuples.
        """
        train_path, validation_path, test_path = None, None, None
        if train:
            train_path = os.path.join(path, f'{train}.json')
        if validation:
            validation_path = os.path.join(path, f'{validation}.json')
        if test:
            # Bug fix: honor the caller-supplied split name instead of the
            # hard-coded 'test.json' used inconsistently with the other splits.
            test_path = os.path.join(path, f'{test}.json')

        # Use the same falsiness test as the path construction above, so an
        # empty-string split name cannot reach cls() with a None path.
        train_data = None if not train else cls(train_path, **kwargs)
        validation_data = None if not validation else cls(validation_path, **kwargs)
        test_data = None if not test else cls(test_path, **kwargs)

        return Split(train=train_data, eval=validation_data, test=test_data), Split(
            train=train_path, eval=validation_path, test=test_path
        )
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from ..data_utils.example import Example
|
||||
from .base_task import BaseTask
|
||||
from .dialogue_dataset import E2EDialogueDataset, AnnotationClassifierDataset
|
||||
from .dialogue_dataset import E2EDialogueDataset, ErrorClassificationDataset
|
||||
from .registry import register_task
|
||||
|
||||
|
||||
|
@ -34,38 +34,6 @@ class E2EDialogueTask(BaseTask):
|
|||
kwargs['e2e_evaluation'] = self.args.e2e_dialogue_evaluation
|
||||
return E2EDialogueDataset.return_splits(path=root, make_example=self._make_example, **kwargs)
|
||||
|
||||
@register_task('annotation_classifier')
class AnnotationClassiferTask(BaseTask):
    """Binary sequence-classification task over annotated dialogue turns.

    NOTE(review): the class name keeps the original 'Classifer' spelling and
    the 'annotation_classifier' registry key, since both are part of the
    public interface.
    """

    def __init__(self, name, args):
        # Label at index 0 is 'positive', index 1 is 'negative'.
        self.id2label = ['positive', 'negative']
        self.num_labels = 2
        super().__init__(name, args)
        self._metrics = ['sc_f1', 'sc_precision', 'sc_recall']

    def utterance_field(self):
        return 'context'

    def _make_example(self, turn, **kwargs):
        """Convert one turn dict into a classification example.

        Returns None for turns without a 'type' key so they are skipped
        by the dataset loader.
        """
        if 'type' not in turn:
            return None
        # Renamed the unpacked 'type' field to avoid shadowing the builtin.
        dial_id, turn_id, input_text, output_text, train_target, turn_type = (
            turn['dial_id'],
            turn['turn_id'],
            turn['input_text'],
            turn['output_text'],
            turn['train_target'],
            turn['type'],
        )

        example_id = '/'.join([dial_id, str(turn_id), train_target])

        # '1' for positive turns, '0' otherwise (clearer than indexing a
        # list with a bool).
        answer = '1' if turn_type == 'positive' else '0'
        return Example.from_raw(
            self.name + '/' + str(example_id),
            input_text + '_' + output_text,
            '',
            answer,
            preprocess=self.preprocess_field,
            lower=False,
        )

    def get_splits(self, root, **kwargs):
        kwargs['e2e_evaluation'] = self.args.e2e_dialogue_evaluation
        return AnnotationClassifierDataset.return_splits(path=root, make_example=self._make_example, **kwargs)
|
||||
|
||||
|
||||
@register_task('risawoz')
|
||||
class RiSAWOZ(E2EDialogueTask):
|
||||
|
@ -163,3 +131,48 @@ class BiTODDST(BiTOD):
|
|||
kwargs['train_target'] = 'dst'
|
||||
kwargs['e2e_evaluation'] = self.args.e2e_dialogue_evaluation
|
||||
return E2EDialogueDataset.return_splits(path=root, make_example=self._make_example, **kwargs)
|
||||
|
||||
|
||||
@register_task('error_cls')
class ErrorClassificationTask(BiTOD):
    """Task that labels dialogue turns as 'positive' or 'negative'."""

    def __init__(self, name, args):
        super().__init__(name, args)

        # Binary label vocabulary and its inverse lookup table.
        self.label2id = {'negative': 0, 'positive': 1}
        self.id2label = {idx: label for label, idx in self.label2id.items()}
        self.num_labels = len(self.id2label)

        # '##' is the separator token between input and output text below.
        self.special_tokens.update(['##'])

    @property
    def metrics(self):
        return ['sc_f1', 'sc_precision', 'sc_recall']

    def _make_example(self, turn, **kwargs):
        """Build one classification example from a turn dict.

        Turns lacking a 'type' key yield None and are skipped by the loader.
        """
        if 'type' not in turn:
            return None

        dial_id = turn['dial_id']
        turn_id = turn['turn_id']
        input_text = turn['input_text']
        output_text = turn['output_text']
        train_target = turn['train_target']
        turn_type = turn['type']

        answer = str(self.label2id[turn_type])
        example_id = '/'.join([dial_id, str(turn_id), train_target])

        return Example.from_raw(
            self.name + '/' + str(example_id),
            input_text + ' ## ' + output_text,
            '',
            answer,
            preprocess=self.preprocess_field,
            lower=False,
        )

    def get_splits(self, root, **kwargs):
        """Return dataset splits built from *root* using this task's example maker."""
        kwargs['e2e_evaluation'] = self.args.e2e_dialogue_evaluation
        return ErrorClassificationDataset.return_splits(path=root, make_example=self._make_example, **kwargs)
|
||||
|
|
Loading…
Reference in New Issue