# (file-viewer residue: 210 lines, 8.7 KiB)
import os
import xml.etree.ElementTree as ET
import glob
import io
from .. import data
class TranslationDataset(data.Dataset):
    """Defines a dataset for machine translation."""

    @staticmethod
    def sort_key(ex):
        """Return a sort key combining the source and target lengths of ``ex``.

        The two lengths are merged with ``data.interleave_keys``.
        """
        return data.interleave_keys(len(ex.src), len(ex.trg))

    def __init__(self, path, exts, fields, **kwargs):
        """Create a TranslationDataset given paths and fields.

        Arguments:
            path: Common prefix of paths to the data files for both languages.
            exts: A tuple containing the extension to path for each language.
            fields: A tuple containing the fields that will be used for data
                in each language. If it is a pair of bare fields rather than
                (name, field) pairs, they are named 'src' and 'trg'.
            Remaining keyword arguments: Passed to the constructor of
                data.Dataset.
        """
        if not isinstance(fields[0], (tuple, list)):
            fields = [('src', fields[0]), ('trg', fields[1])]

        src_path, trg_path = tuple(os.path.expanduser(path + x) for x in exts)

        examples = []
        # Read both sides explicitly as UTF-8: the platform default encoding
        # is not guaranteed to be UTF-8, and IWSLT.clean writes UTF-8 files.
        with io.open(src_path, mode='r', encoding='utf-8') as src_file, \
                io.open(trg_path, mode='r', encoding='utf-8') as trg_file:
            # The two files are parallel: line i of one aligns with line i
            # of the other. zip stops at the shorter file.
            for src_line, trg_line in zip(src_file, trg_file):
                src_line, trg_line = src_line.strip(), trg_line.strip()
                # Drop pairs where either side is blank.
                if src_line != '' and trg_line != '':
                    examples.append(data.Example.fromlist(
                        [src_line, trg_line], fields))

        super(TranslationDataset, self).__init__(examples, fields, **kwargs)

    @classmethod
    def splits(cls, exts, fields, root='.data',
               train='train', validation='val', test='test', **kwargs):
        """Create dataset objects for splits of a TranslationDataset.

        Arguments:
            exts: A tuple containing the extension to path for each language.
            fields: A tuple containing the fields that will be used for data
                in each language.
            root: Root dataset storage directory. Default is '.data'.
            train: The prefix of the train data. Default: 'train'.
            validation: The prefix of the validation data. Default: 'val'.
            test: The prefix of the test data. Default: 'test'.
            Remaining keyword arguments: Passed to the constructor (``cls``)
                of each split's dataset.

        Returns:
            A tuple of the requested datasets in (train, validation, test)
            order; any split whose prefix is None is omitted.
        """
        path = cls.download(root)
        datasets = (
            None if prefix is None
            else cls(os.path.join(path, prefix), exts, fields, **kwargs)
            for prefix in (train, validation, test))
        return tuple(d for d in datasets if d is not None)
class Multi30k(TranslationDataset):
    """The small-dataset WMT 2016 multimodal task, also known as Flickr30k"""

    # NOTE(review): the validation and test2016 archive entries follow
    # upstream torchtext; confirm the URLs are still reachable.
    urls = ['http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz',
            'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz',
            'http://www.quest.dcs.shef.ac.uk/'
            'wmt17_files_mlt/mmt_task1_test2016.tar.gz']
    name = 'multi30k'
    dirname = ''

    @classmethod
    def splits(cls, exts, fields, root='.data',
               train='train', validation='val', test='test2016', **kwargs):
        """Create dataset objects for splits of the Multi30k dataset.

        Arguments:
            exts: A tuple containing the extension to path for each language.
            fields: A tuple containing the fields that will be used for data
                in each language.
            root: Root dataset storage directory. Default is '.data'.
            train: The prefix of the train data. Default: 'train'.
            validation: The prefix of the validation data. Default: 'val'.
            test: The prefix of the test data. Default: 'test2016'.
            Remaining keyword arguments: Passed to the splits method of
                TranslationDataset.

        Returns:
            A tuple of the requested datasets in (train, validation, test)
            order; any split whose prefix is None is omitted.
        """
        return super(Multi30k, cls).splits(
            exts, fields, root, train, validation, test, **kwargs)
class IWSLT(TranslationDataset):
    """The IWSLT 2016 TED talk translation task"""

    base_url = 'https://wit3.fbk.eu/archive/2016-01//texts/{}/{}/{}.tgz'
    name = 'iwslt'
    base_dirname = '{}-{}'

    @classmethod
    def splits(cls, exts, fields, root='.data',
               train='train', validation='IWSLT16.TED.tst2013',
               test='IWSLT16.TED.tst2014', **kwargs):
        """Create dataset objects for splits of the IWSLT dataset.

        The language pair comes from ``exts`` with the leading dot removed.
        For example, ``('.de', '.en')`` gives ``cls.dirname == 'de-en'``.
        ``cls.dirname`` and ``cls.urls`` are set on the class before
        downloading.

        Arguments:
            exts: A tuple containing the extension to path for each language.
            fields: A tuple containing the fields that will be used for data
                in each language.
            root: Root dataset storage directory. Default is '.data'.
            train: The prefix of the train data. Default: 'train'.
            validation: The prefix of the validation data.
                Default: 'IWSLT16.TED.tst2013'.
            test: The prefix of the test data. Default: 'IWSLT16.TED.tst2014'.
            Remaining keyword arguments: Passed to the constructor (``cls``)
                of each split's dataset.

        Returns:
            A tuple of the requested datasets in (train, validation, test)
            order; any split whose prefix is None is omitted.
        """
        src_lang, trg_lang = exts[0][1:], exts[1][1:]
        cls.dirname = cls.base_dirname.format(src_lang, trg_lang)
        cls.urls = [cls.base_url.format(src_lang, trg_lang, cls.dirname)]
        check = os.path.join(root, cls.name, cls.dirname)
        path = cls.download(root, check=check)

        # Split files are named '<prefix>.<src>-<trg>' plus the language ext.
        train, validation, test = (
            None if prefix is None else '.'.join([prefix, cls.dirname])
            for prefix in (train, validation, test))

        # The download contains *.xml and train.tags* files that clean()
        # converts to plain text. The cleaned train file marks that this
        # conversion has already been done.
        if not os.path.exists(
                os.path.join(path, '.'.join(['train', cls.dirname])) + exts[0]):
            cls.clean(path)

        datasets = (
            None if prefix is None
            else cls(os.path.join(path, prefix), exts, fields, **kwargs)
            for prefix in (train, validation, test))
        return tuple(d for d in datasets if d is not None)

    @staticmethod
    def clean(path):
        """Convert the raw IWSLT files under ``path`` into plain-text files.

        * Each ``*.xml`` file is written next to itself without its ``.xml``
          extension, with one ``<seg>`` text per line. The segments are taken
          from the ``doc`` elements of the root's first child.
        * Each ``train.tags*`` file is copied to the same name without
          ``.tags``. Lines containing any of the metadata tags in
          ``xml_tags`` are dropped.

        All output is written as UTF-8.
        """
        for f_xml in glob.iglob(os.path.join(path, '*.xml')):
            f_txt = os.path.splitext(f_xml)[0]
            with io.open(f_txt, mode='w', encoding='utf-8') as fd_txt:
                root = ET.parse(f_xml).getroot()[0]
                for doc in root.findall('doc'):
                    for e in doc.findall('seg'):
                        # An empty <seg/> has text None. Write a blank line
                        # so the parallel files stay line-aligned;
                        # TranslationDataset skips pairs with a blank side.
                        fd_txt.write((e.text or '').strip() + '\n')

        xml_tags = ['<url', '<keywords', '<talkid', '<description',
                    '<reviewer', '<translator', '<title', '<speaker']
        for f_orig in glob.iglob(os.path.join(path, 'train.tags*')):
            # Strip '.tags' from the file name only, never from directories.
            f_txt = os.path.join(os.path.dirname(f_orig),
                                 os.path.basename(f_orig).replace('.tags', ''))
            with io.open(f_txt, mode='w', encoding='utf-8') as fd_txt, \
                    io.open(f_orig, mode='r', encoding='utf-8') as fd_orig:
                for line in fd_orig:
                    if not any(tag in line for tag in xml_tags):
                        fd_txt.write(line.strip() + '\n')
class WMT14(TranslationDataset):
    """The WMT 2014 English-German dataset, as preprocessed by Google Brain.

    Though this download contains test sets from 2015 and 2016, the train set
    differs slightly from WMT 2015 and 2016 and significantly from WMT 2017."""

    # Each entry pairs a download URL with an archive file name
    # (presumably the local name used by download(); confirm).
    urls = [('https://drive.google.com/uc?export=download&'
             'id=0B_bZck-ksdkpM25jRUN2X2UxMm8', 'wmt16_en_de.tar.gz')]
    name = 'wmt14'
    dirname = ''

    @classmethod
    def splits(cls, exts, fields, root='.data',
               train='train.tok.clean.bpe.32000',
               validation='newstest2013.tok.bpe.32000',
               test='newstest2014.tok.bpe.32000', **kwargs):
        """Create dataset objects for splits of the WMT 2014 dataset.

        Arguments:
            exts: A tuple containing the extensions for each language. Must be
                either ('.en', '.de') or the reverse.
            fields: A tuple containing the fields that will be used for data
                in each language.
            root: Root dataset storage directory. Default is '.data'.
            train: The prefix of the train data. Default:
                'train.tok.clean.bpe.32000'.
            validation: The prefix of the validation data. Default:
                'newstest2013.tok.bpe.32000'.
            test: The prefix of the test data. Default:
                'newstest2014.tok.bpe.32000'.
            Remaining keyword arguments: Passed to the splits method of
                TranslationDataset.

        Returns:
            A tuple of the requested datasets in (train, validation, test)
            order; any split whose prefix is None is omitted.
        """
        return super(WMT14, cls).splits(
            exts, fields, root, train, validation, test, **kwargs)