88from collections import defaultdict
99
1010
11+ def _get_field_by_translator (translator ):
12+ if not translator :
13+ return "ref"
14+ else :
15+ return f"ref:{ translator } "
16+
1117class WMTXMLDataset (Dataset ):
1218 """
1319 The 2021+ WMT dataset format. Everything is contained in a single file.
1420 Can be parsed with the lxml parser.
1521 """
16-
1722 @staticmethod
18- def _unwrap_wmt21_or_later (raw_file , allowed_refs = [] ):
23+ def _unwrap_wmt21_or_later (raw_file ):
1924 """
2025 Unwraps the XML file from wmt21 or later.
2126 This script is adapted from https://github.com/wmt-conference/wmt-format-tools
@@ -37,27 +42,20 @@ def _unwrap_wmt21_or_later(raw_file, allowed_refs=[]):
3742 for ref_doc in tree .getroot ().findall (".//ref" ):
3843 ref_langs .add (ref_doc .get ("lang" ))
3944 translator = ref_doc .get ("translator" )
40- if len (allowed_refs ) == 0 or translator in allowed_refs :
41- translators .add (translator )
45+ translators .add (translator )
4246
4347 assert (
4448 len (src_langs ) == 1
4549 ), f"Multiple source languages found in the file: { raw_file } "
4650 assert (
4751 len (ref_langs ) == 1
48- ), f"Multiple reference languages found in the file: { raw_file } "
52+ ), f"Found { len ( ref_langs ) } reference languages found in the file: { raw_file } "
4953
5054 src = []
5155 docids = []
5256 orig_langs = []
5357
54- def get_field_by_translator (translator ):
55- if not translator :
56- return "ref"
57- else :
58- return f"ref:{ translator } "
59-
60- refs = {get_field_by_translator (translator ): [] for translator in translators }
58+ refs = { _get_field_by_translator (translator ): [] for translator in translators }
6159
6260 systems = defaultdict (list )
6361
@@ -97,7 +95,7 @@ def get_sents(doc):
9795 if not any ([value .get (seg_id , "" ) for value in trans_to_ref .values ()]):
9896 continue
9997 for translator in translators :
100- refs [get_field_by_translator (translator )].append (
98+ refs [_get_field_by_translator (translator )].append (
10199 trans_to_ref .get (translator , {translator : {}}).get (seg_id , "" )
102100 )
103101 src .append (src_sents [seg_id ])
@@ -109,22 +107,31 @@ def get_sents(doc):
109107
110108 return {"src" : src , ** refs , "docid" : docids , "origlang" : orig_langs , ** systems }
111109
110+ def _get_langpair_path (self , langpair ):
111+ """
112+ Returns the path for this language pair.
113+ This is useful because in WMT22, the language-pair data structure can be a dict,
114+ in order to allow for overriding which test set to use.
115+ """
116+ langpair_data = self ._get_langpair_metadata (langpair )[langpair ]
117+ rel_path = langpair_data ["path" ] if type (langpair_data ) == dict else langpair_data [0 ]
118+ return os .path .join (self ._rawdir , rel_path )
119+
112120 def process_to_text (self , langpair = None ):
113121 """Processes raw files to plain text files.
114122
115123 :param langpair: The language pair to process. e.g. "en-de". If None, all files will be processed.
116124 """
117125 # ensure that the dataset is downloaded
118126 self .maybe_download ()
119- langpairs = self ._get_langpair_metadata (langpair )
120127
121- for langpair , files in langpairs . items ( ):
122- rawfile = os . path . join (
123- self . _rawdir , files [ 0 ]
124- ) # all source and reference data in one file for wmt21 and later
128+ for langpair in sorted ( self . _get_langpair_metadata ( langpair ). keys () ):
129+ # The data type can be a list of paths, or a dict, containing the " path"
130+ # and an override on which labeled reference to use (key "refs")
131+ rawfile = self . _get_langpair_path ( langpair )
125132
126133 with smart_open (rawfile ) as fin :
127- fields = self ._unwrap_wmt21_or_later (fin , allowed_refs = self . kwargs . get ( "refs" , []) )
134+ fields = self ._unwrap_wmt21_or_later (fin )
128135
129136 for fieldname in fields :
130137 textfile = self ._get_txt_file_path (langpair , fieldname )
@@ -137,11 +144,37 @@ def process_to_text(self, langpair=None):
137144 for line in fields [fieldname ]:
138145 print (self ._clean (line ), file = fout )
139146
147+ def _get_langpair_allowed_refs (self , langpair ):
148+ """
149+ Returns the preferred references for this language pair.
150+ This can be set in the language pair block (as in WMT22), and backs off to the
151+ test-set-level default, or nothing.
152+
153+ There is one exception. In the metadata, sometimes there is no translator field
154+ listed (e.g., wmt22:liv-en). In this case, the reference is set to "", and the
155+ field "ref" is returned.
156+ """
157+ defaults = self .kwargs .get ("refs" , [])
158+ langpair_data = self ._get_langpair_metadata (langpair )[langpair ]
159+ if type (langpair_data ) == dict :
160+ allowed_refs = langpair_data .get ("refs" , defaults )
161+ else :
162+ allowed_refs = defaults
163+ allowed_refs = [_get_field_by_translator (ref ) for ref in allowed_refs ]
164+
165+ return allowed_refs
166+
140167 def get_reference_files (self , langpair ):
168+ """
169+ Returns the requested reference files.
170+ This is defined as a default at the test-set level, and can be overridden per language.
171+ """
172+ # Iterate through the (label, file path) pairs, looking for permitted labels
173+ allowed_refs = self ._get_langpair_allowed_refs (langpair )
141174 all_files = self .get_files (langpair )
142175 all_fields = self .fieldnames (langpair )
143176 ref_files = [
144- f for f , field in zip (all_files , all_fields ) if field . startswith ( "ref" )
177+ f for f , field in zip (all_files , all_fields ) if field in allowed_refs
145178 ]
146179 return ref_files
147180
@@ -157,10 +190,9 @@ def fieldnames(self, langpair):
157190 :return: a list of field names
158191 """
159192 self .maybe_download ()
160- meta = self ._get_langpair_metadata (langpair )[langpair ]
161- rawfile = os .path .join (self ._rawdir , meta [0 ])
193+ rawfile = self ._get_langpair_path (langpair )
162194
163195 with smart_open (rawfile ) as fin :
164- fields = self ._unwrap_wmt21_or_later (fin , allowed_refs = self . kwargs . get ( "refs" , []) )
196+ fields = self ._unwrap_wmt21_or_later (fin )
165197
166198 return list (fields .keys ())
0 commit comments