Skip to content

Commit

Permalink
Refactor hessian service and torch sample layer attention. (#1242)
Browse files Browse the repository at this point in the history
* refactor hessian service
* add data util
* update gptq trainers
* update qparams computation for hessian service api change
* add keras pytest to workflow, fix pytest discovery
  • Loading branch information
irenaby authored Oct 22, 2024
1 parent de1f973 commit cfb313c
Show file tree
Hide file tree
Showing 40 changed files with 1,376 additions and 977 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/run_keras_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,15 @@ jobs:
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install tensorflow==${{ inputs.tf-version }} sony-custom-layers
pip install pytest
- name: Run unittests
# Some tests are sensitive to memory because we use tf gradients on a multi-thread/process
# CPU environment (https://github.com/tensorflow/tensorflow/issues/41718).
# For this reason, if we run them in such an environment, we need to run them first non-parallel separately.
run: |
python -m unittest discover tests/keras_tests -v
- name: Run pytest
run: |
pytest tests_pytest/keras
2 changes: 2 additions & 0 deletions .github/workflows/run_pytorch_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,7 @@ jobs:
- name: Run unittests
run: |
python -m unittest discover tests/pytorch_tests -v
- name: Run pytest
run: |
pytest tests_pytest/pytorch
5 changes: 5 additions & 0 deletions .github/workflows/run_tests_suite_coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ jobs:
run: |
source tf_env/bin/activate
coverage run --parallel-mode -m --omit "*__init__.py" --include "model_compression_toolkit/**/*.py" unittest tests/test_suite.py -v
- name: Run TensorFlow pytest
run: |
source tf_env/bin/activate
coverage run --parallel-mode -m --omit "*__init__.py" --include "model_compression_toolkit/**/*.py" pytest tests_pytest/keras
- name: Set up Pytorch environment
run: |
Expand Down
72 changes: 43 additions & 29 deletions model_compression_toolkit/core/common/framework_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
from abc import ABC, abstractmethod
from typing import Callable, Any, List, Tuple, Dict
from typing import Callable, Any, List, Tuple, Dict, Generator

import numpy as np

Expand Down Expand Up @@ -46,7 +46,7 @@ def constants(self):
Returns: Module of the framework constants.
"""
raise NotImplemented(f'{self.__class__.__name__} did not supply a constants module.') # pragma: no cover
raise NotImplementedError(f'{self.__class__.__name__} did not supply a constants module.') # pragma: no cover

@abstractmethod
def get_hessian_scores_calculator(self,
Expand All @@ -64,7 +64,7 @@ def get_hessian_scores_calculator(self,
Returns: HessianScoresCalculator to use for the hessian approximation scores computation for this request.
"""
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
f'framework\'s get_hessian_scores_calculator method.') # pragma: no cover

@abstractmethod
Expand All @@ -77,7 +77,7 @@ def to_numpy(self, tensor: Any) -> np.ndarray:
Returns:
Numpy array converted from the input tensor.
"""
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
f'framework\'s to_numpy method.') # pragma: no cover

@abstractmethod
Expand All @@ -90,7 +90,7 @@ def to_tensor(self, tensor: np.ndarray) -> Any:
Returns:
Framework's tensor converted from the input Numpy array.
"""
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
f'framework\'s to_tensor method.') # pragma: no cover

@abstractmethod
Expand All @@ -106,7 +106,7 @@ def model_reader(self,
Returns:
Graph representing the input model.
"""
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
f'framework\'s model_reader method.') # pragma: no cover

@abstractmethod
Expand All @@ -131,7 +131,7 @@ def model_builder(self,
Returns:
A tuple with the model and additional relevant supporting objects.
"""
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
f'framework\'s model_builder method.') # pragma: no cover

@abstractmethod
Expand All @@ -148,7 +148,7 @@ def run_model_inference(self,
Returns:
The frameworks model's output.
"""
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
f'framework\'s run_model_inference method.') # pragma: no cover

@abstractmethod
Expand All @@ -167,7 +167,7 @@ def shift_negative_correction(self,
Returns:
Graph after SNC.
"""
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
f'framework\'s apply_shift_negative_correction method.') # pragma: no cover

@abstractmethod
Expand All @@ -184,7 +184,7 @@ def get_substitutions_channel_equalization(self,
Returns:
A list of the framework substitutions used after we collect statistics.
"""
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
f'framework\'s get_substitutions_channel_equalization method.') # pragma: no cover

@abstractmethod
Expand All @@ -194,7 +194,7 @@ def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List
Returns: A list of the framework substitutions used to prepare the graph.
"""
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
f'framework\'s get_substitutions_prepare_graph method.') # pragma: no cover

@abstractmethod
Expand All @@ -208,23 +208,23 @@ def get_substitutions_pre_statistics_collection(self, quant_config: Quantization
Returns: A list of the framework substitutions used before we collect statistics.
"""
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
f'framework\'s get_substitutions_pre_statistics_collection method.') # pragma: no cover

@abstractmethod
def get_linear_collapsing_substitution(self) -> common.BaseSubstitution:
"""
Returns: linear collapsing substitution
"""
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
f'framework\'s get_linear_collapsing_substitution method.') # pragma: no cover

@abstractmethod
def get_op2d_add_const_collapsing_substitution(self) -> common.BaseSubstitution:
"""
Returns: conv2d add const collapsing substitution
"""
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
f'framework\'s get_op2d_add_const_collapsing_substitution method.') # pragma: no cover

@abstractmethod
Expand All @@ -239,15 +239,15 @@ def get_substitutions_statistics_correction(self, quant_config: QuantizationConf
Returns:
A list of the framework substitutions used for statistics correction.
"""
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
f'framework\'s get_substitutions_statistics_correction method.') # pragma: no cover

@abstractmethod
def get_residual_collapsing_substitution(self) -> List[common.BaseSubstitution]:
"""
Returns: A list of the framework substitutions used for residual collapsing
"""
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
f'framework\'s get_residual_collapsing_substitution method.') # pragma: no cover


Expand All @@ -263,7 +263,7 @@ def get_substitutions_post_statistics_collection(self, quant_config: Quantizatio
Returns:
A list of the framework substitutions used after we collect statistics.
"""
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
f'framework\'s get_substitutions_post_statistics_collection method.') # pragma: no cover

@abstractmethod
Expand All @@ -272,7 +272,7 @@ def get_substitutions_virtual_weights_activation_coupling(self) -> List[common.B
Returns: A list of Keras substitutions used to build a virtual graph with composed activation-weights pairs.
"""

raise NotImplemented(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
f'framework\'s get_substitutions_virtual_weights_activation_coupling '
f'method.') # pragma: no cover

Expand All @@ -288,7 +288,7 @@ def get_substitutions_after_second_moment_correction(self, quant_config: Quantiz
Returns:
A list of the framework substitutions used after we apply second moment statistics.
"""
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
f'framework\'s get_substitutions_after_second_moment_correction '
f'method.') # pragma: no cover

Expand Down Expand Up @@ -316,7 +316,7 @@ def get_sensitivity_evaluator(self,
A function that computes the metric.
"""

raise NotImplemented(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
f'framework\'s get_sensitivity_evaluator method.') # pragma: no cover

def get_node_prior_info(self, node: BaseNode,
Expand All @@ -334,7 +334,7 @@ def get_node_prior_info(self, node: BaseNode,
NodePriorInfo with information about the node.
"""

raise NotImplemented(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
f'framework\'s get_node_prior_info method.') # pragma: no cover

def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool:
Expand All @@ -345,7 +345,7 @@ def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool
Returns: True if the node should be considered an interest point, False otherwise.
"""

raise NotImplemented(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
f'framework\'s count_node_for_mixed_precision_interest_points method.') # pragma: no cover

def get_mp_node_distance_fn(self, n: BaseNode,
Expand All @@ -364,7 +364,7 @@ def get_mp_node_distance_fn(self, n: BaseNode,
Returns: A distance function between two tensors and a axis on which the distance is computed (if exists).
"""

raise NotImplemented(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
f'framework\'s get_mp_node_distance_fn method.') # pragma: no cover


Expand All @@ -381,7 +381,7 @@ def is_output_node_compatible_for_hessian_score_computation(self,
"""

raise NotImplemented(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
f'framework\'s is_output_node_compatible_for_hessian_score_computation method.') # pragma: no cover

@abstractmethod
Expand All @@ -398,7 +398,7 @@ def get_node_mac_operations(self,
Returns: The MAC count of the operation
"""

raise NotImplemented(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
f'framework\'s get_node_mac_operations method.') # pragma: no cover

@abstractmethod
Expand All @@ -419,7 +419,7 @@ def apply_second_moment_correction(self,
Returns:
A Graph after second moment correction.
"""
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
f'framework\'s apply_second_moment_correction method.') # pragma: no cover

@abstractmethod
Expand All @@ -436,7 +436,7 @@ def sensitivity_eval_inference(self,
Returns:
The output of the model inference on the given input.
"""
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
f'framework\'s sensitivity_eval_inference method.') # pragma: no cover

def get_inferable_quantizers(self, node: BaseNode):
Expand All @@ -452,5 +452,19 @@ def get_inferable_quantizers(self, node: BaseNode):
"""

raise NotImplemented(f'{self.__class__.__name__} have to implement the '
f'framework\'s get_inferable_quantizers method.') # pragma: no cover
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
f'framework\'s get_inferable_quantizers method.') # pragma: no cover

@staticmethod
def convert_data_gen_to_dataloader(data_gen_fn: Callable[[], Generator], batch_size: int):
"""
Create DataLoader based on samples yielded by data_gen.
Args:
data_gen_fn: data generator factory.
batch_size: target batch size.
Returns:
Framework dataloader.
"""
raise NotImplementedError() # pragma: no cover
Loading

0 comments on commit cfb313c

Please sign in to comment.