-
Notifications
You must be signed in to change notification settings - Fork 45
Refactor(linear): split LinearBackward kernel into 3 independent kernels #142
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 6 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
b27f3cb
Refactor(linear): split LinearBackward kernel into 3 independent kernels
chen2021673 8b94ba8
refactor(matmul): split MatmulBackward kernel into 2 independent kernels
chen2021673 97dabe4
refactor(gemm): extract shared GemmCuda primitive; split matmul kerne…
chen2021673 252e6cd
refactor(sgemv): extract shared SgemvCuda primitive; add sgemv branch…
chen2021673 a7e1b99
refactor: fix Matmul nullptr safety and convert GemmParams/SgemvParam…
chen2021673 15be0d2
refactor(gemm): remove blas_handle from GemmParams/SgemvParams; add d…
chen2021673 d6e3899
refactor(gemm): move gemm.cuh/gemm.cu to src/kernels/cuda/common/
chen2021673 05d4153
style(cuda): separate include groups with blank line in linear/matmul…
chen2021673 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,80 @@ | ||
| #pragma once | ||
|
|
||
| #include <cublas_v2.h> | ||
|
|
||
| #include "infini_train/include/datatype.h" | ||
| #include "infini_train/include/device.h" | ||
|
|
||
| namespace infini_train::kernels::cuda { | ||
|
|
||
| /** | ||
| * Parameter bundle for a single GEMM call: | ||
| * C = alpha * op(A) * op(B) + beta * C | ||
| * | ||
| * batch_count == 1 → non-batched path (cublasGemmEx) | ||
| * batch_count > 1 → strided-batched (cublasGemmStridedBatchedEx) | ||
| * | ||
| * When batch_count == 1, stride_a/b/c are unused and must be left at 0. | ||
| */ | ||
| struct GemmParams { | ||
| cublasOperation_t trans_a = CUBLAS_OP_N; | ||
| cublasOperation_t trans_b = CUBLAS_OP_N; | ||
|
|
||
| int m = 0; // rows of op(A) and C | ||
| int n = 0; // cols of op(B) and C | ||
| int k = 0; // cols of op(A) == rows of op(B) | ||
|
|
||
| const void *A = nullptr; | ||
| int lda = 0; | ||
| const void *B = nullptr; | ||
| int ldb = 0; | ||
| void *C = nullptr; | ||
| int ldc = 0; | ||
|
|
||
| float alpha = 1.0f; | ||
| float beta = 0.0f; | ||
|
|
||
| // batch_count=1: non-batched (Linear path); stride_a/b/c must be 0 | ||
| // batch_count>1: strided-batched (Matmul path) | ||
| int batch_count = 1; | ||
| long long stride_a = 0; | ||
| long long stride_b = 0; | ||
| long long stride_c = 0; | ||
|
|
||
| DataType input_dtype; // dtype of A and B | ||
| DataType output_dtype; // dtype of C (may differ, e.g. bf16 in → fp32 out) | ||
| }; | ||
|
|
||
| /** | ||
| * Execute the GEMM described by `p` via cuBLAS. | ||
| * Dispatches to cublasGemmEx (batch_count==1) or | ||
| * cublasGemmStridedBatchedEx (batch_count>1). | ||
| * Uses CUBLAS_COMPUTE_32F for all input dtypes to ensure precision. | ||
| * Aborts on cuBLAS error (via CUBLAS_CHECK / LOG(FATAL)). | ||
| */ | ||
| void GemmCuda(const Device &device, const GemmParams &p); | ||
|
|
||
| /** | ||
| * Parameter bundle for a single SGEMV call (fp32 only): | ||
| * y = alpha * op(A) * x + beta * y | ||
| * | ||
| * op(A) is m_phys-by-n_phys when trans==N, or n_phys-by-m_phys when trans==T, | ||
| * where m_phys and n_phys are the physical (pre-transpose) row/col counts of A. | ||
| */ | ||
| struct SgemvParams { | ||
| cublasOperation_t trans = CUBLAS_OP_N; | ||
| int m = 0; | ||
| int n = 0; | ||
| const float *A = nullptr; | ||
| int lda = 0; | ||
| const float *x = nullptr; | ||
| int incx = 1; | ||
| float *y = nullptr; | ||
| int incy = 1; | ||
| float alpha = 1.0f; | ||
| float beta = 0.0f; | ||
| }; | ||
|
|
||
| void SgemvCuda(const Device &device, const SgemvParams &p); | ||
|
|
||
| } // namespace infini_train::kernels::cuda |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.