diff --git a/genienlp/tasks/dialogue_dataset.py b/genienlp/tasks/dialogue_dataset.py
index 0fc0eeee..c88944cd 100644
--- a/genienlp/tasks/dialogue_dataset.py
+++ b/genienlp/tasks/dialogue_dataset.py
@@ -48,45 +48,7 @@ class E2EDialogueDataset(CQA):
             train=train_path, eval=validation_path, test=test_path
         )
 
-class AnnotationClassifierDataset(CQA):
+
+class ErrorClassificationDataset(E2EDialogueDataset):
+    """E2EDialogueDataset variant flagged as a sequence-classification dataset."""
     is_sequence_classification = True
-    def __init__(self, path, *, make_example, **kwargs):
-        subsample = kwargs.pop('subsample')
-        examples = []
-
-        with open(path) as fin:
-            data = ujson.load(fin)['data']
-            for turn in data:
-                processed = make_example(turn, train_target=kwargs.get('train_target', False))
-                if processed:
-                    examples.append(processed)
-
-                if subsample is not None and len(examples) >= subsample:
-                    break
-
-        super().__init__(examples, **kwargs)
-
-        # do not sort eval/ test set so we can compute individual scores for each subtask (e2e_dialogue_score)
-        self.eval_sort_key_fn = None
-
-        # in e2e evaluation use 1 batch at a time
-        if kwargs.get('e2e_evaluation', False):
-            self.eval_batch_size_fn = default_batch_fn
-
-    @classmethod
-    def return_splits(cls, path='.data', train='train', validation='valid', test='test', **kwargs):
-        train_path, validation_path, test_path = None, None, None
-        if train:
-            train_path = os.path.join(path, f'{train}.json')
-        if validation:
-            validation_path = os.path.join(path, f'{validation}.json')
-        if test:
-            test_path = os.path.join(path, 'test.json')
-
-        train_data = None if train is None else cls(train_path, **kwargs)
-        validation_data = None if validation is None else cls(validation_path, **kwargs)
-        test_data = None if test is None else cls(test_path, **kwargs)
-
-        return Split(train=train_data, eval=validation_data, test=test_data), Split(
-            train=train_path, eval=validation_path, test=test_path
-        )
diff --git a/genienlp/tasks/dialogue_task.py b/genienlp/tasks/dialogue_task.py
index c5d56ad5..52bc5dd5 100644
--- a/genienlp/tasks/dialogue_task.py
+++ b/genienlp/tasks/dialogue_task.py
@@ -1,6 +1,6 @@
 from ..data_utils.example import Example
 from .base_task import BaseTask
-from .dialogue_dataset import E2EDialogueDataset, AnnotationClassifierDataset
+from .dialogue_dataset import E2EDialogueDataset, ErrorClassificationDataset
 from .registry import register_task
 
 
@@ -34,38 +34,6 @@ class E2EDialogueTask(BaseTask):
         kwargs['e2e_evaluation'] = self.args.e2e_dialogue_evaluation
         return E2EDialogueDataset.return_splits(path=root, make_example=self._make_example, **kwargs)
 
-@register_task('annotation_classifier')
-class AnnotationClassiferTask(BaseTask):
-    def __init__(self, name, args):
-        self.id2label = ['positive', 'negative']
-        self.num_labels = 2
-        super().__init__(name, args)
-        self._metrics = ['sc_f1', 'sc_precision', 'sc_recall']
-
-    def utterance_field(self):
-        return 'context'
-
-    def _make_example(self, turn, **kwargs):
-        if 'type' not in turn: return None
-        dial_id, turn_id, input_text, output_text, train_target, type = (
-            turn['dial_id'],
-            turn['turn_id'],
-            turn['input_text'],
-            turn['output_text'],
-            turn['train_target'],
-            turn['type']
-        )
-
-        example_id = '/'.join([dial_id, str(turn_id), train_target])
-
-        return Example.from_raw(
-            self.name + '/' + str(example_id), input_text + '_' + output_text, '', ['0', '1'][type == 'positive'], preprocess=self.preprocess_field, lower=False
-        )
-
-    def get_splits(self, root, **kwargs):
-        kwargs['e2e_evaluation'] = self.args.e2e_dialogue_evaluation
-        return AnnotationClassifierDataset.return_splits(path=root, make_example=self._make_example, **kwargs)
-
 
 @register_task('risawoz')
 class RiSAWOZ(E2EDialogueTask):
@@ -163,3 +131,62 @@ class BiTODDST(BiTOD):
         kwargs['train_target'] = 'dst'
         kwargs['e2e_evaluation'] = self.args.e2e_dialogue_evaluation
         return E2EDialogueDataset.return_splits(path=root, make_example=self._make_example, **kwargs)
+
+
+@register_task('error_cls')
+class ErrorClassificationTask(BiTOD):
+    """Two-class ('negative' / 'positive') classification task over annotated dialogue turns.
+
+    The context of each example is ``input_text ## output_text`` and the
+    answer is the label id from ``label2id``, rendered as a string.
+    """
+
+    def __init__(self, name, args):
+        super().__init__(name, args)
+
+        self.label2id = {'negative': 0, 'positive': 1}
+        self.id2label = {v: k for k, v in self.label2id.items()}
+        self.num_labels = len(self.id2label)
+
+        # '##' separates input from output in the context built by _make_example
+        self.special_tokens.update(['##'])
+
+    @property
+    def metrics(self):
+        """Names of the sequence-classification metrics used for this task."""
+        return ['sc_f1', 'sc_precision', 'sc_recall']
+
+    def _make_example(self, turn, **kwargs):
+        """Convert one turn dict into an Example, or return None if it has no 'type' key.
+
+        Raises KeyError when turn['type'] is not a key of ``self.label2id``.
+        """
+        if 'type' not in turn:
+            return None
+        # bound to `label` rather than `type` so the builtin is not shadowed
+        dial_id, turn_id, input_text, output_text, train_target, label = (
+            turn['dial_id'],
+            turn['turn_id'],
+            turn['input_text'],
+            turn['output_text'],
+            turn['train_target'],
+            turn['type'],
+        )
+
+        answer = str(self.label2id[label])
+
+        example_id = '/'.join([dial_id, str(turn_id), train_target])
+
+        return Example.from_raw(
+            self.name + '/' + example_id,
+            input_text + ' ## ' + output_text,
+            '',
+            answer,
+            preprocess=self.preprocess_field,
+            lower=False,
+        )
+
+    def get_splits(self, root, **kwargs):
+        """Delegate split loading to ErrorClassificationDataset.return_splits with path=root."""
+        kwargs['e2e_evaluation'] = self.args.e2e_dialogue_evaluation
+        return ErrorClassificationDataset.return_splits(path=root, make_example=self._make_example, **kwargs)