|
2 | 2 |
|
3 | 3 | from sklearn import clone, datasets |
4 | 4 |
|
| 5 | +from sklearn.ensemble import RandomForestClassifier as SklearnRF |
5 | 6 | import dislib as ds |
6 | 7 | from dislib.classification import CascadeSVM, RandomForestClassifier |
7 | 8 | from dislib.cluster import DBSCAN, KMeans, GaussianMixture |
@@ -79,6 +80,43 @@ def test_fit(self): |
79 | 80 | self.assertTrue(hasattr(searcher, 'scorer_')) |
80 | 81 | self.assertEqual(searcher.n_splits_, 5) |
81 | 82 |
|
| 83 | + def test_fit_sk(self): |
| 84 | + """Tests GridSearchCV fit().""" |
| 85 | + x_np, y_np = datasets.load_iris(return_X_y=True) |
| 86 | + x = ds.array(x_np, (30, 4)) |
| 87 | + y = ds.array(y_np[:, np.newaxis], (30, 1)) |
| 88 | + |
| 89 | + param_grid = {'n_estimators': (2, 4), |
| 90 | + 'max_depth': range(3, 5)} |
| 91 | + rf = SklearnRF() |
| 92 | + print("ESTIMATOR TYPE") |
| 93 | + print(str(type(rf))) |
| 94 | + |
| 95 | + searcher = GridSearchCV(rf, param_grid) |
| 96 | + searcher.fit(x, y) |
| 97 | + |
| 98 | + expected_keys = {'param_max_depth', 'param_n_estimators', 'params', |
| 99 | + 'mean_test_score', 'std_test_score', |
| 100 | + 'rank_test_score'} |
| 101 | + split_keys = {'split%d_test_score' % i for i in range(5)} |
| 102 | + expected_keys.update(split_keys) |
| 103 | + self.assertSetEqual(set(searcher.cv_results_.keys()), expected_keys) |
| 104 | + |
| 105 | + expected_params = [(3, 2), (3, 4), (4, 2), (4, 4)] |
| 106 | + for params in searcher.cv_results_['params']: |
| 107 | + m = params['max_depth'] |
| 108 | + n = params['n_estimators'] |
| 109 | + self.assertIn((m, n), expected_params) |
| 110 | + expected_params.remove((m, n)) |
| 111 | + self.assertEqual(len(expected_params), 0) |
| 112 | + |
| 113 | + self.assertTrue(hasattr(searcher, 'best_estimator_')) |
| 114 | + self.assertTrue(hasattr(searcher, 'best_score_')) |
| 115 | + self.assertTrue(hasattr(searcher, 'best_params_')) |
| 116 | + self.assertTrue(hasattr(searcher, 'best_index_')) |
| 117 | + self.assertTrue(hasattr(searcher, 'scorer_')) |
| 118 | + self.assertEqual(searcher.n_splits_, 5) |
| 119 | + |
82 | 120 | def test_fit_2(self): |
83 | 121 | """Tests GridSearchCV fit() with different data.""" |
84 | 122 | x_np, y_np = datasets.load_breast_cancer(return_X_y=True) |
|
0 commit comments