1
- use std:: env;
1
+ use std:: env:: { self , VarError } ;
2
2
use std:: fs:: { read_dir, File } ;
3
3
use std:: io:: Write ;
4
4
use std:: path:: { Path , PathBuf } ;
5
5
use std:: process:: Command ;
6
+ use std:: str:: FromStr ;
6
7
7
8
use cc:: Build ;
8
9
use once_cell:: sync:: Lazy ;
10
+ use glob:: glob;
9
11
10
12
// This build file is based on:
11
13
// https://github.com/mdrokz/rust-llama.cpp/blob/master/build.rs
@@ -365,23 +367,16 @@ fn compile_blis(cx: &mut Build) {
365
367
}
366
368
367
369
fn compile_hipblas ( cx : & mut Build , cxx : & mut Build , mut hip : Build ) -> & ' static str {
368
- const DEFAULT_ROCM_PATH_STR : & str = " /opt/rocm/";
370
+ let rocm_path_str = env :: var ( "ROCM_PATH" ) . or ( Ok :: < String , VarError > ( String :: from_str ( " /opt/rocm/") . unwrap ( ) ) ) . unwrap ( ) ;
369
371
370
- let rocm_path_str = env:: var ( "ROCM_PATH" )
371
- . map_err ( |_| DEFAULT_ROCM_PATH_STR . to_string ( ) )
372
- . unwrap ( ) ;
373
- println ! ( "Compiling HIPBLAS GGML. Using ROCm from {rocm_path_str}" ) ;
372
+ println ! ( "Compiling hipBLAS GGML. Using ROCm from {rocm_path_str}" ) ;
374
373
375
374
let rocm_path = PathBuf :: from ( rocm_path_str) ;
376
375
let rocm_include = rocm_path. join ( "include" ) ;
377
376
let rocm_lib = rocm_path. join ( "lib" ) ;
378
377
let rocm_hip_bin = rocm_path. join ( "bin/hipcc" ) ;
379
378
380
- let cuda_lib = "ggml-cuda" ;
381
- let cuda_file = cuda_lib. to_string ( ) + ".cu" ;
382
- let cuda_header = cuda_lib. to_string ( ) + ".h" ;
383
-
384
- let defines = [ "GGML_USE_HIPBLAS" , "GGML_USE_CUBLAS" ] ;
379
+ let defines = [ "GGML_USE_HIPBLAS" , "GGML_USE_CUDA" ] ;
385
380
for def in defines {
386
381
cx. define ( def, None ) ;
387
382
cxx. define ( def, None ) ;
@@ -390,24 +385,39 @@ fn compile_hipblas(cx: &mut Build, cxx: &mut Build, mut hip: Build) -> &'static
390
385
cx. include ( & rocm_include) ;
391
386
cxx. include ( & rocm_include) ;
392
387
388
+ let ggml_cuda = glob ( LLAMA_PATH . join ( "ggml-cuda" ) . join ( "*.cu" ) . to_str ( ) . unwrap ( ) )
389
+ . unwrap ( ) . filter_map ( Result :: ok) . collect :: < Vec < _ > > ( ) ;
390
+ let ggml_template_fattn = glob ( LLAMA_PATH . join ( "ggml-cuda" ) . join ( "template-instances" ) . join ( "fattn-vec*.cu" ) . to_str ( ) . unwrap ( ) )
391
+ . unwrap ( ) . filter_map ( Result :: ok) . collect :: < Vec < _ > > ( ) ;
392
+ let ggml_template_wmma = glob ( LLAMA_PATH . join ( "ggml-cuda" ) . join ( "template-instances" ) . join ( "fattn-wmma*.cu" ) . to_str ( ) . unwrap ( ) )
393
+ . unwrap ( ) . filter_map ( Result :: ok) . collect :: < Vec < _ > > ( ) ;
394
+ let ggml_template_mmq = glob ( LLAMA_PATH . join ( "ggml-cuda" ) . join ( "template-instances" ) . join ( "mmq*.cu" ) . to_str ( ) . unwrap ( ) )
395
+ . unwrap ( ) . filter_map ( Result :: ok) . collect :: < Vec < _ > > ( ) ;
396
+
393
397
hip. compiler ( rocm_hip_bin)
394
398
. std ( "c++11" )
395
- . file ( LLAMA_PATH . join ( cuda_file) )
396
- . include ( LLAMA_PATH . join ( cuda_header) )
399
+ . define ( "LLAMA_CUDA_DMMV_X" , Some ( "32" ) )
400
+ . define ( "LLAMA_CUDA_MMV_Y" , Some ( "1" ) )
401
+ . define ( "LLAMA_CUDA_KQUANTS_ITER" , Some ( "2" ) )
402
+ . file ( LLAMA_PATH . join ( "ggml-cuda.cu" ) )
403
+ . files ( ggml_cuda)
404
+ . files ( ggml_template_fattn)
405
+ . files ( ggml_template_wmma)
406
+ . files ( ggml_template_mmq)
407
+ . include ( LLAMA_PATH . join ( "" ) )
408
+ . include ( LLAMA_PATH . join ( "ggml-cuda" ) )
397
409
. define ( "GGML_USE_HIPBLAS" , None )
398
- . compile ( cuda_lib) ;
410
+ . define ( "GGML_USE_CUDA" , None )
411
+ . compile ( "ggml-cuda" ) ;
399
412
400
- println ! (
401
- "cargo:rustc-link-search=native={}" ,
402
- rocm_lib. to_string_lossy( )
403
- ) ;
413
+ println ! ( "cargo:rustc-link-search=native={}" , rocm_lib. to_string_lossy( ) ) ;
404
414
405
415
let rocm_libs = [ "hipblas" , "rocblas" , "amdhip64" ] ;
406
416
for lib in rocm_libs {
407
417
println ! ( "cargo:rustc-link-lib={lib}" ) ;
408
418
}
409
419
410
- cuda_lib
420
+ "ggml-cuda"
411
421
}
412
422
413
423
fn compile_cuda ( cx : & mut Build , cxx : & mut Build , featless_cxx : Build ) -> & ' static str {
0 commit comments