import os
import xml.etree.ElementTree as ET
import glob
import io

from .. import data


class TranslationDataset(data.Dataset):
    """Defines a dataset for machine translation."""

    @staticmethod
    def sort_key(ex):
        # Interleave the bits of the source and target lengths so that
        # examples with similar lengths in both languages sort close together.
        return data.interleave_keys(len(ex.src), len(ex.trg))

    def __init__(self, path, exts, fields, **kwargs):
        """Create a TranslationDataset given paths and fields.

        Arguments:
            path: Common prefix of paths to the data files for both languages.
            exts: A tuple containing the extension to path for each language.
            fields: A tuple containing the fields that will be used for data
                in each language.
            Remaining keyword arguments: Passed to the constructor of
                data.Dataset.
        """
        if not isinstance(fields[0], (tuple, list)):
            fields = [('src', fields[0]), ('trg', fields[1])]

        src_path, trg_path = tuple(os.path.expanduser(path + x) for x in exts)

        examples = []
        with open(src_path) as src_file, open(trg_path) as trg_file:
            for src_line, trg_line in zip(src_file, trg_file):
                src_line, trg_line = src_line.strip(), trg_line.strip()
                if src_line != '' and trg_line != '':
                    examples.append(data.Example.fromlist(
                        [src_line, trg_line], fields))

        super(TranslationDataset, self).__init__(examples, fields, **kwargs)

    @classmethod
    def splits(cls, exts, fields, root='.data',
               train='train', validation='val', test='test', **kwargs):
        """Create dataset objects for splits of a TranslationDataset.

        Arguments:
            exts: A tuple containing the extension to path for each language.
            fields: A tuple containing the fields that will be used for data
                in each language.
            root: Root dataset storage directory. Default is '.data'.
            train: The prefix of the train data. Default: 'train'.
            validation: The prefix of the validation data. Default: 'val'.
            test: The prefix of the test data. Default: 'test'.
            Remaining keyword arguments: Passed to the splits method of
                Dataset.
        """
        path = cls.download(root)

        train_data = None if train is None else cls(
            os.path.join(path, train), exts, fields, **kwargs)
        val_data = None if validation is None else cls(
            os.path.join(path, validation), exts, fields, **kwargs)
        test_data = None if test is None else cls(
            os.path.join(path, test), exts, fields, **kwargs)
        return tuple(d for d in (train_data, val_data, test_data)
                     if d is not None)


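# Illustrative usage sketch (not part of the original module): loading a custom
# parallel corpus stored as, e.g., 'data/corpus.de' and 'data/corpus.en'. The
# file prefix, language pair, and Field settings below are assumptions made
# only for this example.
#
#     from torchtext import data
#
#     SRC = data.Field(tokenize=str.split)
#     TRG = data.Field(tokenize=str.split, init_token='<sos>', eos_token='<eos>')
#     corpus = TranslationDataset(path='data/corpus', exts=('.de', '.en'),
#                                 fields=(SRC, TRG))
#     SRC.build_vocab(corpus, min_freq=2)
#     TRG.build_vocab(corpus, min_freq=2)

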
class Multi30k(TranslationDataset):
    """The small-dataset WMT 2016 multimodal task, also known as Flickr30k."""

    urls = ['http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz',
            'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz',
            'http://www.quest.dcs.shef.ac.uk/'
            'wmt17_files_mmt/mmt_task1_test2016.tar.gz']
    name = 'multi30k'
    dirname = ''

    @classmethod
    def splits(cls, exts, fields, root='.data',
               train='train', validation='val', test='test2016', **kwargs):
        """Create dataset objects for splits of the Multi30k dataset.

        Arguments:
            exts: A tuple containing the extension to path for each language.
            fields: A tuple containing the fields that will be used for data
                in each language.
            root: Root dataset storage directory. Default is '.data'.
            train: The prefix of the train data. Default: 'train'.
            validation: The prefix of the validation data. Default: 'val'.
            test: The prefix of the test data. Default: 'test2016'.
            Remaining keyword arguments: Passed to the splits method of
                Dataset.
        """
        return super(Multi30k, cls).splits(
            exts, fields, root, train, validation, test, **kwargs)


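# Illustrative usage sketch (not part of the original module): splits()
# downloads the Multi30k archives into `root` and returns the train,
# validation, and test datasets. The Field definitions and batch size below
# are assumptions made only for this example.
#
#     from torchtext import data
#
#     SRC = data.Field(tokenize=str.split, lower=True)
#     TRG = data.Field(tokenize=str.split, lower=True,
#                      init_token='<sos>', eos_token='<eos>')
#     train, val, test = Multi30k.splits(exts=('.de', '.en'),
#                                        fields=(SRC, TRG))
#     SRC.build_vocab(train, min_freq=2)
#     TRG.build_vocab(train, min_freq=2)
#     train_iter, val_iter, test_iter = data.BucketIterator.splits(
#         (train, val, test), batch_size=32)

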
class IWSLT(TranslationDataset):
    """The IWSLT 2016 TED talk translation task."""

    base_url = 'https://wit3.fbk.eu/archive/2016-01//texts/{}/{}/{}.tgz'
    name = 'iwslt'
    base_dirname = '{}-{}'

    @classmethod
    def splits(cls, exts, fields, root='.data',
               train='train', validation='IWSLT16.TED.tst2013',
               test='IWSLT16.TED.tst2014', **kwargs):
        """Create dataset objects for splits of the IWSLT dataset.

        Arguments:
            exts: A tuple containing the extension to path for each language.
            fields: A tuple containing the fields that will be used for data
                in each language.
            root: Root dataset storage directory. Default is '.data'.
            train: The prefix of the train data. Default: 'train'.
            validation: The prefix of the validation data. Default:
                'IWSLT16.TED.tst2013'.
            test: The prefix of the test data. Default:
                'IWSLT16.TED.tst2014'.
            Remaining keyword arguments: Passed to the splits method of
                Dataset.
        """
        # The archive to download depends on the language pair, so the class
        # attributes are filled in from the requested extensions.
        cls.dirname = cls.base_dirname.format(exts[0][1:], exts[1][1:])
        cls.urls = [cls.base_url.format(exts[0][1:], exts[1][1:], cls.dirname)]
        check = os.path.join(root, cls.name, cls.dirname)
        path = cls.download(root, check=check)

        if train is not None:
            train = '.'.join([train, cls.dirname])
        if validation is not None:
            validation = '.'.join([validation, cls.dirname])
        if test is not None:
            test = '.'.join([test, cls.dirname])

        # Convert the downloaded XML and tagged files to plain text on the
        # first run.
        if not os.path.exists(os.path.join(path, '.'.join(['train', cls.dirname])) + exts[0]):
            cls.clean(path)

        train_data = None if train is None else cls(
            os.path.join(path, train), exts, fields, **kwargs)
        val_data = None if validation is None else cls(
            os.path.join(path, validation), exts, fields, **kwargs)
        test_data = None if test is None else cls(
            os.path.join(path, test), exts, fields, **kwargs)
        return tuple(d for d in (train_data, val_data, test_data)
                     if d is not None)

    @staticmethod
    def clean(path):
        """Convert the downloaded XML and tagged files into plain text."""
        # Dev and test sets ship as XML: keep only the <seg> contents.
        for f_xml in glob.iglob(os.path.join(path, '*.xml')):
            print(f_xml)
            f_txt = os.path.splitext(f_xml)[0]
            with io.open(f_txt, mode='w', encoding='utf-8') as fd_txt:
                root = ET.parse(f_xml).getroot()[0]
                for doc in root.findall('doc'):
                    for e in doc.findall('seg'):
                        fd_txt.write(e.text.strip() + '\n')

        # The training data is plain text interleaved with metadata tags:
        # drop any line that contains one of the known tags.
        xml_tags = ['<url', '<keywords', '<talkid', '<description',
                    '<reviewer', '<translator', '<title', '<speaker']
        for f_orig in glob.iglob(os.path.join(path, 'train.tags*')):
            print(f_orig)
            f_txt = f_orig.replace('.tags', '')
            with io.open(f_txt, mode='w', encoding='utf-8') as fd_txt, \
                    io.open(f_orig, mode='r', encoding='utf-8') as fd_orig:
                for line in fd_orig:
                    if not any(tag in line for tag in xml_tags):
                        fd_txt.write(line.strip() + '\n')


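# Illustrative usage sketch (not part of the original module): the language
# pair encoded in `exts` selects which IWSLT archive is downloaded, and the
# XML/tagged files are converted to plain text on first use via clean(). The
# Field settings below are assumptions made only for this example.
#
#     from torchtext import data
#
#     SRC = data.Field(tokenize=str.split)
#     TRG = data.Field(tokenize=str.split, init_token='<sos>', eos_token='<eos>')
#     train, val, test = IWSLT.splits(exts=('.de', '.en'), fields=(SRC, TRG))
#     SRC.build_vocab(train, max_size=50000)
#     TRG.build_vocab(train, max_size=50000)

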
class WMT14(TranslationDataset):
    """The WMT 2014 English-German dataset, as preprocessed by Google Brain.

    Though this download contains test sets from 2015 and 2016, the train set
    differs slightly from WMT 2015 and 2016 and significantly from WMT 2017.
    """

    urls = [('https://drive.google.com/uc?export=download&'
             'id=0B_bZck-ksdkpM25jRUN2X2UxMm8', 'wmt16_en_de.tar.gz')]
    name = 'wmt14'
    dirname = ''

    @classmethod
    def splits(cls, exts, fields, root='.data',
               train='train.tok.clean.bpe.32000',
               validation='newstest2013.tok.bpe.32000',
               test='newstest2014.tok.bpe.32000', **kwargs):
        """Create dataset objects for splits of the WMT 2014 dataset.

        Arguments:
            exts: A tuple containing the extensions for each language. Must be
                either ('.en', '.de') or the reverse.
            fields: A tuple containing the fields that will be used for data
                in each language.
            root: Root dataset storage directory. Default is '.data'.
            train: The prefix of the train data. Default:
                'train.tok.clean.bpe.32000'.
            validation: The prefix of the validation data. Default:
                'newstest2013.tok.bpe.32000'.
            test: The prefix of the test data. Default:
                'newstest2014.tok.bpe.32000'.
            Remaining keyword arguments: Passed to the splits method of
                Dataset.
        """
        return super(WMT14, cls).splits(
            exts, fields, root, train, validation, test, **kwargs)
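

# Illustrative usage sketch (not part of the original module): the WMT14
# download is already tokenized and BPE-segmented, so a whitespace tokenizer
# is sufficient. The Field settings below are assumptions made only for this
# example.
#
#     from torchtext import data
#
#     SRC = data.Field(tokenize=str.split, batch_first=True)
#     TRG = data.Field(tokenize=str.split, batch_first=True,
#                      init_token='<sos>', eos_token='<eos>')
#     train, val, test = WMT14.splits(exts=('.en', '.de'), fields=(SRC, TRG))
#     SRC.build_vocab(train, max_size=32000)
#     TRG.build_vocab(train, max_size=32000)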