Skip to content

Commit 489809a

Browse files
committed
feat: added support for passing custom feature extractors to online dataset loader
1 parent 3c22222 commit 489809a

2 files changed

Lines changed: 20 additions & 6 deletions

File tree

flaxdiff/data/online_loader.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -117,22 +117,30 @@ def map_sample(
117117
# "error": str(e)
118118
# })
119119
pass
120-
120+
121+
def default_feature_extractor(sample):
122+
return {
123+
"url": sample["url"],
124+
"caption": sample["caption"],
125+
}
121126

122127
def map_batch(
123128
batch, num_threads=256, image_shape=(256, 256),
124129
min_image_shape=(128, 128),
125130
timeout=15, retries=3, image_processor=default_image_processor,
126131
upscale_interpolation=cv2.INTER_CUBIC,
127132
downscale_interpolation=cv2.INTER_AREA,
133+
feature_extractor=default_feature_extractor,
128134
):
129135
try:
130136
map_sample_fn = partial(map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
131137
timeout=timeout, retries=retries, image_processor=image_processor,
132138
upscale_interpolation=upscale_interpolation,
133139
downscale_interpolation=downscale_interpolation)
134140
with ThreadPoolExecutor(max_workers=num_threads) as executor:
135-
executor.map(map_sample_fn, batch["url"], batch['caption'])
141+
features = feature_extractor(batch)
142+
url, caption = features["url"], features["caption"]
143+
executor.map(map_sample_fn, url, caption)
136144
except Exception as e:
137145
print(f"Error maping batch", e)
138146
traceback.print_exc()
@@ -149,12 +157,14 @@ def parallel_image_loader(
149157
num_threads=256, timeout=15, retries=3, image_processor=default_image_processor,
150158
upscale_interpolation=cv2.INTER_CUBIC,
151159
downscale_interpolation=cv2.INTER_AREA,
160+
feature_extractor=default_feature_extractor,
152161
):
153162
map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape,
154163
min_image_shape=min_image_shape,
155164
timeout=timeout, retries=retries, image_processor=image_processor,
156165
upscale_interpolation=upscale_interpolation,
157-
downscale_interpolation=downscale_interpolation)
166+
downscale_interpolation=downscale_interpolation,
167+
feature_extractor=feature_extractor)
158168
shard_len = len(dataset) // num_workers
159169
print(f"Local Shard lengths: {shard_len}")
160170
with multiprocessing.Pool(num_workers) as pool:
@@ -181,6 +191,7 @@ def __init__(
181191
image_processor=default_image_processor,
182192
upscale_interpolation=cv2.INTER_CUBIC,
183193
downscale_interpolation=cv2.INTER_AREA,
194+
feature_extractor=default_feature_extractor,
184195
):
185196
self.dataset = dataset
186197
self.num_workers = num_workers
@@ -191,7 +202,8 @@ def __init__(
191202
num_workers=num_workers,
192203
timeout=timeout, retries=retries, image_processor=image_processor,
193204
upscale_interpolation=upscale_interpolation,
194-
downscale_interpolation=downscale_interpolation)
205+
downscale_interpolation=downscale_interpolation,
206+
feature_extractor=feature_extractor)
195207
self.thread = threading.Thread(target=loader, args=(dataset,))
196208
self.thread.start()
197209

@@ -256,6 +268,7 @@ def __init__(
256268
image_processor=default_image_processor,
257269
upscale_interpolation=cv2.INTER_CUBIC,
258270
downscale_interpolation=cv2.INTER_AREA,
271+
feature_extractor=default_feature_extractor,
259272
):
260273
if isinstance(dataset, str):
261274
dataset_path = dataset
@@ -281,7 +294,8 @@ def __init__(
281294
num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
282295
timeout=timeout, retries=retries, image_processor=image_processor,
283296
upscale_interpolation=upscale_interpolation,
284-
downscale_interpolation=downscale_interpolation)
297+
downscale_interpolation=downscale_interpolation,
298+
feature_extractor=feature_extractor)
285299
self.batch_size = batch_size
286300

287301
# Launch a thread to load batches in the background

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
setup(
1212
name='flaxdiff',
1313
packages=find_packages(),
14-
version='0.1.31',
14+
version='0.1.32',
1515
description='A versatile and easy to understand Diffusion library',
1616
long_description=open('README.md').read(),
1717
long_description_content_type='text/markdown',

0 commit comments

Comments
 (0)