Skip to content

Commit 5166cf7

Browse files
authored
Added WMT22 data (closes #215) (#216)
* Added WMT22 data * allow langpair-specific overrides to select the default annotator * added both refs for wmt21/dev
1 parent e416ee2 commit 5166cf7

4 files changed

Lines changed: 145 additions & 25 deletions

File tree

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,7 @@ data/
33
build
44
dist
55
__pycache__
6+
sacrebleu.egg-info
7+
.sacrebleu
8+
*~
9+
.DS_Store

sacrebleu/dataset/__init__.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,51 @@
7474

7575
DATASETS = {
7676
# wmt
77+
"wmt22": WMTXMLDataset(
78+
"wmt22",
79+
data=["https://github.com/wmt-conference/wmt22-news-systems/archive/refs/tags/v1.1.tar.gz"],
80+
description="Official evaluation and system data for WMT22.",
81+
md5=["0840978b9b50b9ac3b2b081e37d620b9"],
82+
langpairs={
83+
"cs-en": {
84+
"path": "wmt22-news-systems-1.1/xml/wmttest2022.cs-en.all.xml",
85+
"refs": ["B"],
86+
},
87+
"cs-uk": ["wmt22-news-systems-1.1/xml/wmttest2022.cs-uk.all.xml"],
88+
"de-en": ["wmt22-news-systems-1.1/xml/wmttest2022.de-en.all.xml"],
89+
"de-fr": ["wmt22-news-systems-1.1/xml/wmttest2022.de-fr.all.xml"],
90+
"en-cs": {
91+
"path": "wmt22-news-systems-1.1/xml/wmttest2022.en-cs.all.xml",
92+
"refs": ["B"],
93+
},
94+
"en-de": ["wmt22-news-systems-1.1/xml/wmttest2022.en-de.all.xml"],
95+
"en-hr": ["wmt22-news-systems-1.1/xml/wmttest2022.en-hr.all.xml"],
96+
"en-ja": ["wmt22-news-systems-1.1/xml/wmttest2022.en-ja.all.xml"],
97+
"en-liv": ["wmt22-news-systems-1.1/xml/wmttest2022.en-liv.all.xml"],
98+
"en-ru": ["wmt22-news-systems-1.1/xml/wmttest2022.en-ru.all.xml"],
99+
"en-uk": ["wmt22-news-systems-1.1/xml/wmttest2022.en-uk.all.xml"],
100+
"en-zh": ["wmt22-news-systems-1.1/xml/wmttest2022.en-zh.all.xml"],
101+
"fr-de": ["wmt22-news-systems-1.1/xml/wmttest2022.fr-de.all.xml"],
102+
"ja-en": ["wmt22-news-systems-1.1/xml/wmttest2022.ja-en.all.xml"],
103+
"liv-en": {
104+
"path": "wmt22-news-systems-1.1/xml/wmttest2022.liv-en.all.xml",
105+
# no translator because data is English-original
106+
"refs": [""],
107+
},
108+
"ru-en": ["wmt22-news-systems-1.1/xml/wmttest2022.ru-en.all.xml"],
109+
"ru-sah": {
110+
"path": "wmt22-news-systems-1.1/xml/wmttest2022.ru-sah.all.xml",
111+
# no translator because data is Yakut-original
112+
"refs": [""],
113+
},
114+
"sah-ru": ["wmt22-news-systems-1.1/xml/wmttest2022.sah-ru.all.xml"],
115+
"uk-cs": ["wmt22-news-systems-1.1/xml/wmttest2022.uk-cs.all.xml"],
116+
"uk-en": ["wmt22-news-systems-1.1/xml/wmttest2022.uk-en.all.xml"],
117+
"zh-en": ["wmt22-news-systems-1.1/xml/wmttest2022.zh-en.all.xml"],
118+
},
119+
# the default reference to use with this dataset
120+
refs=["A"],
121+
),
77122
"wmt21/systems": WMTXMLDataset(
78123
"wmt21/systems",
79124
data=["https://github.com/wmt-conference/wmt21-news-systems/archive/refs/tags/v1.3.tar.gz"],
@@ -101,8 +146,9 @@
101146
"xh-zu": ["wmt21-news-systems-1.3/xml/florestest2021.xh-zu.all.xml"],
102147
"zu-xh": ["wmt21-news-systems-1.3/xml/florestest2021.zu-xh.all.xml"],
103148
},
149+
# the reference to use with this dataset
150+
refs=["A"],
104151
),
105-
106152
"wmt21": WMTXMLDataset(
107153
"wmt21",
108154
data=["http://data.statmt.org/wmt21/translation-task/test.tgz"],
@@ -210,6 +256,8 @@
210256
"en-is": ["dev/xml/newsdev2021.en-is.xml"],
211257
"is-en": ["dev/xml/newsdev2021.is-en.xml"],
212258
},
259+
# datasets are bidirectional in origin, so use both refs
260+
refs=["A", ""],
213261
),
214262
"wmt20/tworefs": FakeSGMLDataset(
215263
"wmt20/tworefs",

sacrebleu/dataset/wmt_xml.py

Lines changed: 55 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,19 @@
88
from collections import defaultdict
99

1010

11+
def _get_field_by_translator(translator):
    """
    Maps a translator label to the name of the field that holds its reference.

    An empty or missing translator yields the bare "ref" field; any other
    label yields "ref:<translator>".

    :param translator: The translator label, possibly empty or None.
    :return: The field name for that translator's reference.
    """
    return f"ref:{translator}" if translator else "ref"
16+
1117
class WMTXMLDataset(Dataset):
1218
"""
1319
The 2021+ WMT dataset format. Everything is contained in a single file.
1420
Can be parsed with the lxml parser.
1521
"""
16-
1722
@staticmethod
18-
def _unwrap_wmt21_or_later(raw_file, allowed_refs=[]):
23+
def _unwrap_wmt21_or_later(raw_file):
1924
"""
2025
Unwraps the XML file from wmt21 or later.
2126
This script is adapted from https://github.com/wmt-conference/wmt-format-tools
@@ -37,27 +42,20 @@ def _unwrap_wmt21_or_later(raw_file, allowed_refs=[]):
3742
for ref_doc in tree.getroot().findall(".//ref"):
3843
ref_langs.add(ref_doc.get("lang"))
3944
translator = ref_doc.get("translator")
40-
if len(allowed_refs) == 0 or translator in allowed_refs:
41-
translators.add(translator)
45+
translators.add(translator)
4246

4347
assert (
4448
len(src_langs) == 1
4549
), f"Multiple source languages found in the file: {raw_file}"
4650
assert (
4751
len(ref_langs) == 1
48-
), f"Multiple reference languages found in the file: {raw_file}"
52+
), f"Found {len(ref_langs)} reference languages found in the file: {raw_file}"
4953

5054
src = []
5155
docids = []
5256
orig_langs = []
5357

54-
def get_field_by_translator(translator):
55-
if not translator:
56-
return "ref"
57-
else:
58-
return f"ref:{translator}"
59-
60-
refs = {get_field_by_translator(translator): [] for translator in translators}
58+
refs = { _get_field_by_translator(translator): [] for translator in translators }
6159

6260
systems = defaultdict(list)
6361

@@ -97,7 +95,7 @@ def get_sents(doc):
9795
if not any([value.get(seg_id, "") for value in trans_to_ref.values()]):
9896
continue
9997
for translator in translators:
100-
refs[get_field_by_translator(translator)].append(
98+
refs[_get_field_by_translator(translator)].append(
10199
trans_to_ref.get(translator, {translator: {}}).get(seg_id, "")
102100
)
103101
src.append(src_sents[seg_id])
@@ -109,22 +107,31 @@ def get_sents(doc):
109107

110108
return {"src": src, **refs, "docid": docids, "origlang": orig_langs, **systems}
111109

110+
def _get_langpair_path(self, langpair):
    """
    Returns the path to the raw file for this language pair.

    The language-pair entry is either a list of relative paths, of which the
    first is used, or (as in WMT22) a dict whose "path" key holds the relative
    path. The dict form exists so that a language pair can also override which
    labeled reference to use (key "refs").

    :param langpair: The language pair, e.g. "en-de".
    :return: The relative path joined onto self._rawdir.
    """
    langpair_data = self._get_langpair_metadata(langpair)[langpair]
    # isinstance, unlike an exact type comparison, also accepts dict subclasses
    if isinstance(langpair_data, dict):
        rel_path = langpair_data["path"]
    else:
        rel_path = langpair_data[0]
    return os.path.join(self._rawdir, rel_path)
119+
112120
def process_to_text(self, langpair=None):
113121
"""Processes raw files to plain text files.
114122
115123
:param langpair: The language pair to process. e.g. "en-de". If None, all files will be processed.
116124
"""
117125
# ensure that the dataset is downloaded
118126
self.maybe_download()
119-
langpairs = self._get_langpair_metadata(langpair)
120127

121-
for langpair, files in langpairs.items():
122-
rawfile = os.path.join(
123-
self._rawdir, files[0]
124-
) # all source and reference data in one file for wmt21 and later
128+
for langpair in sorted(self._get_langpair_metadata(langpair).keys()):
129+
# The data type can be a list of paths, or a dict, containing the "path"
130+
# and an override on which labeled reference to use (key "refs")
131+
rawfile = self._get_langpair_path(langpair)
125132

126133
with smart_open(rawfile) as fin:
127-
fields = self._unwrap_wmt21_or_later(fin, allowed_refs=self.kwargs.get("refs", []))
134+
fields = self._unwrap_wmt21_or_later(fin)
128135

129136
for fieldname in fields:
130137
textfile = self._get_txt_file_path(langpair, fieldname)
@@ -137,11 +144,37 @@ def process_to_text(self, langpair=None):
137144
for line in fields[fieldname]:
138145
print(self._clean(line), file=fout)
139146

147+
def _get_langpair_allowed_refs(self, langpair):
    """
    Returns the preferred reference fields for this language pair.

    The preference can be set in the language pair block (as in WMT22) under the
    "refs" key, and backs off to the test-set-level default (the "refs" keyword
    argument), or nothing.

    There is one exception. In the metadata, sometimes there is no translator field
    listed (e.g., wmt22:liv-en). In this case, the reference is set to "", and the
    field "ref" is returned.

    :param langpair: The language pair, e.g. "cs-en".
    :return: a list of field names, e.g. ["ref:A"]
    """
    defaults = self.kwargs.get("refs", [])
    langpair_data = self._get_langpair_metadata(langpair)[langpair]
    # Only the dict form can carry an override; isinstance also accepts dict subclasses
    if isinstance(langpair_data, dict):
        allowed_refs = langpair_data.get("refs", defaults)
    else:
        allowed_refs = defaults
    # Translate translator labels ("A", "") into field names ("ref:A", "ref")
    return [_get_field_by_translator(ref) for ref in allowed_refs]
166+
140167
def get_reference_files(self, langpair):
    """
    Returns the reference files selected for this language pair.

    The selection is a default set at the test-set level, which a language pair
    can override.

    :param langpair: The language pair, e.g. "cs-en".
    :return: a list of file paths whose field name is a permitted reference
    """
    permitted = self._get_langpair_allowed_refs(langpair)
    paths = self.get_files(langpair)
    labels = self.fieldnames(langpair)
    # Walk the (file path, label) pairs, keeping the paths with a permitted label
    selected = []
    for path, label in zip(paths, labels):
        if label in permitted:
            selected.append(path)
    return selected
147180

@@ -157,10 +190,9 @@ def fieldnames(self, langpair):
157190
:return: a list of field names
158191
"""
159192
self.maybe_download()
160-
meta = self._get_langpair_metadata(langpair)[langpair]
161-
rawfile = os.path.join(self._rawdir, meta[0])
193+
rawfile = self._get_langpair_path(langpair)
162194

163195
with smart_open(rawfile) as fin:
164-
fields = self._unwrap_wmt21_or_later(fin, allowed_refs=self.kwargs.get("refs", []))
196+
fields = self._unwrap_wmt21_or_later(fin)
165197

166198
return list(fields.keys())

test/test_dataset.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,4 +75,40 @@ def test_source_and_references():
7575
"""
7676
for ds in dataset.DATASETS.values():
7777
for pair in ds.langpairs:
78-
assert len(list(ds.source(pair))) == len(list(ds.references(pair)))
78+
src_len = len(list(ds.source(pair)))
79+
ref_len = len(list(ds.references(pair)))
80+
assert src_len == ref_len, f"source/reference failure for {ds.name}:{pair} len(source)={src_len} len(references)={ref_len}"
81+
82+
83+
def test_wmt22_references():
    """
    WMT21 added the ability to specify which reference to use (among many in the XML).
    The default was "A" for everything.
    WMT22 added the ability to override this default on a per-langpair basis, by
    replacing the langpair list of paths with a dict that holds the path (key "path")
    and the annotator override (key "refs").
    """
    wmt22 = dataset.DATASETS["wmt22"]

    # make sure CS-EN returns all reference fields
    cs_en_fields = wmt22.fieldnames("cs-en")
    for ref in ["ref:B", "ref:C"]:
        assert ref in cs_en_fields
    assert "ref:A" not in cs_en_fields

    # make sure ref:B is the one used by default
    assert wmt22._get_langpair_allowed_refs("cs-en") == ["ref:B"]

    # similar check for another dataset: there should be no default ("A"),
    # and the only ref found should be the unannotated one
    liv_en_fields = wmt22.fieldnames("liv-en")
    assert "ref:A" not in liv_en_fields
    assert "ref" in liv_en_fields

    # and that ref:A is the default for all languages where it wasn't overridden
    for langpair, langpair_data in wmt22.langpairs.items():
        # the dict form marks a langpair-level override
        if isinstance(langpair_data, dict):
            assert wmt22._get_langpair_allowed_refs(langpair) != ["ref:A"]
        else:
            assert wmt22._get_langpair_allowed_refs(langpair) == ["ref:A"]
113+
114+

0 commit comments

Comments
 (0)