diff --git a/amazon/ionbenchmark/Format.py b/amazon/ionbenchmark/Format.py index f7cecf85e..3500bc16e 100644 --- a/amazon/ionbenchmark/Format.py +++ b/amazon/ionbenchmark/Format.py @@ -13,28 +13,26 @@ def _file_is_ion_text(file): def format_is_ion(format_option): - return (format_option == Format.ION_BINARY.value) or (format_option == Format.ION_TEXT.value) + return format_option in (Format.ION_BINARY, Format.ION_TEXT) def format_is_json(format_option): - return (format_option == Format.JSON.value) or (format_option == Format.SIMPLEJSON.value) \ - or (format_option == Format.UJSON.value) or (format_option == Format.RAPIDJSON.value) + return format_option in (Format.JSON, Format.SIMPLEJSON, Format.UJSON, Format.RAPIDJSON) def format_is_cbor(format_option): - return (format_option == Format.CBOR.value) or (format_option == Format.CBOR2.value) + return format_option in (Format.CBOR, Format.CBOR2) def format_is_binary(format_option): - return format_is_cbor(format_option) or (format_option == Format.ION_BINARY.value) \ - or (format_option == Format.PROTOBUF.value) or (format_option == Format.SD_PROTOBUF.value) + return format_option in (Format.ION_BINARY, Format.PROTOBUF, Format.SD_PROTOBUF, Format.CBOR, Format.CBOR2) def rewrite_file_to_format(file, format_option): temp_file_name_base = 'temp_' + os.path.splitext(os.path.basename(file))[0] - if format_option == Format.ION_BINARY.value: + if format_option is Format.ION_BINARY: temp_file_name_suffix = '.10n' - elif format_option == Format.ION_TEXT.value: + elif format_option is Format.ION_TEXT: temp_file_name_suffix = '.ion' else: temp_file_name_suffix = '' @@ -45,16 +43,16 @@ def rewrite_file_to_format(file, format_option): if format_is_ion(format_option): # Write data if a conversion is required - if (format_option == Format.ION_BINARY.value and _file_is_ion_text(file)) \ - or (format_option == Format.ION_TEXT.value and _file_is_ion_binary(file)): + if (format_option is Format.ION_BINARY and _file_is_ion_text(file)) \ + or (format_option is Format.ION_TEXT and _file_is_ion_binary(file)): # Load data with open(file, 'br') as fp: obj = simpleion.load(fp, single_value=False) with open(temp_file_name, 'bw') as fp: - if format_option == Format.ION_BINARY.value: - simpleion.dump(obj, fp, binary=True) + if format_option is Format.ION_BINARY: + simpleion.dump(obj, fp, binary=True, sequence_as_stream=True) else: - simpleion.dump(obj, fp, binary=False) + simpleion.dump(obj, fp, binary=False, sequence_as_stream=True) else: shutil.copy(file, temp_file_name) else: @@ -75,3 +73,10 @@ class Format(Enum): CBOR2 = 'cbor2' PROTOBUF = 'protobuf' SD_PROTOBUF = 'self_describing_protobuf' + + @staticmethod + def by_value(value): + for e in Format: + if e.value == value: + return e + raise ValueError(f"No enum constant with value {value}") diff --git a/amazon/ionbenchmark/benchmark_spec.py b/amazon/ionbenchmark/benchmark_spec.py index 4ab97f151..26fce3779 100644 --- a/amazon/ionbenchmark/benchmark_spec.py +++ b/amazon/ionbenchmark/benchmark_spec.py @@ -9,6 +9,7 @@ from amazon.ion.simpleion import IonPyValueModel from amazon.ion.symbols import SymbolToken +from amazon.ionbenchmark.Format import Format # Global defaults for CLI test specs _tool_defaults = { @@ -115,7 +116,7 @@ def __init__(self, params: dict, user_overrides: dict = None, user_defaults: dic raise ValueError(f"Missing required parameter '{k}'") if 'name' not in self: - self['name'] = f'({self.get_format()},{self.derive_operation_name()},{path.basename(self.get_input_file())})' + self['name'] = f'({self.get_format().value},{self.derive_operation_name()},{path.basename(self.get_input_file())})' def __missing__(self, key): # Instead of raising a KeyError like a usual dict, just return None. @@ -139,7 +140,7 @@ def get_name(self): return self["name"] def get_format(self): - return self["format"] + return Format.by_value(self["format"]) def get_input_file(self): return self["input_file"] @@ -200,33 +201,33 @@ def _get_model_flags(self): def _get_loader_dumper(self): data_format = self.get_format() - if data_format == 'ion_binary': + if data_format is Format.ION_BINARY: return _ion_load_dump.IonLoadDump(binary=True, c_ext=self['py_c_extension'], value_model=self._get_model_flags()) - elif data_format == 'ion_text': + elif data_format is Format.ION_TEXT: return _ion_load_dump.IonLoadDump(binary=False, c_ext=self['py_c_extension'], value_model=self._get_model_flags()) - elif data_format == 'json': + elif data_format is Format.JSON: import json return json - elif data_format == 'ujson': + elif data_format is Format.UJSON: import ujson return ujson - elif data_format == 'simplejson': + elif data_format is Format.SIMPLEJSON: import simplejson return simplejson - elif data_format == 'rapidjson': + elif data_format is Format.RAPIDJSON: import rapidjson return rapidjson - elif data_format == 'cbor': + elif data_format is Format.CBOR: import cbor return cbor - elif data_format == 'cbor2': + elif data_format is Format.CBOR2: import cbor2 return cbor2 - elif data_format == 'self_describing_protobuf': + elif data_format is Format.SD_PROTOBUF: from self_describing_proto import SelfDescribingProtoSerde # TODO: Consider making the cache option configurable from the spec file return SelfDescribingProtoSerde(cache_type_info=True) - elif data_format == 'protobuf': + elif data_format is Format.PROTOBUF: import proto type_name = self['protobuf_type'] if not type_name: diff --git a/tests/test_benchmark_cli.py b/tests/test_benchmark_cli.py index 3aa1cb198..bc49ed62c 100644 --- a/tests/test_benchmark_cli.py +++ b/tests/test_benchmark_cli.py @@ -118,7 +118,7 @@ def test_write_multi_duplicated_format(file=generate_test_path('integers.ion')): @parametrize( - *tuple((f.value for f in Format.Format if Format.format_is_json(f.value))) + *tuple((f.value for f in Format.Format if Format.format_is_json(f))) ) def test_write_json_format(f): (error_code, _, _) = run_cli(['write', generate_test_path('json/object.json'), '--format', f'{f}']) @@ -126,7 +126,7 @@ def test_write_json_format(f): @parametrize( - *tuple((f.value for f in Format.Format if Format.format_is_json(f.value))) + *tuple((f.value for f in Format.Format if Format.format_is_json(f))) ) def test_read_json_format(f): (error_code, _, _) = run_cli(['read', generate_test_path('json/object.json'), '--format', f'{f}']) @@ -134,7 +134,7 @@ def test_read_json_format(f): @parametrize( - *tuple((f.value for f in Format.Format if Format.format_is_cbor(f.value))) + *tuple((f.value for f in Format.Format if Format.format_is_cbor(f))) ) def test_write_cbor_format(f): (error_code, _, _) = run_cli(['write', generate_test_path('cbor/sample'), '--format', f'{f}']) @@ -167,7 +167,7 @@ def test_read_io_type(f): *tuple((Format.Format.ION_TEXT, Format.Format.ION_BINARY)) ) def test_format_is_ion(f): - assert format_is_ion(f.value) is True + assert format_is_ion(f) is True @parametrize( @@ -178,7 +178,7 @@ def test_format_is_ion(f): )) ) def test_format_is_json(f): - assert format_is_json(f.value) is True + assert format_is_json(f) is True @parametrize( @@ -186,7 +186,7 @@ def test_format_is_json(f): Format.Format.CBOR2 ) def test_format_is_cbor(f): - assert format_is_cbor(f.value) is True + assert format_is_cbor(f) is True def assert_ion_string_equals(act, exp): @@ -208,13 +208,28 @@ def test_compare_with_large_regression(): assert error_code -def test_format_conversion_ion_binary_to_ion_text(): - rewrite_file_to_format(generate_test_path('integers.ion'), Format.Format.ION_BINARY.value) - assert os.path.exists('temp_integers.10n') - os.remove('temp_integers.10n') - - -def test_format_conversion_ion_text_to_ion_binary(): - rewrite_file_to_format(generate_test_path('integers.10n'), Format.Format.ION_TEXT.value) - assert os.path.exists('temp_integers.ion') - os.remove('temp_integers.ion') +@parametrize( + Format.Format.ION_TEXT, + Format.Format.ION_BINARY) +def test_ion_format_conversion(target_format): + + if target_format is Format.Format.ION_BINARY: + source_name = 'integers.ion' + target_name = 'temp_integers.10n' + else: + source_name = 'integers.10n' + target_name = 'temp_integers.ion' + + with open(generate_test_path(source_name), 'rb') as source_file: + source_values = simpleion.load(source_file, single_value=False) + + rewrite_file_to_format(generate_test_path(source_name), target_format) + + assert os.path.exists(target_name) + with open(target_name, 'rb') as target_file: + target_values = simpleion.load(target_file, single_value=False, parse_eagerly=True) + assert len(target_values) == len(source_values) + for (s, c) in zip(source_values, target_values): + assert s == c + + os.remove(target_name) diff --git a/tests/test_benchmark_spec.py b/tests/test_benchmark_spec.py index ab398da21..b2a302355 100644 --- a/tests/test_benchmark_spec.py +++ b/tests/test_benchmark_spec.py @@ -3,6 +3,7 @@ from amazon.ion.simpleion import IonPyValueModel from amazon.ion.symbols import SymbolToken +from amazon.ionbenchmark.Format import Format from amazon.ionbenchmark.benchmark_spec import BenchmarkSpec @@ -21,7 +22,7 @@ def test_get_input_file_size(): def test_get_format(): - assert _minimal_spec.get_format() == 'ion_text' + assert _minimal_spec.get_format() is Format.ION_TEXT def test_get_command():