
Commit 2fcf03e

PromoterAI benchmarks (#16)
* Add PromoterAI benchmark dataset processing
* Add PromoterAI benchmark results
* Refactor sat_mut_mpra
1 parent eb9d04b commit 2fcf03e

6 files changed: +209 additions, -41 deletions


experiments/evals/config/config.yaml

Lines changed: 40 additions & 2 deletions
@@ -11,9 +11,42 @@ sat_mut_mpra_promoter:
   - PKLR
   - TERT
 
+# PromoterAI benchmark datasets from GitHub
+promoterai_benchmarks:
+  promoterai_gtex_outlier: GTEx_outlier.tsv
+  promoterai_cagi5_saturation: CAGI5_saturation.tsv
+  promoterai_mpra_saturation: MPRA_saturation.tsv
+  promoterai_gtex_eqtl: GTEx_eQTL.tsv
+  promoterai_mpra_eqtl: MPRA_eQTL.tsv
+  promoterai_ukbb_proteome: UKBB_proteome.tsv
+  promoterai_gel_rna: GEL_RNA.tsv
+
+# Combined dataset groups for efficient batch inference
+combined_dataset_groups:
+  promoterai_combined:
+    datasets:
+      - promoterai_gtex_outlier
+      - promoterai_cagi5_saturation
+      - promoterai_mpra_saturation
+      - promoterai_gtex_eqtl
+      - promoterai_mpra_eqtl
+      - promoterai_ukbb_proteome
+      - promoterai_gel_rna
+  sat_mut_mpra_combined:
+    datasets:
+      - sat_mut_mpra_promoter_F9
+      - sat_mut_mpra_promoter_GP1BA
+      - sat_mut_mpra_promoter_HBB
+      - sat_mut_mpra_promoter_HBG1
+      - sat_mut_mpra_promoter_HNF4A
+      - sat_mut_mpra_promoter_LDLR
+      - sat_mut_mpra_promoter_MSMB
+      - sat_mut_mpra_promoter_PKLR
+      - sat_mut_mpra_promoter_TERT
+
 context_size: 512
-per_device_batch_size: 128
-torch_compile: False  # overhead not worth it for small datasets and fast models
+per_device_batch_size: 512
+torch_compile: True  # consider if overhead is worth it for small datasets and fast models
 
 # first part run for 370k steps, second part run for 130k steps
 models:
@@ -49,3 +82,8 @@ dataset_configs:
     # This applies to all promoter-specific datasets
     metrics: [Spearman]
     scorings: [absLLR.plus.score]
+
+  promoterai_benchmark:
+    # This applies to all PromoterAI benchmark datasets
+    metrics: [AUPRC]
+    scorings: [absLLR.plus.score]
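For orientation, a minimal sketch of how these new config keys look once parsed, assuming only that PyYAML is available (Snakemake does the equivalent via its configfile directive; key names are taken from the diff above):

    import yaml

    # Load the eval config the same way Snakemake's configfile directive would
    with open("experiments/evals/config/config.yaml") as f:
        config = yaml.safe_load(f)

    # Individual PromoterAI benchmarks: dataset name -> source TSV filename
    for name, tsv in config["promoterai_benchmarks"].items():
        print(f"{name}: {tsv}")

    # Combined groups: one batched inference pass per group, split back out later
    for group, spec in config["combined_dataset_groups"].items():
        print(f"{group}: {len(spec['datasets'])} member datasets")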

experiments/evals/workflow/Snakefile

Lines changed: 25 additions & 3 deletions
@@ -2,18 +2,31 @@ configfile: "config/config.yaml"
 
 
 def get_all_datasets():
-    """Get list of all dataset names for wildcard constraints."""
+    """Get list of all dataset names that have metrics computed (individual benchmarks only)."""
     datasets = []
     for dataset in config["dataset_configs"].keys():
         if dataset == "sat_mut_mpra_promoter":
             # Expand for each promoter
             for promoter in config["sat_mut_mpra_promoter"]:
                 datasets.append(f"sat_mut_mpra_promoter_{promoter}")
+        elif dataset == "promoterai_benchmark":
+            # Expand for each PromoterAI benchmark
+            for benchmark in config["promoterai_benchmarks"].keys():
+                datasets.append(benchmark)
         else:
             datasets.append(dataset)
     return datasets
 
 
+def get_all_datasets_including_combined():
+    """Get all datasets including combined groups for intermediate processing."""
+    datasets = get_all_datasets()
+    # Add combined dataset groups
+    for group_name in config.get("combined_dataset_groups", {}).keys():
+        datasets.append(group_name)
+    return datasets
+
+
 def get_all_metric_files():
     """Generate list of all metric files based on dataset_configs."""
     files = []
@@ -29,6 +42,15 @@ def get_all_metric_files():
                     files.append(
                         f"results/metrics/{dataset_name}/{metric}/{model}_{scoring}.tsv"
                     )
+        elif dataset == "promoterai_benchmark":
+            # Handle promoterai_benchmark - expand for each benchmark
+            for benchmark in config["promoterai_benchmarks"].keys():
+                for metric in cfg["metrics"]:
+                    for model in config["models"].keys():
+                        for scoring in cfg["scorings"]:
+                            files.append(
+                                f"results/metrics/{benchmark}/{metric}/{model}_{scoring}.tsv"
+                            )
         else:
             # Regular datasets
             for metric in cfg["metrics"]:
@@ -63,11 +85,11 @@ include: "rules/common.smk"
 include: "rules/gnomad.smk"
 include: "rules/metrics.smk"
 include: "rules/model.smk"
+include: "rules/promoterai_benchmarks.smk"
 include: "rules/sat_mut_mpra.smk"
 include: "rules/traitgym.smk"
 
 
 rule all:
     input:
-        get_all_metric_files(),
-        get_all_correlation_files()
+        get_all_correlation_files(),
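To see what the two helpers yield, here is a self-contained approximation of their expansion logic against a stub config (illustrative only; the real functions read Snakemake's global config object populated from config.yaml):

    config = {
        "dataset_configs": {
            "sat_mut_mpra_promoter": {},
            "promoterai_benchmark": {},
        },
        "sat_mut_mpra_promoter": ["PKLR", "TERT"],
        "promoterai_benchmarks": {"promoterai_gtex_outlier": "GTEx_outlier.tsv"},
        "combined_dataset_groups": {
            "promoterai_combined": {"datasets": ["promoterai_gtex_outlier"]},
        },
    }

    def get_all_datasets():
        datasets = []
        for dataset in config["dataset_configs"]:
            if dataset == "sat_mut_mpra_promoter":
                datasets += [f"sat_mut_mpra_promoter_{p}" for p in config["sat_mut_mpra_promoter"]]
            elif dataset == "promoterai_benchmark":
                datasets += list(config["promoterai_benchmarks"])
            else:
                datasets.append(dataset)
        return datasets

    def get_all_datasets_including_combined():
        # Combined groups are valid targets for intermediate files
        # (features, predictions) but not for per-benchmark metrics
        return get_all_datasets() + list(config.get("combined_dataset_groups", {}))

    print(get_all_datasets())
    # ['sat_mut_mpra_promoter_PKLR', 'sat_mut_mpra_promoter_TERT', 'promoterai_gtex_outlier']
    print(get_all_datasets_including_combined())
    # ['sat_mut_mpra_promoter_PKLR', 'sat_mut_mpra_promoter_TERT', 'promoterai_gtex_outlier', 'promoterai_combined']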

experiments/evals/workflow/rules/metrics.smk

Lines changed: 73 additions & 34 deletions
@@ -1,46 +1,85 @@
-rule metrics_AUPRC:
-    input:
-        "results/dataset/{dataset}.parquet",
-        "results/prediction/{dataset}/{model}.parquet",
-    output:
-        "results/metrics/{dataset}/AUPRC/{model}.tsv",
-    wildcard_constraints:
-        dataset="|".join(get_all_datasets()),
-    run:
-        y_true = pd.read_parquet(input[0], columns=["label"]).label
-        y_pred = pd.read_parquet(input[1], columns=["score"]).score
-        AUPRC = average_precision_score(y_true, y_pred)
-        pd.DataFrame({"AUPRC": [AUPRC]}).to_csv(output[0], sep="\t", index=False, float_format="%.3f")
+# Metric function definitions
+def metric_auprc(y_true, y_pred):
+    """Compute Area Under Precision-Recall Curve."""
+    return average_precision_score(y_true, y_pred)
 
 
-rule metrics_AUROC:
-    input:
-        "results/dataset/{dataset}.parquet",
-        "results/prediction/{dataset}/{model}.parquet",
-    output:
-        "results/metrics/{dataset}/AUROC/{model}.tsv",
-    wildcard_constraints:
-        dataset="|".join(get_all_datasets()),
-    run:
-        y_true = pd.read_parquet(input[0], columns=["label"]).label
-        y_pred = pd.read_parquet(input[1], columns=["score"]).score
-        AUROC = roc_auc_score(y_true, y_pred)
-        pd.DataFrame({"AUROC": [AUROC]}).to_csv(output[0], sep="\t", index=False, float_format="%.3f")
+def metric_auroc(y_true, y_pred):
+    """Compute Area Under ROC Curve."""
+    return roc_auc_score(y_true, y_pred)
+
 
+def metric_spearman(y_true, y_pred):
+    """Compute Spearman correlation coefficient."""
+    return spearmanr(y_true, y_pred)[0]
 
-rule metrics_Spearman:
+
+# Metric registry - maps metric names to functions
+METRIC_FUNCTIONS = {
+    "AUPRC": metric_auprc,
+    "AUROC": metric_auroc,
+    "Spearman": metric_spearman,
+}
+
+
+def get_combined_group(dataset_name):
+    """Return combined group name if dataset belongs to one, else None."""
+    for group_name, group_config in config.get("combined_dataset_groups", {}).items():
+        if dataset_name in group_config["datasets"]:
+            return group_name
+    return None
+
+
+def get_dataset_input(wildcards):
+    """Get dataset input path - combined or individual."""
+    combined_group = get_combined_group(wildcards.dataset)
+    if combined_group:
+        return f"results/dataset/{combined_group}.parquet"
+    else:
+        return f"results/dataset/{wildcards.dataset}.parquet"
+
+
+def get_prediction_input(wildcards):
+    """Get prediction input path - combined or individual."""
+    combined_group = get_combined_group(wildcards.dataset)
+    if combined_group:
+        return f"results/prediction/{combined_group}/{wildcards.model}.parquet"
+    else:
+        return f"results/prediction/{wildcards.dataset}/{wildcards.model}.parquet"
+
+
+rule metrics:
+    """Unified metrics rule - handles AUPRC, AUROC, Spearman for all datasets."""
     input:
-        "results/dataset/{dataset}.parquet",
-        "results/prediction/{dataset}/{model}.parquet",
+        dataset=get_dataset_input,
+        prediction=get_prediction_input,
     output:
-        "results/metrics/{dataset}/Spearman/{model}.tsv",
+        "results/metrics/{dataset}/{metric}/{model}.tsv",
     wildcard_constraints:
         dataset="|".join(get_all_datasets()),
     run:
-        y_true = pd.read_parquet(input[0], columns=["label"]).label
-        y_pred = pd.read_parquet(input[1], columns=["score"]).score
-        Spearman = spearmanr(y_true, y_pred)[0]
-        pd.DataFrame({"Spearman": [Spearman]}).to_csv(output[0], sep="\t", index=False, float_format="%.3f")
+        # Load data
+        df_dataset = pd.read_parquet(input.dataset)
+        df_pred = pd.read_parquet(input.prediction)
+
+        # Filter to specific benchmark if using combined dataset (positional filtering)
+        if 'dataset' in df_dataset.columns:
+            mask = df_dataset['dataset'] == wildcards.dataset
+            df_dataset = df_dataset[mask]
+            df_pred = df_pred[mask]  # Apply same positional mask
+
+        # Extract labels and scores
+        y_true = df_dataset["label"]
+        y_pred = df_pred["score"]
+
+        # Compute metric using registry
+        metric_func = METRIC_FUNCTIONS[wildcards.metric]
+        value = metric_func(y_true, y_pred)
+
+        # Save result
+        pd.DataFrame({wildcards.metric: [value]}).to_csv(
+            output[0], sep="\t", index=False, float_format="%.3f"
+        )
 
 
 rule aggregate_metrics:
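Note that the filtering in the unified rule is positional: the boolean mask is built from the dataset frame and reused on the prediction frame, which only works if the prediction parquet preserves the combined dataset's row order and length. A toy demonstration of that invariant (made-up data; column names follow the rule above):

    import pandas as pd
    from sklearn.metrics import average_precision_score

    # Toy combined dataset with predictions that are row-aligned by construction
    df_dataset = pd.DataFrame({
        "dataset": ["a", "a", "b", "b"],
        "label":   [1, 0, 1, 0],
    })
    df_pred = pd.DataFrame({"score": [0.9, 0.2, 0.8, 0.4]})

    # The invariant the rule relies on: same length, same row order
    assert len(df_dataset) == len(df_pred)

    mask = df_dataset["dataset"] == "a"   # boolean mask from the dataset frame
    y_true = df_dataset[mask]["label"]
    y_pred = df_pred[mask]["score"]       # same positional mask applied to predictions

    print(average_precision_score(y_true, y_pred))  # 1.0 for this toy data

If the prediction file were ever sorted or deduplicated independently, the mask would silently select the wrong rows, so the row-alignment assumption is the key thing to preserve upstream.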

experiments/evals/workflow/rules/model.smk

Lines changed: 2 additions & 2 deletions
@@ -21,7 +21,7 @@ rule model_llr:
     output:
         "results/features/{dataset}/{model}_LLR.parquet",
     wildcard_constraints:
-        dataset="|".join(get_all_datasets()),
+        dataset="|".join(get_all_datasets_including_combined()),
         model="|".join(config["models"].keys()),
     threads:
         workflow.cores
@@ -57,7 +57,7 @@ rule model_abs_llr:
     output:
         "results/features/{dataset}/{model}_absLLR.parquet",
     wildcard_constraints:
-        dataset="|".join(get_all_datasets()),
+        dataset="|".join(get_all_datasets_including_combined()),
         model="|".join(config["models"].keys()),
     run:
         df = pd.read_parquet(input[0])
experiments/evals/workflow/rules/promoterai_benchmarks.smk

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+# PromoterAI benchmark datasets from GitHub
+# Downloads and processes all benchmarks from: https://github.com/Illumina/PromoterAI/tree/master/data/benchmark
+
+rule promoterai_benchmark:
+    output:
+        "results/dataset/{dataset}.parquet",
+    wildcard_constraints:
+        dataset="|".join(config["promoterai_benchmarks"].keys()),
+    params:
+        filename=lambda wildcards: config["promoterai_benchmarks"][wildcards.dataset],
+    run:
+        url = f"https://raw.githubusercontent.com/Illumina/PromoterAI/master/data/benchmark/{params.filename}"
+        V = pd.read_csv(url, sep="\t")
+        V["chrom"] = V["chrom"].str.replace("^chr", "", regex=True)
+        # Group by coordinates (variants near multiple genes)
+        # Label is True if consequence is not "none" for ANY gene
+        V_grouped = V.groupby(COORDINATES, as_index=False).agg({
+            "consequence": lambda x: (x != "none").any(),
+        })
+        V_grouped = V_grouped.rename(columns={"consequence": "label"})
+        V_grouped.to_parquet(output[0], index=False)
+
+
+rule combine_promoterai_datasets:
+    """Combine all promoterai benchmarks into a single dataset for efficient batch inference."""
+    input:
+        lambda wildcards: expand(
+            "results/dataset/{dataset}.parquet",
+            dataset=config["combined_dataset_groups"][wildcards.combined_group]["datasets"]
+        )
+    output:
+        "results/dataset/{combined_group}.parquet"
+    wildcard_constraints:
+        combined_group="promoterai_combined"
+    run:
+        datasets = config["combined_dataset_groups"][wildcards.combined_group]["datasets"]
+        dfs = []
+        for dataset_name, dataset_path in zip(datasets, input):
+            df = pd.read_parquet(dataset_path)
+            df["dataset"] = dataset_name  # Add dataset identifier column
+            dfs.append(df)
+
+        # Concatenate all datasets
+        combined = pd.concat(dfs, ignore_index=True)
+        combined.to_parquet(output[0], index=False)
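The grouping step collapses per-gene rows into one row per variant. COORDINATES is not defined in this file; it presumably comes from rules/common.smk and is assumed below to be the usual variant key ["chrom", "pos", "ref", "alt"]. A toy illustration of the labeling logic (the consequence strings are made up; only "none" vs non-"none" matters):

    import pandas as pd

    COORDINATES = ["chrom", "pos", "ref", "alt"]  # assumed definition; see rules/common.smk

    # One variant annotated against two genes: any non-"none" consequence makes it a positive
    V = pd.DataFrame({
        "chrom": ["1", "1", "2"],
        "pos": [100, 100, 200],
        "ref": ["A", "A", "C"],
        "alt": ["G", "G", "T"],
        "consequence": ["none", "underexpression", "none"],
    })

    V_grouped = V.groupby(COORDINATES, as_index=False).agg(
        {"consequence": lambda x: (x != "none").any()}
    ).rename(columns={"consequence": "label"})

    print(V_grouped)
    #   chrom  pos ref alt  label
    # 0     1  100   A   G   True
    # 1     2  200   C   T  False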

experiments/evals/workflow/rules/sat_mut_mpra.smk

Lines changed: 24 additions & 0 deletions
@@ -6,3 +6,27 @@ rule sat_mut_mpra_promoter_dataset:
         V["label"] = V["label"].abs()  # abs(LFC)
         for promoter, path in zip(config["sat_mut_mpra_promoter"], output):
             V[V["element"] == promoter].to_parquet(path, index=False)
+
+
+rule combine_sat_mut_mpra_datasets:
+    """Combine all sat_mut_mpra promoter datasets into a single dataset for efficient batch inference."""
+    input:
+        lambda wildcards: expand(
+            "results/dataset/{dataset}.parquet",
+            dataset=config["combined_dataset_groups"][wildcards.combined_group]["datasets"]
+        )
+    output:
+        "results/dataset/{combined_group}.parquet"
+    wildcard_constraints:
+        combined_group="sat_mut_mpra_combined"
+    run:
+        datasets = config["combined_dataset_groups"][wildcards.combined_group]["datasets"]
+        dfs = []
+        for dataset_name, dataset_path in zip(datasets, input):
+            df = pd.read_parquet(dataset_path)
+            df["dataset"] = dataset_name  # Add dataset identifier column
+            dfs.append(df)
+
+        # Concatenate all datasets
+        combined = pd.concat(dfs, ignore_index=True)
+        combined.to_parquet(output[0], index=False)
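Because both combine rules write results/dataset/{combined_group}.parquet, the wildcard_constraints on combined_group are what keep them unambiguous: each rule matches only its own group name. A combined dataset, or any downstream metric file, can then be requested as an ordinary Snakemake target; a plausible invocation (core count and model name are placeholders, not from this commit):

    snakemake --cores 8 results/dataset/sat_mut_mpra_combined.parquet
    snakemake --cores 8 results/metrics/promoterai_gtex_outlier/AUPRC/some_model_absLLR.plus.score.tsv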
