Skip to content

Commit a9dba80

Browse files
authored
migrate from unittest to pytest (#140)
* migrate unittest to pytest * fix typo in test name
1 parent 34c6a60 commit a9dba80

File tree

3 files changed

+74
-57
lines changed

3 files changed

+74
-57
lines changed

boruta/test/test_boruta.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import numpy as np
2+
import pandas as pd
3+
import pytest
4+
from sklearn.ensemble import RandomForestClassifier
5+
6+
from boruta import BorutaPy
7+
8+
9+
@pytest.mark.parametrize("tree_n,expected", [(10, 44), (100, 141)])
10+
def test_get_tree_num(tree_n, expected):
11+
rfc = RandomForestClassifier(max_depth=10)
12+
bt = BorutaPy(rfc)
13+
assert bt._get_tree_num(tree_n) == expected
14+
15+
16+
@pytest.fixture(scope="module")
17+
def Xy():
18+
np.random.seed(42)
19+
y = np.random.binomial(1, 0.5, 1000)
20+
X = np.zeros((1000, 10))
21+
22+
z = (y - np.random.binomial(1, 0.1, 1000) +
23+
np.random.binomial(1, 0.1, 1000))
24+
z[z == -1] = 0
25+
z[z == 2] = 1
26+
27+
# 5 relevant features
28+
X[:, 0] = z
29+
X[:, 1] = (y * np.abs(np.random.normal(0, 1, 1000))
30+
+ np.random.normal(0, 0.1, 1000))
31+
X[:, 2] = y + np.random.normal(0, 1, 1000)
32+
X[:, 3] = y**2 + np.random.normal(0, 1, 1000)
33+
X[:, 4] = np.sqrt(y) + np.random.binomial(2, 0.1, 1000)
34+
35+
# 5 irrelevant features
36+
X[:, 5] = np.random.normal(0, 1, 1000)
37+
X[:, 6] = np.random.poisson(1, 1000)
38+
X[:, 7] = np.random.binomial(1, 0.3, 1000)
39+
X[:, 8] = np.random.normal(0, 1, 1000)
40+
X[:, 9] = np.random.poisson(1, 1000)
41+
42+
return X, y
43+
44+
45+
def test_if_boruta_extracts_relevant_features(Xy):
46+
X, y = Xy
47+
rfc = RandomForestClassifier()
48+
bt = BorutaPy(rfc)
49+
bt.fit(X, y)
50+
assert list(range(5)) == list(np.where(bt.support_)[0])
51+
52+
53+
def test_if_it_works_with_dataframe_input(Xy):
54+
X, y = Xy
55+
X_df, y_df = pd.DataFrame(X), pd.Series(y)
56+
bt = BorutaPy(RandomForestClassifier())
57+
bt.fit(X_df, y_df)
58+
assert list(range(5)) == list(np.where(bt.support_)[0])
59+
60+
61+
def test_dataframe_is_returned(Xy):
62+
X, y = Xy
63+
X_df, y_df = pd.DataFrame(X), pd.Series(y)
64+
rfc = RandomForestClassifier()
65+
bt = BorutaPy(rfc)
66+
bt.fit(X_df, y_df)
67+
assert isinstance(bt.transform(X_df, return_df=True), pd.DataFrame)

boruta/test/unit_tests.py

Lines changed: 0 additions & 57 deletions
This file was deleted.

test_requirements.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
-r requirements.txt
2+
pytest>=5.4.1
3+
4+
# repo maintenance tooling
5+
black>=21.5b1
6+
flake8>=3.9.2
7+
isort>=5.8.0

0 commit comments

Comments
 (0)