Skip to content

Commit 715f0f4

Browse files
committed
cat_cols parameter to adv-validatoin
1 parent 1f011a4 commit 715f0f4

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

nyaggle/validation/adversarial_validate.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def adversarial_validate(X_train: pd.DataFrame,
1919
X_test: pd.DataFrame,
2020
importance_type: str = 'gain',
2121
estimator: Optional[BaseEstimator] = None,
22+
cat_cols = None,
2223
cv = None) -> ADVResult:
2324
"""
2425
Perform adversarial validation between X_train and X_test.
@@ -74,8 +75,12 @@ def adversarial_validate(X_train: pd.DataFrame,
7475
if cv is None:
7576
cv = Take(1, KFold(5, shuffle=True, random_state=0))
7677

78+
fit_params = {'verbose': -1}
79+
if cat_cols:
80+
fit_params['categorical_feature'] = cat_cols
81+
7782
result = cross_validate(estimator, concat, y, None, cv=cv, predict_proba=True,
78-
eval_func=roc_auc_score, fit_params={'verbose': -1}, importance_type=importance_type)
83+
eval_func=roc_auc_score, fit_params=fit_params, importance_type=importance_type)
7984

8085
importance = pd.concat(result.importance)
8186
importance = importance.groupby('feature')['importance'].mean().reset_index()

0 commit comments

Comments
 (0)