diff --git a/Runtime/LLMLib.cs b/Runtime/LLMLib.cs index 0c630927..3cf08add 100644 --- a/Runtime/LLMLib.cs +++ b/Runtime/LLMLib.cs @@ -367,6 +367,7 @@ public class LLMLib static bool has_avx = false; static bool has_avx2 = false; static bool has_avx512 = false; + List dependencyHandles = new List(); #if (UNITY_ANDROID || UNITY_IOS) && !UNITY_EDITOR @@ -496,6 +497,12 @@ static LLMLib() /// public LLMLib(string arch) { + foreach (string dependency in GetArchitectureDependencies(arch)) + { + LLMUnitySetup.Log($"Loading {dependency}"); + dependencyHandles.Add(LibraryLoader.LoadLibrary(dependency)); + } + libraryHandle = LibraryLoader.LoadLibrary(GetArchitecturePath(arch)); if (libraryHandle == IntPtr.Zero) { @@ -550,6 +557,35 @@ public static string GetArchitectureCheckerPath() return Path.Combine(LLMUnitySetup.libraryPath, filename); } + /// + /// Gets additional dependencies for the specified architecture. + /// + /// architecture + /// paths of dependency dlls + public static List GetArchitectureDependencies(string arch) + { + List dependencies = new List(); + if (arch == "cuda-cu12.2.0-full") + { + if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer || Application.platform == RuntimePlatform.WindowsServer) + { + dependencies.Add(Path.Combine(LLMUnitySetup.libraryPath, $"windows-{arch}/cudart64_12.dll")); + dependencies.Add(Path.Combine(LLMUnitySetup.libraryPath, $"windows-{arch}/cublasLt64_12.dll")); + dependencies.Add(Path.Combine(LLMUnitySetup.libraryPath, $"windows-{arch}/cublas64_12.dll")); + } + } else if (arch == "vulkan") { + if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer || Application.platform == RuntimePlatform.WindowsServer) + { + dependencies.Add(Path.Combine(LLMUnitySetup.libraryPath, $"windows-{arch}/vulkan-1.dll")); + } + else if (Application.platform == RuntimePlatform.LinuxEditor || Application.platform == RuntimePlatform.LinuxPlayer || Application.platform == RuntimePlatform.LinuxServer) + { + dependencies.Add(Path.Combine(LLMUnitySetup.libraryPath, $"linux-{arch}/libvulkan.so.1")); + } + } + return dependencies; + } + /// /// Gets the path of the llama.cpp library for the specified architecture. /// @@ -724,6 +760,7 @@ public string GetStringWrapperResult(IntPtr stringWrapper) public void Destroy() { if (libraryHandle != IntPtr.Zero) LibraryLoader.FreeLibrary(libraryHandle); + foreach (IntPtr dependencyHandle in dependencyHandles) LibraryLoader.FreeLibrary(dependencyHandle); } } }