diff --git a/decanlp/tasks/almond/__init__.py b/decanlp/tasks/almond/__init__.py index d48f3a88..cb0e9883 100644 --- a/decanlp/tasks/almond/__init__.py +++ b/decanlp/tasks/almond/__init__.py @@ -152,7 +152,7 @@ class AlmondDataset(generic_dataset.CQA): aux_data = cls(os.path.join(path, 'aux' + '.tsv'), fields, contextual=contextual, **kwargs) train_data = None if train is None else cls( - os.path.join(path, train + '.tsv'), fields, **kwargs) + os.path.join(path, train + '.tsv'), fields, contextual=contextual, **kwargs) val_data = None if validation is None else cls( os.path.join(path, validation + '.tsv'), fields, contextual=contextual, **kwargs) test_data = None if test is None else cls(