diff --git a/boruta/boruta_py.py b/boruta/boruta_py.py index 1f97bd8..4b463cc 100644 --- a/boruta/boruta_py.py +++ b/boruta/boruta_py.py @@ -326,7 +326,7 @@ def _fit(self, X, y): # set n_estimators if self.n_estimators != 'auto': - self.estimator.set_params(n_estimators=self.n_estimators) + self._set_n_estimators(self.n_estimators) # main feature selection loop while np.any(dec_reg == 0) and _iter < self.max_iter: @@ -335,7 +335,7 @@ def _fit(self, X, y): # number of features that aren't rejected not_rejected = np.where(dec_reg >= 0)[0].shape[0] n_tree = self._get_tree_num(not_rejected) - self.estimator.set_params(n_estimators=n_tree) + self._set_n_estimators(n_estimators=n_tree) # make sure we start with a new tree in each iteration if self._is_lightgbm: @@ -454,6 +454,17 @@ def _transform(self, X, weak=False, return_df=False): X = X[:, indices] return X + def _set_n_estimators(self, n_estimators): + try: + self.estimator.set_params(n_estimators=n_estimators) + except ValueError: + raise ValueError( + f"The estimator {self.estimator} does not take the parameter " + "n_estimators. Use Random Forests or gradient boosting machines " + "instead." + ) + return self + def _get_tree_num(self, n_feat): depth = None try: diff --git a/boruta/test/test_boruta.py b/boruta/test/test_boruta.py index 27d42e6..3ad05ec 100644 --- a/boruta/test/test_boruta.py +++ b/boruta/test/test_boruta.py @@ -2,6 +2,7 @@ import pandas as pd import pytest from sklearn.ensemble import RandomForestClassifier +from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier from boruta import BorutaPy @@ -26,8 +27,8 @@ def Xy(): # 5 relevant features X[:, 0] = z - X[:, 1] = (y * np.abs(np.random.normal(0, 1, 1000)) - + np.random.normal(0, 0.1, 1000)) + X[:, 1] = (y * np.abs(np.random.normal(0, 1, 1000)) + + np.random.normal(0, 0.1, 1000)) X[:, 2] = y + np.random.normal(0, 1, 1000) X[:, 3] = y**2 + np.random.normal(0, 1, 1000) X[:, 4] = np.sqrt(y) + np.random.binomial(2, 0.1, 1000) @@ -65,3 +66,18 @@ def test_dataframe_is_returned(Xy): bt = BorutaPy(rfc) bt.fit(X_df, y_df) assert isinstance(bt.transform(X_df, return_df=True), pd.DataFrame) + + +@pytest.mark.parametrize("tree", [ExtraTreeClassifier(), DecisionTreeClassifier()]) +def test_boruta_with_decision_trees(tree, Xy): + msg = ( + f"The estimator {tree} does not take the parameter " + "n_estimators. Use Random Forests or gradient boosting machines " + "instead." + ) + X, y = Xy + bt = BorutaPy(tree) + with pytest.raises(ValueError) as record: + bt.fit(X, y) + + assert str(record.value) == msg \ No newline at end of file