Skip to content

Commit

Permalink
added tests for unprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreFCruz committed Apr 2, 2024
1 parent 13b1b62 commit a921842
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 50 deletions.
3 changes: 3 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[flake8]
max-complexity=10
max-line-length=127
15 changes: 4 additions & 11 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,13 @@ jobs:
python -m pip install ".[all]"
# ^ install local package with all extras
- name: Lint errors with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 error_parity --count --select=E9,F63,F7,F82 --show-source --statistics
flake8 tests --count --select=E9,F63,F7,F82 --show-source --statistics
- name: Lint warnings with flake8
- name: Lint with flake8
continue-on-error: true
run: |
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 error_parity --count --max-complexity=10 --max-line-length=127 --statistics
flake8 tests --count --max-complexity=10 --max-line-length=127 --statistics
flake8 error_parity --count --statistics
flake8 tests --count --statistics
- name: Test with pytest
run: |
coverage run -m pytest tests && coverage report -m
# pytest tests
coverage run -m pytest tests && coverage report -m --fail-under=75
78 changes: 39 additions & 39 deletions tests/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,42 +223,42 @@ def predictor(idx):
)


# def test_unprocessing(
# y_true: np.ndarray,
# y_pred_scores: np.ndarray,
# sensitive_attribute: np.ndarray,
# random_seed: int,
# ):
# """Tests that unprocessing strictly increases accuracy.
# """
# # Predictor function
# # # > predicts the generated scores from the sample indices
# def predictor(idx):
# return y_pred_scores[idx]

# # Hence, for this example, the features are the sample indices
# num_samples = len(y_true)
# X_features = np.arange(num_samples)

# clf = RelaxedThresholdOptimizer(
# predictor=predictor,
# tolerance=1,
# false_pos_cost=1,
# false_neg_cost=1,
# seed=random_seed,
# )

# # Fit postprocessing to data
# clf.fit(X=X_features, y=y_true, group=sensitive_attribute)

# # Optimal binarized predictions
# y_pred_binary = clf(X_features, group=sensitive_attribute)

# # Original accuracy (using group-blind thresholds)
# original_acc = accuracy_score(y_true, (y_pred_scores >= 0.5).astype(int))

# # Unprocessed accuracy (using group-dependent thresholds)
# unprocessed_acc = accuracy_score(y_true, y_pred_binary)

# # Assert that unprocessing always improves (or maintains) accuracy
# assert unprocessed_acc >= original_acc
def test_unprocessing(
    y_true: np.ndarray,
    y_pred_scores: np.ndarray,
    sensitive_attribute: np.ndarray,
    random_seed: int,
):
    """Check that unprocessing never decreases accuracy.

    With ``tolerance=1`` the fairness constraint is fully relaxed, so the
    optimizer is free to pick group-dependent thresholds that maximize
    accuracy; the result must therefore match or beat a single group-blind
    0.5 threshold on the same scores.
    """
    # For this synthetic setup the "features" are just the sample indices;
    # the predictor maps an index back to its pre-generated score.
    n_samples = len(y_true)
    X_features = np.arange(n_samples)

    def predictor(indices):
        return y_pred_scores[indices]

    postprocessor = RelaxedThresholdOptimizer(
        predictor=predictor,
        tolerance=1,
        false_pos_cost=1,
        false_neg_cost=1,
        seed=random_seed,
    )

    # Fit postprocessing (threshold selection) on the data.
    postprocessor.fit(X=X_features, y=y_true, group=sensitive_attribute)

    # Binarized predictions under the fitted group-dependent thresholds.
    y_pred_binary = postprocessor(X_features, group=sensitive_attribute)

    # Baseline: group-blind 0.5 threshold applied directly to the scores.
    baseline_acc = accuracy_score(y_true, (y_pred_scores >= 0.5).astype(int))

    # Accuracy after unprocessing (group-dependent thresholds).
    unprocessed_acc = accuracy_score(y_true, y_pred_binary)

    # Unprocessing may only improve (or maintain) accuracy.
    assert unprocessed_acc >= baseline_acc

0 comments on commit a921842

Please sign in to comment.