This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit b92fd9a

Arjun Subramonian
authored
Contextualized bias mitigation (#5176)
* added linear and hard debiasers
* worked on documentation
* committing changes before branch switch
* committing changes before switching branch
* finished bias direction, linear and hard debiasers, need to write tests
* finished bias direction test
* Commiting changes before switching branch
* finished hard and linear debiasers
* finished OSCaR
* bias mitigators tests and bias metrics remaining
* added bias mitigator tests
* added bias mitigator tests
* finished tests for bias mitigation methods
* fixed gpu issues
* fixed gpu issues
* fixed gpu issues
* resolve issue with count_nonzero not being differentiable
* added more references
* fairness during finetuning
* finished bias mitigator wrapper
* added reference
* updated CHANGELOG and fixed minor docs issues
* move id tensors to embedding device
* fixed to use predetermined bias direction
* fixed minor doc errors
* snli reader registration issue
* fixed _pretrained from params issue
* fixed device issues
* evaluate bias mitigation initial commit
* finished evaluate bias mitigation
* handles multiline prediction files
* fixed minor bugs
* fixed minor bugs
* improved prediction diff JSON format
* forgot to resolve a conflict
* Refactored evaluate bias mitigation to use NLI metric
* Added SNLIPredictionsDiff class
* ensured dataloader is same for bias mitigated and baseline models
* finished evaluate bias mitigation
* Update CHANGELOG.md
* Replaced local data files with github raw content links
* Update allennlp/fairness/bias_mitigator_applicator.py
  Co-authored-by: Pete <[email protected]>
* deleted evaluate_bias_mitigation from git tracking
* removed evaluate-bias-mitigation instances from rest of repo
* addressed Akshita's comments
* moved bias mitigator applicator test to allennlp-models
* removed unnecessary files

Co-authored-by: Arjun Subramonian <[email protected]>
Co-authored-by: Arjun Subramonian <[email protected]>
Co-authored-by: Arjun Subramonian <[email protected]>
Co-authored-by: Arjun Subramonian <[email protected]>
Co-authored-by: Akshita Bhagia <[email protected]>
Co-authored-by: Pete <[email protected]>
1 parent aa52a9a commit b92fd9a

File tree: 12 files changed (+2557 -4 lines)


CHANGELOG.md

Lines changed: 4 additions & 3 deletions

@@ -24,6 +24,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 
 - Added `TaskSuite` base class and command line functionality for running [`checklist`](https://github.com/marcotcr/checklist) test suites, along with implementations for `SentimentAnalysisSuite`, `QuestionAnsweringSuite`, and `TextualEntailmentSuite`. These can be found in the `allennlp.confidence_checks.task_checklists` module.
+- Added `BiasMitigatorApplicator`, which wraps any Model and mitigates biases by finetuning
+  on a downstream task.
 - Added `allennlp diff` command to compute a diff on model checkpoints, analogous to what `git diff` does on two files.
 - Meta data defined by the class `allennlp.common.meta.Meta` is now saved in the serialization directory and archive file
   when training models from the command line. This is also now part of the `Archive` named tuple that's returned from `load_archive()`.

@@ -54,7 +56,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed `wandb` callback to work in distributed training.
 - Fixed `tqdm` logging into multiple files with `allennlp-optuna`.
 
-
 ## [v2.4.0](https://github.com/allenai/allennlp/releases/tag/v2.4.0) - 2021-04-22
 
 ### Added

@@ -80,8 +81,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Add new dimension to the `interpret` module: influence functions via the `InfluenceInterpreter` base class, along with a concrete implementation: `SimpleInfluence`.
 - Added a `quiet` parameter to the `MultiProcessDataLoading` that disables `Tqdm` progress bars.
 - The test for distributed metrics now takes a parameter specifying how often you want to run it.
-- Created the fairness module and added four fairness metrics: `Independence`, `Separation`, and `Sufficiency`.
-- Added three bias metrics to the fairness module: `WordEmbeddingAssociationTest`, `EmbeddingCoherenceTest`, `NaturalLanguageInference`, and `AssociationWithoutGroundTruth`.
+- Created the fairness module and added three fairness metrics: `Independence`, `Separation`, and `Sufficiency`.
+- Added four bias metrics to the fairness module: `WordEmbeddingAssociationTest`, `EmbeddingCoherenceTest`, `NaturalLanguageInference`, and `AssociationWithoutGroundTruth`.
 - Added four bias direction methods (`PCABiasDirection`, `PairedPCABiasDirection`, `TwoMeansBiasDirection`, `ClassificationNormalBiasDirection`) and four bias mitigation methods (`LinearBiasMitigator`, `HardBiasMitigator`, `INLPBiasMitigator`, `OSCaRBiasMitigator`).
 
 ### Changed
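
The `BiasMitigatorApplicator` entry above describes wrapping any Model so that bias mitigation is applied while finetuning on a downstream task. As a rough illustration of that idea, the sketch below uses a plain PyTorch forward hook to rewrite an embedding layer's output before the rest of the forward pass runs. This is a conceptual sketch only, not the AllenNLP implementation or API; the helper names and the projection-based mitigator (in the spirit of the `LinearBiasMitigator` listed above) are assumptions made for the example.

import torch
import torch.nn as nn

def apply_mitigator_during_finetuning(embedding_layer: nn.Module, mitigator):
    """Hypothetical helper: route the embedding layer's output through `mitigator`.

    `mitigator` is any callable mapping an embedding tensor to a debiased
    tensor of the same shape.
    """

    def hook(module, inputs, output):
        # Returning a value from a forward hook replaces the module's output.
        return mitigator(output)

    handle = embedding_layer.register_forward_hook(hook)
    return handle  # call handle.remove() to stop mitigating

def linear_projection_mitigator(bias_direction: torch.Tensor):
    """Remove each embedding's component along a unit-norm bias direction."""

    def mitigate(embeddings: torch.Tensor) -> torch.Tensor:
        coeff = embeddings @ bias_direction              # shape (...,)
        return embeddings - coeff.unsqueeze(-1) * bias_direction

    return mitigate

Finetuning then proceeds as usual: every forward pass sees bias-mitigated embeddings, and task gradients flow through the mitigation step.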

allennlp/fairness/__init__.py

Lines changed: 16 additions & 1 deletion

@@ -3,7 +3,8 @@
 
 1. measure the fairness of models according to multiple definitions of fairness
 2. measure bias amplification
-3. debias embeddings during training time and post-processing
+3. mitigate bias in static and contextualized embeddings during training time and
+   post-processing
 """
 
 from allennlp.fairness.fairness_metrics import Independence, Separation, Sufficiency

@@ -25,3 +26,17 @@
     INLPBiasMitigator,
     OSCaRBiasMitigator,
 )
+from allennlp.fairness.bias_utils import load_words, load_word_pairs
+from allennlp.fairness.bias_mitigator_applicator import BiasMitigatorApplicator
+from allennlp.fairness.bias_mitigator_wrappers import (
+    HardBiasMitigatorWrapper,
+    LinearBiasMitigatorWrapper,
+    INLPBiasMitigatorWrapper,
+    OSCaRBiasMitigatorWrapper,
+)
+from allennlp.fairness.bias_direction_wrappers import (
+    PCABiasDirectionWrapper,
+    PairedPCABiasDirectionWrapper,
+    TwoMeansBiasDirectionWrapper,
+    ClassificationNormalBiasDirectionWrapper,
+)
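
With the re-exports added above, the new utilities and wrappers are importable directly from `allennlp.fairness`. A short sketch of what this enables (the symbols come straight from the diff; the `by_name` lookup relies on AllenNLP's standard `Registrable` machinery):

from allennlp.fairness import (
    load_words,
    load_word_pairs,
    BiasMitigatorApplicator,
    HardBiasMitigatorWrapper,
    PCABiasDirectionWrapper,
    TwoMeansBiasDirectionWrapper,
)
from allennlp.fairness.bias_direction_wrappers import BiasDirectionWrapper

# Registered wrappers can also be resolved by the names used in configuration,
# e.g. "pca", "paired_pca", "two_means", "classification_normal".
pca_wrapper_cls = BiasDirectionWrapper.by_name("pca")
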
allennlp/fairness/bias_direction_wrappers.py

Lines changed: 269 additions & 0 deletions
@@ -0,0 +1,269 @@

import torch
from typing import Union, Optional
from os import PathLike

from allennlp.fairness.bias_direction import (
    BiasDirection,
    PCABiasDirection,
    PairedPCABiasDirection,
    TwoMeansBiasDirection,
    ClassificationNormalBiasDirection,
)
from allennlp.fairness.bias_utils import load_word_pairs, load_words

from allennlp.common import Registrable
from allennlp.data.tokenizers.tokenizer import Tokenizer
from allennlp.data import Vocabulary


class BiasDirectionWrapper(Registrable):
    """
    Parent class for bias direction wrappers.
    """

    def __init__(self):
        self.direction: BiasDirection = None
        self.noise: float = None

    def __call__(self, module):
        raise NotImplementedError

    def train(self, mode: bool = True):
        """

        # Parameters

        mode : `bool`, optional (default=`True`)
            Sets `requires_grad` to value of `mode` for bias direction.
        """
        self.direction.requires_grad = mode

    def add_noise(self, t: torch.Tensor):
        """

        # Parameters

        t : `torch.Tensor`
            Tensor to which to add small amount of Gaussian noise.
        """
        return t + self.noise * torch.randn(t.size(), device=t.device)


@BiasDirectionWrapper.register("pca")
class PCABiasDirectionWrapper(BiasDirectionWrapper):
    """

    # Parameters

    seed_words_file : `Union[PathLike, str]`
        Path of file containing seed words.
    tokenizer : `Tokenizer`
        Tokenizer used to tokenize seed words.
    direction_vocab : `Vocabulary`, optional (default=`None`)
        Vocabulary of tokenizer. If `None`, assumes tokenizer is of
        type `PreTrainedTokenizer` and uses tokenizer's `vocab` attribute.
    namespace : `str`, optional (default=`"tokens"`)
        Namespace of direction_vocab to use when tokenizing.
        Disregarded when direction_vocab is `None`.
    requires_grad : `bool`, optional (default=`False`)
        Option to enable gradient calculation for bias direction.
    noise : `float`, optional (default=`1e-10`)
        To avoid numerical instability if embeddings are initialized uniformly.
    """

    def __init__(
        self,
        seed_words_file: Union[PathLike, str],
        tokenizer: Tokenizer,
        direction_vocab: Optional[Vocabulary] = None,
        namespace: str = "tokens",
        requires_grad: bool = False,
        noise: float = 1e-10,
    ):
        self.ids = load_words(seed_words_file, tokenizer, direction_vocab, namespace)
        self.direction = PCABiasDirection(requires_grad=requires_grad)
        self.noise = noise

    def __call__(self, module):
        # embed subword token IDs and mean pool to get
        # embedding of original word
        ids_embeddings = []
        for i in self.ids:
            i = i.to(module.weight.device)
            ids_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True))
        ids_embeddings = torch.cat(ids_embeddings)

        # adding trivial amount of noise
        # to eliminate linear dependence amongst all embeddings
        # when training first starts
        ids_embeddings = self.add_noise(ids_embeddings)

        return self.direction(ids_embeddings)


@BiasDirectionWrapper.register("paired_pca")
class PairedPCABiasDirectionWrapper(BiasDirectionWrapper):
    """

    # Parameters

    seed_word_pairs_file : `Union[PathLike, str]`
        Path of file containing seed word pairs.
    tokenizer : `Tokenizer`
        Tokenizer used to tokenize seed words.
    direction_vocab : `Vocabulary`, optional (default=`None`)
        Vocabulary of tokenizer. If `None`, assumes tokenizer is of
        type `PreTrainedTokenizer` and uses tokenizer's `vocab` attribute.
    namespace : `str`, optional (default=`"tokens"`)
        Namespace of direction_vocab to use when tokenizing.
        Disregarded when direction_vocab is `None`.
    requires_grad : `bool`, optional (default=`False`)
        Option to enable gradient calculation for bias direction.
    noise : `float`, optional (default=`1e-10`)
        To avoid numerical instability if embeddings are initialized uniformly.
    """

    def __init__(
        self,
        seed_word_pairs_file: Union[PathLike, str],
        tokenizer: Tokenizer,
        direction_vocab: Optional[Vocabulary] = None,
        namespace: str = "tokens",
        requires_grad: bool = False,
        noise: float = 1e-10,
    ):
        self.ids1, self.ids2 = load_word_pairs(
            seed_word_pairs_file, tokenizer, direction_vocab, namespace
        )
        self.direction = PairedPCABiasDirection(requires_grad=requires_grad)
        self.noise = noise

    def __call__(self, module):
        # embed subword token IDs and mean pool to get
        # embedding of original word
        ids1_embeddings = []
        for i in self.ids1:
            i = i.to(module.weight.device)
            ids1_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True))
        ids2_embeddings = []
        for i in self.ids2:
            i = i.to(module.weight.device)
            ids2_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True))
        ids1_embeddings = torch.cat(ids1_embeddings)
        ids2_embeddings = torch.cat(ids2_embeddings)

        ids1_embeddings = self.add_noise(ids1_embeddings)
        ids2_embeddings = self.add_noise(ids2_embeddings)

        return self.direction(ids1_embeddings, ids2_embeddings)


@BiasDirectionWrapper.register("two_means")
class TwoMeansBiasDirectionWrapper(BiasDirectionWrapper):
    """

    # Parameters

    seed_word_pairs_file : `Union[PathLike, str]`
        Path of file containing seed word pairs.
    tokenizer : `Tokenizer`
        Tokenizer used to tokenize seed words.
    direction_vocab : `Vocabulary`, optional (default=`None`)
        Vocabulary of tokenizer. If `None`, assumes tokenizer is of
        type `PreTrainedTokenizer` and uses tokenizer's `vocab` attribute.
    namespace : `str`, optional (default=`"tokens"`)
        Namespace of direction_vocab to use when tokenizing.
        Disregarded when direction_vocab is `None`.
    requires_grad : `bool`, optional (default=`False`)
        Option to enable gradient calculation for bias direction.
    noise : `float`, optional (default=`1e-10`)
        To avoid numerical instability if embeddings are initialized uniformly.
    """

    def __init__(
        self,
        seed_word_pairs_file: Union[PathLike, str],
        tokenizer: Tokenizer,
        direction_vocab: Optional[Vocabulary] = None,
        namespace: str = "tokens",
        requires_grad: bool = False,
        noise: float = 1e-10,
    ):
        self.ids1, self.ids2 = load_word_pairs(
            seed_word_pairs_file, tokenizer, direction_vocab, namespace
        )
        self.direction = TwoMeansBiasDirection(requires_grad=requires_grad)
        self.noise = noise

    def __call__(self, module):
        # embed subword token IDs and mean pool to get
        # embedding of original word
        ids1_embeddings = []
        for i in self.ids1:
            i = i.to(module.weight.device)
            ids1_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True))
        ids2_embeddings = []
        for i in self.ids2:
            i = i.to(module.weight.device)
            ids2_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True))
        ids1_embeddings = torch.cat(ids1_embeddings)
        ids2_embeddings = torch.cat(ids2_embeddings)

        ids1_embeddings = self.add_noise(ids1_embeddings)
        ids2_embeddings = self.add_noise(ids2_embeddings)

        return self.direction(ids1_embeddings, ids2_embeddings)


@BiasDirectionWrapper.register("classification_normal")
class ClassificationNormalBiasDirectionWrapper(BiasDirectionWrapper):
    """

    # Parameters

    seed_word_pairs_file : `Union[PathLike, str]`
        Path of file containing seed word pairs.
    tokenizer : `Tokenizer`
        Tokenizer used to tokenize seed words.
    direction_vocab : `Vocabulary`, optional (default=`None`)
        Vocabulary of tokenizer. If `None`, assumes tokenizer is of
        type `PreTrainedTokenizer` and uses tokenizer's `vocab` attribute.
    namespace : `str`, optional (default=`"tokens"`)
        Namespace of direction_vocab to use when tokenizing.
        Disregarded when direction_vocab is `None`.
    noise : `float`, optional (default=`1e-10`)
        To avoid numerical instability if embeddings are initialized uniformly.
    """

    def __init__(
        self,
        seed_word_pairs_file: Union[PathLike, str],
        tokenizer: Tokenizer,
        direction_vocab: Optional[Vocabulary] = None,
        namespace: str = "tokens",
        noise: float = 1e-10,
    ):
        self.ids1, self.ids2 = load_word_pairs(
            seed_word_pairs_file, tokenizer, direction_vocab, namespace
        )
        self.direction = ClassificationNormalBiasDirection()
        self.noise = noise

    def __call__(self, module):
        # embed subword token IDs and mean pool to get
        # embedding of original word
        ids1_embeddings = []
        for i in self.ids1:
            i = i.to(module.weight.device)
            ids1_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True))
        ids2_embeddings = []
        for i in self.ids2:
            i = i.to(module.weight.device)
            ids2_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True))
        ids1_embeddings = torch.cat(ids1_embeddings)
        ids2_embeddings = torch.cat(ids2_embeddings)

        ids1_embeddings = self.add_noise(ids1_embeddings)
        ids2_embeddings = self.add_noise(ids2_embeddings)

        return self.direction(ids1_embeddings, ids2_embeddings)
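
As a usage illustration, a bias direction wrapper can be pointed at any module that exposes a `weight` tensor and maps token IDs to embeddings via `forward` (a plain `torch.nn.Embedding` works). The sketch below is hedged: the seed-word file path and tokenizer choice are placeholders, the seed-word file format is whatever `load_words` accepts (not shown in this diff), and the vocabulary sizes are illustrative only.

import torch
from allennlp.data.tokenizers import PretrainedTransformerTokenizer
from allennlp.fairness import PCABiasDirectionWrapper

# Placeholder seed-word file; its format is defined by load_words.
tokenizer = PretrainedTransformerTokenizer("bert-base-uncased")
bias_direction = PCABiasDirectionWrapper(
    seed_words_file="gender_seed_words.json",
    tokenizer=tokenizer,
)

# Stand-in for a model's token embedder.
embedding = torch.nn.Embedding(num_embeddings=30522, embedding_dim=768)

# Each seed word's subword IDs are embedded and mean-pooled, a trivial amount of
# noise is added, and the PCA bias direction is computed from the result.
direction = bias_direction(embedding)  # 1-D tensor in the embedding space

Because the direction is computed from the current embedding weights, `train(mode)` exposes a switch for whether gradients flow through the direction during finetuning.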

allennlp/fairness/bias_metrics.py

Lines changed: 2 additions & 0 deletions

@@ -258,6 +258,8 @@ class NaturalLanguageInference(Metric):
     3. Threshold:tau (T:tau): A parameterized measure that reports the fraction
        of examples whose probability of neutral is above tau.
 
+    # Parameters
+
     neutral_label : `int`, optional (default=`2`)
         The discrete integer label corresponding to a neutral entailment prediction.
     taus : `List[float]`, optional (default=`[0.5, 0.7]`)
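
For reference, the Threshold:tau measure described in the docstring above reduces to a simple fraction. The snippet below is a standalone sketch of that arithmetic, not the `NaturalLanguageInference` metric's internal code; it assumes `probs` holds each example's predicted probability of the neutral label and treats "above tau" as a strict inequality.

import torch

def threshold_tau(probs: torch.Tensor, tau: float) -> float:
    """Fraction of examples whose probability of neutral is above tau."""
    return (probs > tau).float().mean().item()

# With the default taus from the docstring, [0.5, 0.7]:
probs = torch.tensor([0.9, 0.4, 0.75, 0.6])
fractions = {tau: threshold_tau(probs, tau) for tau in (0.5, 0.7)}
# fractions == {0.5: 0.75, 0.7: 0.5}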
