Skip to content

Commit 8b280cb

Browse files
enakta0xE0F
and authored
refactor(generators): unify generators to work with any storage backend (argonne-lcf#329)
Every new storage backend required copy-pasting each generator into an _XXX sibling file: npz_generator_s3.py, npy_generator_s3.py and so on. The only difference was whether to write the output locally on disk, directly via numpy/PIL, or via the storage interface. This makes the pattern unsustainable: two duplicated formats today, more with each new backend — incurring a significant maintenance burden. Since all generators already had a storage instance and used it to generate file names, we can leverage it. The only set of generators now can check if the stroage is locally available via `islocalfs` and use some optimisation, if any. If the storage is not local, the sample serializes to io.BytesIO, call buf.getvalue(), and delegate to self.storage.put_data(). All storage backends receive plain bytes as designed by the storage interface, removing type inspection, seek() and getvalue() calls scattered across backends. - FileStorage.put_data was never called, had text-mode open and a double get_uri call (once from the generator, once inside put_data itself). Now it is the default write path for LOCAL_FS, used by almost every workload config. get_data aligned to binary mode ("rb") for consistency. - AIStoreStorage.put_data: remove isinstance dispatch, accept bytes directly. - S3TorchStorage.put_data: remove data.getvalue() — just write data. - generator_factory: removed S3/AIStore branching for NPZ, NPY, JPEG. - factory referenced jpeg_generator_s3.JPEGGeneratorS3 which never existed; JPEG + S3/AIStore would crash at import time. After this patch, adding a new storage backend requires no changes in any generator. Adding a new data format automatically works with all backends. Signed-off-by: Denis Barakhtanov <dbarahtanov@enakta.com> Co-authored-by: Denis Barakhtanov <denis.barahtanov@gmail.com>
1 parent ea53bcf commit 8b280cb

File tree

11 files changed

+48
-164
lines changed

11 files changed

+48
-164
lines changed

dlio_benchmark/data_generator/generator_factory.py

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@
1414
See the License for the specific language governing permissions and
1515
limitations under the License.
1616
"""
17-
from dlio_benchmark.utils.config import ConfigArguments
18-
19-
from dlio_benchmark.common.enumerations import FormatType, StorageType
17+
from dlio_benchmark.common.enumerations import FormatType
2018
from dlio_benchmark.common.error_code import ErrorCodes
2119

2220
class GeneratorFactory(object):
@@ -25,7 +23,6 @@ def __init__(self):
2523

2624
@staticmethod
2725
def get_generator(type):
28-
_args = ConfigArguments.get_instance()
2926
if type == FormatType.TFRECORD:
3027
from dlio_benchmark.data_generator.tf_generator import TFRecordGenerator
3128
return TFRecordGenerator()
@@ -36,29 +33,14 @@ def get_generator(type):
3633
from dlio_benchmark.data_generator.csv_generator import CSVGenerator
3734
return CSVGenerator()
3835
elif type == FormatType.NPZ:
39-
# Use S3 generators for both S3 and AIStore
40-
if _args.storage_type in (StorageType.S3, StorageType.AISTORE):
41-
from dlio_benchmark.data_generator.npz_generator_s3 import NPZGeneratorS3
42-
return NPZGeneratorS3()
43-
else:
44-
from dlio_benchmark.data_generator.npz_generator import NPZGenerator
45-
return NPZGenerator()
36+
from dlio_benchmark.data_generator.npz_generator import NPZGenerator
37+
return NPZGenerator()
4638
elif type == FormatType.NPY:
47-
# Use S3 generators for both S3 and AIStore
48-
if _args.storage_type in (StorageType.S3, StorageType.AISTORE):
49-
from dlio_benchmark.data_generator.npy_generator_s3 import NPYGeneratorS3
50-
return NPYGeneratorS3()
51-
else:
52-
from dlio_benchmark.data_generator.npy_generator import NPYGenerator
53-
return NPYGenerator()
39+
from dlio_benchmark.data_generator.npy_generator import NPYGenerator
40+
return NPYGenerator()
5441
elif type == FormatType.JPEG:
55-
# Use S3 generators for both S3 and AIStore
56-
if _args.storage_type in (StorageType.S3, StorageType.AISTORE):
57-
from dlio_benchmark.data_generator.jpeg_generator_s3 import JPEGGeneratorS3
58-
return JPEGGeneratorS3()
59-
else:
60-
from dlio_benchmark.data_generator.jpeg_generator import JPEGGenerator
61-
return JPEGGenerator()
42+
from dlio_benchmark.data_generator.jpeg_generator import JPEGGenerator
43+
return JPEGGenerator()
6244
elif type == FormatType.PNG:
6345
from dlio_benchmark.data_generator.png_generator import PNGGenerator
6446
return PNGGenerator()

dlio_benchmark/data_generator/jpeg_generator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
See the License for the specific language governing permissions and
1515
limitations under the License.
1616
"""
17+
import io
1718
import numpy as np
1819
import PIL.Image as im
1920

@@ -53,5 +54,8 @@ def generate(self):
5354
self.logger.info(f"Generated file {i}/{self.total_files_to_generate}")
5455
out_path_spec = self.storage.get_uri(self._file_list[i])
5556
progress(i+1, self.total_files_to_generate, "Generating JPEG Data")
56-
img.save(out_path_spec, format='JPEG', bits=8)
57+
output = out_path_spec if self.storage.islocalfs() else io.BytesIO()
58+
img.save(output, format='JPEG', bits=8)
59+
if not self.storage.islocalfs():
60+
self.storage.put_data(out_path_spec, output.getvalue())
5761
np.random.seed()

dlio_benchmark/data_generator/npy_generator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
See the License for the specific language governing permissions and
1515
limitations under the License.
1616
"""
17+
import io
1718
import numpy as np
1819

1920
from dlio_benchmark.data_generator.data_generator import DataGenerator
@@ -49,5 +50,8 @@ def generate(self):
4950

5051
out_path_spec = self.storage.get_uri(self._file_list[i])
5152
progress(i+1, self.total_files_to_generate, "Generating NPY Data")
52-
np.save(out_path_spec, records)
53+
output = out_path_spec if self.storage.islocalfs() else io.BytesIO()
54+
np.save(output, records)
55+
if not self.storage.islocalfs():
56+
self.storage.put_data(out_path_spec, output.getvalue())
5357
np.random.seed()

dlio_benchmark/data_generator/npy_generator_s3.py

Lines changed: 0 additions & 57 deletions
This file was deleted.

dlio_benchmark/data_generator/npz_generator.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
See the License for the specific language governing permissions and
1515
limitations under the License.
1616
"""
17+
import io
1718
import numpy as np
1819

1920
from dlio_benchmark.common.enumerations import Compression
@@ -48,8 +49,11 @@ def generate(self):
4849
records = gen_random_tensor(shape=(dim_, dim[2*i+1], self.num_samples), dtype=self._args.record_element_dtype, rng=rng)
4950
out_path_spec = self.storage.get_uri(self._file_list[i])
5051
progress(i+1, self.total_files_to_generate, "Generating NPZ Data")
52+
output = out_path_spec if self.storage.islocalfs() else io.BytesIO()
5153
if self.compression != Compression.ZIP:
52-
np.savez(out_path_spec, x=records, y=record_labels)
54+
np.savez(output, x=records, y=record_labels)
5355
else:
54-
np.savez_compressed(out_path_spec, x=records, y=record_labels)
56+
np.savez_compressed(output, x=records, y=record_labels)
57+
if not self.storage.islocalfs():
58+
self.storage.put_data(out_path_spec, output.getvalue())
5559
np.random.seed()

dlio_benchmark/data_generator/npz_generator_s3.py

Lines changed: 0 additions & 59 deletions
This file was deleted.

dlio_benchmark/data_generator/png_generator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
See the License for the specific language governing permissions and
1515
limitations under the License.
1616
"""
17+
import io
1718
import numpy as np
1819
import PIL.Image as im
1920

@@ -49,5 +50,8 @@ def generate(self):
4950
self.logger.info(f"Generated file {i}/{self.total_files_to_generate}")
5051
out_path_spec = self.storage.get_uri(self._file_list[i])
5152
progress(i+1, self.total_files_to_generate, "Generating PNG Data")
52-
img.save(out_path_spec, format='PNG', bits=8)
53+
output = out_path_spec if self.storage.islocalfs() else io.BytesIO()
54+
img.save(output, format='PNG', bits=8)
55+
if not self.storage.islocalfs():
56+
self.storage.put_data(out_path_spec, output.getvalue())
5357
np.random.seed()

dlio_benchmark/storage/aistore_storage.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
"""
1717

1818
import os
19-
import io
2019
import logging
2120

2221
try:
@@ -89,12 +88,12 @@ def _clean_key(self, id):
8988
Extract the object key from a full S3/AIS URI.
9089
9190
Why this is needed:
92-
- S3 generators (NPYGeneratorS3, NPYReaderS3) pass full URIs like:
91+
- Generators call storage.get_uri(file_list[i]) which pass full URIs like:
9392
"s3://dlio-benchmark-native/train/img_08_of_16.npy"
9493
or "ais://dlio-benchmark-native/train/img_08_of_16.npy"
9594
- AIStore SDK expects just the object key:
9695
"train/img_08_of_16.npy"
97-
- This method strips the "s3://" or "ais://" prefix and bucket name
96+
- This method strips the scheme and bucket name from the URI
9897
9998
Handles:
10099
s3://bucket/path/file.ext -> path/file.ext
@@ -226,22 +225,12 @@ def put_data(self, id, data, offset=None, length=None):
226225

227226
key = self._clean_key(id)
228227

229-
# Convert data to bytes
230-
if isinstance(data, io.BytesIO):
231-
data.seek(0)
232-
body = data.read()
233-
elif isinstance(data, bytes):
234-
body = data
235-
else:
236-
body = bytes(data)
237-
238-
# Put object
239228
obj = self.bucket.object(key)
240-
obj.get_writer().put_content(body)
229+
obj.get_writer().put_content(data)
241230

242231
# TODO: add offset and length support
243232

244-
logging.debug(f"Successfully uploaded: {key} ({len(body)} bytes)")
233+
logging.debug(f"Successfully uploaded: {key} ({len(data)} bytes)")
245234
return True
246235
except Exception as e:
247236
logging.error(f"Error putting data to {id}: {e}")

dlio_benchmark/storage/file_storage.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,16 @@ def delete_node(self, id):
9090
# TODO Handle partial read and writes
9191
@dlp.log
9292
def put_data(self, id, data, offset=None, length=None):
93-
with open(self.get_uri(id), "w") as fd:
93+
# id is the fully-resolved path (callers call get_uri() before put_data).
94+
# Do NOT call self.get_uri(id) here — that would double-prefix the namespace.
95+
with open(id, "wb") as fd:
9496
fd.write(data)
9597

9698
@dlp.log
9799
def get_data(self, id, data, offset=None, length=None):
98-
with open(self.get_uri(id), "r") as fd:
100+
# id is the fully-resolved path (callers call get_uri() before put_data).
101+
# Do NOT call self.get_uri(id) here — that would double-prefix the namespace.
102+
with open(id, "rb") as fd:
99103
data = fd.read()
100104
return data
101105

@@ -105,3 +109,6 @@ def isfile(self, id):
105109

106110
def get_basename(self, id):
107111
return os.path.basename(id)
112+
113+
def islocalfs(self):
114+
return True

dlio_benchmark/storage/s3_torch_storage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def put_data(self, id, data, offset=None, length=None):
131131

132132
bucket_name = parsed.netloc
133133
writer = self.s3_client.put_object(bucket_name, id)
134-
writer.write(data.getvalue())
134+
writer.write(data)
135135
writer.close()
136136
return None
137137

0 commit comments

Comments
 (0)