
Commit b351557

feat: add batch invariant mean op
1 parent f7e2366 commit b351557

3 files changed

Lines changed: 216 additions & 0 deletions


Lines changed: 178 additions & 0 deletions
@@ -0,0 +1,178 @@
import torch
import triton
import triton.language as tl
import triton.runtime.driver as driver

def get_npu_properties():
    device = torch.npu.current_device()
    return driver.active.utils.get_device_properties(device)

@triton.jit
def mean_kernel(
    input_ptr,
    output_ptr,
    input_stride0,
    input_stride1,
    input_stride2,
    output_stride0,
    output_stride1,
    M,  # size before reduction dim
    N,  # size of reduction dim
    K,  # size after reduction dim
    BLOCK_SIZE,
    SUB_BLOCK_SIZE: tl.constexpr,
):
    """
    Kernel for computing mean along a single dimension.
    Input is viewed as (M, K, N) where N is the dimension being reduced.
    Each program computes BLOCK_SIZE consecutive output elements.
    """
    # Program ID gives us which output elements we're computing
    pid = tl.program_id(0)

    # Compute output indices of the first element in this program's block
    m_idx = pid * BLOCK_SIZE // K
    k_idx = pid * BLOCK_SIZE % K

    # Bounds check
    if m_idx >= M or k_idx >= K:
        return

    for i in range(0, BLOCK_SIZE):
        # The last program may own fewer than BLOCK_SIZE outputs
        if m_idx < M:
            # Accumulate sum across the reduction dimension in fixed-size
            # sub-blocks, so the reduction order does not depend on batch size
            acc = 0.0
            for n_start in range(0, N, SUB_BLOCK_SIZE):
                n_offsets = n_start + tl.arange(0, SUB_BLOCK_SIZE)
                mask = n_offsets < N

                # Calculate input indices
                input_idx = (
                    m_idx * input_stride0
                    + k_idx * input_stride1
                    + n_offsets * input_stride2
                )

                # Load and accumulate
                vals = tl.load(input_ptr + input_idx, mask=mask, other=0.0)
                acc += tl.sum(vals)

            # Compute mean and store
            mean_val = acc / N
            output_idx = m_idx * output_stride0 + k_idx * output_stride1
            tl.store(output_ptr + output_idx, mean_val)

        # Update indices for next iteration
        k_idx += 1
        if k_idx >= K:
            k_idx = 0
            m_idx += 1

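# Worked example of the (M, N, K) decomposition above (illustrative, not part
# of the original commit): for an input of shape (2, 3, 4, 5) with dim=2, we
# get M = 2 * 3 = 6 (dims before the reduced one), N = 4 (the reduced dim),
# and K = 5 (dims after it), so the kernel writes M * K = 30 output elements,
# each the mean of N = 4 values.
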
def mean_dim(
    input: torch.Tensor,
    dim: int,
    keepdim: bool = False,
    dtype: torch.dtype | None = None,
) -> torch.Tensor:
    """
    Triton implementation of torch.mean with single dimension reduction.

    Args:
        input: Input tensor
        dim: Single dimension along which to compute mean
        keepdim: Whether to keep the reduced dimension
        dtype: Output dtype. If None, uses input dtype (or float32 for integer inputs)

    Returns:
        Tensor with mean values along specified dimension
    """
    # Validate inputs
    assert "npu" in str(input.device).lower(), "Input must be an npu tensor"
    assert (
        -input.ndim <= dim < input.ndim
    ), f"Invalid dimension {dim} for tensor with {input.ndim} dimensions"

    # Handle negative dim
    if dim < 0:
        dim = dim + input.ndim

    # Handle dtype
    if dtype is None:
        if input.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
            dtype = torch.float32
        else:
            dtype = input.dtype

    # Convert input to appropriate dtype if needed
    if input.dtype != dtype:
        input = input.to(dtype)

    # Get input shape
    shape = list(input.shape)

    # Calculate dimensions for kernel
    M = 1
    for i in range(dim):
        M *= shape[i]

    N = shape[dim]

    K = 1
    for i in range(dim + 1, len(shape)):
        K *= shape[i]

    # Reshape input to a 3D view (M, K, N) so the reduced dim is innermost
    input_3d = input.reshape(M, N, K)
    input_3d = input_3d.transpose(1, 2).contiguous()

    # Create output shape
    if keepdim:
        output_shape = shape.copy()
        output_shape[dim] = 1
    else:
        output_shape = shape[:dim] + shape[dim + 1 :]

    # Create output tensor
    output = torch.empty(output_shape, dtype=dtype, device=input.device)

    # Reshape output for kernel
    if keepdim:
        output_2d = output.reshape(M, 1, K).squeeze(1)
    else:
        output_2d = output.reshape(M, K)

    # Launch kernel: split the M * K outputs evenly across the vector cores,
    # with a fixed sub-block size for the inner reduction
    num_core = get_npu_properties()["num_vectorcore"]
    grid = (num_core,)
    BLOCK_SIZE = triton.cdiv(M * K, num_core)
    SUB_BLOCK_SIZE = 4096

    mean_kernel[grid](
        input_3d,
        output_2d,
        input_3d.stride(0),
        input_3d.stride(1),
        input_3d.stride(2),
        output_2d.stride(0),
        output_2d.stride(1) if output_2d.ndim > 1 else 0,
        M,
        N,
        K,
        BLOCK_SIZE,
        SUB_BLOCK_SIZE,
    )

    return output

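# Illustrative usage of mean_dim (not part of the original commit; assumes an
# NPU device with the Triton toolchain available):
#   x = torch.randn(8, 1024, 32, device="npu")   # M=8, N=1024, K=32
#   y = mean_dim(x, dim=1)                       # shape (8, 32)
#   torch.testing.assert_close(y, x.mean(dim=1))
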
def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None = None):
    assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}"
    if len(dim) == 1:
        return mean_dim(input, dim[0], keepdim=keepdim, dtype=dtype)
    else:
        assert input.dtype in {
            torch.float16,
            torch.bfloat16,
            torch.float32,
        }, "only float types supported for now"
        # Multi-dim fallback: one float32 sum followed by a single division
        n_elems = 1
        for d in dim:
            n_elems *= input.shape[d]
        return torch.sum(input, dim=dim, keepdim=keepdim, dtype=torch.float32) / n_elems
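
For reference, a minimal sketch of the idea that makes mean_dim batch invariant (plain PyTorch, not from the commit): the reduction over N is always split into the same fixed-size SUB_BLOCK_SIZE chunks in the same order, regardless of how many rows are reduced alongside, so each output element sees a bit-identical accumulation order.

import torch

def fixed_order_mean(row, sub_block=4096):
    # Sum in fixed-size chunks, in a fixed order, mirroring the kernel's
    # SUB_BLOCK_SIZE loop; the order never depends on the batch size.
    acc = torch.zeros((), dtype=row.dtype)
    for start in range(0, row.numel(), sub_block):
        acc = acc + row[start : start + sub_block].sum()
    return acc / row.numel()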

tests/test_batch_invariant_mean.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
import torch
from batch_invariant_ops import set_batch_invariant_mode

device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu")
torch.set_default_device(device_type)

# Toggle the mode once up front (warm-up)
with set_batch_invariant_mode(True):
    pass

def test_batch_invariance_mean(dtype=torch.float32):
    B, D, K = 2048, 4096, 64
    a = torch.linspace(-100, 100, B * D * K, dtype=dtype).reshape(B, D, K)

    out1 = torch.mean(a[:1], dim=1)
    out2 = torch.mean(a, dim=1)[:1]

    # Check if results are identical
    diff = (out1 - out2).abs().max()
    return diff.item() == 0, diff.item()

def run_iters(iters=10):
    for dtype in [torch.float32, torch.float16]:
        is_deterministic = True
        difflist = []
        for i in range(iters):
            isd, df = test_batch_invariance_mean(dtype)
            is_deterministic = is_deterministic and isd
            difflist.append(df)
        print(
            f"Batch Deterministic: {is_deterministic} "
            f"run-to-run max/min/diff {max(difflist)}/{min(difflist)}/{max(difflist) - min(difflist)} "
            f"for {dtype} in {iters} iterations"
        )

# Test with standard PyTorch
print("Standard PyTorch:")
with set_batch_invariant_mode(False):
    run_iters()

# Test with batch-invariant operations
print("\nBatch-Invariant Mode:")
with set_batch_invariant_mode(True):
    run_iters()
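
For context, the property the test exercises, as a standalone sketch (shapes mirror the test; the final assertion is only expected to hold with batch-invariant mode enabled):

a = torch.linspace(-100, 100, 2048 * 4096 * 64).reshape(2048, 4096, 64)

# Mean over a one-row slice vs. a slice of the full-batch mean. Standard
# kernels may split the reduction differently at different batch sizes, so
# these can differ in the last bits; a batch-invariant mean must make them
# match bit for bit.
lhs = torch.mean(a[:1], dim=1)
rhs = torch.mean(a, dim=1)[:1]
assert torch.equal(lhs, rhs)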
