Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreFCruz committed Apr 25, 2024
1 parent bf27a8a commit 95b9b1f
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 69 deletions.
13 changes: 11 additions & 2 deletions error_parity/pareto_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def fit_and_evaluate_postprocessing(
seed: int = 42,
y_fit_pred_scores: np.ndarray = None, # pre-computed predictions on the fit data
bootstrap: bool = True,
bootstrap_kwargs: dict = None,
**bootstrap_kwargs: dict,
) -> dict[str, dict]:
"""Fit and evaluate a postprocessing intervention on the given predictor.
Expand Down Expand Up @@ -106,12 +106,16 @@ def _evaluate_on_data(data: tuple):
X, Y, S = data

if bootstrap:
kwargs = bootstrap_kwargs or dict(
# Default kwargs for bootstrapping
kwargs = dict(
confidence_pct=95,
seed=seed,
threshold=0.50,
)

# Update kwargs with any extra bootstrap kwargs
kwargs.update(bootstrap_kwargs)

eval_func = partial(
evaluate_predictions_bootstrap,
**kwargs,
Expand Down Expand Up @@ -197,6 +201,10 @@ def callable_predictor(X) -> np.ndarray:
assert 1 <= len(preds.shape) <= 2, f"Model outputs predictions in shape {preds.shape}"
return preds if len(preds.shape) == 1 else preds[:, -1]

# Pre-compute predictions on the fit data
X_fit, _, _ = fit_data
y_fit_pred_scores = callable_predictor(X_fit)

postproc_template = RelaxedThresholdOptimizer(
predictor=callable_predictor,
constraint=fairness_constraint,
Expand All @@ -211,6 +219,7 @@ def _func_call(tol: float):
fit_data=fit_data,
eval_data=eval_data,
bootstrap=bootstrap,
y_fit_pred_scores=y_fit_pred_scores,
**kwargs)

except Exception as exc:
Expand Down
20 changes: 9 additions & 11 deletions examples/example-with-postprocessing-and-inprocessing.ipynb

Large diffs are not rendered by default.

72 changes: 35 additions & 37 deletions examples/relaxed-equalized-odds.usage-example-folktables.ipynb

Large diffs are not rendered by default.

25 changes: 12 additions & 13 deletions examples/relaxed-equalized-odds.usage-example-synthetic-data.ipynb

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@
"INFO:root:ROC convex hull contains 41.6% of the original points.\n",
"INFO:root:ROC convex hull contains 36.6% of the original points.\n",
"INFO:root:ROC convex hull contains 38.6% of the original points.\n",
"INFO:root:cvxpy solver took 0.000259958s; status is optimal.\n",
"INFO:root:cvxpy solver took 0.000282458s; status is optimal.\n",
"INFO:root:Optimal solution value: 0.15335531408011688\n",
"INFO:root:Variable Global ROC point: value [0.10552007 0.71687162]\n",
"INFO:root:Variable ROC point for group 0: value [0.23852472 0.69338557]\n",
Expand All @@ -270,14 +270,14 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 153 ms, sys: 9.35 ms, total: 162 ms\n",
"Wall time: 785 ms\n"
"CPU times: user 152 ms, sys: 9 ms, total: 161 ms\n",
"Wall time: 242 ms\n"
]
},
{
"data": {
"text/plain": [
"<error_parity.threshold_optimizer.RelaxedThresholdOptimizer at 0x137c42950>"
"<error_parity.threshold_optimizer.RelaxedThresholdOptimizer at 0x12b1a5db0>"
]
},
"execution_count": 10,
Expand Down Expand Up @@ -599,7 +599,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "dc11df709fe3435f9a2aa12917cd5dae",
"model_id": "8c3f25c7c9b34de3b6c636584189fb00",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -621,7 +621,6 @@
" \"fit\": (X, y_true, group),\n",
" },\n",
" fairness_constraint=FAIRNESS_CONSTRAINT,\n",
" y_fit_pred_scores=predictor(X),\n",
" predict_method=\"__call__\", # for callable predictors\n",
" bootstrap=True,\n",
" seed=SEED,\n",
Expand Down

0 comments on commit 95b9b1f

Please sign in to comment.