Commit 36e0631
GEMM + Swiglu fused Grouped MLP for MXFP8 (#2769)
* GEMM + Swiglu fused Grouped MLP for MXFP8
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
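For context, the headline change fuses the first grouped GEMM with the SwiGLU activation for per-expert (grouped) MLPs on MXFP8 inputs. A minimal unfused sketch of that computation in plain PyTorch; all names here are illustrative, not the Transformer Engine API:

```python
# Unfused reference for the computation this commit fuses: a grouped
# (per-expert) MLP whose FC1 GEMM and SwiGLU activation now run as one
# fused kernel for MXFP8. Names and shapes are illustrative.
import torch
import torch.nn.functional as F

def grouped_swiglu_mlp_reference(x, splits, fc1_weights, fc2_weights):
    """x: [tokens, hidden]; splits: tokens per expert; each FC1 weight is
    [2*ffn, hidden] (gate and up projections packed); each FC2 weight is
    [hidden, ffn]."""
    outs = []
    for x_e, w1, w2 in zip(torch.split(x, splits), fc1_weights, fc2_weights):
        h = x_e @ w1.t()                  # FC1 GEMM -> [n_e, 2*ffn]
        gate, up = h.chunk(2, dim=-1)     # SwiGLU: silu(gate) * up
        outs.append((F.silu(gate) * up) @ w2.t())  # FC2 GEMM -> [n_e, hidden]
    return torch.cat(outs)
```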
* cleanup/lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Properly cache the alpha tensor
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* nD dummy grad
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Handle 0 tokens on an entire rank
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
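A note on the edge case above: when a rank receives zero tokens for every expert, each GEMM has an empty M dimension and the forward must still produce a correctly shaped (empty) output. A hedged sketch of the case:

```python
# Zero tokens on the entire rank: 0-row matmuls are legal in PyTorch and
# produce 0-row outputs, so a reference path degenerates gracefully; the
# fused path presumably needs explicit handling for the same case, which is
# what this fix targets.
import torch

x = torch.empty(0, 1024)                 # no tokens on this rank
y = x @ torch.randn(4096, 1024).t()      # FC1 GEMM with empty M dimension
assert y.shape == (0, 4096)              # shape is still well-defined
```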
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Temporarily downgrade cuBLAS version check
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Delayed wgrad tests pass for basic GroupedLinear
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* merge everything
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Rebase into fused_mxfp8_grouped_mlp; unit tests for delayed wgrad working
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Fix tests being skipped for fusible ops
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Integrate mxfp8 dbias kernel in group_quantize
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Add bias/dbias fused support with CuTe GEMMs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
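For reference, the dbias that gets fused into the GEMM epilogue is mathematically just a reduction of the output gradient over the token dimension; a hedged unfused equivalent:

```python
# dbias = column-wise sum of the output gradient; fusing it into the GEMM
# epilogue avoids a separate reduction kernel over grad_output.
import torch

grad_output = torch.randn(512, 4096)   # [tokens, out_features]
dbias = grad_output.sum(dim=0)         # [out_features], same shape as bias
```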
* Check bias/dbias support
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Pack biases more efficiently
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* GroupedTensor for biases to avoid concat
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
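The bias packing above replaces per-step concatenation with one contiguous buffer plus per-expert views. A hedged sketch of the idea; `GroupedBiases` is an illustrative stand-in, not the actual GroupedTensor class:

```python
# One contiguous allocation for all expert biases; per-expert access is a
# view into it, so no torch.cat (and no copy) is needed to produce the
# packed buffer a grouped GEMM consumes. Illustrative only.
import torch

class GroupedBiases:
    def __init__(self, sizes, dtype=torch.bfloat16):
        self.buffer = torch.zeros(sum(sizes), dtype=dtype)
        offsets = [0]
        for s in sizes:
            offsets.append(offsets[-1] + s)
        # Views share storage with self.buffer.
        self.views = [self.buffer[a:b] for a, b in zip(offsets, offsets[1:])]

biases = GroupedBiases([4096] * 8)
packed = biases.buffer  # already contiguous; nothing to concatenate
```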
* format
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Support 1D grouped tensor shape for bias and fix checkpointing
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Fixes and tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Refactor grouped tensor marking for paged stashing
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Remove setting logical_shape in mark_grouped_tensor
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Cleanup logical_shape
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* pass the tests for now
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
* address some review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* more cleanups
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* cleanup
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
* refactor wgrad logic
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Rename argument from single_grouped_parameter to single_grouped_weight
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Check wgrad store context is not empty for 0 token case.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Test only checks for fusion if fused kernel is available
Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Fix tolerance to match BF16 precision for the CuTe GEMM
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
* Update transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
* address further review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* address more review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
* Address more review comments + test for the zero-work grouped tensor case
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
* cuBLASLt: remove zero-work GEMM avoidance
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix the wgrad test
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
* Split dbias functionality from the group_quantize API
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Format and lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Port fixes and add better docs for the paged-stashing WAR (workaround)
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Guard fusion via env
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
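A minimal sketch of what gating the fusion behind an environment variable might look like; the actual variable name is not given in this summary, so `NVTE_FUSED_GROUPED_MLP` below is hypothetical:

```python
# Hypothetical env-var guard for the fused path; the real variable name is
# not shown in this commit summary.
import os

def fused_grouped_mlp_enabled() -> bool:
    return os.getenv("NVTE_FUSED_GROUPED_MLP", "1") == "1"

# At fusion time: take the fused GEMM+SwiGLU path only if enabled, else
# fall back to the unfused basic ops.
```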
* Change to trigger CI
Remove unnecessary blank line in docstring.
* To retrigger CI
* Space to trigger the pipeline
* Fix zero-work cuBLAS GEMM
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Varun Thumbe <vthumbe@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

1 parent b8e17cb · commit 36e0631
File tree

30 files changed: +3784 −234 lines

- qa/L0_pytorch_unittest
- tests
  - cpp/operator
  - pytorch
- transformer_engine
  - common
    - gemm
    - include/transformer_engine
    - util
  - pytorch
    - csrc
      - extensions
    - module
    - ops
      - basic
      - fused
    - tensor
      - storage