Add ruff and isort #578

Merged
8 commits merged on Jul 24, 2025

27 changes: 22 additions & 5 deletions .pre-commit-config.yaml
@@ -3,13 +3,23 @@ ci:
autoupdate_schedule: quarterly

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: check-json
- id: check-yaml
- id: debug-statements
- id: mixed-line-ending
args: [--fix=lf]

- repo: https://github.com/Lucas-C/pre-commit-hooks
rev: v1.5.5
hooks:
- id: insert-license
files: |
(?x)^(
auto_round/.*(py|yaml|yml|sh)
auto_round/.*(py|yaml|yml|sh)|
auto_round_extension/.*(py|yaml|yml|sh)
)$
args:
[
@@ -26,7 +36,14 @@ repos:
args: [-w]
additional_dependencies:
- tomli
exclude: |
(?x)^(
examples/.*(txt|patch)
)$

- repo: https://github.com/pycqa/isort
rev: 6.0.1
hooks:
- id: isort

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.4
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix, --no-cache]
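
With this configuration, the isort and ruff hooks run automatically on each commit for anyone who has pre-commit installed. A minimal sketch of invoking them across the whole tree from Python (assuming the pre-commit package is installed and on the path):

import subprocess

# Run every configured hook against all files; equivalent to the CLI call
# `pre-commit run --all-files`. check=False so a non-zero exit (hooks that
# modified files) does not raise.
subprocess.run(["pre-commit", "run", "--all-files"], check=False)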
5 changes: 3 additions & 2 deletions auto_round/__main__.py
@@ -13,6 +13,7 @@
# limitations under the License.
import sys


def run_eval():
from auto_round.script.llm import setup_eval_parser
args = setup_eval_parser()
@@ -58,7 +59,7 @@ def run_fast():

def run_mllm():
if "--eval" in sys.argv:
from auto_round.script.llm import setup_eval_parser, eval
from auto_round.script.llm import eval, setup_eval_parser
sys.argv.remove("--eval")
args = setup_eval_parser()
args.mllm = True
@@ -76,7 +77,7 @@ def run_mllm():

def run_lmms():
# from auto_round.script.lmms_eval import setup_lmms_args, eval
from auto_round.script.mllm import setup_lmms_parser, lmms_eval
from auto_round.script.mllm import lmms_eval, setup_lmms_parser
args = setup_lmms_parser()
lmms_eval(args)

83 changes: 45 additions & 38 deletions auto_round/autoround.py
@@ -12,54 +12,62 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import os
import re
import sys
import time
from typing import Any, Union

import accelerate
import torch
import copy
import time
from typing import Union, Any
from transformers import set_seed
from torch import autocast
from tqdm import tqdm
import accelerate
from transformers import set_seed

from auto_round.data_type import QUANT_FUNC_WITH_DTYPE
from auto_round.export.export_to_gguf.config import GGUF_CONFIG, GGUF_INNER_CONFIG, ModelType
from auto_round.wrapper import WrapperMultiblock, wrapper_block, unwrapper_block, WrapperLinear, unwrapper_layer
from auto_round.low_cpu_mem.utils import get_layers_before_block
from auto_round.utils import (
SUPPORTED_DTYPES,
SUPPORTED_LAYER_TYPES,
TORCH_VERSION_AT_LEAST_2_6,
CpuInfo,
_gguf_args_check,
block_forward,
check_is_cpu,
check_seqlen_compatible,
check_skippable_keywords,
check_to_quantized,
clear_memory,
collect_best_params,
compile_func,
convert_dtype_str2torch,
detect_device,
find_matching_blocks,
flatten_list,
get_block_names,
get_layer_config_by_gguf_format,
get_layer_features,
get_layer_names_in_block,
get_lm_head_name,
get_module,
get_shared_keys,
htcore,
infer_bits_by_data_type,
init_cache,
is_debug_mode,
is_optimum_habana_available,
llm_load_model,
logger,
to_device,
to_dtype,
get_layer_names_in_block,
mv_module_from_gpu,
unsupport_meta_device, clear_memory,
compile_func,
find_matching_blocks, is_debug_mode,
TORCH_VERSION_AT_LEAST_2_6,
SUPPORTED_LAYER_TYPES,
get_layer_features,
set_module,
llm_load_model,
reset_params,
init_cache, check_skippable_keywords, get_shared_keys, SUPPORTED_DTYPES, infer_bits_by_data_type,
_gguf_args_check,
check_seqlen_compatible,
get_layer_config_by_gguf_format, get_lm_head_name, flatten_list
set_module,
to_device,
to_dtype,
unsupport_meta_device,
)
from auto_round.data_type import QUANT_FUNC_WITH_DTYPE
from auto_round.low_cpu_mem.utils import get_layers_before_block
from auto_round.wrapper import WrapperLinear, WrapperMultiblock, unwrapper_block, unwrapper_layer, wrapper_block


class AutoRound(object):
@@ -232,9 +240,9 @@ def __init__(
self.cache_device = torch.device("cpu") if self.low_gpu_mem_usage else self.device

##activation
self.act_group_size = act_group_size if not (act_group_size is None) else self.group_size
self.act_bits = act_bits if not (act_bits is None) else self.bits
self.act_sym = act_sym if not (act_sym is None) else self.sym
self.act_group_size = act_group_size if act_group_size is not None else self.group_size
self.act_bits = act_bits if act_bits is not None else self.bits
self.act_sym = act_sym if act_sym is not None else self.sym
self.act_dynamic = act_dynamic
self.act_data_type = act_data_type
if self.act_data_type is None:
@@ -272,15 +280,15 @@ def __init__(

self.enable_torch_compile = enable_torch_compile
if not self.enable_torch_compile and TORCH_VERSION_AT_LEAST_2_6 and self.act_bits > 8 and not is_debug_mode() \
and self.low_cpu_mem_usage != True and "fp8" not in self.data_type and "fp8" not in self.act_data_type:
and not self.low_cpu_mem_usage and "fp8" not in self.data_type and "fp8" not in self.act_data_type:
logger.info("'enable_torch_compile' is set to `False` by default. " \
"Enabling it can reduce tuning cost by 20%, but it might throw an exception.")

if self.act_bits <= 8 and self.enable_torch_compile:
self.enable_torch_compile = False
logger.warning("reset enable_torch_compile to `False` as activation quantization is enabled")

if self.low_cpu_mem_usage == True and self.enable_torch_compile:
if self.low_cpu_mem_usage and self.enable_torch_compile:
self.enable_torch_compile = False
logger.warning("reset enable_torch_compile to `False` as low_cpu_mem_usage is enabled")

@@ -535,7 +543,7 @@ def parse_format_to_list(self, format: str) -> list:
only_gguf = False
if only_gguf:
self.scale_dtype = torch.float32
logger.info(f"change `scale_dtype` to `torch.float32`")
logger.info("change `scale_dtype` to `torch.float32`")

# Adjust format settings based on compatibility
for index in range(len(formats)):
@@ -592,7 +600,7 @@ def _check_supported_format(self, format: str) -> bool:
f" group_size={self.group_size}, sym={self.sym}, act_bits={self.act_bits}"
elif format != "fake":
logger.warning(
f"Currently only support to export auto_round format quantized model"
"Currently only support to export auto_round format quantized model"
" with fp8 dtype activation for activation quantization."
" Change format to fake and save."
)
@@ -1320,7 +1328,7 @@ def quantize(self):
keys = inputs.keys()
input_id_str = [key for key in keys if key.startswith('hidden_state')]
if len(input_id_str) != 1:
raise RuntimeError(f"hidden_states arg mismatch error,"
raise RuntimeError("hidden_states arg mismatch error,"
"please raise an issue in https://github.com/intel/auto-round/issues")
inputs["input_ids"] = inputs.pop(input_id_str[0], None)
if q_inputs is not None:
@@ -1853,10 +1861,10 @@ def forward(m, hidden_states=None, *positional_inputs, **kwargs):
self.batch_dim = 1
if len(hidden_states.shape) > 1 and hidden_states.shape[1] > self.batch_size:
logger.error(
f"this model has not been supported, "
f"please raise an issue in https://github.com/intel/auto-round/issues"
f" or try to set the `batch_size` to 1 and "
f"`gradient_accumulate_steps` to your current batch size.")
"this model has not been supported, "
"please raise an issue in https://github.com/intel/auto-round/issues"
" or try to set the `batch_size` to 1 and "
"`gradient_accumulate_steps` to your current batch size.")
exit(-1)

if hidden_states is not None:
@@ -2094,8 +2102,7 @@ def get_act_max_hook(module, input, output):
hook_handles = []

for n, m in model.named_modules():
# for block
if hasattr(m, "act_dynamic") and m.act_dynamic == False and check_to_quantized(m):
if hasattr(m, "act_dynamic") and not m.act_dynamic and check_to_quantized(m):
hook = m.register_forward_hook(get_act_max_hook)
hook_handles.append(hook)
continue
@@ -2405,7 +2412,7 @@ def quant_blocks(
model_type=model_type)
else:
PACKING_LAYER_WITH_FORMAT[target_backend](tmp_m.tmp_name, self.model, self.formats[0])
pbar.set_description(f"Quantizing done")
pbar.set_description("Quantizing done")
pbar.update(1)
pbar.close()

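Most of the autoround.py changes above are mechanical: isort regroups the imports into standard-library, third-party, and first-party blocks (alphabetized within each), and the remaining edits are the kind of rewrites ruff's --fix pass produces, such as `not (x is None)` becoming `x is not None`, `== False` comparisons becoming `not x`, and f-strings without placeholders losing the f prefix. A small self-contained sketch of those rewrites (illustrative names only, not code from this repository):

import logging

logger = logging.getLogger(__name__)

def pick_act_bits(act_bits, bits, act_dynamic):
    # was: act_bits if not (act_bits is None) else bits
    act_bits = act_bits if act_bits is not None else bits
    # was: if act_dynamic == False:
    if not act_dynamic:
        # was: logger.info(f"change `scale_dtype` to `torch.float32`")
        logger.info("change `scale_dtype` to `torch.float32`")
    return act_bits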
8 changes: 4 additions & 4 deletions auto_round/calib_dataset.py
@@ -14,12 +14,12 @@

import json
import random
import sys

import torch
from datasets import Dataset, IterableDataset, load_dataset, concatenate_datasets
from datasets import Features, Sequence, Value
from datasets import Dataset, Features, IterableDataset, Sequence, Value, concatenate_datasets, load_dataset
from torch.utils.data import DataLoader
import sys

from .utils import is_local_path, logger

CALIB_DATASETS = {}
@@ -136,7 +136,7 @@ def get_pile_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", split
error_message = str(e)
# Check for proxy or SSL error
if "proxy" in error_message.lower() or isinstance(e, ssl.SSLError) or "SSL" in error_message.upper():
logger.error(f"Network error detected, please checking proxy settings." \
logger.error("Network error detected, please checking proxy settings." \
"Error: {error_message}. Or consider using a backup dataset by `pip install modelscope`" \
" and set '--dataset swift/pile-val-backup' in AutoRound API.")
else:
10 changes: 8 additions & 2 deletions auto_round/data_type/fp8.py
@@ -13,9 +13,15 @@
# limitations under the License.

import torch
from auto_round.data_type.utils import get_gaudi_fp8_ste_func, float8_e4m3fn_ste, reshape_pad_tensor_by_group_size, \
revert_tensor_by_pad, float8_e5m2_ste

from auto_round.data_type.register import register_dtype
from auto_round.data_type.utils import (
float8_e4m3fn_ste,
float8_e5m2_ste,
get_gaudi_fp8_ste_func,
reshape_pad_tensor_by_group_size,
revert_tensor_by_pad,
)


@register_dtype(("fp8_sym","fp8","fp8_e4m3"))
5 changes: 3 additions & 2 deletions auto_round/data_type/gguf.py
@@ -13,8 +13,9 @@
# limitations under the License.

import torch
from auto_round.data_type.utils import round_ste, reshape_pad_tensor_by_group_size, revert_tensor_by_pad, logger

from auto_round.data_type.register import register_dtype
from auto_round.data_type.utils import logger, reshape_pad_tensor_by_group_size, revert_tensor_by_pad, round_ste
from auto_round.utils import get_reciprocal


@@ -506,7 +507,7 @@ def quant_tensor_gguf_sym_dq(
Returns:
Quantized and de-quantized tensor, scale, zero-point
"""
from auto_round.export.export_to_gguf.config import QK_K, K_SCALE_SIZE, GGML_QUANT_SIZES
from auto_round.export.export_to_gguf.config import GGML_QUANT_SIZES, K_SCALE_SIZE, QK_K
from auto_round.export.export_to_gguf.packing import make_q3_quants, make_qx_quants

if bits not in [3, 6]:
3 changes: 2 additions & 1 deletion auto_round/data_type/int.py
@@ -13,8 +13,9 @@
# limitations under the License.

import torch
from auto_round.data_type.utils import round_ste, reshape_pad_tensor_by_group_size, revert_tensor_by_pad

from auto_round.data_type.register import register_dtype
from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad, round_ste


@register_dtype("int_sym")
5 changes: 3 additions & 2 deletions auto_round/data_type/mxfp.py
@@ -13,8 +13,9 @@
# limitations under the License.

import torch
from auto_round.data_type.utils import floor_ste, round_ste, reshape_pad_tensor_by_group_size, revert_tensor_by_pad
from auto_round.data_type.register import register_dtype, QUANT_FUNC_WITH_DTYPE

from auto_round.data_type.register import QUANT_FUNC_WITH_DTYPE, register_dtype
from auto_round.data_type.utils import floor_ste, reshape_pad_tensor_by_group_size, revert_tensor_by_pad, round_ste

MXFP_FORMAT_CACHE = {
# data type: ebits, mbits, emax, max_norm, min_norm
2 changes: 1 addition & 1 deletion auto_round/data_type/nvfp.py
@@ -16,7 +16,7 @@

from auto_round.data_type.fp8 import float8_e4m3fn_ste
from auto_round.data_type.register import register_dtype
from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad, logger
from auto_round.data_type.utils import logger, reshape_pad_tensor_by_group_size, revert_tensor_by_pad


# taken from
5 changes: 4 additions & 1 deletion auto_round/data_type/utils.py
@@ -12,11 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import lru_cache

import torch

from auto_round.data_type.register import QUANT_FUNC_WITH_DTYPE
from functools import lru_cache
from auto_round.utils import logger


def reshape_pad_tensor_by_group_size(data: torch.Tensor, group_size: int):
"""Reshapes and pads the tensor to ensure that it can be quantized in groups of `group_size`.

3 changes: 1 addition & 2 deletions auto_round/data_type/w4fp8.py
@@ -15,8 +15,7 @@
import torch

from auto_round.data_type.register import register_dtype
from auto_round.data_type.utils import get_gaudi_fp8_ste_func, float8_e4m3fn_ste

from auto_round.data_type.utils import float8_e4m3fn_ste, get_gaudi_fp8_ste_func

# @register_dtype("fp8_gaudi3_to_int_sym")
# def progressive_quant_fp8_int4_gaudi3(
2 changes: 1 addition & 1 deletion auto_round/eval/evaluation.py
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Optional, Union

from lm_eval import simple_evaluate as lm_simple_evaluate
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"
