Commit 80f0980

doc2query-t5
1 parent b92c746 commit 80f0980

File tree

3 files changed: +274 -1 lines changed


data/.gitignore

+1

@@ -1,2 +1,3 @@
 collection.tsv
+doc2query.tsv
 index/

pyproject.toml

+3 -1

@@ -15,7 +15,9 @@ requires = [
   "pyterrier-t5@git+https://github.com/terrierteam/pyterrier_t5.git",
   "pyterrier_doc2query@git+https://github.com/terrierteam/pyterrier_doc2query.git",
   "torch",
-  "transformers[torch]"
+  "transformers[torch]",
+  "more-itertools",
+  "tqdm",
 ]

 [build-system]
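
Both new dependencies are consumed by the added script: more-itertools supplies chunked() for batching the collection, and tqdm renders the progress bar during generation.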

scripts/doc2query-t5.py

+270
@@ -0,0 +1,270 @@
import re
import signal
import sys
from pathlib import Path
from typing import Dict, List, Tuple

import pandas as pd
import torch
from more_itertools import chunked
from tqdm import tqdm
from transformers import T5ForConditionalGeneration, T5TokenizerFast


MODEL_NAME: str = "castorini/doc2query-t5-large-msmarco"
MAX_LENGTH: int = 512
BATCH_SIZE: int = 64
DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"
NUM_SAMPLES: int = 3

INPUT_FILE: Path = Path("data/collection.tsv")
OUTPUT_FILE: Path = Path("data/doc2query.tsv")


class Doc2Query:
    """
    A class for generating queries from documents using T5.

    Attributes
    ----------
    model : T5ForConditionalGeneration
        The T5 model to use
    tokenizer : T5TokenizerFast
        The T5 tokenizer to use
    device : torch.device
        The device the model runs on
    max_length : int
        The maximum token length of the input
    num_samples : int
        The number of queries to generate per document
    batch_size : int
        The batch size to use
    input_file : Path
        The input TSV file: (docid, document)
    output_file : Path
        The output TSV file: (docid, query_0, ..., query_{num_samples - 1})
    output_df : pd.DataFrame
        The output dataframe
    pattern : re.Pattern
        Pattern matching a leading URL, which is stripped from each document
    """

    model: T5ForConditionalGeneration
    tokenizer: T5TokenizerFast
    device: torch.device
    max_length: int
    num_samples: int
    batch_size: int
    input_file: Path
    output_file: Path
    output_df: pd.DataFrame
    pattern: re.Pattern

    def __init__(
        self,
        model_name: str = MODEL_NAME,
        max_length: int = MAX_LENGTH,
        batch_size: int = BATCH_SIZE,
        device: str = DEVICE,
        num_samples: int = NUM_SAMPLES,
        input_file: Path = INPUT_FILE,
        output_file: Path = OUTPUT_FILE,
    ):
        """
        Constructor for the Doc2Query class.

        Parameters
        ----------
        model_name : str, optional
            The name of the model to use, by default MODEL_NAME
        max_length : int, optional
            The maximum token length of the input, by default MAX_LENGTH
        batch_size : int, optional
            The batch size to use, by default BATCH_SIZE
        device : str, optional
            The device to use, by default DEVICE
        num_samples : int, optional
            The number of queries to generate per document, by default NUM_SAMPLES
        input_file : Path, optional
            The input file, by default INPUT_FILE
        output_file : Path, optional
            The output file, by default OUTPUT_FILE
        """

        def _handle_signal(signo: int, _frame) -> None:
            # Flush pending queries to disk, then exit with the signal number.
            self.write_output()
            sys.exit(signo)

        signal.signal(signal.SIGINT, _handle_signal)
        signal.signal(signal.SIGTERM, _handle_signal)

        self.device = torch.device(device)
        self.model = (
            T5ForConditionalGeneration.from_pretrained(model_name)
            .to(self.device)
            .eval()
        )
        self.tokenizer = T5TokenizerFast.from_pretrained(
            model_name,
            legacy=False,
            model_max_length=max_length,
        )
        self.max_length = max_length
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.output_file = output_file
        assert input_file.exists(), f"Input file {input_file} does not exist"
        self.input_file = input_file
        columns = ["docid"] + [f"query_{i}" for i in range(num_samples)]
        if self.output_file.exists():
            # Resume from a previous (possibly interrupted) run.
            self.output_df = pd.read_table(self.output_file, names=columns, header=None)
        else:
            self.output_df = pd.DataFrame(columns=columns)
        self.pattern = re.compile(r"^\s*http\S+")

    def add_new_queries(self, new_queries: List[Tuple[str, List[str]]]):
        """
        Add new queries to the output dataframe.

        Parameters
        ----------
        new_queries : List[Tuple[str, List[str]]]
            The new queries to add: (docid, queries)
        """
        new_data: Dict[str, List[str]] = {"docid": []}
        for i in range(self.num_samples):
            new_data[f"query_{i}"] = []

        for docid, queries in new_queries:
            assert 1 <= len(queries) <= self.num_samples

            if len(queries) == self.num_samples:
                # A full set of queries replaces any existing row for this docid.
                if docid in self.output_df["docid"].values:
                    self.output_df = self.output_df[self.output_df["docid"] != docid]
                new_data["docid"].append(docid)
                for i, query in enumerate(queries):
                    new_data[f"query_{i}"].append(query)
            else:
                # A partial set of queries completes an existing, incomplete row.
                assert docid in self.output_df["docid"].values
                existing_queries: List[str] = []
                for i in range(self.num_samples):
                    # Keep only stored queries that are not NaN/empty.
                    query = self.output_df[self.output_df["docid"] == docid][
                        f"query_{i}"
                    ].values[0]
                    if not pd.isna(query) and str(query).strip() != "":
                        existing_queries.append(query)

                assert len(existing_queries) + len(queries) == self.num_samples

                # Replace the incomplete row with the completed one.
                self.output_df = self.output_df[self.output_df["docid"] != docid]
                new_data["docid"].append(docid)
                for i, query in enumerate(existing_queries + queries):
                    new_data[f"query_{i}"].append(query)

        self.output_df = pd.concat(
            [self.output_df, pd.DataFrame(new_data)], ignore_index=True
        )

    def write_output(self):
        self.output_df.to_csv(self.output_file, sep="\t", index=False, header=False)

    def __del__(self):
        self.write_output()

    def generate_queries(self):
        """
        Generate queries for every document in the input file that does not
        already have a complete set of queries in the output file.
        """
        input_df = pd.read_table(
            self.input_file, names=["docid", "document"], header=None
        )
        # Skip docids that already have a complete row in the output dataframe,
        # i.e. one with no NaN/empty query values; re-process incomplete rows.
        incomplete_ids: int = 0
        complete_docids = set(self.output_df["docid"])
        for _, row in self.output_df.iterrows():
            for i in range(self.num_samples):
                query = row[f"query_{i}"]
                if pd.isna(query) or str(query).strip() == "":
                    complete_docids.discard(row["docid"])
                    incomplete_ids += 1
                    break
        input_df = input_df[~input_df["docid"].isin(complete_docids)]

        print(
            f"Processing {len(input_df)} documents "
            f"(skipping {len(complete_docids)}, re-processing {incomplete_ids}). "
            f"Minimum ID: {input_df['docid'].min()}, maximum ID: {input_df['docid'].max()}"
        )

        self._generate_queries(
            list(zip(input_df["docid"].values, input_df["document"].values))
        )

    def _generate_queries(self, documents: List[Tuple[str, str]]):
        """
        Generate queries for a list of documents, writing the output file
        after every batch so that interrupted runs can be resumed.

        Parameters
        ----------
        documents : List[Tuple[str, str]]
            The list of documents: (docid, document)
        """
        iterator = chunked(documents, self.batch_size)
        total = (len(documents) + self.batch_size - 1) // self.batch_size
        for batch in tqdm(iterator, total=total):
            docids: List[str] = []
            docs: List[str] = []
            for docid, doc in batch:
                docids.append(docid)
                docs.append(doc)
            queries = self._doc2query(docs)
            new_queries: List[Tuple[str, List[str]]] = list(zip(docids, queries))
            self.add_new_queries(new_queries)
            self.write_output()

    def _doc2query(self, texts: List[str]) -> List[List[str]]:
        """
        Generate queries for a list of texts.

        Parameters
        ----------
        texts : List[str]
            The list of texts

        Returns
        -------
        List[List[str]]
            One list of num_samples queries per input text
        """
        # Strip a leading URL so the model sees the document body first.
        docs = [self.pattern.sub("", doc) for doc in texts]

        with torch.no_grad():
            input_ids = self.tokenizer(
                docs,
                max_length=self.max_length,
                return_tensors="pt",
                padding=True,
                truncation=True,
            ).input_ids.to(self.device)
            outputs = self.model.generate(
                input_ids=input_ids,
                max_length=self.max_length,
                do_sample=True,
                top_k=10,
                num_return_sequences=self.num_samples,
            )
        decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        # generate() returns num_samples sequences per input, flattened in order.
        return list(chunked(decoded, self.num_samples))


if __name__ == "__main__":
    Doc2Query().generate_queries()
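
The script is resumable: it flushes data/doc2query.tsv after every batch, and on restart generate_queries() skips any docid whose row already holds a complete set of queries. A minimal usage sketch for driving the class with non-default settings (the parameter values below are illustrative, not part of the commit):

from pathlib import Path

# Illustrative settings; the defaults mirror the constants at the top of the script.
d2q = Doc2Query(
    batch_size=16,  # smaller batches fit smaller GPUs
    num_samples=3,  # queries generated per document
    input_file=Path("data/collection.tsv"),
    output_file=Path("data/doc2query.tsv"),
)
d2q.generate_queries()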
