Update ErrorClassificationTask

mehrad 2022-07-18 15:32:51 -07:00
parent 0071321c12
commit 73075497f9
2 changed files with 48 additions and 74 deletions


@@ -48,45 +48,6 @@ class E2EDialogueDataset(CQA):
            train=train_path, eval=validation_path, test=test_path
        )
class AnnotationClassifierDataset(CQA):
class ErrorClassificationDataset(E2EDialogueDataset):
    is_sequence_classification = True

    def __init__(self, path, *, make_example, **kwargs):
        subsample = kwargs.pop('subsample')
        examples = []

        with open(path) as fin:
            data = ujson.load(fin)['data']

        for turn in data:
            processed = make_example(turn, train_target=kwargs.get('train_target', False))
            if processed:
                examples.append(processed)
            if subsample is not None and len(examples) >= subsample:
                break

        super().__init__(examples, **kwargs)

        # do not sort eval/ test set so we can compute individual scores for each subtask (e2e_dialogue_score)
        self.eval_sort_key_fn = None

        # in e2e evaluation use 1 batch at a time
        if kwargs.get('e2e_evaluation', False):
            self.eval_batch_size_fn = default_batch_fn

    @classmethod
    def return_splits(cls, path='.data', train='train', validation='valid', test='test', **kwargs):
        train_path, validation_path, test_path = None, None, None
        if train:
            train_path = os.path.join(path, f'{train}.json')
        if validation:
            validation_path = os.path.join(path, f'{validation}.json')
        if test:
            test_path = os.path.join(path, 'test.json')

        train_data = None if train is None else cls(train_path, **kwargs)
        validation_data = None if validation is None else cls(validation_path, **kwargs)
        test_data = None if test is None else cls(test_path, **kwargs)

        return Split(train=train_data, eval=validation_data, test=test_data), Split(
            train=train_path, eval=validation_path, test=test_path
        )
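For reference, the return_splits inherited from E2EDialogueDataset joins the root path with f'{train}.json', f'{validation}.json', and 'test.json', and __init__ expects each of those files to hold a top-level 'data' list of turn dicts that make_example consumes. A minimal sketch of one such split file, with purely hypothetical field values, could be written like this:

# Sketch of a split file (e.g. valid.json) in the layout ErrorClassificationDataset reads;
# the field values below are made up for illustration.
import ujson

split = {
    'data': [
        {
            'dial_id': 'dlg-001',                       # dialogue identifier
            'turn_id': 3,                               # turn index within the dialogue
            'input_text': 'user asks for a cheap hotel',
            'output_text': 'system offers Hotel A',
            'train_target': 'dst',                      # subtask this turn belongs to
            'type': 'negative',                         # label read by _make_example
        }
    ]
}

with open('valid.json', 'w') as fout:
    ujson.dump(split, fout)

Turns whose dict lacks a 'type' key are skipped by _make_example, and the optional subsample kwarg caps how many processed examples are kept per split.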


@@ -1,6 +1,6 @@
from ..data_utils.example import Example
from .base_task import BaseTask
from .dialogue_dataset import E2EDialogueDataset, AnnotationClassifierDataset
from .dialogue_dataset import E2EDialogueDataset, ErrorClassificationDataset
from .registry import register_task
@@ -34,38 +34,6 @@ class E2EDialogueTask(BaseTask):
        kwargs['e2e_evaluation'] = self.args.e2e_dialogue_evaluation
        return E2EDialogueDataset.return_splits(path=root, make_example=self._make_example, **kwargs)

@register_task('annotation_classifier')
class AnnotationClassiferTask(BaseTask):
    def __init__(self, name, args):
        self.id2label = ['positive', 'negative']
        self.num_labels = 2
        super().__init__(name, args)
        self._metrics = ['sc_f1', 'sc_precision', 'sc_recall']

    def utterance_field(self):
        return 'context'

    def _make_example(self, turn, **kwargs):
        if 'type' not in turn: return None
        dial_id, turn_id, input_text, output_text, train_target, type = (
            turn['dial_id'],
            turn['turn_id'],
            turn['input_text'],
            turn['output_text'],
            turn['train_target'],
            turn['type']
        )
        example_id = '/'.join([dial_id, str(turn_id), train_target])
        return Example.from_raw(
            self.name + '/' + str(example_id), input_text + '_' + output_text, '', ['0', '1'][type == 'positive'], preprocess=self.preprocess_field, lower=False
        )

    def get_splits(self, root, **kwargs):
        kwargs['e2e_evaluation'] = self.args.e2e_dialogue_evaluation
        return AnnotationClassifierDataset.return_splits(path=root, make_example=self._make_example, **kwargs)

@register_task('risawoz')
class RiSAWOZ(E2EDialogueTask):
@@ -163,3 +131,48 @@ class BiTODDST(BiTOD):
        kwargs['train_target'] = 'dst'
        kwargs['e2e_evaluation'] = self.args.e2e_dialogue_evaluation
        return E2EDialogueDataset.return_splits(path=root, make_example=self._make_example, **kwargs)

@register_task('error_cls')
class ErrorClassificationTask(BiTOD):
    def __init__(self, name, args):
        super().__init__(name, args)
        self.label2id = {'negative': 0, 'positive': 1}
        self.id2label = {v: k for k, v in self.label2id.items()}
        self.num_labels = len(self.id2label)
        self.special_tokens.update(['##'])

    @property
    def metrics(self):
        return ['sc_f1', 'sc_precision', 'sc_recall']

    def _make_example(self, turn, **kwargs):
        if 'type' not in turn:
            return None
        dial_id, turn_id, input_text, output_text, train_target, type = (
            turn['dial_id'],
            turn['turn_id'],
            turn['input_text'],
            turn['output_text'],
            turn['train_target'],
            turn['type'],
        )
        answer = str(self.label2id[type])
        example_id = '/'.join([dial_id, str(turn_id), train_target])
        return Example.from_raw(
            self.name + '/' + str(example_id),
            input_text + ' ## ' + output_text,
            '',
            answer,
            preprocess=self.preprocess_field,
            lower=False,
        )

    def get_splits(self, root, **kwargs):
        kwargs['e2e_evaluation'] = self.args.e2e_dialogue_evaluation
        return ErrorClassificationDataset.return_splits(path=root, make_example=self._make_example, **kwargs)
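To make the new mapping concrete, here is a standalone sketch of the turn-to-example conversion that _make_example performs, written without the genienlp Example machinery; the helper name and the sample turn values are hypothetical:

# Mirrors ErrorClassificationTask._make_example: build the example id, join the two
# text fields around the '##' separator, and map the 'type' label to an id string.
LABEL2ID = {'negative': 0, 'positive': 1}

def build_error_cls_example(turn):
    # hypothetical helper, for illustration only
    if 'type' not in turn:
        return None
    example_id = '/'.join([turn['dial_id'], str(turn['turn_id']), turn['train_target']])
    context = turn['input_text'] + ' ## ' + turn['output_text']
    answer = str(LABEL2ID[turn['type']])
    return example_id, context, answer

turn = {
    'dial_id': 'dlg-001',
    'turn_id': 3,
    'input_text': 'user asks for a cheap hotel',
    'output_text': 'system offers Hotel A',
    'train_target': 'dst',
    'type': 'positive',
}
print(build_error_cls_example(turn))
# -> ('dlg-001/3/dst', 'user asks for a cheap hotel ## system offers Hotel A', '1')

Adding '##' to special_tokens in __init__ presumably keeps the separator between the two text fields from being split apart during tokenization.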