Commit e879bf8
[PyTorch][Fused Attn] Add support for cuDNN to return Softmax
* cudnn now returns Stats always and Max only with `return_max_logit=true`
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
* fix a typo that caused a bug
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
* update doc strings
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix more docs
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
* fixes from the feedback
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
* update cudnn-frontend to v1.19.1
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
* update the cudnn frontend
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
* fix a wrong omission
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>Stats always and Max when return_max_logit=True (#2677)1 parent 4ead776 commit e879bf8
File tree
5 files changed
+41
-59
lines changed- transformer_engine
- common
- fused_attn
- include/transformer_engine
- pytorch
- cpp_extensions
- csrc/extensions
5 files changed
+41
-59
lines changedLines changed: 22 additions & 42 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
112 | 112 | | |
113 | 113 | | |
114 | 114 | | |
115 | | - | |
| 115 | + | |
116 | 116 | | |
117 | 117 | | |
118 | 118 | | |
| |||
343 | 343 | | |
344 | 344 | | |
345 | 345 | | |
346 | | - | |
| 346 | + | |
347 | 347 | | |
348 | 348 | | |
349 | 349 | | |
| |||
357 | 357 | | |
358 | 358 | | |
359 | 359 | | |
360 | | - | |
361 | | - | |
362 | | - | |
363 | | - | |
364 | 360 | | |
365 | 361 | | |
366 | | - | |
367 | 362 | | |
368 | 363 | | |
369 | | - | |
370 | 364 | | |
371 | 365 | | |
372 | | - | |
373 | 366 | | |
374 | 367 | | |
375 | 368 | | |
| |||
387 | 380 | | |
388 | 381 | | |
389 | 382 | | |
390 | | - | |
391 | | - | |
392 | | - | |
393 | | - | |
394 | | - | |
395 | | - | |
396 | | - | |
| 383 | + | |
| 384 | + | |
| 385 | + | |
| 386 | + | |
| 387 | + | |
397 | 388 | | |
398 | 389 | | |
399 | 390 | | |
| |||
403 | 394 | | |
404 | 395 | | |
405 | 396 | | |
406 | | - | |
| 397 | + | |
407 | 398 | | |
408 | 399 | | |
409 | 400 | | |
| |||
1137 | 1128 | | |
1138 | 1129 | | |
1139 | 1130 | | |
| 1131 | + | |
| 1132 | + | |
| 1133 | + | |
| 1134 | + | |
| 1135 | + | |
| 1136 | + | |
| 1137 | + | |
| 1138 | + | |
| 1139 | + | |
| 1140 | + | |
1140 | 1141 | | |
1141 | 1142 | | |
1142 | 1143 | | |
| |||
1147 | 1148 | | |
1148 | 1149 | | |
1149 | 1150 | | |
1150 | | - | |
1151 | | - | |
1152 | | - | |
1153 | | - | |
1154 | | - | |
1155 | | - | |
1156 | | - | |
1157 | | - | |
1158 | | - | |
1159 | | - | |
1160 | | - | |
1161 | | - | |
1162 | | - | |
1163 | | - | |
1164 | | - | |
1165 | | - | |
1166 | | - | |
1167 | | - | |
1168 | | - | |
1169 | 1151 | | |
1170 | 1152 | | |
1171 | 1153 | | |
| |||
1189 | 1171 | | |
1190 | 1172 | | |
1191 | 1173 | | |
| 1174 | + | |
| 1175 | + | |
| 1176 | + | |
1192 | 1177 | | |
1193 | 1178 | | |
1194 | | - | |
1195 | | - | |
1196 | | - | |
1197 | | - | |
1198 | | - | |
1199 | | - | |
| 1179 | + | |
1200 | 1180 | | |
1201 | 1181 | | |
1202 | 1182 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
118 | 118 | | |
119 | 119 | | |
120 | 120 | | |
121 | | - | |
| 121 | + | |
122 | 122 | | |
123 | 123 | | |
124 | 124 | | |
125 | 125 | | |
126 | 126 | | |
127 | 127 | | |
128 | 128 | | |
129 | | - | |
| 129 | + | |
130 | 130 | | |
131 | 131 | | |
132 | 132 | | |
133 | 133 | | |
134 | 134 | | |
135 | 135 | | |
136 | 136 | | |
137 | | - | |
| 137 | + | |
138 | 138 | | |
139 | 139 | | |
140 | 140 | | |
| |||
Lines changed: 2 additions & 2 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
206 | 206 | | |
207 | 207 | | |
208 | 208 | | |
209 | | - | |
| 209 | + | |
210 | 210 | | |
211 | 211 | | |
212 | 212 | | |
| |||
269 | 269 | | |
270 | 270 | | |
271 | 271 | | |
272 | | - | |
| 272 | + | |
273 | 273 | | |
274 | 274 | | |
275 | 275 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
353 | 353 | | |
354 | 354 | | |
355 | 355 | | |
356 | | - | |
357 | | - | |
358 | | - | |
359 | | - | |
360 | | - | |
361 | | - | |
| 356 | + | |
| 357 | + | |
| 358 | + | |
| 359 | + | |
| 360 | + | |
| 361 | + | |
| 362 | + | |
| 363 | + | |
| 364 | + | |
| 365 | + | |
362 | 366 | | |
363 | 367 | | |
364 | 368 | | |
365 | 369 | | |
366 | 370 | | |
367 | 371 | | |
368 | 372 | | |
369 | | - | |
| 373 | + | |
370 | 374 | | |
371 | 375 | | |
372 | | - | |
373 | | - | |
374 | 376 | | |
375 | 377 | | |
376 | 378 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
259 | 259 | | |
260 | 260 | | |
261 | 261 | | |
262 | | - | |
| 262 | + | |
263 | 263 | | |
264 | 264 | | |
265 | 265 | | |
266 | | - | |
| 266 | + | |
267 | 267 | | |
268 | 268 | | |
269 | 269 | | |
270 | 270 | | |
271 | | - | |
| 271 | + | |
272 | 272 | | |
273 | 273 | | |
274 | 274 | | |
| |||
0 commit comments