@@ -93,20 +93,21 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
 
 // Launch activation and gating kernel.
 #ifdef USE_ROCM
-  #define LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(KERNEL) \
-    int d = input.size(-1) / 2; \
-    int64_t num_tokens = input.numel() / input.size(-1); \
-    dim3 grid(num_tokens); \
-    dim3 block(std::min(d, 1024)); \
-    const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
-    const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
-    VLLM_DISPATCH_FLOATING_TYPES( \
-        input.scalar_type(), "scaled_act_and_mul_kernel", [&] { \
-          vllm::scaled_act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
-              <<<grid, block, 0, stream>>>(out.data_ptr<c10::Float8_e4m3fnuz>(), \
-                                           input.data_ptr<scalar_t>(), d, \
-                                           1.0 / (*scale.data_ptr<float>())); \
-        });
+  #define LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(KERNEL) \
+    int d = input.size(-1) / 2; \
+    int64_t num_tokens = input.numel() / input.size(-1); \
+    dim3 grid(num_tokens); \
+    dim3 block(std::min(d, 1024)); \
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
+    VLLM_DISPATCH_FLOATING_TYPES( \
+        input.scalar_type(), "scaled_act_and_mul_kernel", [&] { \
+          vllm::scaled_act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
+              <<<grid, block, 0, stream>>>( \
+                  out.data_ptr<c10::Float8_e4m3fnuz>(), \
+                  input.data_ptr<scalar_t>(), d, \
+                  1.0 / (*scale.data_ptr<float>())); \
+        });
 #endif
 
 void silu_and_mul(torch::Tensor& out,  // [..., d]
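For reference, this macro is meant to expand inside a host-side wrapper that supplies out, input, and scale from its own scope. A minimal sketch of such a wrapper follows; the scaled_silu_and_mul name and the vllm::silu_kernel functor are assumptions modeled on the unscaled silu_and_mul op above, not part of this diff.

// Illustrative sketch only -- not part of this diff. The wrapper name and the
// vllm::silu_kernel functor are assumed, mirroring the unscaled silu_and_mul path.
// The parameter names must be out, input, and scale, because the macro body
// refers to those identifiers directly.
void scaled_silu_and_mul(torch::Tensor& out,    // [..., d], FP8 (Float8_e4m3fnuz)
                         torch::Tensor& input,  // [..., 2 * d]
                         torch::Tensor& scale)  // single-element float tensor
{
  LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}

Note that the macro passes 1.0 / (*scale.data_ptr<float>()) to the kernel, presumably so the kernel can multiply each output element by the inverse scale rather than divide.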