Get AMD GPU support working on Windows

Compilation of the ggml-cuda.cu module will happen automatically for AMD
users when the $HIP_PATH environment variable is defined and points to the
HIP SDK, which lets us link hipBLAS and rocBLAS. This change also lets us
bundle a prebuilt DLL for Windows users that works on stock installs;
however, its batched performance is much slower. Linux support might work,
but it hasn't been tested yet.

See #122
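
To make the mechanism above concrete, here is a minimal shell sketch of the
decision the runtime makes. The compiler invocation, library names, and
output path are illustrative assumptions, not the exact commands llamafile
issues:

    # Hypothetical sketch of the AMD GPU support path described above.
    if [ -n "$HIP_PATH" ]; then
        # HIP SDK present: compile ggml-cuda.cu for AMD, linking hipBLAS/rocBLAS.
        "$HIP_PATH/bin/hipcc" -shared -O2 -DGGML_USE_HIPBLAS \
            llama.cpp/ggml-cuda.cu -lhipblas -lrocblas \
            -o "$HOME/.llamafile/ggml-rocm.dll"
    else
        # No SDK: fall back to the prebuilt tinyBLAS DLL bundled with the
        # executable (works on stock Windows installs, but batched
        # performance is slower), or to CPU inference if no GPU can be used.
        echo "HIP_PATH not set; using bundled tinyBLAS DLL or CPU inference"
    fi
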
jart committed Jan 3, 2024
1 parent 04d6e93 commit 1f1c53f
Showing 15 changed files with 659 additions and 138 deletions.
5 changes: 5 additions & 0 deletions llama.cpp/common.cpp
@@ -523,6 +523,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
params.unsecure = true;
} else if (arg == "--nocompile") {
FLAG_nocompile = true;
} else if (arg == "--recompile") {
FLAG_recompile = true;
} else if (arg == "--tinyblas") {
FLAG_tinyblas = true; // undocumented
} else if (arg == "--gpu") {
@@ -560,6 +562,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
params.n_gpu_layers = std::stoi(argv[i]);
if (params.n_gpu_layers == 0) {
FLAG_gpu = LLAMAFILE_GPU_DISABLED;
}
} else if (arg == "--gpu-layers-draft" || arg == "-ngld" || arg == "--n-gpu-layers-draft") {
passed_gpu_flags = true;
if (++i >= argc) {
93 changes: 83 additions & 10 deletions llama.cpp/ggml-cuda.cu
@@ -17,14 +17,10 @@
#error "you need to use a 64-bit compiler for llamafile"
#endif

#if defined(GGML_USE_HIPBLAS)
#if defined(GGML_USE_TINYBLAS) && defined(GGML_USE_HIPBLAS)
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
#ifdef __HIP_PLATFORM_AMD__
// for rocblas_initialize()
#include "rocblas/rocblas.h"
#endif // __HIP_PLATFORM_AMD__
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
@@ -33,19 +29,18 @@
#define CUBLAS_OP_N HIPBLAS_OP_N
#define CUBLAS_OP_T HIPBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
#define CUBLAS_TF32_TENSOR_OP_MATH 0
#define CUDA_R_16F HIPBLAS_R_16F
#define CUDA_R_32F HIPBLAS_R_32F
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define cublasGemmAlgo_t hipblasGemmAlgo_t
#define cublasOperation_t hipblasOperation_t
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
#define cublasCreate hipblasCreate
#define cublasGemmEx hipblasGemmEx
#define cublasGemmBatchedEx hipblasGemmBatchedEx
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
#define cublasHandle_t hipblasHandle_t
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
#define cublasSetStream hipblasSetStream
#define cublasSgemm hipblasSgemm
#define cublasStatus_t hipblasStatus_t
#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
@@ -86,16 +81,92 @@
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
#define cudaStream_t hipStream_t
#define cudaSuccess hipSuccess
#include "tinyblas.cu"
#define cublasSgemm tinyblasSgemm
#define cublasGemmEx tinyblasGemmEx
#define cublasGemmBatchedEx tinyblasGemmBatchedEx
#define cublasGemmStridedBatchedEx tinyblasGemmStridedBatchedEx
#define cublasGetStatusString(x) "REDACTED!cublasGetStatusString"

#elif defined(GGML_USE_TINYBLAS)

#include "tinyblas.cu"
#define cublasHandle_t cudaStream_t
#define cublasSgemm tinyblasSgemm
#define cublasGemmEx tinyblasGemmEx
#define cublasGemmBatchedEx tinyblasGemmBatchedEx
#define cublasGemmStridedBatchedEx tinyblasGemmStridedBatchedEx
#define cublasGetStatusString(x) "REDACTED!cublasGetStatusString"

#elif defined(GGML_USE_HIPBLAS)
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
#ifdef __HIP_PLATFORM_AMD__
// for rocblas_initialize()
#include "rocblas/rocblas.h"
#endif // __HIP_PLATFORM_AMD__
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N HIPBLAS_OP_N
#define CUBLAS_OP_T HIPBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_TF32_TENSOR_OP_MATH 0
#define CUDA_R_16F HIPBLAS_R_16F
#define CUDA_R_32F HIPBLAS_R_32F
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
#define cublasCreate hipblasCreate
#define cublasGemmEx hipblasGemmEx
#define cublasGemmBatchedEx hipblasGemmBatchedEx
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
#define cublasHandle_t hipblasHandle_t
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
#define cublasSetStream hipblasSetStream
#define cublasSgemm hipblasSgemm
#define cublasStatus_t hipblasStatus_t
#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
#define cudaDeviceProp hipDeviceProp_t
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t
#define cudaEventCreateWithFlags hipEventCreateWithFlags
#define cudaEventDisableTiming hipEventDisableTiming
#define cudaEventRecord hipEventRecord
#define cudaEvent_t hipEvent_t
#define cudaEventDestroy hipEventDestroy
#define cudaFree hipFree
#define cudaFreeHost hipHostFree
#define cudaGetDevice hipGetDevice
#define cudaGetDeviceCount hipGetDeviceCount
#define cudaGetDeviceProperties hipGetDeviceProperties
#define cudaGetErrorString hipGetErrorString
#define cudaGetLastError hipGetLastError
#define cudaMalloc hipMalloc
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
#define cudaMemcpy hipMemcpy
#define cudaMemcpy2DAsync hipMemcpy2DAsync
#define cudaMemcpyAsync hipMemcpyAsync
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
#define cudaMemcpyKind hipMemcpyKind
#define cudaMemset hipMemset
#define cudaMemsetAsync hipMemsetAsync
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
#define cudaSetDevice hipSetDevice
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
#define cudaStreamFireAndForget hipStreamFireAndForget
#define cudaStreamNonBlocking hipStreamNonBlocking
#define cudaStreamSynchronize hipStreamSynchronize
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
#define cudaStream_t hipStream_t
#define cudaSuccess hipSuccess

#else
#include <cuda_runtime.h>
#include <cublas_v2.h>
@@ -6961,9 +7032,11 @@ void ggml_init_cublas() {
if (!initialized) {

#ifdef __HIP_PLATFORM_AMD__
#ifndef GGML_USE_TINYBLAS
// Workaround for a rocBLAS bug when using multiple graphics cards:
// https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346
rocblas_initialize();
#endif
CUDA_CHECK(cudaDeviceSynchronize());
#endif

5 changes: 3 additions & 2 deletions llama.cpp/llava/llava-cli.cpp
@@ -231,8 +231,9 @@ int llava_cli(int argc, char ** argv) {
show_additional_info(argc, argv);
return 1;
}
if (params.mmproj.empty() || (params.image.empty() && !prompt_contains_image(params.prompt))) {
fprintf(stderr, "%s: fatal error: --image flag missing\n", argv[0]);

if (params.mmproj.empty()) {
fprintf(stderr, "%s: fatal error: --mmproj must also be passed when an --image is specified in cli mode\n", argv[0]);
return 1;
}

63 changes: 41 additions & 22 deletions llama.cpp/main/main.1
@@ -353,15 +353,19 @@ Force system to keep model in RAM rather than swapping or compressing.
Do not memory-map model (slower load but may reduce pageouts if not using mlock).
.It Fl Fl numa
Attempt optimizations that help on some NUMA systems. If run without this previously, it is recommended to drop the system page cache before using this. See https://github.com/ggerganov/llama.cpp/issues/1437.
.It Fl Fl recompile
Force GPU support to be recompiled at runtime if possible.
.It Fl Fl nocompile
Never compile GPU support at runtime.
.Pp
If
.Pa ~/.llamafile/ggml-cuda.dll
already exists on the file system (or .so for UNIX and .dylib for
MacOS), then it'll be linked as-is without question. Otherwise,
If the appropriate DSO file already exists under
.Pa ~/.llamafile/
then it'll be linked as-is without question. If a prebuilt DSO is
present in the PKZIP content of the executable, then it'll be extracted
and linked if possible. Otherwise,
.Nm
will fall back to CPU inference.
will skip any attempt to compile GPU support and simply fall back to
using CPU inference.
.It Fl Fl gpu Ar GPU
Specifies which brand of GPU should be used. Valid choices are:
.Pp
@@ -370,25 +374,39 @@ Specifies which brand of GPU should be used. Valid choices are:
.Ar AUTO :
Use any GPU if possible, otherwise fall back to CPU inference (default)
.It
.Ar AMD :
Use AMD GPU. The AMD ROCm SDK must be installed and the HIP_PATH
environment variable must be defined. If an AMD GPU could not be used
for any reason, then a fatal error will be raised.
.It
.Ar APPLE :
Use Apple Metal GPU. This is only available on MacOS ARM64. If Metal
could not be used for any reason, then a fatal error will be raised.
.It
.Ar AMD :
Use AMD GPUs. The AMD HIP ROCm SDK should be installed, in which case we
assume the HIP_PATH environment variable has been defined. The set of
gfx microarchitectures needed to run on the host machine is determined
automatically based on the output of the hipInfo command. On Windows,
.Nm
release binaries are distributed with a tinyBLAS DLL so it'll work out
of the box without requiring the HIP SDK to be installed. However,
tinyBLAS is slower than rocBLAS for batch and image processing, so it's
recommended that the SDK be installed anyway. If an AMD GPU could not be
used for any reason, then a fatal error will be raised.
.It
.Ar NVIDIA :
Use NVIDIA GPU. If an NVIDIA GPU could not be used for any reason, a
Use NVIDIA GPUs. If an NVIDIA GPU could not be used for any reason, a
fatal error will be raised. On Windows, NVIDIA GPU support will use our
tinyBLAS library, since it works on stock Windows installs. If both MSVC
and CUDA are installed beforehand, and
tinyBLAS library, since it works on stock Windows installs. However,
tinyBLAS is slower for batch and image processing. It's possible to
use NVIDIA's closed-source cuBLAS library instead. To do that, both MSVC
and CUDA need to be installed and the
.Nm
is run for the first time on the x64 command prompt, then llamafile will
use NVIDIA's faster cuBLAS library instead. On Linux and other systems,
the CUDA SDK must always be installed, so that native GPU support can be
compiled on the fly.
command should be run once from the x64 MSVC command prompt with the
.Fl Fl recompile
flag passed. The GGML library will then be compiled and saved to
.Pa ~/.llamafile/
so the special process only needs to happen a single time.
.It
.Ar DISABLED :
Never use GPU and instead use CPU inference. This setting is implied by
.Fl ngl Ar 0 .
.El
.Pp
.It Fl ngl Ar N , Fl Fl n-gpu-layers Ar N
@@ -588,8 +606,7 @@ llama.cpp command line interface, utilizing WizardCoder-Python-13B
weights:
.Bd -literal -offset indent
llamafile \[rs]
-m wizardcoder-python-13b-v1.0.Q8_0.gguf \[rs]
--temp 0 -r '}\[rs]n' -r '\`\`\`\[rs]n' \[rs]
-m wizardcoder-python-13b-v1.0.Q8_0.gguf --temp 0 -r '}\[rs]n' -r '\`\`\`\[rs]n' \[rs]
-e -p '\`\`\`c\[rs]nvoid *memcpy(void *dst, const void *src, size_t size) {\[rs]n'
.Ed
.Pp
@@ -692,10 +709,12 @@ work to be a production worthy component of a public-facing service. For
example, C++ exceptions caused by JSON parsing errors will make it abort
and print a backtrace.
.Sh PROTIP
NVIDIA users need to pass the
The
.Fl ngl Ar 35
flag to enable GPU acceleration. It's not enabled by default since it
sometimes needs to be tuned for system hardware and model architecture.
flag needs to be passed in order to use GPUs made by NVIDIA and AMD.
It's not enabled by default since it sometimes needs to be tuned for the
system hardware and model architecture in order to achieve optimal
performance and avoid compromising a shared display.
.Sh SEE ALSO
.Xr llamafile-quantize 1 ,
.Xr llamafile-perplexity 1 ,
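
Putting the GPU flags documented above together, a hedged usage sketch
(the model filename and prompt are placeholders; whether recompilation
succeeds depends on the locally installed SDK):

    # Force AMD GPU use, offload 35 layers, and rebuild the GPU module
    # against the installed HIP SDK (a one-time step).
    llamafile --gpu AMD --recompile -ngl 35 \
        -m wizardcoder-python-13b-v1.0.Q8_0.gguf -p 'hello world'
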
3 changes: 3 additions & 0 deletions llama.cpp/main/main.1.asc
@@ -324,6 +324,9 @@ OOPPTTIIOONNSS
page cache before using this. See
https://github.com/ggerganov/llama.cpp/issues/1437.

----rreeccoommppiillee
Force GPU support to be recompiled at runtime if possible.

----nnooccoommppiillee
Never compile GPU support at runtime.

6 changes: 1 addition & 5 deletions llama.cpp/main/main.cpp
@@ -132,7 +132,7 @@ int main(int argc, char ** argv) {
return server_cli(argc, argv);
}

if (has_argument(argc, argv, "--mmproj")) {
if (has_argument(argc, argv, "--image")) {
return llava_cli(argc, argv);
}

@@ -142,10 +142,6 @@
if (!gpt_params_parse(argc, argv, params)) {
return 1;
}
if (!params.image.empty()) {
fprintf(stderr, "%s: fatal error: --mmproj must also be passed if --image is passed\n", argv[0]);
return 1;
}
llama_sampling_params & sparams = params.sparams;

#ifndef LOG_DISABLE_LOGS
4 changes: 4 additions & 0 deletions llama.cpp/server/server.cpp
@@ -2331,6 +2331,10 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
{
FLAG_nocompile = true;
}
else if (arg == "--recompile")
{
FLAG_recompile = true;
}
else if (arg == "--gpu")
{
if (++i >= argc)
1 change: 1 addition & 0 deletions llamafile/copy.sh
@@ -13,6 +13,7 @@ scp llama.cpp/ggml-cuda.cu \
llamafile/tinyblas.h \
llamafile/tinyblas.cu \
llamafile/llamafile.h \
llamafile/rocm.bat \
llamafile/cuda.bat \
llamafile/cuda.sh \
$HOST:lfbuild/
