Skip to content

Commit bd1c950

Browse files
committed
Fixed bug in compute_scores
1 parent b23832c commit bd1c950

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

raid/evaluate.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,8 @@ def compute_scores(df, thresholds, require_complete=True, include_all=True):
106106
scores = []
107107

108108
# Separate human from model data
109-
df = df[df["model"] != "human"]
110109
dfh = df[df["model"] == "human"]
110+
df = df[df["model"] != "human"]
111111

112112
# For each domain, attack, model, and decoding strategy, filter the dataset
113113
for d in get_unique_items(df, "domain", include_all):
@@ -122,6 +122,7 @@ def compute_scores(df, thresholds, require_complete=True, include_all=True):
122122
for r in get_unique_items(df, "repetition_penalty", include_all):
123123
df_filter = dfs[dfs["repetition_penalty"] == r] if r != "all" else dfs
124124

125+
#print(f"Inner loop. DF filer len is {len(df_filter)} dfh filter len is {len(dfh_filter)}")
125126
# If no outputs for this split, continue
126127
if len(df_filter) == 0 or len(dfh_filter) == 0:
127128
continue
@@ -139,6 +140,9 @@ def compute_scores(df, thresholds, require_complete=True, include_all=True):
139140
# For each target FPR value
140141
tprs = {}
141142
for fpr in thresholds.keys():
143+
# Get thresholds for the particular fpr value
144+
fpr_thresholds = thresholds[fpr]
145+
142146
# Initialize predictions
143147
preds = []
144148

@@ -149,7 +153,7 @@ def compute_scores(df, thresholds, require_complete=True, include_all=True):
149153

150154
# Select the domain-specific threshold to use for classification
151155
# (If thresholds is a dict, use the domain-specific threshold)
152-
t = thresholds[domain] if type(thresholds) == dict else thresholds
156+
t = fpr_thresholds[domain] if type(fpr_thresholds) == dict else fpr_thresholds
153157

154158
# Get the 0 to 1 scores for the detector
155159
y_model = df_domain["score"].to_numpy()

0 commit comments

Comments
 (0)