Skip to content

Commit 4758343

Browse files
authored
Merge pull request #89 from nyanp/fix/scikit-learn-1.0
Support scikit-learn 1.0
2 parents 32bffe8 + 3ad1cc8 commit 4758343

File tree

2 files changed

+3
-6
lines changed

2 files changed

+3
-6
lines changed

nyaggle/validation/split.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def check_cv(cv: Union[int, Iterable, BaseCrossValidator] = 5,
4646
else:
4747
return KFold(cv, shuffle=True, random_state=random_state)
4848

49-
return model_selection.check_cv(cv, y, stratified)
49+
return model_selection.check_cv(cv, y, classifier=stratified)
5050

5151

5252
class Take(BaseCrossValidator):
@@ -380,8 +380,7 @@ class StratifiedGroupKFold(_BaseKFold):
380380

381381
def __init__(self, n_splits: int = 3, shuffle: bool = False,
382382
random_state: Optional[Union[int, np.random.RandomState]] = None):
383-
super(StratifiedGroupKFold, self).__init__(n_splits, shuffle,
384-
random_state)
383+
super().__init__(n_splits, shuffle=shuffle, random_state=random_state)
385384

386385
def _make_test_folds(self, X, y=None, groups=None):
387386
"""

tests/experiment/test_hyperparameter_tuner.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@ def _check_parameter_tunes(params, x, y):
1616

1717

1818
def test_regression_problem_parameter_tunes():
19-
dataset = datasets.load_boston()
20-
x = pd.DataFrame(dataset.data, columns=dataset.feature_names)
21-
y = pd.Series(dataset.target)
19+
x, y = datasets.load_diabetes(return_X_y=True, as_frame=True)
2220
params = {
2321
'objective': 'regression',
2422
'metric': 'rmse',

0 commit comments

Comments
 (0)