fix imports in tutorials
reuvenp committed Jan 15, 2025
1 parent 47f5298 commit 387c58d
Showing 2 changed files with 20 additions and 25 deletions.
@@ -218,19 +218,18 @@
 "cell_type": "code",
 "source": [
 "from mct_quantizers import QuantizationMethod\n",
-"from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import schema, TargetPlatformCapabilities, Signedness, \\\n",
-" AttributeQuantizationConfig, OpQuantizationConfig\n",
+"from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import schema\n",
 "\n",
 "simd_size = 1\n",
 "\n",
 "def get_tpc():\n",
 " # Define the default weight attribute configuration\n",
-" default_weight_attr_config = AttributeQuantizationConfig(\n",
+" default_weight_attr_config = schema.AttributeQuantizationConfig(\n",
 " weights_quantization_method=QuantizationMethod.UNIFORM,\n",
 " )\n",
 "\n",
 " # Define the OpQuantizationConfig\n",
-" default_config = OpQuantizationConfig(\n",
+" default_config = schema.OpQuantizationConfig(\n",
 " default_weight_attr_config=default_weight_attr_config,\n",
 " attr_weights_configs_mapping={},\n",
 " activation_quantization_method=QuantizationMethod.UNIFORM,\n",
@@ -249,11 +248,11 @@
 "\n",
 " # Create the quantization configuration options and model\n",
 " default_configuration_options = schema.QuantizationConfigOptions(quantization_configurations=tuple([default_config]))\n",
-" tpc = TargetPlatformCapabilities(default_qco=default_configuration_options,\n",
-" tpc_minor_version=1,\n",
-" tpc_patch_version=0,\n",
-" tpc_platform_type=\"custom_pruning_notebook_tpc\",\n",
-" operator_set=tuple(operator_set))\n",
+" tpc = schema.TargetPlatformCapabilities(default_qco=default_configuration_options,\n",
+" tpc_minor_version=1,\n",
+" tpc_patch_version=0,\n",
+" tpc_platform_type=\"custom_pruning_notebook_tpc\",\n",
+" operator_set=tuple(operator_set))\n",
 " return tpc\n"
 ],
 "metadata": {
@@ -189,13 +189,9 @@
 },
 "outputs": [],
 "source": [
-"from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import Signedness, AttributeQuantizationConfig\n",
-"from model_compression_toolkit import DefaultDict\n",
-"from model_compression_toolkit.constants import FLOAT_BITWIDTH\n",
-"from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, KERAS_KERNEL, BIAS_ATTR, BIAS\n",
 "from mct_quantizers import QuantizationMethod\n",
-"from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import schema, TargetPlatformCapabilities, Signedness, \\\n",
-" AttributeQuantizationConfig, OpQuantizationConfig, QuantizationConfigOptions\n",
+"from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR\n",
+"from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import schema\n",
 "\n",
 "def get_tpc():\n",
 " \"\"\"\n",
@@ -210,23 +206,23 @@
 " \"\"\"\n",
 "\n",
 " # define a default quantization config for all non-specified weights attributes.\n",
-" default_weight_attr_config = AttributeQuantizationConfig(\n",
+" default_weight_attr_config = schema.AttributeQuantizationConfig(\n",
 " weights_quantization_method=QuantizationMethod.POWER_OF_TWO,\n",
 " weights_n_bits=8,\n",
 " weights_per_channel_threshold=False,\n",
 " enable_weights_quantization=False,\n",
 " lut_values_bitwidth=None)\n",
 "\n",
 " # define a quantization config to quantize the kernel (for layers where there is a kernel attribute).\n",
-" kernel_base_config = AttributeQuantizationConfig(\n",
+" kernel_base_config = schema.AttributeQuantizationConfig(\n",
 " weights_quantization_method=QuantizationMethod.SYMMETRIC,\n",
 " weights_n_bits=2,\n",
 " weights_per_channel_threshold=True,\n",
 " enable_weights_quantization=True,\n",
 " lut_values_bitwidth=None)\n",
 "\n",
 " # define a quantization config to quantize the bias (for layers where there is a bias attribute).\n",
-" bias_config = AttributeQuantizationConfig(\n",
+" bias_config = schema.AttributeQuantizationConfig(\n",
 " weights_quantization_method=QuantizationMethod.POWER_OF_TWO,\n",
 " weights_n_bits=FLOAT_BITWIDTH,\n",
 " weights_per_channel_threshold=False,\n",
@@ -237,7 +233,7 @@
 " # AttributeQuantizationConfig for weights with no specific AttributeQuantizationConfig.\n",
 " # MCT will compress a layer's kernel and bias according to the configurations that are\n",
 " # set in KERNEL_ATTR and BIAS_ATTR that are passed in attr_weights_configs_mapping.\n",
-" default_config = OpQuantizationConfig(\n",
+" default_config = schema.OpQuantizationConfig(\n",
 " default_weight_attr_config=default_weight_attr_config,\n",
 " attr_weights_configs_mapping={KERNEL_ATTR: kernel_base_config,\n",
 " BIAS_ATTR: bias_config},\n",
@@ -253,7 +249,7 @@
 "\n",
 " # Set default QuantizationConfigOptions in new TargetPlatformCapabilities to be used when no other\n",
 " # QuantizationConfigOptions is set for an OperatorsSet.\n",
-" default_configuration_options = QuantizationConfigOptions(quantization_configurations=[default_config])\n",
+" default_configuration_options = schema.QuantizationConfigOptions(quantization_configurations=[default_config])\n",
 " no_quantization_config = (default_configuration_options.clone_and_edit(enable_activation_quantization=False)\n",
 " .clone_and_edit_weight_attribute(enable_weights_quantization=False))\n",
 "\n",
@@ -263,11 +259,11 @@
 " operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.FLATTEN, qc_options=no_quantization_config))\n",
 "\n",
 "\n",
-" tpc = TargetPlatformCapabilities(default_qco=default_configuration_options,\n",
-" tpc_minor_version=1,\n",
-" tpc_patch_version=0,\n",
-" tpc_platform_type=\"custom_qat_notebook_tpc\",\n",
-" operator_set=tuple(operator_set))\n",
+" tpc = schema.TargetPlatformCapabilities(default_qco=default_configuration_options,\n",
+" tpc_minor_version=1,\n",
+" tpc_patch_version=0,\n",
+" tpc_platform_type=\"custom_qat_notebook_tpc\",\n",
+" operator_set=tuple(operator_set))\n",
 " return tpc\n"
 ]
 },
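For context, the pattern both tutorial notebooks now follow is to import only the schema module from mct_current_schema and reach the configuration classes as attributes on it, instead of importing each class by name. A minimal sketch of that import style, taken from the diff above; the variable name attr_cfg and the print line are illustrative, and it assumes the MCT version these tutorials target exposes AttributeQuantizationConfig on the schema module as shown:

    from mct_quantizers import QuantizationMethod
    from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import schema

    # Schema classes are referenced through the imported `schema` module,
    # matching the updated notebook code.
    attr_cfg = schema.AttributeQuantizationConfig(
        weights_quantization_method=QuantizationMethod.UNIFORM,
    )
    print(type(attr_cfg).__name__)  # -> AttributeQuantizationConfig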
