Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added JSON import and export capabilities to TargetPlatformModel #1311

Merged
merged 2 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from pathlib import Path
from typing import Union

from model_compression_toolkit.logger import Logger
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
import json


def load_target_platform_model(tp_model_or_path: Union[TargetPlatformModel, str]) -> TargetPlatformModel:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change tp model to tpc everywhere in this function (function, variables names, cimments...)
because we know that we are going to change the terminology, but this won't be changed automatically with the refactor to the class name

"""
Parses the tp_model input, which can be either a TargetPlatformModel object
or a string path to a JSON file.

Parameters:
tp_model_or_path (Union[TargetPlatformModel, str]): Input target platform model or path to .JSON file.

Returns:
TargetPlatformModel: The parsed TargetPlatformModel.

Raises:
FileNotFoundError: If the JSON file does not exist.
ValueError: If the JSON content is invalid or cannot initialize the TargetPlatformModel.
TypeError: If the input is neither a TargetPlatformModel nor a valid JSON file path.
"""
if isinstance(tp_model_or_path, TargetPlatformModel):
return tp_model_or_path

if isinstance(tp_model_or_path, str):
path = Path(tp_model_or_path)

if not path.exists() or not path.is_file():
raise FileNotFoundError(f"The path '{tp_model_or_path}' is not a valid file.")
# Verify that the file has a .json extension
if path.suffix.lower() != '.json':
raise ValueError(f"The file '{path}' does not have a '.json' extension.")
try:
with path.open('r', encoding='utf-8') as file:
data = file.read()
except OSError as e:
raise ValueError(f"Error reading the file '{tp_model_or_path}': {e.strerror}.") from e

try:
return TargetPlatformModel.parse_raw(data)
except ValueError as e:
raise ValueError(f"Invalid JSON for loading TargetPlatformModel in '{tp_model_or_path}': {e}.") from e
except Exception as e:
raise ValueError(f"Unexpected error while initializing TargetPlatformModel: {e}.") from e

raise TypeError(
f"tp_model_or_path must be either a TargetPlatformModel instance or a string path to a JSON file, "
f"but received type '{type(tp_model_or_path).__name__}'."
)


def export_target_platform_model(model: TargetPlatformModel, export_path: Union[str, Path]) -> None:
"""
Exports a TargetPlatformModel instance to a JSON file.

Parameters:
model (TargetPlatformModel): The TargetPlatformModel instance to export.
export_path (Union[str, Path]): The file path to export the model to.

Raises:
ValueError: If the model is not an instance of TargetPlatformModel.
OSError: If there is an issue writing to the file.
"""
if not isinstance(model, TargetPlatformModel):
raise ValueError("The provided model is not a valid TargetPlatformModel instance.")

path = Path(export_path)
try:
# Ensure the parent directory exists
path.parent.mkdir(parents=True, exist_ok=True)

# Export the model to JSON and write to the file
with path.open('w', encoding='utf-8') as file:
file.write(model.json(indent=4))
except OSError as e:
raise OSError(f"Failed to write to file '{export_path}': {e.strerror}") from e
145 changes: 115 additions & 30 deletions tests/common_tests/test_tp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR
from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import \
get_config_options_by_operators_set, is_opset_in_model
from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_model, \
export_target_platform_model
from tests.common_tests.helpers.generate_test_tp_model import generate_test_attr_configs, generate_test_op_qc

tp = mct.target_platform
Expand All @@ -30,42 +32,125 @@
TEST_QCO = schema.QuantizationConfigOptions(quantization_configurations=tuple([TEST_QC]))


class TargetPlatformModelingTest(unittest.TestCase):
def cleanup_file(self, file_path):
if os.path.exists(file_path):
os.remove(file_path)
print(f"Cleaned up: {file_path}")
class TPModelInputOutputTests(unittest.TestCase):

def test_dump_to_json(self):
def setUp(self):
# Setup reusable resources or configurations for tests
self.valid_export_path = "exported_model.json"
self.invalid_export_path = "/invalid/path/exported_model.json"
self.invalid_json_content = '{"field1": "value1", "field2": ' # Incomplete JSON
self.invalid_json_file = "invalid_model.json"
self.nonexistent_file = "nonexistent.json"
op1 = schema.OperatorsSet(name="opset1")
op2 = schema.OperatorsSet(name="opset2")
op3 = schema.OperatorsSet(name="opset3")
op12 = schema.OperatorSetConcat(operators_set=[op1, op2])
model = schema.TargetPlatformModel(default_qco=TEST_QCO,
operator_set=(op1, op2, op3),
fusing_patterns=(schema.Fusing(operator_groups=(op12, op3)),
schema.Fusing(operator_groups=(op1, op2))),
tpc_minor_version=1,
tpc_patch_version=0,
tpc_platform_type="dump_to_json",
add_metadata=False)
json_str = model.json()
# Define the output file path
file_path = "target_platform_model.json"
# Register cleanup to delete the file if it exists
self.addCleanup(self.cleanup_file, file_path)

# Write the JSON string to the file
with open(file_path, "w") as f:
f.write(json_str)

with open(file_path, "r") as f:
json_content = f.read()

loaded_target_model = schema.TargetPlatformModel.parse_raw(json_content)
self.assertEqual(model, loaded_target_model)

self.tp_model = schema.TargetPlatformModel(default_qco=TEST_QCO,
operator_set=(op1, op2, op3),
fusing_patterns=(schema.Fusing(operator_groups=(op12, op3)),
schema.Fusing(operator_groups=(op1, op2))),
tpc_minor_version=1,
tpc_patch_version=0,
tpc_platform_type="dump_to_json",
add_metadata=False)

# Create invalid JSON file
with open(self.invalid_json_file, "w") as file:
file.write(self.invalid_json_content)

def tearDown(self):
# Cleanup files created during tests
for file in [self.valid_export_path, self.invalid_json_file]:
if os.path.exists(file):
os.remove(file)

def test_valid_model_object(self):
"""Test that a valid TargetPlatformModel object is returned unchanged."""
result = load_target_platform_model(self.tp_model)
self.assertEqual(self.tp_model, result)

def test_invalid_json_parsing(self):
"""Test that invalid JSON content raises a ValueError."""
with self.assertRaises(ValueError) as context:
load_target_platform_model(self.invalid_json_file)
self.assertIn("Invalid JSON for loading TargetPlatformModel in", str(context.exception))

def test_nonexistent_file(self):
"""Test that a nonexistent file raises FileNotFoundError."""
with self.assertRaises(FileNotFoundError) as context:
load_target_platform_model(self.nonexistent_file)
self.assertIn("is not a valid file", str(context.exception))

def test_non_json_extension(self):
"""Test that a file with a non-JSON extension raises ValueError."""
non_json_file = "test_model.txt"
try:
with open(non_json_file, "w") as file:
file.write(self.invalid_json_content)
with self.assertRaises(ValueError) as context:
load_target_platform_model(non_json_file)
self.assertIn("does not have a '.json' extension", str(context.exception))
finally:
os.remove(non_json_file)

def test_invalid_input_type(self):
"""Test that an unsupported input type raises TypeError."""
invalid_input = 123 # Not a string or TargetPlatformModel
with self.assertRaises(TypeError) as context:
load_target_platform_model(invalid_input)
self.assertIn("must be either a TargetPlatformModel instance or a string path", str(context.exception))

def test_valid_export(self):
"""Test exporting a valid TargetPlatformModel instance to a file."""
export_target_platform_model(self.tp_model, self.valid_export_path)
# Verify the file exists
self.assertTrue(os.path.exists(self.valid_export_path))

# Verify the contents match the model's JSON representation
with open(self.valid_export_path, "r", encoding="utf-8") as file:
content = file.read()
self.assertEqual(content, self.tp_model.json(indent=4))

def test_export_with_invalid_model(self):
"""Test that exporting an invalid model raises a ValueError."""
with self.assertRaises(ValueError) as context:
export_target_platform_model("not_a_model", self.valid_export_path)
self.assertIn("not a valid TargetPlatformModel instance", str(context.exception))

def test_export_with_invalid_path(self):
"""Test that exporting to an invalid path raises an OSError."""
with self.assertRaises(OSError) as context:
export_target_platform_model(self.tp_model, self.invalid_export_path)
self.assertIn("Failed to write to file", str(context.exception))

def test_export_creates_parent_directories(self):
"""Test that exporting creates missing parent directories."""
nested_path = "nested/directory/exported_model.json"
try:
export_target_platform_model(self.tp_model, nested_path)
# Verify the file exists
self.assertTrue(os.path.exists(nested_path))

# Verify the contents match the model's JSON representation
with open(nested_path, "r", encoding="utf-8") as file:
content = file.read()
self.assertEqual(content, self.tp_model.json(indent=4))
finally:
# Cleanup created directories
if os.path.exists(nested_path):
os.remove(nested_path)
if os.path.exists("nested/directory"):
os.rmdir("nested/directory")
if os.path.exists("nested"):
os.rmdir("nested")

def test_export_then_import(self):
"""Test that a model exported and then imported is identical."""
export_target_platform_model(self.tp_model, self.valid_export_path)
imported_model = load_target_platform_model(self.valid_export_path)
self.assertEqual(self.tp_model, imported_model)

class TargetPlatformModelingTest(unittest.TestCase):
def test_immutable_tp(self):

with self.assertRaises(Exception) as e:
Expand Down
4 changes: 3 additions & 1 deletion tests/test_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
from tests.common_tests.function_tests.test_resource_utilization_object import TestResourceUtilizationObject
from tests.common_tests.function_tests.test_threshold_selection import TestThresholdSelection
from tests.common_tests.test_doc_examples import TestCommonDocsExamples
from tests.common_tests.test_tp_model import TargetPlatformModelingTest, OpsetTest, QCOptionsTest, FusingTest
from tests.common_tests.test_tp_model import TargetPlatformModelingTest, OpsetTest, QCOptionsTest, FusingTest, \
TPModelInputOutputTests

found_tf = importlib.util.find_spec("tensorflow") is not None
if found_tf:
Expand Down Expand Up @@ -116,6 +117,7 @@
suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestHistogramCollector))
suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestCollectorsManipulations))
suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TestThresholdSelection))
suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TPModelInputOutputTests))
suiteList.append(unittest.TestLoader().loadTestsFromTestCase(TargetPlatformModelingTest))
suiteList.append(unittest.TestLoader().loadTestsFromTestCase(OpsetTest))
suiteList.append(unittest.TestLoader().loadTestsFromTestCase(QCOptionsTest))
Expand Down
Loading