genienlp/text/test/data/test_dataset.py

104 lines
4.9 KiB
Python

# -*- coding: utf-8 -*-
from __future__ import unicode_literals
import torchtext.data as data
from ..common.torchtext_test_case import TorchtextTestCase
class TestDataset(TorchtextTestCase):
    """Tests for ``data.TabularDataset`` loaded from the test PPID dataset.

    Every test first writes the dataset file through
    ``self.write_test_ppid_dataset`` and then reads it back from
    ``self.test_ppid_dataset_path``.
    """

    def test_tabular_simple_data(self):
        """Load the dataset as csv, tsv and json and verify every example.

        The contents are verified three ways: by index (``__getitem__``),
        by column attribute (``__getattr__``) and by iteration
        (``__iter__``).
        """
        # One (q1 tokens, q2 tokens, label) tuple per expected example.
        # NOTE(review): the first q1 entry previously contained an empty
        # token, which word tokenization cannot produce; it is restored to
        # the character believed to be in the fixture -- confirm against
        # write_test_ppid_dataset.
        expected_examples = [
            (["When", "do", "you", "use", "シ", "instead", "of", "し?"],
             ["When", "do", "you", "use", "\"&\"",
              "instead", "of", "\"and\"?"], "0"),
            (["Where", "was", "Lincoln", "born?"],
             ["Which", "location", "was", "Abraham", "Lincoln", "born?"],
             "1"),
            (["What", "is", "2+2"], ["2+2=?"], "1"),
        ]

        for data_format in ["csv", "tsv", "json"]:
            self.write_test_ppid_dataset(data_format=data_format)

            # Fresh fields for every format so no state is shared.
            question_field = data.Field(sequential=True)
            label_field = data.Field(sequential=False)
            if data_format == "json":
                # JSON: source key -> (attribute name, field).
                fields = {"question1": ("q1", question_field),
                          "question2": ("q2", question_field),
                          "label": ("label", label_field)}
            else:
                # csv/tsv: one (attribute name, field) pair per column.
                # "id" is paired with None (no Field) -- presumably the
                # column is ignored.
                fields = [("id", None), ("q1", question_field),
                          ("q2", question_field), ("label", label_field)]

            dataset = data.TabularDataset(
                path=self.test_ppid_dataset_path, format=data_format,
                fields=fields)

            self.assertEqual(len(dataset), len(expected_examples))

            # Ensure examples have correct contents / test __getitem__
            for i, (q1, q2, label) in enumerate(expected_examples):
                self.assertEqual(dataset[i].q1, q1)
                self.assertEqual(dataset[i].q2, q2)
                self.assertEqual(dataset[i].label, label)

            # Test __getattr__
            # Materialize first so a short column cannot pass unnoticed.
            columns = list(zip(dataset.q1, dataset.q2, dataset.label))
            self.assertEqual(len(columns), len(expected_examples))
            for actual, expected in zip(columns, expected_examples):
                self.assertEqual(actual[0], expected[0])
                self.assertEqual(actual[1], expected[1])
                self.assertEqual(actual[2], expected[2])

            # Test __iter__
            examples = list(dataset)
            self.assertEqual(len(examples), len(expected_examples))
            for example, expected in zip(examples, expected_examples):
                self.assertEqual(example.q1, expected[0])
                self.assertEqual(example.q2, expected[1])
                self.assertEqual(example.label, expected[2])

    def test_json_dataset_one_key_multiple_fields(self):
        """Map one JSON key onto two fields with different tokenizers.

        ``question1`` and ``question2`` are each exposed twice: once through
        a field with the default tokenizer and once through a field created
        with ``tokenize="spacy"``.
        """
        self.write_test_ppid_dataset(data_format="json")

        question_field = data.Field(sequential=True)
        spacy_tok_question_field = data.Field(sequential=True,
                                              tokenize="spacy")
        label_field = data.Field(sequential=False)
        # A list of (attribute name, field) pairs attaches several fields
        # to the same JSON key.
        fields = {"question1": [("q1", question_field),
                                ("q1_spacy", spacy_tok_question_field)],
                  "question2": [("q2", question_field),
                                ("q2_spacy", spacy_tok_question_field)],
                  "label": ("label", label_field)}
        dataset = data.TabularDataset(
            path=self.test_ppid_dataset_path, format="json", fields=fields)

        # (q1, q1_spacy, q2, q2_spacy, label) for each expected example.
        # NOTE(review): the empty tokens in the first example were restored
        # to the characters believed to be in the fixture -- confirm against
        # write_test_ppid_dataset.
        expected_examples = [
            (["When", "do", "you", "use", "シ", "instead", "of", "し?"],
             ["When", "do", "you", "use", "シ", "instead", "of", "し", "?"],
             ["When", "do", "you", "use", "\"&\"",
              "instead", "of", "\"and\"?"],
             ["When", "do", "you", "use", "\"", "&", "\"",
              "instead", "of", "\"", "and", "\"", "?"], "0"),
            (["Where", "was", "Lincoln", "born?"],
             ["Where", "was", "Lincoln", "born", "?"],
             ["Which", "location", "was", "Abraham", "Lincoln", "born?"],
             ["Which", "location", "was", "Abraham", "Lincoln", "born", "?"],
             "1"),
            (["What", "is", "2+2"], ["What", "is", "2", "+", "2"],
             ["2+2=?"], ["2", "+", "2=", "?"], "1"),
        ]

        # Without this check an empty dataset would pass the loop below
        # without running a single assertion.
        examples = list(dataset)
        self.assertEqual(len(examples), len(expected_examples))
        for example, expected in zip(examples, expected_examples):
            self.assertEqual(example.q1, expected[0])
            self.assertEqual(example.q1_spacy, expected[1])
            self.assertEqual(example.q2, expected[2])
            self.assertEqual(example.q2_spacy, expected[3])
            self.assertEqual(example.label, expected[4])

    def test_errors(self):
        """Requesting a key that is absent from the JSON data must fail."""
        # Ensure that trying to retrieve a key not in JSON data errors
        self.write_test_ppid_dataset(data_format="json")

        question_field = data.Field(sequential=True)
        label_field = data.Field(sequential=False)
        # "qeustion1" is misspelled on purpose: it is the missing key.
        fields = {"qeustion1": ("q1", question_field),
                  "question2": ("q2", question_field),
                  "label": ("label", label_field)}

        with self.assertRaises(ValueError):
            data.TabularDataset(
                path=self.test_ppid_dataset_path, format="json",
                fields=fields)