@@ -117,22 +117,30 @@ def map_sample(
117117 # "error": str(e)
118118 # })
119119 pass
120-
120+
def default_feature_extractor(sample):
    """Project a raw dataset sample down to the two fields the loader needs.

    Args:
        sample: mapping (or columnar batch) containing at least the keys
            ``"url"`` and ``"caption"``.

    Returns:
        dict with exactly the ``"url"`` and ``"caption"`` entries copied
        from *sample*.

    Raises:
        KeyError: if either key is absent from *sample*.
    """
    # Same two-key projection as a literal dict; written as a comprehension
    # over the required field names.
    return {field: sample[field] for field in ("url", "caption")}
121126
122127def map_batch (
123128 batch , num_threads = 256 , image_shape = (256 , 256 ),
124129 min_image_shape = (128 , 128 ),
125130 timeout = 15 , retries = 3 , image_processor = default_image_processor ,
126131 upscale_interpolation = cv2 .INTER_CUBIC ,
127132 downscale_interpolation = cv2 .INTER_AREA ,
133+ feature_extractor = default_feature_extractor ,
128134):
129135 try :
130136 map_sample_fn = partial (map_sample , image_shape = image_shape , min_image_shape = min_image_shape ,
131137 timeout = timeout , retries = retries , image_processor = image_processor ,
132138 upscale_interpolation = upscale_interpolation ,
133139 downscale_interpolation = downscale_interpolation )
134140 with ThreadPoolExecutor (max_workers = num_threads ) as executor :
135- executor .map (map_sample_fn , batch ["url" ], batch ['caption' ])
141+ features = feature_extractor (batch )
142+ url , caption = features ["url" ], features ["caption" ]
143+ executor .map (map_sample_fn , url , caption )
136144 except Exception as e :
137145 print (f"Error maping batch" , e )
138146 traceback .print_exc ()
@@ -149,12 +157,14 @@ def parallel_image_loader(
149157 num_threads = 256 , timeout = 15 , retries = 3 , image_processor = default_image_processor ,
150158 upscale_interpolation = cv2 .INTER_CUBIC ,
151159 downscale_interpolation = cv2 .INTER_AREA ,
160+ feature_extractor = default_feature_extractor ,
152161):
153162 map_batch_fn = partial (map_batch , num_threads = num_threads , image_shape = image_shape ,
154163 min_image_shape = min_image_shape ,
155164 timeout = timeout , retries = retries , image_processor = image_processor ,
156165 upscale_interpolation = upscale_interpolation ,
157- downscale_interpolation = downscale_interpolation )
166+ downscale_interpolation = downscale_interpolation ,
167+ feature_extractor = feature_extractor )
158168 shard_len = len (dataset ) // num_workers
159169 print (f"Local Shard lengths: { shard_len } " )
160170 with multiprocessing .Pool (num_workers ) as pool :
@@ -181,6 +191,7 @@ def __init__(
181191 image_processor = default_image_processor ,
182192 upscale_interpolation = cv2 .INTER_CUBIC ,
183193 downscale_interpolation = cv2 .INTER_AREA ,
194+ feature_extractor = default_feature_extractor ,
184195 ):
185196 self .dataset = dataset
186197 self .num_workers = num_workers
@@ -191,7 +202,8 @@ def __init__(
191202 num_workers = num_workers ,
192203 timeout = timeout , retries = retries , image_processor = image_processor ,
193204 upscale_interpolation = upscale_interpolation ,
194- downscale_interpolation = downscale_interpolation )
205+ downscale_interpolation = downscale_interpolation ,
206+ feature_extractor = feature_extractor )
195207 self .thread = threading .Thread (target = loader , args = (dataset ,))
196208 self .thread .start ()
197209
@@ -256,6 +268,7 @@ def __init__(
256268 image_processor = default_image_processor ,
257269 upscale_interpolation = cv2 .INTER_CUBIC ,
258270 downscale_interpolation = cv2 .INTER_AREA ,
271+ feature_extractor = default_feature_extractor ,
259272 ):
260273 if isinstance (dataset , str ):
261274 dataset_path = dataset
@@ -281,7 +294,8 @@ def __init__(
281294 num_workers = num_workers , batch_size = batch_size , num_threads = num_threads ,
282295 timeout = timeout , retries = retries , image_processor = image_processor ,
283296 upscale_interpolation = upscale_interpolation ,
284- downscale_interpolation = downscale_interpolation )
297+ downscale_interpolation = downscale_interpolation ,
298+ feature_extractor = feature_extractor )
285299 self .batch_size = batch_size
286300
287301 # Launch a thread to load batches in the background
0 commit comments