# genienlp/text/test/translation.py

import re

import spacy
from torchtext import data
from torchtext import datasets
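
# NOTE: 'de' / 'en' are spaCy 2.x shortcut links; on spaCy 3.x these loads
# fail, so use the full pipeline names instead (e.g. 'de_core_news_sm' and
# 'en_core_web_sm') after fetching them with `python -m spacy download <name>`.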
spacy_de = spacy.load('de')
spacy_en = spacy.load('en')
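
# Collapse inline <url>...</url> spans to a single @URL@ placeholder token
# before handing the text to the spaCy tokenizer.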
url = re.compile('(<url>.*</url>)')


def tokenize_de(text):
    return [tok.text for tok in spacy_de.tokenizer(url.sub('@URL@', text))]


def tokenize_en(text):
    return [tok.text for tok in spacy_en.tokenizer(url.sub('@URL@', text))]

# Testing IWSLT
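# Each Field stores the tokenizer for one side of the parallel corpus and,
# once built, its vocabulary; splits() downloads IWSLT de-en into .data/ on
# the first run and pairs each .de line with its .en translation.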
DE = data.Field(tokenize=tokenize_de)
EN = data.Field(tokenize=tokenize_en)
train, val, test = datasets.IWSLT.splits(exts=('.de', '.en'), fields=(DE, EN))
print(train.fields)
print(len(train))
print(vars(train[0]))
print(vars(train[100]))
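
# min_freq=3 drops German tokens seen fewer than three times; max_size=50000
# keeps only the 50k most frequent English tokens.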
DE.build_vocab(train.src, min_freq=3)
EN.build_vocab(train.trg, max_size=50000)
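
# BucketIterator groups examples of similar length into the same batch to
# minimise padding; in legacy torchtext, device=0 targets GPU 0 (pass -1,
# or a torch.device, to run on CPU).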
train_iter, val_iter = data.BucketIterator.splits(
    (train, val), batch_size=3, device=0)
print(DE.vocab.freqs.most_common(10))
print(len(DE.vocab))
print(EN.vocab.freqs.most_common(10))
print(len(EN.vocab))
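
# batch.src and batch.trg are LongTensors of token indices with shape
# (sequence_length, batch_size), since the Fields use the default
# batch_first=False.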
batch = next(iter(train_iter))
print(batch.src)
print(batch.trg)

# Testing Multi30k
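# Same pipeline as above, on the much smaller Multi30k image-caption corpus.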
DE = data.Field(tokenize=tokenize_de)
EN = data.Field(tokenize=tokenize_en)
train, val, test = datasets.Multi30k.splits(exts=('.de', '.en'), fields=(DE, EN))
print(train.fields)
print(len(train))
print(vars(train[0]))
print(vars(train[100]))
DE.build_vocab(train.src, min_freq=3)
EN.build_vocab(train.trg, max_size=50000)
train_iter, val_iter = data.BucketIterator.splits(
    (train, val), batch_size=3, device=0)
print(DE.vocab.freqs.most_common(10))
print(len(DE.vocab))
print(EN.vocab.freqs.most_common(10))
print(len(EN.vocab))
batch = next(iter(train_iter))
print(batch.src)
print(batch.trg)

# Testing custom paths
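# TranslationDataset.splits reads already-downloaded files directly, joining
# path + split name + extension (e.g. .data/multi30k/train.de); no test split
# is requested here.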
DE = data.Field(tokenize=tokenize_de)
EN = data.Field(tokenize=tokenize_en)
train, val = datasets.TranslationDataset.splits(
    path='.data/multi30k/', train='train',
    validation='val', exts=('.de', '.en'),
    fields=(DE, EN))
print(train.fields)
print(len(train))
print(vars(train[0]))
print(vars(train[100]))
DE.build_vocab(train.src, min_freq=3)
EN.build_vocab(train.trg, max_size=50000)
train_iter, val_iter = data.BucketIterator.splits(
    (train, val), batch_size=3, device=0)
print(DE.vocab.freqs.most_common(10))
print(len(DE.vocab))
print(EN.vocab.freqs.most_common(10))
print(len(EN.vocab))
batch = next(iter(train_iter))
print(batch.src)
print(batch.trg)