|
| 1 | +import numpy as np |
| 2 | +import pandas as pd |
| 3 | +import pytest |
| 4 | +from sklearn.ensemble import RandomForestClassifier |
| 5 | + |
| 6 | +from boruta import BorutaPy |
| 7 | + |
| 8 | + |
| 9 | +@pytest.mark.parametrize("tree_n,expected", [(10, 44), (100, 141)]) |
| 10 | +def test_get_tree_num(tree_n, expected): |
| 11 | + rfc = RandomForestClassifier(max_depth=10) |
| 12 | + bt = BorutaPy(rfc) |
| 13 | + assert bt._get_tree_num(tree_n) == expected |
| 14 | + |
| 15 | + |
| 16 | +@pytest.fixture(scope="module") |
| 17 | +def Xy(): |
| 18 | + np.random.seed(42) |
| 19 | + y = np.random.binomial(1, 0.5, 1000) |
| 20 | + X = np.zeros((1000, 10)) |
| 21 | + |
| 22 | + z = (y - np.random.binomial(1, 0.1, 1000) + |
| 23 | + np.random.binomial(1, 0.1, 1000)) |
| 24 | + z[z == -1] = 0 |
| 25 | + z[z == 2] = 1 |
| 26 | + |
| 27 | + # 5 relevant features |
| 28 | + X[:, 0] = z |
| 29 | + X[:, 1] = (y * np.abs(np.random.normal(0, 1, 1000)) |
| 30 | + + np.random.normal(0, 0.1, 1000)) |
| 31 | + X[:, 2] = y + np.random.normal(0, 1, 1000) |
| 32 | + X[:, 3] = y**2 + np.random.normal(0, 1, 1000) |
| 33 | + X[:, 4] = np.sqrt(y) + np.random.binomial(2, 0.1, 1000) |
| 34 | + |
| 35 | + # 5 irrelevant features |
| 36 | + X[:, 5] = np.random.normal(0, 1, 1000) |
| 37 | + X[:, 6] = np.random.poisson(1, 1000) |
| 38 | + X[:, 7] = np.random.binomial(1, 0.3, 1000) |
| 39 | + X[:, 8] = np.random.normal(0, 1, 1000) |
| 40 | + X[:, 9] = np.random.poisson(1, 1000) |
| 41 | + |
| 42 | + return X, y |
| 43 | + |
| 44 | + |
| 45 | +def test_if_boruta_extracts_relevant_features(Xy): |
| 46 | + X, y = Xy |
| 47 | + rfc = RandomForestClassifier() |
| 48 | + bt = BorutaPy(rfc) |
| 49 | + bt.fit(X, y) |
| 50 | + assert list(range(5)) == list(np.where(bt.support_)[0]) |
| 51 | + |
| 52 | + |
| 53 | +def test_if_it_works_with_dataframe_input(Xy): |
| 54 | + X, y = Xy |
| 55 | + X_df, y_df = pd.DataFrame(X), pd.Series(y) |
| 56 | + bt = BorutaPy(RandomForestClassifier()) |
| 57 | + bt.fit(X_df, y_df) |
| 58 | + assert list(range(5)) == list(np.where(bt.support_)[0]) |
| 59 | + |
| 60 | + |
| 61 | +def test_dataframe_is_returned(Xy): |
| 62 | + X, y = Xy |
| 63 | + X_df, y_df = pd.DataFrame(X), pd.Series(y) |
| 64 | + rfc = RandomForestClassifier() |
| 65 | + bt = BorutaPy(rfc) |
| 66 | + bt.fit(X_df, y_df) |
| 67 | + assert isinstance(bt.transform(X_df, return_df=True), pd.DataFrame) |
0 commit comments