4
4
import dask .dataframe as dd
5
5
import pyarrow as pa
6
6
import pytest
7
- from dask .distributed import Client
8
7
from fondant .component .data_io import DaskDataLoader , DaskDataWriter
9
8
from fondant .core .component_spec import ComponentSpec , OperationSpec
10
9
from fondant .core .manifest import Manifest
21
20
NUMBER_OF_TEST_ROWS = 151
22
21
23
22
24
- @pytest .fixture ()
25
- def dask_client (): # noqa: PT004
26
- client = Client ()
27
- yield
28
- client .close ()
29
-
30
-
31
23
@pytest .fixture ()
32
24
def manifest ():
33
25
return Manifest .from_file (manifest_path )
@@ -121,7 +113,6 @@ def test_write_dataset(
121
113
dataframe ,
122
114
manifest ,
123
115
component_spec ,
124
- dask_client ,
125
116
):
126
117
"""Test writing out subsets."""
127
118
# Dictionary specifying the expected subsets to write and their column names
@@ -134,7 +125,7 @@ def test_write_dataset(
134
125
operation_spec = OperationSpec (component_spec ),
135
126
)
136
127
# write dataframe to temp dir
137
- data_writer .write_dataframe (dataframe , dask_client )
128
+ data_writer .write_dataframe (dataframe )
138
129
# read written data and assert
139
130
dataframe = dd .read_parquet (
140
131
temp_dir
@@ -152,7 +143,6 @@ def test_write_dataset_custom_produces(
152
143
dataframe ,
153
144
manifest ,
154
145
component_spec_produces ,
155
- dask_client ,
156
146
):
157
147
"""Test writing out subsets."""
158
148
produces = {
@@ -175,7 +165,7 @@ def test_write_dataset_custom_produces(
175
165
)
176
166
177
167
# write dataframe to temp dir
178
- data_writer .write_dataframe (dataframe , dask_client )
168
+ data_writer .write_dataframe (dataframe )
179
169
# # read written data and assert
180
170
dataframe = dd .read_parquet (
181
171
temp_dir
@@ -194,7 +184,6 @@ def test_write_reset_index(
194
184
dataframe ,
195
185
manifest ,
196
186
component_spec ,
197
- dask_client ,
198
187
):
199
188
"""Test writing out the index and fields that have no dask index and checking
200
189
if the id index was created.
@@ -207,19 +196,18 @@ def test_write_reset_index(
207
196
manifest = manifest ,
208
197
operation_spec = OperationSpec (component_spec ),
209
198
)
210
- data_writer .write_dataframe (dataframe , dask_client )
199
+ data_writer .write_dataframe (dataframe )
211
200
dataframe = dd .read_parquet (fn )
212
201
assert dataframe .index .name == "id"
213
202
214
203
215
204
@pytest .mark .parametrize ("partitions" , list (range (1 , 5 )))
216
- def test_write_divisions ( # noqa: PLR0913
205
+ def test_write_divisions (
217
206
tmp_path_factory ,
218
207
dataframe ,
219
208
manifest ,
220
209
component_spec ,
221
210
partitions ,
222
- dask_client ,
223
211
):
224
212
"""Test writing out index and subsets and asserting they have the divisions of the dataframe."""
225
213
# repartition the dataframe (default is 3 partitions)
@@ -233,7 +221,7 @@ def test_write_divisions( # noqa: PLR0913
233
221
operation_spec = OperationSpec (component_spec ),
234
222
)
235
223
236
- data_writer .write_dataframe (dataframe , dask_client )
224
+ data_writer .write_dataframe (dataframe )
237
225
238
226
dataframe = dd .read_parquet (fn )
239
227
assert dataframe .index .name == "id"
@@ -245,7 +233,6 @@ def test_write_fields_invalid(
245
233
dataframe ,
246
234
manifest ,
247
235
component_spec ,
248
- dask_client ,
249
236
):
250
237
"""Test writing out fields but the dataframe columns are incomplete."""
251
238
with tmp_path_factory .mktemp ("temp" ) as fn :
@@ -262,15 +249,14 @@ def test_write_fields_invalid(
262
249
r"but not found in dataframe"
263
250
)
264
251
with pytest .raises (ValueError , match = expected_error_msg ):
265
- data_writer .write_dataframe (dataframe , dask_client )
252
+ data_writer .write_dataframe (dataframe )
266
253
267
254
268
255
def test_write_fields_invalid_several_fields_missing (
269
256
tmp_path_factory ,
270
257
dataframe ,
271
258
manifest ,
272
259
component_spec ,
273
- dask_client ,
274
260
):
275
261
"""Test writing out fields but the dataframe columns are incomplete."""
276
262
with tmp_path_factory .mktemp ("temp" ) as fn :
@@ -288,4 +274,4 @@ def test_write_fields_invalid_several_fields_missing(
288
274
r"but not found in dataframe"
289
275
)
290
276
with pytest .raises (ValueError , match = expected_error_msg ):
291
- data_writer .write_dataframe (dataframe , dask_client )
277
+ data_writer .write_dataframe (dataframe )
0 commit comments