
Remove separate NO_CUBLASLT build. #1103

Conversation

matthewdouglas (Member)

This PR removes the NO_CUBLASLT build option. It also removes the runtime check that selects and loads the separate nocublaslt variants of the library.

Reasoning:

  • Having separate library builds adds complexity and extra build time
  • Since CUDA 11, libcublas itself already takes a dependency on libcublasLt
  • We have runtime checks against compute capability to avoid calling library functions that would be unsupported

So far I've only tested this on an RTX 3060. I do have access to a machine with a GTX 1660, so I'll try to test on that too.


Comment on lines +364 to +384
// TODO: Check overhead. Maybe not worth it; just check in Python lib once,
// and avoid calling lib functions w/o support for them.
// TODO: Address GTX 1660, any other 7.5 devices maybe not supported.
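// Returns true when the active device's compute capability is >= 7.5 (Turing),
// i.e. devices with int8 tensor core support.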
inline bool igemmlt_supported() {
    int device;
    int ccMajor;

    CUDA_CHECK_RETURN(cudaGetDevice(&device));
    CUDA_CHECK_RETURN(cudaDeviceGetAttribute(&ccMajor, cudaDevAttrComputeCapabilityMajor, device));

    if (ccMajor >= 8)
        return true;

    if (ccMajor < 7)
        return false;

    int ccMinor;
    CUDA_CHECK_RETURN(cudaDeviceGetAttribute(&ccMinor, cudaDevAttrComputeCapabilityMinor, device));

    return ccMinor >= 5;
}
matthewdouglas (Member, Author)

Using this as more of a sanity check right now. I'd expect that we wouldn't be calling transform, spmm_coo, or igemmlt on devices that don't support it, but I haven't verified this. In particular, spmm_coo is the one I'm least sure about.
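For reference, a minimal call-site sketch of how such a check could be used (the wrapper name and error handling below are hypothetical, not part of this diff):

#include <cstdio>

// Hypothetical guard, for illustration only: refuse to dispatch the cublasLt-based
// int8 path on devices without int8 tensor core support.
int igemmlt_guarded(/* ... actual igemmlt arguments ... */) {
    if (!igemmlt_supported()) {
        fprintf(stderr, "igemmlt requires a device with compute capability >= 7.5\n");
        return 1; // signal "unsupported" to the caller
    }
    // ... call into the existing cublasLt-based igemmlt implementation ...
    return 0;
}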

@akx (Contributor) commented Mar 5, 2024

If you want, I have a GTX 1070 (under WSL2, works surprisingly well) I can test on:

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.36                 Driver Version: 546.33       CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce GTX 1070        On  | 00000000:2D:00.0  On |                  N/A |
|  0%   62C    P0              37W / 185W |   1204MiB /  8192MiB |      1%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

@Titus-von-Koeller (Collaborator)

Hey @matthewdouglas @akx,

Tim mentioned that it's probably not safe to remove this.

What's your opinion on this? How can we be certain either way? Right now, I'm not sure how best to proceed.

@akx (Contributor) commented Apr 9, 2024

Did Tim say why it's "probably not safe"? Do we know of an actual situation where cublaslt isn't available? Is such a situation something we want to support?

@matthewdouglas (Member, Author) commented Apr 9, 2024

> Did Tim say why it's "probably not safe"? Do we know of an actual situation where cublaslt isn't available? Is such a situation something we want to support?

I'm curious too, but I think there may also just be a naming mix-up here, since cublasLt has shipped with the CUDA Toolkit since v10.1. It may have been placed in some unusual locations early on, but by the time of toolkit 11.0 that's no longer an issue and we should always be able to link to it. PyTorch binaries ship with it. And if I'm not mistaken, libcublas.so itself depends on libcublasLt.so these days.

The main differentiator here is support for int8 tensor cores (hence the check for compute capability >= 7.5). So we would have to make sure not to call F.igemmlt() on devices that lack them. But linking in the cublasLt code IMO shouldn't be a problem as long as we're not trying to run the unsupported matmul ops, and there's already a path in MatMul8bitLt for that. We can also guard some device code with __CUDA_ARCH__, as in the sketch below.
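For illustration, here is a minimal sketch of that kind of __CUDA_ARCH__ guard (the kernel name and signature are hypothetical, not taken from this PR): the tensor-core-only path is compiled only for sm_75 and newer, so older architectures never contain the unsupported device code even though the library still links against cublasLt.

#include <cstdint>

__global__ void int8_matmul_kernel(const int8_t *A, const int8_t *B, int32_t *C, int n) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
    // Turing (7.5) and newer: the int8 tensor core implementation would go here.
#else
    // Pre-Turing devices: compiled to a no-op; host code should dispatch to a fallback instead.
#endif
}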

Some places where cublasLt is used:

  • int igemmlt<int, int, int>(cublasLtHandle_t ltHandle, int m, int n, int k, ...)
  • cublasLtOrder_t get_order<int>()
  • void transform<T, SRC, TARGET, transpose, DTYPE>(cublasLtHandle_t, T *A, T *out, int dim1, int dim2)

Separately, I recall reading somewhere that there is also an intent to deprecate the int8 matmul path that does not use tensor cores (F.igemm, MatMul8bit, and also F.vectorwise_quant, F.vectorwise_mm_dequant).

@Titus-von-Koeller (Collaborator)

I'll try to get in touch with Tim to get more details and relay the new information you provided. Unfortunately, he didn't give any reasoning at the time.

He's quite unavailable atm, so it might take a few days.

Thanks @matthewdouglas for this thorough and knowledgeable analysis; this was once again very helpful!

@matthewdouglas (Member, Author)

Superseded by #1401.
