Skip to content

Commit f9d28a9

Browse files
committed
Kaggle command
1 parent b7015aa commit f9d28a9

File tree

2 files changed

+101
-1
lines changed

2 files changed

+101
-1
lines changed

py_css/interface/kaggle.py

+92
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import logging
2+
from typing import Dict, Tuple, List
3+
import csv
4+
import os
5+
6+
import pandas as pd
7+
8+
import indexer.index as index_module
9+
import models.base as base_model
10+
import models.baseline as baseline_module
11+
12+
index = None
13+
pipeline: base_model.Pipeline
14+
15+
16+
def to_kaggle_format(df: pd.DataFrame) -> str:
17+
"""
18+
Convert a dataframe to the Kaggle submission format.
19+
20+
Parameters
21+
----------
22+
df : pd.DataFrame
23+
The dataframe to convert.
24+
25+
Returns
26+
-------
27+
str
28+
The dataframe in the Kaggle submission format.
29+
"""
30+
# for each query, only keep the best 3 docnos
31+
df = df.groupby("qid").head(3)
32+
33+
output = "qid,docid\n"
34+
for _, row in df.iterrows():
35+
output += f"{row['qid']},{row['docno']}\n"
36+
return output
37+
38+
39+
def main(
40+
*,
41+
recreate: bool,
42+
queries_file_path: str,
43+
output_file_path: str,
44+
baseline_params: Tuple[int, int, int],
45+
) -> None:
46+
"""
47+
The main function of the eval interface.
48+
49+
Parameters
50+
----------
51+
recreate : bool
52+
Whether to recreate the index.
53+
queries_file_path : str
54+
The path to the queries file.
55+
qrels_file_path : str
56+
The path to the qrels file.
57+
baseline_params : Tuple[int, int, int]
58+
The parameters for the baseline model.
59+
"""
60+
global index
61+
global pipeline
62+
63+
index = index_module.get_index(recreate=recreate)
64+
pipeline = baseline_module.Baseline(
65+
index, baseline_params[0], baseline_params[1], baseline_params[2]
66+
)
67+
68+
logging.info("Loading queries...")
69+
queries: Dict[int, Dict[int, base_model.Query]] = {} # topic_id -> (turn_id, query)
70+
with open(queries_file_path, "r") as queries_file:
71+
# Skip the header
72+
queries_file.readline()
73+
csv_reader = csv.reader(queries_file)
74+
for line in csv_reader:
75+
query_id, query, topic_id, turn_id = tuple(line[0:4])
76+
queries.setdefault(int(topic_id), {})[int(turn_id)] = base_model.Query(
77+
query_id=query_id, query=query
78+
)
79+
inputs: List[Tuple[List[base_model.Query], base_model.Context]] = []
80+
for topic_id, qs in queries.items():
81+
inputs.append(
82+
([query for _, query in sorted(qs.items(), key=lambda x: x[0])], [])
83+
)
84+
85+
logging.info("Running queries...")
86+
_, results = pipeline.batch_search_conversation(inputs)
87+
88+
logging.info("Writing results...")
89+
# create file and parent directories if not exist
90+
os.makedirs(os.path.dirname(output_file_path), exist_ok=True)
91+
with open(output_file_path, "w") as output_file:
92+
output_file.write(to_kaggle_format(results))

py_css/main.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import interface.cli as cli_module
77
import interface.run_queries as run_queries_module
88
import interface.eval as eval_module
9+
import interface.kaggle as kaggle_module
910

1011
import models.base as base_module
1112
import models.baseline as baseline_module
@@ -60,7 +61,7 @@ def main():
6061
parser.add_argument(
6162
"command",
6263
type=str,
63-
choices=["cli", "run_file", "eval"],
64+
choices=["cli", "run_file", "eval", "kaggle"],
6465
help='Command to run (e.g., "cli" for command line interface)',
6566
)
6667

@@ -127,6 +128,13 @@ def main():
127128
qrels_file_path=args.qrels,
128129
baseline_params=args.baseline_params,
129130
)
131+
elif args.command == "kaggle":
132+
kaggle_module.main(
133+
recreate=args.recreate,
134+
queries_file_path=args.queries,
135+
output_file_path=args.output,
136+
baseline_params=args.baseline_params,
137+
)
130138

131139

132140
if __name__ == "__main__":

0 commit comments

Comments
 (0)