diff --git a/plm_interpretability/logistic_regression_probe/all_latents.py b/plm_interpretability/logistic_regression_probe/all_latents.py index d20c6bc..dace4dc 100644 --- a/plm_interpretability/logistic_regression_probe/all_latents.py +++ b/plm_interpretability/logistic_regression_probe/all_latents.py @@ -100,6 +100,7 @@ def all_latents( plm_model=plm_model, sae_model=sae_model, plm_layer=plm_layer, + pool_over_annotation=False, ) with warnings.catch_warnings(): # LogisticRegression throws warnings when it can't converge.