
Commit

updated all examples with latest error-parity version
AndreFCruz committed Nov 15, 2023
1 parent b564f1c commit d7f1e3b
Showing 8 changed files with 337 additions and 269 deletions.
22 changes: 22 additions & 0 deletions error_parity/_commons.py
@@ -6,6 +6,28 @@
from scipy.spatial import qhull, ConvexHull


+def arrays_are_equal(*arrays: list[np.ndarray]) -> bool:
+    """Compares two or more arrays and returns whether they are equal."""
+    assert len(arrays) >= 2, \
+        f"At least two arguments must be provided, got {len(arrays)}."
+
+    # Reference array
+    ref_array = arrays[0]
+    ref_array_np = np.array(ref_array)
+
+    for curr_arr in arrays[1:]:
+        curr_arr_np = np.array(curr_arr)
+
+        # Check shape and contents
+        if (ref_array_np.shape != curr_arr_np.shape
+                or not np.allclose(ref_array_np, curr_arr_np)
+        ):
+            return False  # arrays are not equal
+
+    # All checks passed, return True (arrays are equal)
+    return True
+
+
def join_dictionaries(*dicts) -> dict:
    """Joins a sequence of dictionaries into a single dictionary."""
    return reduce(operator.or_, dicts)
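For context, a minimal usage sketch of the new `arrays_are_equal` helper (importing from the private `error_parity._commons` module shown above; nothing here is part of the documented public API):

```python
import numpy as np
from error_parity._commons import arrays_are_equal

# Equal contents compare equal even across container types,
# since each argument is first converted with np.array().
assert arrays_are_equal([0.0, 0.2, 0.5], np.array([0.0, 0.2, 0.5]))

# A shape mismatch, or any value difference beyond np.allclose tolerance, returns False.
assert not arrays_are_equal(np.arange(3), np.arange(4))
```

Because the comparison uses `np.allclose`, tiny floating-point differences are tolerated; this is what lets the pareto-curve code further down treat a `tolerance_ticks` argument equal to `DEFAULT_TOLERANCE_TICKS` as left at its default.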
4 changes: 2 additions & 2 deletions error_parity/classifiers.py
@@ -49,7 +49,7 @@ def __call__(self, X: np.ndarray, group: np.ndarray = None) -> np.ndarray:
        y_pred_binary : np.ndarray[int]
            The predicted class for each input sample.
        """
-        return (self.score_predictor(X) >= self.threshold).astype(int)
+        return (self.score_predictor(X).ravel() >= self.threshold).astype(int)


class BinaryClassifierAtROCDiagonal(Classifier):
@@ -125,7 +125,7 @@ def __call__(self, X: np.ndarray, group: np.ndarray) -> np.ndarray:
            to a group-specific classifier for that sample.
        """
        if len(X) != len(group):
-            raise ValueError(f"Invalid input sizes len(X) != len(group)")
+            raise ValueError(f"Invalid input sizes: len(X) != len(group), {len(X)} != {len(group)}.")

        # Array to store predictions
        num_samples = len(X)
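The `.ravel()` call added above guards against score predictors that return column vectors: comparing an `(n, 1)` array against a scalar threshold broadcasts to an `(n, 1)` prediction array instead of a flat one. A small standalone numpy sketch of the shape issue (illustrative only, not library code):

```python
import numpy as np

scores_2d = np.array([[0.1], [0.7], [0.4]])  # shape (3, 1), e.g. a predictor output kept as a column
threshold = 0.5

# Without flattening, the broadcast keeps the extra axis.
print((scores_2d >= threshold).astype(int).shape)          # (3, 1)

# Flattening first yields the expected 1-D binary predictions.
print((scores_2d.ravel() >= threshold).astype(int).shape)  # (3,)
```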
17 changes: 14 additions & 3 deletions error_parity/pareto_curve.py
@@ -11,6 +11,7 @@

import os
import logging
+import traceback
from functools import partial
from concurrent.futures import ThreadPoolExecutor

@@ -19,7 +20,7 @@

from .threshold_optimizer import RelaxedThresholdOptimizer
from .evaluation import evaluate_predictions, evaluate_predictions_bootstrap
-from ._commons import join_dictionaries, get_cost_envelope
+from ._commons import join_dictionaries, get_cost_envelope, arrays_are_equal


DEFAULT_TOLERANCE_TICKS = np.hstack((
@@ -220,7 +221,9 @@ def _func_call(tol: float):
                **kwargs)

        except Exception as exc:
-            logging.error(f"FAILED fit_relaxed_postprocessing with `tolerance={tol}`: {exc}")
+            logging.error(
+                f"FAILED `fit_and_evaluate_postprocessing(.)` with `tolerance={tol}`; "
+                f"{''.join(traceback.TracebackException.from_exception(exc).format())}")

            return {}  # return empty dictionary
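The rewritten error log above records the full stack trace instead of only `str(exc)`. A self-contained sketch of the same standard-library pattern:

```python
import logging
import traceback

try:
    1 / 0
except Exception as exc:
    # Render the complete traceback of `exc` as a single string.
    tb_str = "".join(traceback.TracebackException.from_exception(exc).format())
    logging.error(f"FAILED with an exception; {tb_str}")
```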

@@ -234,7 +237,12 @@ def _func_call(tol: float):
    if tolerance_tick_step is not None:
        tolerances = np.arange(0.0, 1.0, tolerance_tick_step)

-        if tolerance_ticks is not None and tolerance_ticks != DEFAULT_TOLERANCE_TICKS:
+        if (
+            # > `tolerance_ticks` was provided
+            tolerance_ticks is not None
+            # > and `tolerance_ticks` was set to a non-default value
+            and not arrays_are_equal(tolerance_ticks, DEFAULT_TOLERANCE_TICKS)
+        ):
            logging.error("Please provide only one of `tolerance_ticks` and `tolerance_tick_step`.")

            logging.warning("Use of `tolerance_tick_step` overrides the use of `tolerance_ticks`.")
@@ -243,6 +251,9 @@ def _func_call(tol: float):
    else:
        tolerances = tolerance_ticks

+    # Log tolerances used
+    logging.info(f"Computing postprocessing for the following constraint tolerances: {tolerances}.")
+
    with ThreadPoolExecutor(max_workers=n_jobs) as executor:
        func_call_results = list(
            tqdm(
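For context on the two tolerance arguments handled above, a short sketch of how the grids are built; the numeric values are illustrative, not the library defaults:

```python
import numpy as np

# Option 1: an evenly spaced grid derived from a step size, as in the branch above.
tolerance_tick_step = 0.05
tolerances_from_step = np.arange(0.0, 1.0, tolerance_tick_step)  # 0.00, 0.05, ..., 0.95

# Option 2: an explicit, possibly non-uniform list of ticks.
tolerance_ticks = [0.0, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5]

# Passing both (with a non-default `tolerance_ticks`) now triggers the error/warning
# logged above, and `tolerance_tick_step` takes precedence.
```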
8 changes: 6 additions & 2 deletions error_parity/plotting.py
@@ -240,16 +240,20 @@ def plot_postprocessing_frontier(
    # Get relevant column names
    perf_col = f"{perf_metric}_mean_{show_data_type}"
    disp_col = f"{disp_metric}_mean_{show_data_type}"
+
+    # Check if bootstrap means are available
    has_bootstrap_results = perf_col in postproc_results_df.columns

    if not has_bootstrap_results:
        perf_col = f"{perf_metric}_{show_data_type}"
        disp_col = f"{disp_metric}_{show_data_type}"

    assert perf_col in postproc_results_df.columns, (
-        f"Could not find the column '{perf_col}' for the perf. metric '{perf_metric}'.")
+        f"Could not find the column '{perf_col}' for the perf. metric "
+        f"'{perf_metric}' on data type '{show_data_type}'.")
    assert disp_col in postproc_results_df.columns, (
-        f"Could not find the column '{disp_col}' for the disp. metric '{disp_metric}'.")
+        f"Could not find the column '{disp_col}' for the disp. metric "
+        f"'{disp_metric}' on data type '{show_data_type}'.")

    # Get envelope of postprocessing adjustment frontier
    postproc_frontier = get_envelope_of_postprocessing_frontier(
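To illustrate the column lookup above: results computed with bootstrapping carry a `_mean_` infix in their column names, and the plot falls back to point-estimate columns when those are absent. A hypothetical sketch (the metric and column names below are made up to match the naming pattern, not taken from the library):

```python
import pandas as pd

perf_metric, disp_metric, show_data_type = "accuracy", "equalized_odds_diff", "test"

# Hypothetical results frame without bootstrap columns.
postproc_results_df = pd.DataFrame({
    "accuracy_test": [0.81, 0.80],
    "equalized_odds_diff_test": [0.02, 0.05],
})

perf_col = f"{perf_metric}_mean_{show_data_type}"   # "accuracy_mean_test"
if perf_col not in postproc_results_df.columns:     # no bootstrap means available
    perf_col = f"{perf_metric}_{show_data_type}"    # fall back to "accuracy_test"

print(perf_col)  # accuracy_test
```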
6 changes: 5 additions & 1 deletion error_parity/threshold_optimizer.py
@@ -208,7 +208,7 @@ def error_rate_parity_constraint_violation(self, error_type: str) -> float:

        return self._max_l_inf_between_points(
            points=[
-                roc_point[roc_idx_of_interest]
+                np.reshape(roc_point[roc_idx_of_interest], newshape=(1,))
                for roc_point in self.groupwise_roc_points
            ],
        )
@@ -314,6 +314,10 @@ def fit(
        if y_scores is None:
            y_scores = self.predictor(X)

+        # Flatten y_scores array if needed
+        if isinstance(y_scores, np.ndarray) and len(y_scores.shape) > 1:
+            y_scores = y_scores.ravel()
+
        self._groupwise_roc_data = dict()
        for g in unique_groups:
            group_filter = group == g
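With the flattening step added above, `fit` also accepts score predictors that return `(n, 1)` arrays. A hedged end-to-end sketch on synthetic data, assuming the public API advertised in the error-parity README (constructor arguments `predictor`, `constraint`, `tolerance`; `fit(X=..., y=..., group=...)`; calling the fitted object to predict); signatures not shown in this diff should be treated as assumptions:

```python
import numpy as np
from error_parity import RelaxedThresholdOptimizer

rng = np.random.default_rng(42)
n = 1_000
X = rng.normal(size=(n, 5))
y = rng.integers(0, 2, size=n)
group = rng.integers(0, 2, size=n)

def column_vector_scores(X):
    # Deliberately return a column vector of shape (n, 1); fit() now flattens it internally.
    return (1.0 / (1.0 + np.exp(-X[:, 0]))).reshape(-1, 1)

fair_clf = RelaxedThresholdOptimizer(
    predictor=column_vector_scores,
    constraint="equalized_odds",
    tolerance=0.05,
)
fair_clf.fit(X=X, y=y, group=group)
y_pred = fair_clf(X, group=group)  # group-wise randomized predictions
```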
160 changes: 111 additions & 49 deletions examples/relaxed-equal-opportunity.usage-example-synthetic-data.ipynb

Large diffs are not rendered by default.

191 changes: 139 additions & 52 deletions examples/relaxed-equalized-odds.usage-example-folktables.ipynb

Large diffs are not rendered by default.

198 changes: 38 additions & 160 deletions examples/relaxed-equalized-odds.usage-example-synthetic-data.ipynb

Large diffs are not rendered by default.
