
Commit ebf7cf6

Enable summarization by subsets and groups
Signed-off-by: Jonathan Bnayahu <[email protected]>
1 parent 92367e6 commit ebf7cf6

File tree

1 file changed (+96, -32 lines)


src/unitxt/evaluate_cli.py

Lines changed: 96 additions & 32 deletions
@@ -7,7 +7,7 @@
 import platform
 import subprocess
 import sys
-from datetime import datetime
+from datetime import datetime, timezone
 from functools import partial
 from typing import Any, Dict, List, Optional, Tuple, Union
 
@@ -691,9 +691,8 @@ def _save_results_to_disk(
         "results": global_scores,
     }
 
-    # prepend to the results_path name the time in a wat like this: 2025-04-04T11:37:32
-
-    timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+    # prepend the timestamp in UTC (e.g., 2025-01-18T11-37-32) to the file names
+    timestamp = datetime.now().astimezone(timezone.utc).strftime("%Y-%m-%dT%H-%M-%S")
 
     results_path = prepend_timestamp_to_path(results_path, timestamp)
     samples_path = prepend_timestamp_to_path(samples_path, timestamp)
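
For context, a minimal sketch of what the revised timestamp expression produces (values are illustrative; the cross-machine consistency rationale is presumed rather than stated in the commit):

from datetime import datetime, timezone

# Old behaviour: a naive local-time stamp, which varies with the machine's timezone.
local_stamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")

# New behaviour: the same instant rendered in UTC, e.g. "2025-01-18T11-37-32",
# so file-name prefixes are comparable across machines.
utc_stamp = datetime.now().astimezone(timezone.utc).strftime("%Y-%m-%dT%H-%M-%S")

print(local_stamp, utc_stamp)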
@@ -836,47 +835,112 @@ def main():
     logger.info("Unitxt Evaluation CLI finished successfully.")
 
 
-def extract_scores(directory): # pragma: no cover
+def extract_scores(folder: str, subset: str, group: str): # pragma: no cover
     import pandas as pd
 
-    data = []
+    def safe_score(d: dict, key="score"):
+        na = "N/A"
+        return d.get(key, na) if isinstance(d, dict) else na
 
-    for filename in sorted(os.listdir(directory)):
-        if filename.endswith("evaluation_results.json"):
-            file_path = os.path.join(directory, filename)
-            try:
-                with open(file_path, encoding="utf-8") as f:
-                    content = json.load(f)
+    def extract_subset(results: dict, subset: str, group: str):
+        subset_results = results.get(subset, {})
+        row = {subset: safe_score(subset_results)}
+
+        groups = subset_results.get("groups", {})
+
+        if not groups:
+            return row
 
-                env_info = content.get("environment_info", {})
-                timestamp = env_info.get("timestamp_utc", "N/A")
-                model = env_info.get("parsed_arguments", {}).get("model", "N/A")
-                results = content.get("results", {})
+        group_results = groups.get(group) if group else next(iter(groups.values()), {})
 
-                row = {}
-                row["Model"] = model
-                row["Timestamp"] = timestamp
-                row["Average"] = results.get("score", "N/A")
+        if not isinstance(group_results, dict):
+            return row
 
-                for key in results.keys():
-                    if isinstance(results[key], dict):
-                        score = results[key].get("score", "N/A")
-                        row[key] = score
+        row.update(
+            {k: safe_score(v) for k, v in group_results.items() if isinstance(v, dict)}
+        )
+        return row
+
+    def extract_all(results: dict):
+        row = {"Average": safe_score(results)}
+        row.update(
+            {k: safe_score(v) for k, v in results.items() if isinstance(v, dict)}
+        )
+        return row
+
+    data = []
 
-                data.append(row)
-            except Exception as e:
-                logger.error(f"Error parsing results file {filename}: {e}.")
+    for filename in sorted(os.listdir(folder)):
+        if not filename.endswith("evaluation_results.json"):
+            continue
+
+        file_path = os.path.join(folder, filename)
+        try:
+            with open(file_path, encoding="utf-8") as f:
+                content = json.load(f)
+
+            env_info = content.get("environment_info", {})
+            row = {
+                "Timestamp": safe_score(env_info, "timestamp_utc"),
+                "Model": safe_score(env_info.get("parsed_arguments", {}), "model"),
+            }
+
+            results = content.get("results", {})
+
+            extra = (
+                extract_subset(results, subset, group)
+                if subset
+                else extract_all(results)
+            )
+            row.update(extra)
+            data.append(row)
+        except Exception as e:
+            logger.error(f"Error parsing results file {filename}: {e}.")
 
     return pd.DataFrame(data).sort_values(by="Timestamp", ascending=True)
 
 
+def setup_summarization_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.RawTextHelpFormatter,
+        description="CLI utility for summarizing evaluation results.",
+    )
+
+    parser.add_argument(
+        "--folder",
+        "-f",
+        dest="folder",
+        type=str,
+        default=".",
+        help="Directory containing evaluation results json files. Default: current folder.\n",
+    )
+
+    parser.add_argument(
+        "--subset",
+        "-s",
+        type=str,
+        dest="subset",
+        default=None,
+        help="Subset to filter results by. Default: none.",
+    )
+
+    parser.add_argument(
+        "--group",
+        "-g",
+        type=str,
+        dest="group",
+        default=None,
+        help="Group to filter results to. Requires specifying a subset. Default: first group.",
+    )
+
+    return parser
+
+
 def summarize_cli():
-    if len(sys.argv) != 2:
-        logger.error("Usage: python summarize_cli_results.py <results-directory>")
-        sys.exit(1)
-    directory = sys.argv[1]
-    df = extract_scores(directory)
+    parser = setup_summarization_parser()
+    args = parser.parse_args()
 
+    df = extract_scores(args.folder, args.subset, args.group)
     logger.info(df.to_markdown(index=False))
 
 
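
To illustrate the new summarization modes, here is a minimal, self-contained sketch of the results-file layout that extract_scores() appears to expect, inferred from the keys read above. The folder name, subset ("my_subset"), group ("my_group"), metric key ("accuracy"), model name, and score values are all hypothetical; the sketch also assumes unitxt and pandas are installed and that the module is importable as unitxt.evaluate_cli.

import json
import os

from unitxt.evaluate_cli import extract_scores

# Hypothetical results file mirroring the keys extract_scores() reads.
example = {
    "environment_info": {
        "timestamp_utc": "2025-01-18T11-37-32",
        "parsed_arguments": {"model": "my-model"},
    },
    "results": {
        "score": 0.71,          # overall score -> "Average" column
        "my_subset": {          # per-subset score -> one column per subset
            "score": 0.68,
            "groups": {         # optional per-group breakdown within a subset
                "my_group": {"accuracy": {"score": 0.66}},
            },
        },
    },
}

os.makedirs("results", exist_ok=True)
with open("results/2025-01-18T11-37-32_evaluation_results.json", "w", encoding="utf-8") as f:
    json.dump(example, f)

# Default view: "Average" plus one column per subset.
print(extract_scores("results", None, None))

# Subset view: the subset's own score plus the scores inside one of its groups
# (the first group is used when no group is given).
print(extract_scores("results", "my_subset", "my_group"))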
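
Similarly, a small sketch of how the new flags are parsed; the console-script name that ultimately invokes summarize_cli() is not shown in this commit, so a plain parse_args() call stands in for it (folder, subset, and group values are hypothetical):

from unitxt.evaluate_cli import setup_summarization_parser

# Equivalent to invoking the summarization CLI with:
#   --folder results --subset my_subset --group my_group
args = setup_summarization_parser().parse_args(
    ["--folder", "results", "--subset", "my_subset", "--group", "my_group"]
)
print(args.folder, args.subset, args.group)  # -> results my_subset my_group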