
Commit

updated example notebooks for compatibility with l_p_norm kwarg
AndreFCruz committed Apr 25, 2024
1 parent 0db5357 commit b1809d0
Showing 10 changed files with 248 additions and 225 deletions.
5 changes: 4 additions & 1 deletion error_parity/evaluation.py
@@ -36,6 +36,7 @@ def eval_accuracy_and_equalized_odds(
y_true: np.ndarray,
y_pred_binary: np.ndarray,
sensitive_attr: np.ndarray,
l_p_norm: int = np.inf,
display: bool = False,
) -> tuple[float, float]:
"""Evaluate accuracy and equalized odds of the given predictions.
@@ -48,6 +49,8 @@ def eval_accuracy_and_equalized_odds(
The predicted class labels.
sensitive_attr : np.ndarray
The sensitive attribute data.
l_p_norm : int, optional
The norm to use for the constraint violation, by default np.inf.
display : bool, optional
Whether to print results or not, by default False.
@@ -68,7 +71,7 @@ def eval_accuracy_and_equalized_odds(
roc_points = np.vstack(roc_points)

linf_constraint_violation = [
np.linalg.norm(roc_points[i] - roc_points[j], ord=np.inf)
np.linalg.norm(roc_points[i] - roc_points[j], ord=l_p_norm)
for i, j in product(range(n_groups), range(n_groups))
if i < j
]
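For context, a minimal sketch of how the new kwarg changes the call (the synthetic arrays below are placeholders, not part of this commit):

import numpy as np
from error_parity.evaluation import eval_accuracy_and_equalized_odds

rng = np.random.default_rng(42)
y_true = rng.integers(0, 2, size=1_000)   # placeholder labels
y_pred = rng.integers(0, 2, size=1_000)   # placeholder binarized predictions
group = rng.integers(0, 3, size=1_000)    # placeholder sensitive attribute (3 groups)

# Default keeps the previous behavior: l-infinity norm over pairwise ROC-point distances.
acc, eo_linf = eval_accuracy_and_equalized_odds(
    y_true=y_true, y_pred_binary=y_pred, sensitive_attr=group)

# New: measure the equalized-odds violation under the l-2 norm instead.
acc, eo_l2 = eval_accuracy_and_equalized_odds(
    y_true=y_true, y_pred_binary=y_pred, sensitive_attr=group, l_p_norm=2)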
67 changes: 34 additions & 33 deletions error_parity/pareto_curve.py
@@ -5,6 +5,7 @@
from __future__ import annotations

import os
import copy
import logging
import traceback
from functools import partial
@@ -25,42 +26,31 @@


def fit_and_evaluate_postprocessing(
predictor: callable,
postproc_template: RelaxedThresholdOptimizer,
tolerance: float,
fit_data: tuple,
eval_data: tuple | dict[tuple],
fairness_constraint: str = "equalized_odds",
false_pos_cost: float = 1.,
false_neg_cost: float = 1.,
max_roc_ticks: int = 200,
seed: int = 42,
y_fit_pred_scores: np.ndarray = None, # pre-computed predictions on the fit data
bootstrap: bool = True,
bootstrap_kwargs: dict = None,
**bootstrap_kwargs: dict,
) -> dict[str, dict]:
"""Fit and evaluate a postprocessing intervention on the given predictor.
Parameters
----------
predictor : callable
The callable predictor to fit postprocessing on.
postproc_template: RelaxedThresholdOptimizer
An object that serves as the template to copy when creating the
postprocessing optimizer.
tolerance : float
The tolerance (or slack) for fairness constraint fulfillment.
The tolerance (or slack) for fairness constraint fulfillment. This value
will override the `tolerance` attribute of the `postproc_template` object.
fit_data : tuple
The data used to fit postprocessing.
eval_data : tuple or dict[tuple]
The data or sequence of data to evaluate postprocessing on.
If a tuple is provided, will call it "eval" data in the returned results
dictionary; if a dict is provided, will assume {<key_1>: <data_1>, ...}.
fairness_constraint : str, optional
The name of the fairness constraint to use, by default "equalized_odds".
false_pos_cost : float, optional
The cost of a false positive error, by default 1.
false_neg_cost : float, optional
The cost of a false negative error, by default 1.
max_roc_ticks : int, optional
The maximum number of ticks (precision) to use when computing
group-specific ROC curves, by default 200.
seed : int, optional
The random seed, by default 42
y_fit_pred_scores : np.ndarray, optional
@@ -85,15 +75,8 @@ def fit_and_evaluate_postprocessing(
>>> "test": {"accuracy": 0.65, "...": "..."},
>>> }
"""
clf = RelaxedThresholdOptimizer(
predictor=predictor,
constraint=fairness_constraint,
tolerance=tolerance,
false_pos_cost=false_pos_cost,
false_neg_cost=false_neg_cost,
max_roc_ticks=max_roc_ticks,
seed=seed,
)
clf = copy.copy(postproc_template)
clf.tolerance = tolerance

# Unpack data
X_fit, y_fit, s_fit = fit_data
@@ -105,7 +88,7 @@ def fit_and_evaluate_postprocessing(
# (Theoretical) fit results
results["fit-theoretical"] = {
"accuracy": 1 - clf.cost(1.0, 1.0),
fairness_constraint: clf.constraint_violation(),
clf.constraint: clf.constraint_violation(),
}

ALLOWED_ABS_ERROR = 1e-5
@@ -123,12 +106,16 @@ def _evaluate_on_data(data: tuple):
X, Y, S = data

if bootstrap:
kwargs = bootstrap_kwargs or dict(
# Default kwargs for bootstrapping
kwargs = dict(
confidence_pct=95,
seed=seed,
threshold=0.50,
)

# Update kwargs with any extra bootstrap kwargs
kwargs.update(bootstrap_kwargs)

eval_func = partial(
evaluate_predictions_bootstrap,
**kwargs,
@@ -156,8 +143,9 @@ def _evaluate_on_data(data: tuple):
def compute_postprocessing_curve(
model: object,
fit_data: tuple,
eval_data: tuple or dict[tuple],
eval_data: tuple | dict[tuple],
fairness_constraint: str = "equalized_odds",
l_p_norm: int = np.inf,
bootstrap: bool = True,
tolerance_ticks: list = DEFAULT_TOLERANCE_TICKS,
tolerance_tick_step: float = None,
@@ -180,7 +168,10 @@ def compute_postprocessing_curve(
format as `fit_data`), or a dictionary of <data_name>-><data_triplet>
containing multiple datasets to evaluate on.
fairness_constraint : str, optional
_description_, by default "equalized_odds"
The fairness constraint to use, by default "equalized_odds".
l_p_norm : int, optional
The norm to use when computing the fairness constraint, by default np.inf.
Note: only compatible with the "equalized odds" constraint.
bootstrap : bool, optional
Whether to compute uncertainty estimates via bootstrapping, by default
False.
@@ -210,15 +201,25 @@ def callable_predictor(X) -> np.ndarray:
assert 1 <= len(preds.shape) <= 2, f"Model outputs predictions in shape {preds.shape}"
return preds if len(preds.shape) == 1 else preds[:, -1]

# Pre-compute predictions on the fit data
X_fit, _, _ = fit_data
y_fit_pred_scores = callable_predictor(X_fit)

postproc_template = RelaxedThresholdOptimizer(
predictor=callable_predictor,
constraint=fairness_constraint,
l_p_norm=l_p_norm,
)

def _func_call(tol: float):
try:
return fit_and_evaluate_postprocessing(
predictor=callable_predictor,
postproc_template=postproc_template,
tolerance=tol,
fit_data=fit_data,
eval_data=eval_data,
fairness_constraint=fairness_constraint,
bootstrap=bootstrap,
y_fit_pred_scores=y_fit_pred_scores,
**kwargs)

except Exception as exc:
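For context, a minimal sketch of the new template-based pattern that compute_postprocessing_curve now uses internally, and that fit_and_evaluate_postprocessing expects (the synthetic data and trivial score function are placeholders, not part of this commit):

import numpy as np
from error_parity.threshold_optimizer import RelaxedThresholdOptimizer
from error_parity.pareto_curve import fit_and_evaluate_postprocessing

rng = np.random.default_rng(42)
n = 1_000
X = rng.normal(size=(n, 4))
s = rng.integers(0, 2, size=n)                                 # placeholder sensitive attribute
y = (X[:, 0] + 0.3 * s + rng.normal(size=n) > 0).astype(int)   # placeholder labels

def predictor(features: np.ndarray) -> np.ndarray:
    # Placeholder scoring function; any callable mapping features to scores works here.
    return 1.0 / (1.0 + np.exp(-features[:, 0]))

# The template carries every setting except the tolerance, which is overridden per call.
postproc_template = RelaxedThresholdOptimizer(
    predictor=predictor,
    constraint="equalized_odds",
    l_p_norm=2,   # new kwarg; np.inf reproduces the previous l-infinity behavior
)

results = fit_and_evaluate_postprocessing(
    postproc_template=postproc_template,
    tolerance=0.05,
    fit_data=(X, y, s),
    eval_data={"test": (X, y, s)},
    bootstrap=False,
)
# compute_postprocessing_curve builds a template like this one and sweeps `tolerance` for you.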
7 changes: 6 additions & 1 deletion error_parity/plotting.py
@@ -190,9 +190,14 @@ def plot_postprocessing_solution(
)

# Set axis settings
fairness_constr_str = postprocessed_clf.constraint.replace("_", " ")
if postprocessed_clf.constraint == "equalized_odds":
l_p_norm = postprocessed_clf.l_p_norm if postprocessed_clf.l_p_norm != np.inf else r"\infty"
fairness_constr_str += f" $\\ell_{l_p_norm}$"

plt.suptitle(f"Solution to {postprocessed_clf.tolerance}-relaxed optimum", y=0.96)
plt.title(
f"(fairness constraint: {postprocessed_clf.constraint.replace('_', ' ')})",
f"(fairness constraint: {fairness_constr_str})",
fontsize="small",
)

18 changes: 17 additions & 1 deletion error_parity/threshold_optimizer.py
@@ -519,7 +519,7 @@ def predict(self, X: np.ndarray, *, group: np.ndarray) -> np.ndarray:
return self(X, group=group)

def _check_fit_status(self, raise_error: bool = True) -> bool:
"""Checks whether this classifier has been fit on some data.
"""Check whether this classifier has been fit on some data.
Parameters
----------
Expand All @@ -546,3 +546,19 @@ def _check_fit_status(self, raise_error: bool = True) -> bool:
"This classifier has not yet been fitted to any data.")

return True

def __copy__(self):
"""Create a shallow copy of this object.
The returned copy is in a blank state, i.e., it has not been fit to any
data.
"""
return self.__class__(
predictor=self.predictor,
constraint=self.constraint,
tolerance=self.tolerance,
false_pos_cost=self.false_pos_cost,
false_neg_cost=self.false_neg_cost,
l_p_norm=self.l_p_norm,
max_roc_ticks=self.max_roc_ticks,
seed=self.seed,
)
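
A small sketch of what this shallow copy enables (the settings below are placeholders, not part of this commit):

import copy
import numpy as np
from error_parity.threshold_optimizer import RelaxedThresholdOptimizer

template = RelaxedThresholdOptimizer(
    predictor=lambda X: np.asarray(X)[:, 0],   # placeholder scoring function
    constraint="equalized_odds",
    tolerance=0.10,
    l_p_norm=2,
)

# copy.copy() dispatches to __copy__ and returns an unfitted clone with the same settings,
clone = copy.copy(template)
# whose tolerance can then be overridden independently, as fit_and_evaluate_postprocessing does.
clone.tolerance = 0.01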
