4 changes: 4 additions & 0 deletions .github/workflows/lint.yml
@@ -47,6 +47,10 @@ jobs:
         commands: |
           echo "::add-matcher::.github/workflows/matchers/pylint.json"
           tox -e lint
+      - name: "mypy"
+        commands: |
+          echo "::add-matcher::.github/workflows/matchers/mypy.json"
+          tox -e mypy
 
     steps:
       - name: "Harden Runner"
16 changes: 16 additions & 0 deletions .github/workflows/matchers/mypy.json
@@ -0,0 +1,16 @@
+{
+  "problemMatcher": [
+    {
+      "owner": "mypy",
+      "pattern": [
+        {
+          "regexp": "^(.+):(\\d+):\\s(error|warning):\\s(.+)$",
+          "file": 1,
+          "line": 2,
+          "severity": 3,
+          "message": 4
+        }
+      ]
+    }
+  ]
+}
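For a quick sanity check of the matcher, the regex can be exercised against a line in mypy's default output format (the sample below reuses one of the CI failures recorded later in this diff):

import re

# Same pattern as in mypy.json, with the JSON string escaping undone.
MATCHER = re.compile(r"^(.+):(\d+):\s(error|warning):\s(.+)$")

sample = 'fms_mo/quant/quantizers.py:3918: error: Incompatible types in assignment  [assignment]'
m = MATCHER.match(sample)
assert m is not None
print(m.group(1), m.group(2), m.group(3), m.group(4))  # file, line, severity, message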
4 changes: 2 additions & 2 deletions fms_mo/aiu_addons/i8i8/i8i8_aiu_adapter.py
@@ -14,7 +14,7 @@
 """Implement FMS adapter for INT8xINT8 checkpoints"""
 
 # Standard
-from typing import Mapping
+from typing import Mapping, MutableMapping
 
 # Third Party
 from fms.utils import serialization
@@ -47,7 +47,7 @@ def _int8_qparams_aiu(
 
 
 def _add_defaults_and_concat(
-    new_sd: dict[str, torch.Tensor],
+    new_sd: MutableMapping[str, torch.Tensor],
     modules_seen: set[str],
 ) -> None:
     """
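Typing the parameter as MutableMapping rather than dict keeps it writable for mypy while accepting any dict-like state dict; Mapping alone would not do, since it declares no __setitem__. A minimal illustration (toy function names):

from typing import Mapping, MutableMapping

import torch


def frozen(sd: Mapping[str, torch.Tensor]) -> None:
    sd["w"] = torch.zeros(1)  # mypy: Unsupported target for indexed assignment  [index]


def writable(sd: MutableMapping[str, torch.Tensor]) -> None:
    sd["w"] = torch.zeros(1)  # OK: MutableMapping declares __setitem__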
3 changes: 2 additions & 1 deletion fms_mo/calib.py
@@ -482,7 +482,8 @@ def qmodel_calib(
         return model
 
     DPorDDPdevices = None
-    if "qmodel_prep" not in sys._getframe().f_back.f_code.co_name:
+    f_back = sys._getframe().f_back
+    if f_back and "qmodel_prep" not in f_back.f_code.co_name:
         model.to(currDev)
     qcfg["wasDPmodel"] = qcfg.get("wasDPmodel", isinstance(model, nn.DataParallel))
     qcfg["wasDDPmodel"] = qcfg.get(
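The guard exists because sys._getframe().f_back is Optional[FrameType]: it is None when there is no caller frame, which mypy now enforces. A standalone sketch of the pattern:

import sys


def caller_name() -> str:
    # f_back is Optional[FrameType]; None at the top of the call stack.
    f_back = sys._getframe().f_back
    return f_back.f_code.co_name if f_back else "<no caller>"


print(caller_name())  # "<module>" when called from top-level script code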
6 changes: 3 additions & 3 deletions fms_mo/custom_ext_kernels/utils.py
@@ -74,10 +74,10 @@
     # Third Party
     import torch.library as lib
 
-    reg_op = partial(lib.custom_op, mutates_args=())
+    reg_op = partial(lib.custom_op, mutates_args=())  # type: ignore[attr-defined]
     reg_op_func = lib.define  # NOTE this is func, not decorator
-    kernel_impl = lib.register_kernel
-    reg_fake = lib.register_fake
+    kernel_impl = lib.register_kernel  # type: ignore[attr-defined]
+    reg_fake = lib.register_fake  # type: ignore[attr-defined]
 
 else:
     raise RuntimeError("Custom Op registration only works for >PT2.1")
8 changes: 5 additions & 3 deletions fms_mo/quant/ptq.py
@@ -2631,8 +2631,10 @@ def reset_bn(module: nn.BatchNorm2d):
     Function not currently used.
     """
     if module.track_running_stats:
-        module.running_mean.zero_()
-        module.running_var.fill_(1 - module.eps)
+        if (running_mean := module.running_mean) is not None:
+            running_mean.zero_()
+        if (running_var := module.running_var) is not None:
+            running_var.fill_(1 - module.eps)
         # we do not reset number of tracked batches here
     if module.affine:
         nn.init.ones_(module.weight)
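The "is not None" form of the guard is load-bearing: the buffers are Optional[Tensor] in torch's stubs (None when track_running_stats=False), and a bare truthiness test on a multi-element tensor raises "Boolean value of Tensor with more than one element is ambiguous" at runtime. A standalone sketch:

import torch.nn as nn


def reset_running_stats(bn: nn.BatchNorm2d) -> None:
    # Guard with `is not None`, never `if bn.running_mean:` (tensor truthiness raises).
    if (mean := bn.running_mean) is not None:
        mean.zero_()
    if (var := bn.running_var) is not None:
        var.fill_(1.0 - bn.eps)


reset_running_stats(nn.BatchNorm2d(8))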
@@ -2651,7 +2653,7 @@ def reset_bn(module: nn.BatchNorm2d):
     bn_affine = True  # FrozenBN doesn't have .affine property
 except:
     BNofInteret = (nn.BatchNorm2d, nn.BatchNorm1d)
-    AbsorbLayers = (nn.Conv2d, nn.Linear)
+    AbsorbLayers = (nn.Conv2d, nn.Linear)  # type: ignore[assignment]
 
 
 def search_fold_and_remove_bn(model, mod_folded):
13 changes: 7 additions & 6 deletions fms_mo/quant/quantizers.py
@@ -3895,7 +3895,8 @@
             self.delta = torch.nn.Parameter(delta)
         else:
             delta, zero_point = self.init_quantization_scale(x, self.channel_wise)
-            self.delta.fill_(delta)
+            if (self_data := self.delta) is not None:
+                self_data.fill_(delta)
             self.zero_point.fill_(zero_point)
             self.inited = True
 
@@ -3906,7 +3907,7 @@
         return x_dequant
 
     def init_quantization_scale(self, x: torch.Tensor, channel_wise: bool = False):
-        delta, zero_point = None, None
+        delta, zero_point = 1.0, 0  # init seems unnecessary, at least avoid None causing type chk err
         if channel_wise:
             x_clone = x.clone().detach()
             n_channels = x_clone.shape[0]
@@ -3914,23 +3915,23 @@
                 x_max = x_clone.abs().max(dim=-1)[0].max(dim=-1)[0].max(dim=-1)[0]
             else:
                 x_max = x_clone.abs().max(dim=-1)[0]
             delta = x_max.clone()
>> Check failure (GitHub Actions / lint: mypy) on line 3918: Incompatible types in assignment (expression has type "Tensor", variable has type "float")  [assignment]
             zero_point = x_max.clone()
>> Check failure (GitHub Actions / lint: mypy) on line 3919: Incompatible types in assignment (expression has type "Tensor", variable has type "int")  [assignment]
             # determine the scale and zero point channel-by-channel
             for c in range(n_channels):
                 delta[c], zero_point[c] = self.init_quantization_scale(
>> Check failure (GitHub Actions / lint: mypy) on line 3922: Unsupported target for indexed assignment ("float")  [index]
>> Check failure (GitHub Actions / lint: mypy) on line 3922: Unsupported target for indexed assignment ("int")  [index]
                     x_clone[c], channel_wise=False
                 )
             if len(x.shape) == 4:
                 delta = delta.view(-1, 1, 1, 1)
>> Check failure (GitHub Actions / lint: mypy) on line 3926: "float" has no attribute "view"  [attr-defined]
                 zero_point = zero_point.view(-1, 1, 1, 1)
>> Check failure (GitHub Actions / lint: mypy) on line 3927: "int" has no attribute "view"  [attr-defined]
             else:
                 delta = delta.view(-1, 1)
>> Check failure (GitHub Actions / lint: mypy) on line 3929: "float" has no attribute "view"  [attr-defined]
                 zero_point = zero_point.view(-1, 1)
>> Check failure (GitHub Actions / lint: mypy) on line 3930: "int" has no attribute "view"  [attr-defined]
         else:
             if "max" in self.scale_method:
                 x_min = min(x.min().item(), 0)
                 x_max = max(x.max().item(), 0)
>> Check failure (GitHub Actions / lint: mypy) on line 3934: Incompatible types in assignment (expression has type "int | float", variable has type "Tensor")  [assignment]
             if "scale" in self.scale_method:
                 x_min = x_min * (self.n_bits + 2) / 8
                 x_max = x_max * (self.n_bits + 2) / 8
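All of the failures above stem from the same thing: delta and zero_point are inferred as float/int from their initializers and then rebound to tensors in the channel-wise branch. One way out (a sketch, not what this PR does) is to widen the declared types up front:

from typing import Union

import torch

delta: Union[float, torch.Tensor] = 1.0
zero_point: Union[int, torch.Tensor] = 0

x_max = torch.randn(4, 16).abs().max(dim=-1)[0]
delta = x_max.clone()       # OK under the union annotation
zero_point = x_max.clone()  # OK as well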
@@ -3960,7 +3961,7 @@
                 if score < best_score:
                     best_score = score
                     delta = (new_max - new_min) / (2**self.n_bits - 1)
-                    zero_point = (-new_min / delta).round()
+                    zero_point = (-new_min / delta).round()  # type: ignore[union-attr]
         else:
             raise NotImplementedError
 
@@ -4035,8 +4036,8 @@
         self.reset_ReSig_param(multimodal)
 
         self.beta = 2 / 3
-        self.Wshape = None
-        self.reshape2 = None
+        self.Wshape: list[int] = list()
+        self.reshape2: list[Any] = list()
 
     def forward(self, x):
         if self.useSAWB:
@@ -5389,7 +5390,7 @@
         if "e4m3" in q_mode:
             self.float8_dtype = torch.float8_e4m3fn
         elif "e5m2" in q_mode:
-            self.float8_dtype = torch.float8_e5m2G
+            self.float8_dtype = torch.float8_e5m2
         else:
             raise ValueError("FP8 only supports e4m3 and e5m2")
         self.emulate = emulate
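This one is a live bug rather than a type nit: torch.float8_e5m2G does not exist, so the old line raised AttributeError whenever the e5m2 path was taken. For reference, a table-driven sketch of the same dispatch (function name made up; assumes PyTorch >= 2.1, where both float8 dtypes are available):

import torch

FLOAT8_DTYPES = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2}


def pick_float8(q_mode: str) -> torch.dtype:
    for key, dtype in FLOAT8_DTYPES.items():
        if key in q_mode:
            return dtype
    raise ValueError("FP8 only supports e4m3 and e5m2")


print(pick_float8("fp8_e4m3_sim"))  # torch.float8_e4m3fn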
4 changes: 2 additions & 2 deletions fms_mo/utils/qconfig_utils.py
@@ -15,7 +15,7 @@
 
 # Standard
 from pathlib import Path
-from typing import Any
+from typing import Any, Dict
 import json
 import logging
 import os
@@ -149,7 +149,7 @@ def qconfig_init(recipe: str = None, args: Any = None):
         otherwise use constantLR as default
     """
 
-    qcfg = {}
+    qcfg: Dict[str, Any] = {}
     # 1. create a dict with default values
     qcfg["mapping"] = {
         nn.Conv2d: {"from": nn.Conv2d, "to": QConv2d, "otherwise": QConv2d},
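The annotation on the empty literal matters: for a bare `qcfg = {}`, mypy either demands an annotation or infers a narrow value type from the first insertion and then flags heterogeneous values. A small illustration (keys made up):

from typing import Any, Dict

qcfg: Dict[str, Any] = {}  # an unannotated {} here would draw "Need type annotation"  [var-annotated]
qcfg["nbits_w"] = 8
qcfg["mapping"] = {"Conv2d": "QConv2d"}  # mixed value types are fine under Any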
51 changes: 30 additions & 21 deletions fms_mo/utils/torchscript_utils.py
@@ -55,7 +55,8 @@ def parse_operation(op_str: str):
     operands = op_str[
         last_open_parenthesis_index + 1 : last_close_parenthesis_index
     ].split(",")
-    operands = [operand.strip() for operand in operands] if operands != [""] else None
+    # pylint: disable=line-too-long
+    operands = [operand.strip() for operand in operands] if operands != [""] else None  # type: ignore[assignment]
     return operator, operands
 
 
@@ -178,9 +179,14 @@ def __init__(self, node_input, dictionary_of_nodes: dict):
             )
             operator, operands = parse_operation(op_str)
             if "aten::_conv" in op_str:
-                self.ch_in = list(native_torchscript_node.inputs())[0].type().sizes()
-                # NOTE: Needed for finding shortcut convolutions later
-                self.ch_out = list(native_torchscript_node.outputs())[0].type().sizes()
+                if native_torchscript_node:
+                    self.ch_in = (
+                        list(native_torchscript_node.inputs())[0].type().sizes()
+                    )
+                    # NOTE: Needed for finding shortcut convolutions later
+                    self.ch_out = (
+                        list(native_torchscript_node.outputs())[0].type().sizes()
+                    )
         else:
             node_def = node_input_repr
             op_str, operator, operands = None, None, None
@@ -200,31 +206,34 @@ def __init__(self, node_input, dictionary_of_nodes: dict):
             working_str = node_input_repr[start_index:end_index]
             start_index = end_index + 2
 
-            node_instance.name, node_instance.obj = working_str.split(" : ")
-            node_instance.name = node_instance.name.strip()
+            # pylint: disable=line-too-long
+            node_instance.name, node_instance.obj = working_str.split(" : ")  # type: ignore[attr-defined]
+            node_instance.name = node_instance.name.strip()  # type: ignore[attr-defined]
             if native_torchscript_outputs:
-                if node_instance.name not in native_torchscript_outputs:
+                # pylint: disable=line-too-long
+                if node_instance.name not in native_torchscript_outputs:  # type: ignore[attr-defined]
+                    # pylint: disable=line-too-long
                     logger.error(
-                        f"Node def {node_instance.name} not in nativeTSoutputs "
+                        f"Node def {node_instance.name} not in nativeTSoutputs "  # type: ignore[attr-defined]
                         f"{native_torchscript_outputs}"
                     )
-            node_instance.Op = op_str
+            node_instance.Op = op_str  # type: ignore[attr-defined]
             if node_def_in_one_line > 1:
-                node_instance.unpackIdx = node_index
+                node_instance.unpackIdx = node_index  # type: ignore[attr-defined]
             if line_number:
-                node_instance.lineno = line_number
-                node_instance.operator = operator
+                node_instance.lineno = line_number  # type: ignore[attr-defined]
+                node_instance.operator = operator  # type: ignore[attr-defined]
             # This is the name of parents, not the pointer to the parent nodes
-            node_instance.parents = operands
-            node_instance.parents_ptr = []
-            node_instance.scope = scope_repr
-            node_instance.modname = module_name
-            node_instance.children = []
-            node_instance.children_ptr = []
-            node_instance.TSparents = native_torchscript_parents
-            node_instance.TSoutputs = native_torchscript_outputs
+            node_instance.parents = operands  # type: ignore[attr-defined]
+            node_instance.parents_ptr = []  # type: ignore[attr-defined]
+            node_instance.scope = scope_repr  # type: ignore[attr-defined]
+            node_instance.modname = module_name  # type: ignore[attr-defined]
+            node_instance.children = []  # type: ignore[attr-defined]
+            node_instance.children_ptr = []  # type: ignore[attr-defined]
+            node_instance.TSparents = native_torchscript_parents  # type: ignore[attr-defined]
+            node_instance.TSoutputs = native_torchscript_outputs  # type: ignore[attr-defined]
             # graph.dictionary_of_nodes will keep a record of all the nodes
-            dictionary_of_nodes[node_instance.name] = node_instance
+            dictionary_of_nodes[node_instance.name] = node_instance  # type: ignore[attr-defined]
 
     def __repr__(self):
         return f"{self.name} "
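The long run of attr-defined ignores is needed because these attributes are attached to node_instance dynamically instead of being declared on its class; declaring them up front (or using a dataclass) would remove the per-line ignores. A minimal reproduction with a made-up class:

class Node:
    """No attributes declared, mirroring a dynamically populated node object."""


n = Node()
n.name = "conv1"  # mypy: "Node" has no attribute "name"  [attr-defined]
print(n.name)  # runs fine; the complaint is static-analysis only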
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -132,7 +132,7 @@ known-local-folder=["fms_mo","tests"]
 [tool.mypy]
 mypy_path = [""]
 packages = ["fms_mo", "tests"]
-disable_error_code = []
+disable_error_code = ["import-not-found", "import-untyped"]
 # TODO: tighten MyPy checks by enabling these checks over time.
 check_untyped_defs = false
 disallow_incomplete_defs = false