Skip to content

Commit b8b08ba

Browse files
chrispylmbaak
authored andcommitted
added test for default case of dtype
1 parent a68b069 commit b8b08ba

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

tests/integration/test_pandas_em.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,18 @@ def test_pandas_tfidf(dtype):
136136
np.testing.assert_allclose(actual_value, exp_value, rtol=0, atol=0.001)
137137

138138

139+
def test_pandas_tfidf_default_dtype():
140+
pandas_t = PandasNormalizedTfidfVectorizer()
141+
unique_names = [str(uuid.uuid4()) for i in range(100)]
142+
gt_names = pd.Series(unique_names)
143+
pandas_t.fit(gt_names)
144+
assert pandas_t.idf_.dtype == np.float32
145+
146+
139147
@pytest.mark.parametrize(
140148
("dtype", "data_size"), [(np.float32, 100), (np.float64, 100), (np.float32, 1000000), (np.float64, 1000000)]
141149
)
142-
def test_pandas_tfidf_dtype(dtype, data_size):
150+
def test_pandas_tfidf_dtype_for_different_input_sizes(dtype, data_size):
143151
pandas_t = PandasNormalizedTfidfVectorizer(dtype=dtype)
144152
unique_names = [str(uuid.uuid4()) for i in range(data_size)]
145153
gt_names = pd.Series(unique_names)

0 commit comments

Comments
 (0)