Skip to content

Commit cf06956

Browse files
committed
Add hipBLAS feature and fix build script
Attempts to mirror the CMake build steps as closely as possible. The GPU is detected at runtime, but on my setup it appears to segfault when loading the model. I'm at a bit of a loss about the segfault, but this is definitely better than before.
1 parent 41077c2 commit cf06956

File tree

5 files changed

+34
-21
lines changed

5 files changed

+34
-21
lines changed

Cargo.lock

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

llama-cpp-2/Cargo.toml

+2-1
Original file line number | Diff line number | Diff line change
@@ -16,9 +16,10 @@ tracing = { workspace = true }
1616
[features]
1717
cuda = ["llama-cpp-sys-2/cuda"]
1818
metal = ["llama-cpp-sys-2/metal"]
19+
hipblas = ["llama-cpp-sys-2/hipblas"]
1920
sampler = []
2021

21-
[target.'cfg(all(target_os = "macos", any(target_arch = "aarch64", target_arch = "arm64")))'.dependencies]
22+
[target.'cfg(all(target_os = "macos", any(target_arch = "aarch64", target_arch = "arm64")))'.dependencies]
2223
llama-cpp-sys-2 = { path = "../llama-cpp-sys-2", features=["metal"], version = "0.1.48" }
2324

2425
[lints]

llama-cpp-2/src/lib.rs

+1
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
//! # Feature Flags
1313
//!
1414
//! - `cuda` enables CUDA gpu support.
15+
//! - `hipblas` enables hipBLAS (ROCm) gpu support (experimental).
1516
//! - `sampler` adds the [`context::sample::sampler`] struct for a more rusty way of sampling.
1617
use std::ffi::NulError;
1718
use std::fmt::Debug;

llama-cpp-sys-2/Cargo.toml

+2-1
Original file line number | Diff line number | Diff line change
@@ -49,8 +49,9 @@ include = [
4949
bindgen = { workspace = true }
5050
cc = { workspace = true, features = ["parallel"] }
5151
once_cell = "1.19.0"
52+
glob = "0.3.1"
5253

5354
[features]
5455
cuda = []
5556
metal = []
56-
57+
hipblas = []

llama-cpp-sys-2/build.rs

+28-19
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,13 @@
1-
use std::env;
1+
use std::env::{self, VarError};
22
use std::fs::{read_dir, File};
33
use std::io::Write;
44
use std::path::{Path, PathBuf};
55
use std::process::Command;
6+
use std::str::FromStr;
67

78
use cc::Build;
89
use once_cell::sync::Lazy;
10+
use glob::glob;
911

1012
// This build file is based on:
1113
// https://github.com/mdrokz/rust-llama.cpp/blob/master/build.rs
@@ -365,23 +367,16 @@ fn compile_blis(cx: &mut Build) {
365367
}
366368

367369
fn compile_hipblas(cx: &mut Build, cxx: &mut Build, mut hip: Build) -> &'static str {
368-
const DEFAULT_ROCM_PATH_STR: &str = "/opt/rocm/";
370+
let rocm_path_str = env::var("ROCM_PATH").or(Ok::<String, VarError>(String::from_str("/opt/rocm/").unwrap())).unwrap();
369371

370-
let rocm_path_str = env::var("ROCM_PATH")
371-
.map_err(|_| DEFAULT_ROCM_PATH_STR.to_string())
372-
.unwrap();
373-
println!("Compiling HIPBLAS GGML. Using ROCm from {rocm_path_str}");
372+
println!("Compiling hipBLAS GGML. Using ROCm from {rocm_path_str}");
374373

375374
let rocm_path = PathBuf::from(rocm_path_str);
376375
let rocm_include = rocm_path.join("include");
377376
let rocm_lib = rocm_path.join("lib");
378377
let rocm_hip_bin = rocm_path.join("bin/hipcc");
379378

380-
let cuda_lib = "ggml-cuda";
381-
let cuda_file = cuda_lib.to_string() + ".cu";
382-
let cuda_header = cuda_lib.to_string() + ".h";
383-
384-
let defines = ["GGML_USE_HIPBLAS", "GGML_USE_CUBLAS"];
379+
let defines = ["GGML_USE_HIPBLAS", "GGML_USE_CUDA"];
385380
for def in defines {
386381
cx.define(def, None);
387382
cxx.define(def, None);
@@ -390,24 +385,38 @@ fn compile_hipblas(cx: &mut Build, cxx: &mut Build, mut hip: Build) -> &'static
390385
cx.include(&rocm_include);
391386
cxx.include(&rocm_include);
392387

388+
let ggml_cuda = glob(LLAMA_PATH.join("ggml-cuda").join("*.cu").to_str().unwrap())
389+
.unwrap().filter_map(Result::ok).collect::<Vec<_>>();
390+
let ggml_template_fattn = glob(LLAMA_PATH.join("ggml-cuda").join("template-instances").join("fattn-vec*.cu").to_str().unwrap())
391+
.unwrap().filter_map(Result::ok).collect::<Vec<_>>();
392+
let ggml_template_wmma = glob(LLAMA_PATH.join("ggml-cuda").join("template-instances").join("fattn-wmma*.cu").to_str().unwrap())
393+
.unwrap().filter_map(Result::ok).collect::<Vec<_>>();
394+
let ggml_template_mmq = glob(LLAMA_PATH.join("ggml-cuda").join("template-instances").join("mmq*.cu").to_str().unwrap())
395+
.unwrap().filter_map(Result::ok).collect::<Vec<_>>();
396+
393397
hip.compiler(rocm_hip_bin)
394398
.std("c++11")
395-
.file(LLAMA_PATH.join(cuda_file))
396-
.include(LLAMA_PATH.join(cuda_header))
399+
.define("LLAMA_CUDA_DMMV_X", Some("32"))
400+
.define("LLAMA_CUDA_MMV_Y", Some("1"))
401+
.define("LLAMA_CUDA_KQUANTS_ITER", Some("2"))
402+
.file(LLAMA_PATH.join("ggml-cuda.cu"))
403+
.files(ggml_cuda)
404+
.files(ggml_template_fattn)
405+
.files(ggml_template_wmma)
406+
.files(ggml_template_mmq)
407+
.include(LLAMA_PATH.join(""))
408+
.include(LLAMA_PATH.join("ggml-cuda"))
397409
.define("GGML_USE_HIPBLAS", None)
398-
.compile(cuda_lib);
410+
.compile("ggml-cuda");
399411

400-
println!(
401-
"cargo:rustc-link-search=native={}",
402-
rocm_lib.to_string_lossy()
403-
);
412+
println!("cargo:rustc-link-search=native={}", rocm_lib.to_string_lossy());
404413

405414
let rocm_libs = ["hipblas", "rocblas", "amdhip64"];
406415
for lib in rocm_libs {
407416
println!("cargo:rustc-link-lib={lib}");
408417
}
409418

410-
cuda_lib
419+
"ggml-cuda"
411420
}
412421

413422
fn compile_cuda(cx: &mut Build, cxx: &mut Build, featless_cxx: Build) -> &'static str {

0 commit comments

Comments
 (0)