import os
import xml.etree.ElementTree as ET
import glob
import io

from .. import data


class TranslationDataset(data.Dataset):
    """Defines a dataset for machine translation."""

    @staticmethod
    def sort_key(ex):
        return data.interleave_keys(len(ex.src), len(ex.trg))

    def __init__(self, path, exts, fields, **kwargs):
        """Create a TranslationDataset given paths and fields.

        Arguments:
            path: Common prefix of the paths to the data files for both
                languages.
            exts: A tuple containing the file extension for each language.
            fields: A tuple containing the fields that will be used for data
                in each language.
            Remaining keyword arguments: Passed to the constructor of
                data.Dataset.
        """
        if not isinstance(fields[0], (tuple, list)):
            fields = [('src', fields[0]), ('trg', fields[1])]

        src_path, trg_path = tuple(os.path.expanduser(path + x) for x in exts)

        examples = []
        with open(src_path) as src_file, open(trg_path) as trg_file:
            for src_line, trg_line in zip(src_file, trg_file):
                src_line, trg_line = src_line.strip(), trg_line.strip()
                # Skip pairs in which either side is empty.
                if src_line != '' and trg_line != '':
                    examples.append(data.Example.fromlist(
                        [src_line, trg_line], fields))

        super(TranslationDataset, self).__init__(examples, fields, **kwargs)

    @classmethod
    def splits(cls, exts, fields, root='.data',
               train='train', validation='val', test='test', **kwargs):
        """Create dataset objects for splits of a TranslationDataset.

        Arguments:
            exts: A tuple containing the file extension for each language.
            fields: A tuple containing the fields that will be used for data
                in each language.
            root: Root dataset storage directory. Default is '.data'.
            train: The prefix of the train data. Default: 'train'.
            validation: The prefix of the validation data. Default: 'val'.
            test: The prefix of the test data. Default: 'test'.
            Remaining keyword arguments: Passed to the splits method of
                Dataset.
        """
        path = cls.download(root)

        train_data = None if train is None else cls(
            os.path.join(path, train), exts, fields, **kwargs)
        val_data = None if validation is None else cls(
            os.path.join(path, validation), exts, fields, **kwargs)
        test_data = None if test is None else cls(
            os.path.join(path, test), exts, fields, **kwargs)
        return tuple(d for d in (train_data, val_data, test_data)
                     if d is not None)


class Multi30k(TranslationDataset):
    """The small-dataset WMT 2016 multimodal task, also known as Flickr30k."""

    urls = ['http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz',
            'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz',
            'http://www.quest.dcs.shef.ac.uk/'
            'wmt17_files_mmt/mmt_task1_test2016.tar.gz']
    name = 'multi30k'
    dirname = ''

    @classmethod
    def splits(cls, exts, fields, root='.data',
               train='train', validation='val', test='test2016', **kwargs):
        """Create dataset objects for splits of the Multi30k dataset.

        Arguments:
            exts: A tuple containing the file extension for each language.
            fields: A tuple containing the fields that will be used for data
                in each language.
            root: Root dataset storage directory. Default is '.data'.
            train: The prefix of the train data. Default: 'train'.
            validation: The prefix of the validation data. Default: 'val'.
            test: The prefix of the test data. Default: 'test2016'.
            Remaining keyword arguments: Passed to the splits method of
                Dataset.
        """
        return super(Multi30k, cls).splits(
            exts, fields, root, train, validation, test, **kwargs)


class IWSLT(TranslationDataset):
    """The IWSLT 2016 TED talk translation task."""

    base_url = 'https://wit3.fbk.eu/archive/2016-01//texts/{}/{}/{}.tgz'
    name = 'iwslt'
    base_dirname = '{}-{}'

    @classmethod
    def splits(cls, exts, fields, root='.data',
               train='train', validation='IWSLT16.TED.tst2013',
               test='IWSLT16.TED.tst2014', **kwargs):
        """Create dataset objects for splits of the IWSLT dataset.

        Arguments:
            exts: A tuple containing the file extension for each language.
            fields: A tuple containing the fields that will be used for data
                in each language.
            root: Root dataset storage directory. Default is '.data'.
            train: The prefix of the train data. Default: 'train'.
            validation: The prefix of the validation data.
                Default: 'IWSLT16.TED.tst2013'.
            test: The prefix of the test data.
                Default: 'IWSLT16.TED.tst2014'.
            Remaining keyword arguments: Passed to the splits method of
                Dataset.
        """
        # The download directory and URL depend on the language pair, which
        # is inferred from the order of the extensions.
        cls.dirname = cls.base_dirname.format(exts[0][1:], exts[1][1:])
        cls.urls = [cls.base_url.format(exts[0][1:], exts[1][1:], cls.dirname)]
        check = os.path.join(root, cls.name, cls.dirname)
        path = cls.download(root, check=check)

        if train is not None:
            train = '.'.join([train, cls.dirname])
        if validation is not None:
            validation = '.'.join([validation, cls.dirname])
        if test is not None:
            test = '.'.join([test, cls.dirname])

        # Extract plain text from the raw XML archives on first use.
        if not os.path.exists(os.path.join(path, '.'.join(
                ['train', cls.dirname])) + exts[0]):
            cls.clean(path)

        train_data = None if train is None else cls(
            os.path.join(path, train), exts, fields, **kwargs)
        val_data = None if validation is None else cls(
            os.path.join(path, validation), exts, fields, **kwargs)
        test_data = None if test is None else cls(
            os.path.join(path, test), exts, fields, **kwargs)
        return tuple(d for d in (train_data, val_data, test_data)
                     if d is not None)

    @staticmethod
    def clean(path):
        # Convert the .xml dev/test files to plain text, one segment per line.
        for f_xml in glob.iglob(os.path.join(path, '*.xml')):
            print(f_xml)
            f_txt = os.path.splitext(f_xml)[0]
            with io.open(f_txt, mode='w', encoding='utf-8') as fd_txt:
                root = ET.parse(f_xml).getroot()[0]
                for doc in root.findall('doc'):
                    for e in doc.findall('seg'):
                        fd_txt.write(e.text.strip() + '\n')

        xml_tags = ['