genienlp/decanlp/text/torchtext/data/example.py

90 lines
2.8 KiB
Python

import csv
import sys
import json
import six
def intern_strings(x):
    """Intern the top-level string items of a list or tuple.

    When ``x`` is a list or a tuple, a new list is returned in which every
    ``str`` item has been passed through ``sys.intern``; items of any other
    type are carried over untouched. Note that a tuple input also comes back
    as a list. Any other kind of ``x`` is returned as-is.
    """
    if not isinstance(x, (list, tuple)):
        return x
    return [sys.intern(item) if isinstance(item, str) else item for item in x]
class Example(object):
    """Defines a single training or test example.

    Stores each column of the example as an attribute.
    """

    @classmethod
    def fromJSON(cls, data, fields, **kwargs):
        """Build an example from a JSON string.

        ``data`` is decoded with ``json.loads`` and the result is handed to
        :meth:`fromdict` together with ``fields`` and ``kwargs``.
        """
        return cls.fromdict(json.loads(data), fields, **kwargs)

    @classmethod
    def fromdict(cls, data, fields, **kwargs):
        """Build an example from a mapping of column key to raw value.

        Arguments:
            data: Mapping that is indexed by the keys of ``fields``.
            fields: Mapping from a key of ``data`` to ``None``, to a
                ``(name, field)`` tuple, or to a list of such tuples.
            **kwargs: Forwarded to every ``field.preprocess`` call.

        Returns:
            A new instance with one attribute set per ``(name, field)`` pair.

        Raises:
            ValueError: If a key of ``fields`` is missing from ``data``. The
                check runs before the ``None`` test, so it also applies to
                keys whose entry is ``None``.
        """
        ex = cls()
        for key, vals in fields.items():
            if key not in data:
                raise ValueError("Specified key {} was not found in "
                                 "the input data".format(key))
            if vals is None:
                continue
            # A single (name, field) tuple is treated like a one-item list.
            if not isinstance(vals, list):
                vals = [vals]
            for name, field in vals:
                processed = field.preprocess(data[key], **kwargs)
                setattr(ex, name, intern_strings(processed))
        return ex

    @classmethod
    def fromTSV(cls, data, fields, **kwargs):
        """Build an example from one tab-separated line.

        The line is split on ``'\\t'`` and passed on to :meth:`fromlist`.
        """
        return cls.fromlist(data.split('\t'), fields, **kwargs)

    @classmethod
    def fromCSV(cls, data, fields, **kwargs):
        """Build an example from one CSV line.

        Trailing newlines are removed, the line is parsed with the standard
        ``csv`` module, and the resulting columns go to :meth:`fromlist`.
        """
        data = data.rstrip("\n")
        # If Python 2, encode to utf-8 since CSV doesn't take unicode input
        if six.PY2:
            data = data.encode('utf-8')
        # Use Python CSV module to parse the CSV line
        parsed_csv_lines = csv.reader([data])
        # If Python 2, decode back to unicode (the original input format).
        if six.PY2:
            for line in parsed_csv_lines:
                parsed_csv_line = [six.text_type(col, 'utf-8') for col in line]
                break
        else:
            parsed_csv_line = list(parsed_csv_lines)[0]
        return cls.fromlist(parsed_csv_line, fields, **kwargs)

    @classmethod
    def fromlist(cls, data, fields, **kwargs):
        """Build an example from a sequence of column values.

        Arguments:
            data: Iterable of raw column values.
            fields: Iterable of ``(name, field)`` pairs, matched with
                ``data`` by position. Pairs whose ``field`` is ``None`` are
                skipped.
            **kwargs: Forwarded to every ``field.preprocess`` call.

        Returns:
            A new instance with one attribute per pair that has a field.

        ``zip`` is used for the pairing, so surplus items in the longer of
        the two iterables are ignored.
        """
        ex = cls()
        for (name, field), val in zip(fields, data):
            if field is None:
                continue
            if isinstance(val, six.string_types):
                # Drop the line terminator a raw input line may still carry.
                val = val.rstrip('\n')
            setattr(ex, name, intern_strings(field.preprocess(val, **kwargs)))
        return ex

    @classmethod
    def fromtree(cls, data, fields, subtrees=False, **kwargs):
        """Build example(s) from a bracketed tree string parsed by NLTK.

        Arguments:
            data: String accepted by ``nltk.tree.Tree.fromstring``.
            fields: ``(name, field)`` pairs as expected by :meth:`fromlist`;
                the first receives the space-joined leaves, the second the
                node label.
            subtrees: If true, return a list with one example for every
                subtree yielded by ``tree.subtrees()``; otherwise return a
                single example for the whole tree.
            **kwargs: Forwarded to :meth:`fromlist` in both modes.

        Raises:
            ImportError: If NLTK is not installed (after printing a hint).
        """
        try:
            from nltk.tree import Tree
        except ImportError:
            print("Please install NLTK. "
                  "See the docs at http://nltk.org for more information.")
            raise
        tree = Tree.fromstring(data)
        if subtrees:
            # Forward kwargs here as well, so subtree examples are
            # preprocessed with the same options as a whole-tree example.
            return [
                cls.fromlist([' '.join(t.leaves()), t.label()], fields, **kwargs)
                for t in tree.subtrees()
            ]
        return cls.fromlist(
            [' '.join(tree.leaves()), tree.label()], fields, **kwargs)