"""Smoke-test loading a POS-tagging TSV corpus with torchtext.

Loads the train/dev/test splits, prints a few dataset diagnostics, builds
the vocabularies for the text and label fields, and prints one training
batch.
"""
import os

from torchtext import data

# One field per TSV column; both use torchtext's default Field settings.
TEXT = data.Field()
LABELS = data.Field()

# Expand '~' up front so the location does not depend on the library (or a
# shell) resolving the home directory for us.
# NOTE(review): the train/validation/test arguments below look like suffixes
# added to this prefix (e.g. "pos_wsj.train") -- confirm against the
# installed torchtext version.
DATA_PREFIX = os.path.expanduser(
    '~/chainer-research/jmt-data/pos_wsj/pos_wsj')

train, val, test = data.TabularDataset.splits(
    path=DATA_PREFIX,
    train='.train',
    validation='.dev',
    test='.test',
    format='tsv',
    fields=[('text', TEXT), ('labels', LABELS)],
)

# Quick sanity checks on what was loaded.
print(train.fields)
print(len(train))
print(vars(train[0]))

# Batches of three examples, using text length as the sort key.
# NOTE(review): device=0 presumably selects the first GPU -- confirm for the
# torchtext version in use.
train_iter, val_iter, test_iter = data.BucketIterator.splits(
    (train, val, test),
    batch_size=3,
    sort_key=lambda x: len(x.text),
    device=0,
)

# Vocabularies are built from the training split only, and must exist
# before the first batch is drawn below.
LABELS.build_vocab(train.labels)
TEXT.build_vocab(train.text)

print(TEXT.vocab.freqs.most_common(10))
print(LABELS.vocab.itos)

# Pull a single batch to inspect the text and label contents.
batch = next(iter(train_iter))
print(batch.text)
print(batch.labels)