
Commit 00f0273

Committed Nov 27, 2024
Merge remote-tracking branch 'origin/develop' into kk/rms_norm_opt-regression-fix
2 parents ba1cbee + 529cefe

File tree: 2 files changed (+17 −14 lines)

.github/workflows/clang-format.yml (+2)

@@ -6,6 +6,7 @@ on:
   push:
     branches:
       - main
+      - develop
     paths:
       - '**/*.h'
       - '**/*.cpp'
@@ -15,6 +16,7 @@ on:
   pull_request:
     branches:
       - main
+      - develop
     paths:
       - '**/*.h'
       - '**/*.cpp'

csrc/activation_kernels.cu (+15 −14)

@@ -93,20 +93,21 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
 
 // Launch activation and gating kernel.
 #ifdef USE_ROCM
-#define LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(KERNEL)                           \
-  int d = input.size(-1) / 2;                                                  \
-  int64_t num_tokens = input.numel() / input.size(-1);                         \
-  dim3 grid(num_tokens);                                                       \
-  dim3 block(std::min(d, 1024));                                               \
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));            \
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                \
-  VLLM_DISPATCH_FLOATING_TYPES(                                                \
-      input.scalar_type(), "scaled_act_and_mul_kernel", [&] {                  \
-        vllm::scaled_act_and_mul_kernel<scalar_t, KERNEL<scalar_t>>            \
-            <<<grid, block, 0, stream>>>(out.data_ptr<c10::Float8_e4m3fnuz>(), \
-                                         input.data_ptr<scalar_t>(), d,        \
-                                         1.0 / (*scale.data_ptr<float>()));    \
-      });
+#define LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(KERNEL)                \
+  int d = input.size(-1) / 2;                                       \
+  int64_t num_tokens = input.numel() / input.size(-1);              \
+  dim3 grid(num_tokens);                                            \
+  dim3 block(std::min(d, 1024));                                    \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();     \
+  VLLM_DISPATCH_FLOATING_TYPES(                                     \
+      input.scalar_type(), "scaled_act_and_mul_kernel", [&] {       \
+        vllm::scaled_act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
+            <<<grid, block, 0, stream>>>(                           \
+                out.data_ptr<c10::Float8_e4m3fnuz>(),               \
+                input.data_ptr<scalar_t>(), d,                      \
+                1.0 / (*scale.data_ptr<float>()));                  \
+      });
 #endif
 
 void silu_and_mul(torch::Tensor& out,  // [..., d]
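
Note: this hunk is a pure clang-format reflow of the ROCm-only launch macro; the kernel-launch arguments are split across lines, but behavior is unchanged. For readers unfamiliar with the pattern, below is a minimal, self-contained sketch of the launch shape the macro encodes: one block per token, at most 1024 threads striding over the half-width d, and the output scaled by 1/scale. The names (toy_silu, act_and_mul_demo) are hypothetical, not vLLM's API, and the sketch writes float output rather than the c10::Float8_e4m3fnuz used in the real kernel.

// Minimal sketch (hypothetical names) of the launch pattern used by
// LAUNCH_SCALED_ACTIVATION_GATE_KERNEL above. Float output for simplicity.
#include <algorithm>
#include <cstdio>
#include <cuda_runtime.h>

__device__ __forceinline__ float toy_silu(float x) {
  return x / (1.0f + expf(-x));  // SiLU: x * sigmoid(x)
}

// input is [num_tokens, 2 * d]: first half activations, second half gates.
__global__ void act_and_mul_demo(float* out, const float* input, int d,
                                 float inv_scale) {
  const int64_t token = blockIdx.x;          // one block per token
  const float* row = input + token * 2 * d;  // this token's row
  for (int i = threadIdx.x; i < d; i += blockDim.x) {
    // act(x) * gate, then scale (the real kernel also quantizes to FP8).
    out[token * d + i] = toy_silu(row[i]) * row[d + i] * inv_scale;
  }
}

int main() {
  const int num_tokens = 4, d = 8;
  float *in = nullptr, *out = nullptr;
  cudaMallocManaged(&in, num_tokens * 2 * d * sizeof(float));
  cudaMallocManaged(&out, num_tokens * d * sizeof(float));
  for (int i = 0; i < num_tokens * 2 * d; ++i) in[i] = 0.5f;

  dim3 grid(num_tokens);  // same grid/block choice as the macro
  dim3 block(std::min(d, 1024));
  act_and_mul_demo<<<grid, block>>>(out, in, d, /*inv_scale=*/1.0f);
  cudaDeviceSynchronize();
  printf("out[0][0] = %f\n", out[0]);  // silu(0.5) * 0.5, about 0.1556

  cudaFree(in);
  cudaFree(out);
  return 0;
}

Capping the block at 1024 threads matches the per-block hardware limit, and the block-stride loop lets the same launch shape handle d larger than the block size.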
