Remove unused portions of Field
Everything related to preprocessing, tokenization, and numericalization has been removed; that functionality is now implemented elsewhere.
parent ea83b75089
commit cde004b3e4
@@ -1,10 +1,6 @@
from .dataset import Dataset
from .field import RawField, Field, ReversibleField, SubwordField
from .pipeline import Pipeline
from .field import Field, ReversibleField
from .utils import get_tokenizer, interleave_keys

__all__ = ["Batch",
           "Dataset",
           "RawField", "Field", "ReversibleField", "SubwordField",
           "Pipeline",
__all__ = ["Dataset", "Field", "ReversibleField",
           "get_tokenizer", "interleave_keys"]
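For reference, a hedged sketch of the import surface that remains after this change; the package path is an assumption, adjust it to wherever this data module actually lives.

# Illustrative only: the names match the new __all__ above; the module
# path torchtext.data is assumed, not taken from this diff.
from torchtext.data import Dataset, Field, ReversibleField, get_tokenizer, interleave_keys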
@@ -6,59 +6,11 @@ import torch
from tqdm import tqdm

from .dataset import Dataset
from .pipeline import Pipeline
from .utils import get_tokenizer
from ..vocab import Vocab, SubwordVocab


class RawField(object):
    """ Defines a general datatype.

    Every dataset consists of one or more types of data. For instance, a text
    classification dataset contains sentences and their classes, while a
    machine translation dataset contains paired examples of text in two
    languages. Each of these types of data is represented by a RawField object.
    A RawField object does not assume any property of the data type and
    it holds parameters relating to how a datatype should be processed.

    Attributes:
        preprocessing: The Pipeline that will be applied to examples
            using this field before creating an example.
            Default: None.
        postprocessing: A Pipeline that will be applied to a list of examples
            using this field before assigning to a batch.
            Function signature: (batch(list)) -> object
            Default: None.
    """

    def __init__(self, preprocessing=None, postprocessing=None):
        self.preprocessing = preprocessing
        self.postprocessing = postprocessing

    def preprocess(self, x, field_name=None):
        """ Preprocess an example if the `preprocessing` Pipeline is provided. """
        if self.preprocessing is not None:
            return self.preprocessing(x)
        else:
            return x

    def process(self, batch, *args, **kargs):
        """ Process a list of examples to create a batch.

        Postprocess the batch with the user-provided Pipeline.

        Args:
            batch (list(object)): A list of objects from a batch of examples.
        Returns:
            data (object): Processed object given the input and custom
                postprocessing Pipeline.
        """
        if self.postprocessing is not None:
            batch = self.postprocessing(batch)
        return batch


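For context, a minimal sketch of how the RawField defined above behaves when only a postprocessing callable is supplied; the label strings are illustrative.

# Illustrative only: uses the RawField class as shown in this diff.
labels = RawField(postprocessing=lambda batch: sorted(batch))
examples = [labels.preprocess(x) for x in ["positive", "negative", "neutral"]]
batch = labels.process(examples)  # -> ['negative', 'neutral', 'positive']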
class Field(RawField):
class Field(object):
    """Defines a datatype together with instructions for converting to Tensor.

    Field class models common text processing datatypes that can be represented
@@ -133,7 +85,6 @@ class Field(RawField):
    def __init__(
            self, sequential=True, use_vocab=True, init_token=None,
            eos_token=None, fix_length=None, tensor_type=torch.LongTensor,
            preprocessing=None, postprocessing=None, lower=False,
            tokenize=(lambda s: s.split()), include_lengths=False,
            batch_first=False, pad_token="<pad>", unk_token="<unk>",
            pad_first=False, decap=False, numerical=False):
@@ -145,104 +96,12 @@ class Field(RawField):
        self.unk_token = unk_token
        self.fix_length = fix_length
        self.tensor_type = tensor_type
        self.preprocessing = preprocessing
        self.postprocessing = postprocessing
        self.lower = lower
        self.tokenize = get_tokenizer(tokenize)
        self.include_lengths = include_lengths
        self.batch_first = batch_first
        self.pad_token = pad_token if self.sequential else None
        self.pad_first = pad_first

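For orientation, a hedged sketch of constructing the Field shown in this diff with keyword arguments from the signature above; the token strings are illustrative, and this TEXT field is reused by later sketches.

# Illustrative only: a sequential text field with boundary tokens.
TEXT = Field(sequential=True, use_vocab=True,
             init_token="<init>", eos_token="<eos>",
             lower=True, include_lengths=True, batch_first=True)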
    def preprocess(self, x, tokenize=None, field_name=None):
        """Load a single example using this field, tokenizing if necessary.

        If the input is a Python 2 `str`, it will be converted to Unicode
        first. If `sequential=True`, it will be tokenized. Then the input
        will be optionally lowercased and passed to the user-provided
        `preprocessing` Pipeline."""
        if (six.PY2 and isinstance(x, six.string_types) and not
                isinstance(x, six.text_type)):
            x = Pipeline(lambda s: six.text_type(s, encoding='utf-8'))(x)
        if self.sequential and isinstance(x, six.text_type):
            if tokenize is None:
                x = self.tokenize(x.rstrip('\n'))
            else:
                x = tokenize(x.rstrip('\n'), field_name=field_name)
        if self.lower:
            x = Pipeline(six.text_type.lower)(x)
        if self.preprocessing is not None:
            return self.preprocessing(x)
        else:
            return x

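A minimal sketch of what preprocess does for the TEXT field sketched earlier: whitespace tokenization from the default tokenizer, then lowercasing because lower=True; the input string is illustrative.

# Illustrative only.
tokens = TEXT.preprocess("The quick Brown Fox\n")
# -> ['the', 'quick', 'brown', 'fox']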
    def process(self, batch, device, train, **kwargs):
        """ Process a list of examples to create a torch.Tensor.

        Pad, numericalize, and postprocess a batch and create a tensor.

        Args:
            batch (list(object)): A list of objects from a batch of examples.
        Returns:
            data (torch.autograd.Variable): Processed object given the input
                and custom postprocessing Pipeline.
        """
        if self.numerical:
            if isinstance(batch[0], list):
                pad_value = max([max(example) for example in batch]) + 1000
                batch = deepcopy(batch)
                for example in batch:
                    if self.init_token is not None:
                        for idx, ex in enumerate(example):
                            example[idx] += 1

                max_len = max([len(example) for example in batch])
                for example in batch:
                    if len(example) < max_len:
                        example += [pad_value] * (max_len - len(example))
                tensor = torch.LongTensor(batch)
                tensor = tensor.to(device)
        else:
            padded = self.pad(batch)
            tensor = self.numericalize(padded, device=device, train=train, **kwargs)
        return tensor

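A worked sketch of the numeric branch above, with illustrative values: integer examples are right-padded to the batch maximum using a pad value chosen to lie outside the original range (the init-token shift is omitted here).

# Illustrative only: mirrors the padding logic of the numeric branch.
batch = [[3, 7], [1, 2, 5]]
pad_value = max(max(ex) for ex in batch) + 1000   # 1007
max_len = max(len(ex) for ex in batch)            # 3
padded = [ex + [pad_value] * (max_len - len(ex)) for ex in batch]
# -> [[3, 7, 1007], [1, 2, 5]]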
    def pad(self, minibatch):
        """Pad a batch of examples using this field.

        Pads to self.fix_length if provided, otherwise pads to the length of
        the longest example in the batch. Prepends self.init_token and appends
        self.eos_token if those attributes are not None. Returns a tuple of the
        padded list and a list containing lengths of each example if
        `self.include_lengths` is `True` and `self.sequential` is `True`, else just
        returns the padded list. If `self.sequential` is `False`, no padding is applied.
        """
        minibatch = list(minibatch)
        if not self.sequential:
            return minibatch
        if self.fix_length is None:
            max_len = max(len(x) for x in minibatch)
        else:
            max_len = self.fix_length + (
                self.init_token, self.eos_token).count(None) - 2
        padded, lengths = [], []
        for x in minibatch:
            if self.pad_first:
                padded.append(
                    [self.pad_token] * max(0, max_len - len(x)) +
                    ([] if self.init_token is None else [self.init_token]) +
                    list(x[:max_len]) +
                    ([] if self.eos_token is None else [self.eos_token]))
            else:
                padded.append(
                    ([] if self.init_token is None else [self.init_token]) +
                    list(x[:max_len]) +
                    ([] if self.eos_token is None else [self.eos_token]) +
                    [self.pad_token] * max(0, max_len - len(x)))
            lengths.append(len(padded[-1]) - max(0, max_len - len(x)))
        if self.include_lengths:
            return (padded, lengths)
        return padded

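A hedged sketch of pad for the TEXT field sketched earlier, with the expected output written out by hand from the code above.

# Illustrative only (init_token="<init>", eos_token="<eos>", pad_token="<pad>").
padded, lengths = TEXT.pad([['hello', 'world'], ['hi']])
# padded  -> [['<init>', 'hello', 'world', '<eos>'],
#             ['<init>', 'hi', '<eos>', '<pad>']]
# lengths -> [4, 3]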
    def build_vocab(self, field_names, *args, **kwargs):
        """Construct the Vocab object for this field from one or more datasets.
@@ -276,98 +135,6 @@ class Field(RawField):
                self.vocab.itos.append(w)


    def vocab_from_counter(self, counter, **kwargs):
        specials = list(OrderedDict.fromkeys(
            tok for tok in [self.unk_token, self.pad_token, self.init_token,
                            self.eos_token]
            if tok is not None))
        self.vocab = self.vocab_cls(counter, specials=specials, **kwargs)


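A minimal sketch of the specials handling in vocab_from_counter; the Counter contents are illustrative and vocab_cls is assumed to behave like a torchtext-style Vocab.

# Illustrative only: duplicate specials are dropped while preserving order.
from collections import Counter, OrderedDict
counter = Counter({'hello': 3, 'world': 1})
specials = list(OrderedDict.fromkeys(
    tok for tok in ['<unk>', '<pad>', '<init>', '<eos>'] if tok is not None))
# -> ['<unk>', '<pad>', '<init>', '<eos>'], then passed as specials= to the vocab class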
    def numericalize(self, arr, device=None, train=True, limited=None, l2f=None, oov2l=None):
        """Turn a batch of examples that use this field into a Variable.

        If the field has include_lengths=True, a tensor of lengths will be
        included in the return value.

        Arguments:
            arr (List[List[str]], or tuple of (List[List[str]], List[int])):
                List of tokenized and padded examples, or tuple of List of
                tokenized and padded examples and List of lengths of each
                example if self.include_lengths is True.
            device (-1 or None): Device to create the Variable's Tensor on.
                Use -1 for CPU and None for the currently active GPU device.
                Default: None.
            train (boolean): Whether the batch is for a training set.
                If False, the Variable will be created with volatile=True.
                Default: True.
        """
        if limited is None:
            limited = self.vocab.stoi
        if self.include_lengths and not isinstance(arr, tuple):
            raise ValueError("Field has include_lengths set to True, but "
                             "input data is not a tuple of "
                             "(data batch, batch lengths).")
        if isinstance(arr, tuple):
            arr, lengths = arr
            # lengths = torch.LongTensor(lengths)

        if self.use_vocab:
            if self.sequential:
                def limited_idx(x):
                    if x in limited:
                        lim_idx = limited[x]
                    elif x in oov2l:
                        lim_idx = oov2l[x]
                    else:
                        lim_idx = len(limited) + len(oov2l)
                        oov2l[x] = lim_idx
                        l2f[lim_idx] = self.vocab.stoi[x]
                    return lim_idx

                lim_arr = [[limited_idx(x) for x in ex] for ex in arr]
                num = [[self.vocab.stoi[x] for x in ex] for ex in arr]

                # arr = [[self.vocab.stoi[x] for x in ex] for ex in arr]
            else:
                num = [self.vocab.stoi[x] for x in arr]

            if self.postprocessing is not None:
                num = self.postprocessing(num, self.vocab, train)
        else:
            if self.tensor_type not in self.tensor_types:
                raise ValueError(
                    "Specified Field tensor_type {} can not be used with "
                    "use_vocab=False because we do not know how to numericalize it. "
                    "Please raise an issue at "
                    "https://github.com/pytorch/text/issues".format(self.tensor_type))
            numericalization_func = self.tensor_types[self.tensor_type]
            # It doesn't make sense to explicitly coerce to a numeric type if
            # the data is sequential, since it's unclear how to coerce padding tokens
            # to a numeric type.
            if not self.sequential:
                num = [numericalization_func(x) if isinstance(x, six.string_types)
                       else x for x in arr]
            if self.postprocessing is not None:
                num = self.postprocessing(num, None, train)

        num = self.tensor_type(num)
        lim_arr = self.tensor_type(lim_arr)
        if self.sequential and not self.batch_first:
            num.t_()
            lim_arr.t_()
        if self.sequential:
            num = num.contiguous()
            lim_arr = lim_arr.contiguous()
        num = num.to(device)
        lim_arr = lim_arr.to(device)
        # if self.include_lengths:
        #     lengths = lengths.cuda(device)
        if self.include_lengths:
            return num, lengths, lim_arr, arr
        return arr


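A hedged sketch of the limited-vocabulary bookkeeping in numericalize above: tokens in the limited mapping keep their index, new out-of-vocabulary tokens are appended to oov2l, and l2f records the mapping back to the full vocabulary. All dictionary contents are illustrative.

# Illustrative only: mirrors limited_idx with plain dicts.
limited = {'<pad>': 0, 'the': 1}                # limited vocab: token -> index
full_stoi = {'<pad>': 0, 'the': 1, 'zebra': 7}  # full vocab: token -> index
oov2l, l2f = {}, {}

def limited_idx(tok):
    if tok in limited:
        return limited[tok]
    if tok in oov2l:
        return oov2l[tok]
    idx = len(limited) + len(oov2l)   # next free slot after the limited vocab
    oov2l[tok] = idx
    l2f[idx] = full_stoi[tok]         # remember the full-vocab index
    return idx

print([limited_idx(t) for t in ['the', 'zebra', 'zebra']])  # [1, 2, 2]
print(l2f)                                                   # {2: 7}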
class ReversibleField(Field):

    def __init__(self, **kwargs):
@@ -416,34 +183,3 @@ class ReversibleField(Field):
            return [self.detokenize(ex) for ex in batch]
        else:
            return [''.join(ex) for ex in batch]


class SubwordField(ReversibleField):

    vocab_cls = SubwordVocab

    def __init__(self, **kwargs):
        kwargs['tokenize'] = 'subword'
        if 'unk_token' not in kwargs:
            kwargs['unk_token'] = '�'
        super(SubwordField, self).__init__(**kwargs)

    def segment(self, *args):
        """Segment one or more datasets with this subword field.

        Arguments:
            Positional arguments: Dataset objects or other indexable
                mutable sequences to segment. If a Dataset object is provided,
                all columns corresponding to this field are used; individual
                columns can also be provided directly.
        """
        sources = []
        for arg in args:
            if isinstance(arg, Dataset):
                sources += [getattr(arg, name) for name, field in
                            arg.fields.items() if field is self]
            else:
                sources.append(arg)
        for data in sources:
            for x in tqdm(data, 'segmenting'):
                x[:] = self.vocab.segment(x)
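A hedged sketch of the in-place segmentation pattern used by segment above; the vocab object is a stand-in stub, real SubwordVocab segmentation differs.

# Illustrative only: x[:] = vocab.segment(x) rewrites each example in place.
class StubVocab:
    def segment(self, tokens):
        # pretend subword segmentation: split every token in half
        return [piece for tok in tokens for piece in (tok[:2], tok[2:]) if piece]

example = ['hello', 'world']
example[:] = StubVocab().segment(example)
print(example)  # ['he', 'llo', 'wo', 'rld']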
@@ -1,85 +0,0 @@
class Pipeline(object):
    """Defines a pipeline for transforming sequence data.

    The input is assumed to be utf-8 encoded `str` (Python 3) or
    `unicode` (Python 2).

    Attributes:
        convert_token: The function to apply to input sequence data.
        pipes: The Pipelines that will be applied to input sequence
            data in order.
    """
    def __init__(self, convert_token=None):
        """Create a pipeline.

        Arguments:
            convert_token: The function to apply to input sequence data.
                If None, the identity function is used. Default: None
        """
        if convert_token is None:
            self.convert_token = Pipeline.identity
        elif callable(convert_token):
            self.convert_token = convert_token
        else:
            raise ValueError("Pipeline input convert_token {} is not None "
                             "or callable".format(convert_token))
        self.pipes = [self]

    def __call__(self, x, *args):
        """Apply the current Pipeline(s) to an input.

        Arguments:
            x: The input to process with the Pipeline(s).
            Positional arguments: Forwarded to the `call` function
                of the Pipeline(s).
        """
        for pipe in self.pipes:
            x = pipe.call(x, *args)
        return x

    def call(self, x, *args):
        """Apply _only_ the convert_token function of the current pipeline
        to the input. If the input is a list, a list with the results of
        applying the `convert_token` function to all input elements is
        returned.

        Arguments:
            x: The input to apply the convert_token function to.
            Positional arguments: Forwarded to the `convert_token` function
                of the current Pipeline.
        """
        if isinstance(x, list):
            return [self.convert_token(tok, *args) for tok in x]
        return self.convert_token(x, *args)

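A minimal usage sketch of the Pipeline class as defined above; the conversion function is illustrative.

# Illustrative only: a single-stage pipeline; lists are mapped element-wise.
lowercase = Pipeline(str.lower)
print(lowercase("HeLLo"))             # 'hello'
print(lowercase(["HeLLo", "World"]))  # ['hello', 'world']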
    def add_before(self, pipeline):
        """Add a Pipeline to be applied before this processing pipeline.

        Arguments:
            pipeline: The Pipeline or callable to apply before this
                Pipeline.
        """
        if not isinstance(pipeline, Pipeline):
            pipeline = Pipeline(pipeline)
        self.pipes = pipeline.pipes[:] + self.pipes[:]
        return self

    def add_after(self, pipeline):
        """Add a Pipeline to be applied after this processing pipeline.

        Arguments:
            pipeline: The Pipeline or callable to apply after this
                Pipeline.
        """
        if not isinstance(pipeline, Pipeline):
            pipeline = Pipeline(pipeline)
        self.pipes = self.pipes[:] + pipeline.pipes[:]
        return self

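A hedged sketch of chaining with add_before and add_after; the stages are illustrative.

# Illustrative only: strip whitespace first, then lowercase, then tag.
pipe = Pipeline(str.lower)
pipe.add_before(str.strip)
pipe.add_after(lambda tok: 'tok:' + tok)
print(pipe("  HeLLo  "))  # 'tok:hello'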
    @staticmethod
    def identity(x):
        """Return a copy of the input.

        This is here for serialization compatibility with pickle.
        """
        return x