Commit c04cb12

fix more lint errors
Signed-off-by: Bill Nell <[email protected]>
1 parent 916f902

3 files changed: +27 -30 lines changed
vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 5 additions & 4 deletions
@@ -2,11 +2,11 @@
 
 import importlib
 import threading
-import weakref
 from abc import abstractmethod
 from dataclasses import dataclass
 from enum import Enum
 from typing import Callable, List, Optional, Tuple
+from weakref import WeakValueDictionary
 
 import torch
 import torch.nn.functional as F
@@ -266,7 +266,7 @@ def apply(
 class AllToAllCache:
 
     def __init__(self):
-        self._cache = weakref.WeakValueDictionary()
+        self._cache: WeakValueDictionary = WeakValueDictionary()
         self._lock = threading.RLock()  # Reentrant lock for thread safety
 
     def get_or_create(self, **kwargs):
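For context, `AllToAllCache` hands out shared handles keyed by their construction arguments while holding them only weakly, so unused entries can be garbage-collected rather than leak. The annotation is very likely the lint fix itself: mypy cannot infer the type parameters of an empty generic container, so a bare `WeakValueDictionary()` typically triggers a "Need type annotation" error. The commit shows only the `get_or_create` signature; the sketch below fills in an assumed body for the pattern, with `AllToAll` as a hypothetical stand-in for the cached resource.

    import threading
    from weakref import WeakValueDictionary

    class AllToAll:  # hypothetical stand-in for the cached resource
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    class AllToAllCache:
        def __init__(self):
            # Weak values: an entry vanishes once no caller references it.
            self._cache: WeakValueDictionary = WeakValueDictionary()
            self._lock = threading.RLock()  # Reentrant lock for thread safety

        def get_or_create(self, **kwargs):
            # Assumed body: derive a hashable key from the kwargs, then
            # return the cached instance or construct one under the lock.
            key = tuple(sorted(kwargs.items()))
            with self._lock:
                instance = self._cache.get(key)
                if instance is None:
                    instance = AllToAll(**kwargs)
                    self._cache[key] = instance
                return instance

A plain `Lock` would deadlock if constructing one handle ever re-entered the cache on the same thread; the reentrant `RLock` avoids that.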
@@ -802,7 +802,8 @@ def __init__(
         if quant_config is None:
             quant_method = UnquantizedFusedMoEMethod(moe)
         else:
-            quant_method = quant_config.get_quant_method(self, prefix)
+            quant_method = quant_config.get_quant_method(
+                self, prefix)  # type: ignore
         assert isinstance(quant_method, FusedMoEMethodBase)
 
         assert quant_method is not None
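The `# type: ignore` suppresses a single mypy diagnostic on that call. A minimal, hypothetical reproduction of the kind of mismatch it covers, assuming `get_quant_method` is annotated against a different layer base class than `FusedMoE` (the real signatures may differ):

    from typing import Optional

    class QuantizeMethodBase: ...

    class LinearBase: ...

    class QuantizationConfig:
        # Assumed signature, for illustration only.
        def get_quant_method(self, layer: LinearBase,
                             prefix: str) -> Optional[QuantizeMethodBase]:
            return QuantizeMethodBase()

    class FusedMoE:
        def __init__(self, quant_config: QuantizationConfig, prefix: str):
            # FusedMoE is not a LinearBase, so mypy reports an arg-type
            # error here; the ignore comment silences exactly this line.
            quant_method = quant_config.get_quant_method(
                self, prefix)  # type: ignore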
@@ -812,7 +813,7 @@ def __init__(
 
         if dispatch_combine is not None:
             world_size = moe.ep_size
-            dp_size = moe.ep_size // moe.dp_size
+            dp_size = int(moe.ep_size // moe.dp_size)
             success = self.quant_method.set_dispatch_combine(
                 dp_size, world_size, dispatch_combine)
             if not success:
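The `int(...)` wrapper is for the type checker rather than the runtime: `//` already returns an `int` when both operands are `int`, but if either field is typed more loosely, mypy infers `float` for the result. A small illustration (the `MoEConfig` fields here are assumptions, not the real config):

    from dataclasses import dataclass

    @dataclass
    class MoEConfig:  # hypothetical stand-in for the real config object
        ep_size: int
        dp_size: float  # one float operand makes `//` infer float

    moe = MoEConfig(ep_size=8, dp_size=2.0)
    dp_size: int = int(moe.ep_size // moe.dp_size)  # cast satisfies mypy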

vllm/model_executor/layers/fused_moe/modular_kernel.py

Lines changed: 18 additions & 21 deletions
@@ -339,27 +339,24 @@ def forward(
             a1, a1_scale, a2_scale, topk_weights, topk_ids, global_num_experts,
             expert_map, apply_router_weight_on_input)
 
-        if True:
-            fused_out = self.fused_experts.apply(
-                a1q,
-                w1,
-                w2,
-                topk_ids,
-                activation=activation,
-                global_num_experts=global_num_experts,
-                expert_map=expert_map,
-                w1_scale=w1_scale,
-                w2_scale=w2_scale,
-                w1_zp=w1_zp,
-                w2_zp=w2_zp,
-                a1q_scale=a1q_scale,
-                a2_scale=a2_scale,
-                workspace13=workspace13,
-                workspace2=workspace2,
-                expert_num_tokens=expert_num_tokens,
-            )
-        else:
-            fused_out = torch.empty_like(a1q)
+        fused_out = self.fused_experts.apply(
+            a1q,
+            w1,
+            w2,
+            topk_ids,
+            activation=activation,
+            global_num_experts=global_num_experts,
+            expert_map=expert_map,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            w1_zp=w1_zp,
+            w2_zp=w2_zp,
+            a1q_scale=a1q_scale,
+            a2_scale=a2_scale,
+            workspace13=workspace13,
+            workspace2=workspace2,
+            expert_num_tokens=expert_num_tokens,
+        )
 
         self.dispatch_combine.combine(output, fused_out, topk_weights,
                                       topk_ids, apply_router_weight_on_input)

vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py

Lines changed: 4 additions & 5 deletions
@@ -21,11 +21,10 @@ def __init__(self,
                  block_m: Optional[int] = None,
                  allow_deep_gemm: bool = False):
         super().__init__()
-        self.triton_expert = TritonExperts(use_fp8_w8a8, use_int8_w8a8,
-                                           use_int4_w4a16, use_int8_w8a16,
-                                           per_channel_quant, block_shape,
-                                           block_m)
-        self.deep_gemm_expert = DeepGemmExperts()
+        self.triton_expert: TritonExperts = TritonExperts(
+            use_fp8_w8a8, use_int8_w8a8, use_int4_w4a16, use_int8_w8a16,
+            per_channel_quant, block_shape, block_m)
+        self.deep_gemm_expert: DeepGemmExperts = DeepGemmExperts()
         self.allow_deep_gemm = allow_deep_gemm
         self.use_fp8_w8a8 = use_fp8_w8a8
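As with `_cache` above, the explicit `: TritonExperts` and `: DeepGemmExperts` annotations declare the attribute types up front. This is the usual remedy when mypy cannot infer a useful type from the right-hand side, for example when the constructor comes from an untyped or conditionally imported module. A minimal illustration with hypothetical names:

    from typing import Any

    def build_expert() -> Any:
        # Stands in for a constructor mypy sees as untyped (returns Any).
        return object()

    class ExpertBase:
        def apply(self) -> None: ...

    class Wrapper:
        def __init__(self) -> None:
            # Annotated: the attribute is ExpertBase, so misuse is caught.
            # Unannotated, it would be Any and every use would go unchecked.
            self.expert: ExpertBase = build_expert()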
