Skip to content

Commit

Permalink
Implemented demographic parity constraint (#7)
Browse files Browse the repository at this point in the history
* generalizing test functions to test all constraints

* DP constraints seems to be working

* testing DP constraint

* fixed bug with true negative rate parity constraints

* updated example notebooks

* updated notebooks with examples on other fairness metrics

* readme update
  • Loading branch information
AndreFCruz authored Mar 20, 2024
1 parent 3ee8c7f commit ea22982
Show file tree
Hide file tree
Showing 11 changed files with 378 additions and 167 deletions.
11 changes: 5 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ from error_parity import RelaxedThresholdOptimizer
# Given any trained model that outputs real-valued scores
fair_clf = RelaxedThresholdOptimizer(
predictor=lambda X: model.predict_proba(X)[:, -1], # for sklearn API
# predictor=model, # use this for a callable model
constraint="equalized_odds",
tolerance=0.05, # fairness constraint tolerance
# predictor=model, # use this for a callable model
constraint="equalized_odds", # other constraints are available
tolerance=0.05, # fairness constraint tolerance
)

# Fit the fairness adjustment on some data
Expand Down Expand Up @@ -84,10 +84,9 @@ Currently implemented fairness constraints:
- [x] predictive equality;
- i.e., equal group-specific FPR;
- use `constraint="false_positive_rate_parity"`;

Road-map:
- [ ] demographic parity;
- [x] demographic parity;
- i.e., equal group-specific predicted prevalence;
- use `constraint="demographic_parity"`;


## Citing
Expand Down
2 changes: 1 addition & 1 deletion error_parity/_version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""File to keep the package version in one place."""
__version__ = "0.3.8"
__version__ = "0.3.9"
__version_info__ = tuple(__version__.split("."))
55 changes: 41 additions & 14 deletions error_parity/cvxpy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@

# Set of all fairness constraints with a cvxpy LP implementation
ALL_CONSTRAINTS = {
"equalized_odds",
"true_positive_rate_parity",
"false_positive_rate_parity",
"true_negative_rate_parity",
"false_negative_rate_parity",
"equalized_odds", # equal TPR and equal FPR across groups
"true_positive_rate_parity", # TPR parity, same as FNR parity
"false_positive_rate_parity", # FPR parity, same as TNR parity
"true_negative_rate_parity", # TNR parity, same as FPR parity
"false_negative_rate_parity", # FNR parity, same as TPR parity
"demographic_parity", # equal positive prediction rates across groups
}

NOT_SUPPORTED_CONSTRAINTS_ERROR_MESSAGE = (
Expand Down Expand Up @@ -232,6 +233,7 @@ def compute_fair_optimum(
groupwise_roc_hulls: dict[int, np.ndarray],
group_sizes_label_pos: np.ndarray,
group_sizes_label_neg: np.ndarray,
groupwise_prevalence: np.ndarray,
global_prevalence: float,
false_positive_cost: float = 1.0,
false_negative_cost: float = 1.0,
Expand Down Expand Up @@ -286,11 +288,12 @@ def compute_fair_optimum(
n_groups = len(groupwise_roc_hulls)
if n_groups != len(group_sizes_label_neg) or n_groups != len(group_sizes_label_pos):
raise ValueError(
f"Invalid arguments; all of the following should have the same "
f"length: groupwise_roc_hulls, group_sizes_label_neg, group_sizes_label_pos;"
"Invalid arguments; all of the following should have the same "
"length: groupwise_roc_hulls, group_sizes_label_neg, group_sizes_label_pos;"
f"got: {len(groupwise_roc_hulls)}, {len(group_sizes_label_neg)}, {len(group_sizes_label_pos)};"
)

# Group-wise ROC points
# Group-wise ROC points --- in the form (FPR, TPR)
groupwise_roc_points_vars = [
cp.Variable(shape=2, name=f"ROC point for group {i}", nonneg=True)
for i in range(n_groups)
Expand All @@ -307,7 +310,9 @@ def compute_fair_optimum(
== group_sizes_label_pos @ np.array([p[1] for p in groupwise_roc_points_vars]),
]

### APPLY FAIRNESS CONSTRAINTS
# ** APPLY FAIRNESS CONSTRAINTS **
# NOTE: feature request: compatibility with multiple constraints simultaneously

# If "equalized_odds"
# > i.e., constrain l-inf distance between any two groups' ROCs being less than `tolerance`
if fairness_constraint == "equalized_odds":
Expand All @@ -323,19 +328,19 @@ def compute_fair_optimum(
elif fairness_constraint.endswith("rate_parity"):
roc_idx_of_interest: int
if (
fairness_constraint == "true_positive_rate_parity"
or fairness_constraint == "false_negative_rate_parity"
fairness_constraint == "true_positive_rate_parity" # TPR
or fairness_constraint == "false_negative_rate_parity" # FNR
):
roc_idx_of_interest = 1

elif (
fairness_constraint == "false_positive_rate_parity"
or fairness_constraint == "false_negative_rate_parity"
fairness_constraint == "false_positive_rate_parity" # FPR
or fairness_constraint == "true_negative_rate_parity" # TNR
):
roc_idx_of_interest = 0

else:
# This point should never be reached as fairness constraint was previously validated
# This point should never be reached as fairness_constraint was previously validated
raise ValueError(NOT_SUPPORTED_CONSTRAINTS_ERROR_MESSAGE)

constraints += [
Expand All @@ -348,6 +353,28 @@ def compute_fair_optimum(
if i < j
]

# If demographic parity, i.e., equal positive prediction rates across groups
# note: this ignores the labels Y and only considers predictions Y_hat
elif fairness_constraint == "demographic_parity":

# NOTE: PPR = TPR * prevalence + FPR * (1 - prevalence)
def group_positive_prediction_rate(group_idx: int):
"""Computes group-wise PPR as a function of the ROC cvxpy vars.

The positive prediction rate (PPR) of a group is obtained from its ROC
point and its label prevalence:
PPR = TPR * prevalence + FPR * (1 - prevalence)

`group_idx` indexes both `groupwise_prevalence` and
`groupwise_roc_points_vars`, which are read from the enclosing scope.
"""
# Prevalence is taken from the `groupwise_prevalence` argument (data, not a cvxpy variable)
group_prevalence = groupwise_prevalence[group_idx]
# ROC point variables are in the form (FPR, TPR): index 1 is TPR, index 0 is FPR
group_tpr = groupwise_roc_points_vars[group_idx][1]
group_fpr = groupwise_roc_points_vars[group_idx][0]

return group_tpr * group_prevalence + group_fpr * (1 - group_prevalence)

# Add constraints on the absolute difference between group-wise PPRs
constraints += [
cp.abs(
group_positive_prediction_rate(i) - group_positive_prediction_rate(j)
) <= tolerance
for i, j in product(range(n_groups), range(n_groups))
if i < j
]

# TODO: implement other constraints here
else:
raise NotImplementedError(NOT_SUPPORTED_CONSTRAINTS_ERROR_MESSAGE)
Expand Down
100 changes: 80 additions & 20 deletions error_parity/threshold_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,29 +88,40 @@ def __init__(
self._groupwise_roc_data: dict = None
self._groupwise_roc_hulls: dict = None
self._groupwise_roc_points: np.ndarray = None
self._groupwise_prevalence: np.ndarray = None
self._global_roc_point: np.ndarray = None
self._global_prevalence: float = None
self._realized_classifier: EnsembleGroupwiseClassifiers = None

@property
def groupwise_roc_data(self) -> dict:
"""Group-specific ROC data containing (FPR, TPR, threshold) triplets."""
return self._groupwise_roc_data

@property
def groupwise_roc_hulls(self) -> dict:
"""Group-specific ROC convex hulls achieved by underlying predictor."""
return self._groupwise_roc_hulls

@property
def groupwise_roc_points(self) -> np.ndarray:
"""Group-specific ROC points achieved by solution."""
return self._groupwise_roc_points

@property
def groupwise_prevalence(self) -> np.ndarray:
"""Group-specific prevalence, i.e., P(Y=1|A=a); one entry per group, set by `fit`."""
return self._groupwise_prevalence

@property
def global_roc_point(self) -> np.ndarray:
"""Global ROC point achieved by solution."""
return self._global_roc_point

@property
def groupwise_roc_hulls(self) -> dict:
"""Group-specific ROC convex hulls achieved by underlying predictor."""
return self._groupwise_roc_hulls

@property
def groupwise_roc_data(self) -> dict:
"""Group-specific ROC data containing (FPR, TPR, threshold) triplets."""
return self._groupwise_roc_data
def global_prevalence(self) -> float:
    """Global prevalence, i.e., P(Y=1).

    This is a single scalar (not an array); the backing attribute is
    initialized to `None` in the constructor.
    """
    return self._global_prevalence

def cost(
self,
Expand Down Expand Up @@ -147,8 +158,14 @@ def cost(
false_neg_cost=false_neg_cost or self.false_neg_cost,
)

def constraint_violation(self) -> float:
"""Constraint violation of the LP solution found.
def constraint_violation(self, constraint_name: str = None) -> float:
"""Theoretical constraint violation of the LP solution found.
Parameters
----------
constraint_name : str, optional
Optionally, may provide another constraint name that will be used
instead of this classifier's self.constraint;
Returns
-------
Expand All @@ -157,13 +174,21 @@ def constraint_violation(self) -> float:
"""
self._check_fit_status()

if self.constraint not in ALL_CONSTRAINTS:
if constraint_name is not None:
logging.warning(
f"Calculating constraint violation for {constraint_name} constraint;\n"
f"Note: this classifier was fitted with a {self.constraint} constraint;"
)
else:
constraint_name = self.constraint

if constraint_name not in ALL_CONSTRAINTS:
raise ValueError(NOT_SUPPORTED_CONSTRAINTS_ERROR_MESSAGE)

if self.constraint == "equalized_odds":
if constraint_name == "equalized_odds":
return self.equalized_odds_violation()

elif self.constraint.endswith("rate_parity"):
elif constraint_name.endswith("rate_parity"):
constraint_to_error_type = {
"true_positive_rate_parity": "fn",
"false_positive_rate_parity": "fp",
Expand All @@ -172,13 +197,16 @@ def constraint_violation(self) -> float:
}

return self.error_rate_parity_constraint_violation(
error_type=constraint_to_error_type[self.constraint],
error_type=constraint_to_error_type[constraint_name],
)

elif constraint_name == "demographic_parity":
return self.demographic_parity_violation()

else:
raise NotImplementedError(
f"Standalone constraint violation not yet computed for "
f"constraint='{self.constraint}'."
f"constraint='{constraint_name}'."
)

def error_rate_parity_constraint_violation(self, error_type: str) -> float:
Expand Down Expand Up @@ -208,7 +236,9 @@ def error_rate_parity_constraint_violation(self, error_type: str) -> float:

return self._max_l_inf_between_points(
points=[
np.reshape(roc_point[roc_idx_of_interest], newshape=(1,))
np.reshape( # NOTE: must pass an array object, not scalars
roc_point[roc_idx_of_interest], # use only FPR or TPR (whichever was constrained)
newshape=(1,))
for roc_point in self.groupwise_roc_points
],
)
Expand All @@ -230,6 +260,31 @@ def equalized_odds_violation(self) -> float:
points=self.groupwise_roc_points,
)

def demographic_parity_violation(self) -> float:
    """Computes the theoretical violation of the demographic parity constraint.

    That is, the maximum distance between groups' PPR (positive prediction
    rate), where PPR = TPR * prevalence + FPR * (1 - prevalence).

    Returns
    -------
    float
        The demographic parity constraint violation.
    """
    self._check_fit_status()

    # Compute groups' PPR (positive prediction rate) from each group's
    # (FPR, TPR) ROC point and that group's prevalence
    groupwise_ppr = [
        group_tpr * group_prev + group_fpr * (1 - group_prev)
        for (group_fpr, group_tpr), group_prev in zip(
            self.groupwise_roc_points, self.groupwise_prevalence
        )
    ]

    return self._max_l_inf_between_points(
        points=[
            # NOTE: must pass an array object, not scalars; the target shape is
            # given positionally because the `newshape=` keyword of
            # `np.reshape` is deprecated in recent NumPy releases
            np.reshape(ppr, (1,))
            for ppr in groupwise_ppr
        ],
    )

@staticmethod
def _max_l_inf_between_points(points: list[float | np.ndarray]) -> float:
# Number of points (should correspond to the number of groups)
Expand Down Expand Up @@ -300,7 +355,7 @@ def fit(
group_sizes_label_pos = np.array([np.sum(y[group == g]) for g in unique_groups])

if np.sum(group_sizes_label_neg) + np.sum(group_sizes_label_pos) != len(y):
raise RuntimeError(f"Failed sanity check. Are you using non-binary labels?")
raise RuntimeError("Failed sanity check. Are you using non-binary labels?")

# Convert to relative sizes
group_sizes_label_neg = group_sizes_label_neg.astype(float) / np.sum(
Expand All @@ -310,6 +365,11 @@ def fit(
group_sizes_label_pos
)

# Compute group-wise prevalence rates
self._groupwise_prevalence = np.array(
[np.mean(y[group == g]) for g in unique_groups]
)

# Compute group-wise ROC curves
if y_scores is None:
y_scores = self.predictor(X)
Expand Down Expand Up @@ -363,7 +423,8 @@ def fit(
groupwise_roc_hulls=self._groupwise_roc_hulls,
group_sizes_label_pos=group_sizes_label_pos,
group_sizes_label_neg=group_sizes_label_neg,
global_prevalence=self._global_prevalence,
groupwise_prevalence=self.groupwise_prevalence,
global_prevalence=self.global_prevalence,
false_positive_cost=self.false_pos_cost,
false_negative_cost=self.false_neg_cost,
)
Expand Down Expand Up @@ -431,7 +492,6 @@ def _check_fit_status(self, raise_error: bool = True) -> bool:
return False

raise RuntimeError(
"This classifier has not yet been fitted to any data."
)
"This classifier has not yet been fitted to any data.")

return True
2 changes: 1 addition & 1 deletion examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@
| [relaxed-equalized-odds.usage-example-folktables.ipynb](relaxed-equalized-odds.usage-example-folktables.ipynb) | equalized odds | ACSIncome | Example usage of `RelaxedThresholdOptimizer` to map Pareto frontier of attainable fairness-accuracy trade-offs for a given predictor. |
| [parse-folktables-datasets.ipynb](parse-folktables-datasets.ipynb) | - | ACSIncome / folktables | Notebook that downloads and parses folktables datasets (required to run the folktables/ACSIncome examples). |
| [relaxed-equalized-odds.usage-example-synthetic-data.ipynb](relaxed-equalized-odds.usage-example-synthetic-data.ipynb) | equalized odds | synthetic (no downloads necessary) | Stand-alone example on synthetic data. |
| [relaxed-equal-opportunity.usage-example-synthetic-data.ipynb](relaxed-equal-opportunity.usage-example-synthetic-data.ipynb) | equal opportunity | synthetic (no downloads) | Stand-alone example for equal opportunity. |
| [usage-example-for-other-constraints.synthetic-data.ipynb](usage-example-for-other-constraints.synthetic-data.ipynb) | TPR equality, FPR equality, demographic parity | synthetic (no downloads) | Stand-alone example with other available fairness metrics (based on TPR, FPR, or PPR). |
| [example-with-postprocessing-and-inprocessing.ipynb](example-with-postprocessing-and-inprocessing.ipynb) | equalized odds | synthetic (no downloads) | Example of using relaxed postprocessing with an in-processing fairness algorithm. |
| [brute-force-example_equalized-odds-thresholding.ipynb](brute-force-example_equalized-odds-thresholding.ipynb) | equalized odds | synthetic (no downloads) | Comparison between using the `RelaxedThresholdOptimizer` and a brute-force solver (out of curiosity). |
Loading

0 comments on commit ea22982

Please sign in to comment.