diff --git a/.github/doxygen/Doxyfile b/.github/doxygen/Doxyfile index c81da466..d10a6532 100644 --- a/.github/doxygen/Doxyfile +++ b/.github/doxygen/Doxyfile @@ -48,7 +48,7 @@ PROJECT_NAME = "LLM for Unity" # could be handy for archiving the generated documentation or if some version # control system is used. -PROJECT_NUMBER = v2.4.0 +PROJECT_NUMBER = v2.4.1 # Using the PROJECT_BRIEF tag one can provide an optional one line description # for a project that appears at the top of each page and should give viewer a diff --git a/CHANGELOG.md b/CHANGELOG.md index 9ad39aaa..64f3ed8a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,13 @@ +## v2.4.1 +#### 🚀 Features + +- Static library linking on mobile (fixes iOS signing) (PR: #289) + +#### 🐛 Fixes + +- Fix support for extras (flash attention, iQ quants) (PR: #292) + + ## v2.4.0 #### 🚀 Features diff --git a/CHANGELOG.release.md b/CHANGELOG.release.md index 5abea998..8733cb3d 100644 --- a/CHANGELOG.release.md +++ b/CHANGELOG.release.md @@ -1,16 +1,8 @@ ### 🚀 Features -- iOS deployment (PR: #267) -- Improve building process (PR: #282) -- Add structured output / function calling sample (PR: #281) -- Update LlamaLib to v1.2.0 (llama.cpp b4218) (PR: #283) +- Static library linking on mobile (fixes iOS signing) (PR: #289) ### 🐛 Fixes -- Clear temp build directory before building (PR: #278) - -### 📦 General - -- Remove support for extras (flash attention, iQ quants) (PR: #284) -- remove support for LLM base prompt (PR: #285) +- Fix support for extras (flash attention, iQ quants) (PR: #292) diff --git a/Editor/LLMBuildProcessor.cs b/Editor/LLMBuildProcessor.cs index f0638a88..96d0f3a8 100644 --- a/Editor/LLMBuildProcessor.cs +++ b/Editor/LLMBuildProcessor.cs @@ -2,6 +2,9 @@ using UnityEditor.Build; using UnityEditor.Build.Reporting; using UnityEngine; +#if UNITY_IOS +using UnityEditor.iOS.Xcode; +#endif namespace LLMUnity { @@ -43,9 +46,27 @@ private void OnBuildError(string condition, string stacktrace, LogType type) if (type == LogType.Error) BuildCompleted(); } +#if UNITY_IOS + /// + /// Adds the Accelerate framework (for ios) + /// + public static void AddAccelerate(string outputPath) + { + string projPath = PBXProject.GetPBXProjectPath(outputPath); + PBXProject proj = new PBXProject(); + proj.ReadFromFile(projPath); + proj.AddFrameworkToProject(proj.GetUnityMainTargetGuid(), "Accelerate.framework", false); + proj.AddFrameworkToProject(proj.GetUnityFrameworkTargetGuid(), "Accelerate.framework", false); + proj.WriteToFile(projPath); + } +#endif + // called after the build public void OnPostprocessBuild(BuildReport report) { +#if UNITY_IOS + AddAccelerate(report.summary.outputPath); +#endif BuildCompleted(); } diff --git a/Editor/LLMEditor.cs b/Editor/LLMEditor.cs index d2b7ca49..5ed29afe 100644 --- a/Editor/LLMEditor.cs +++ b/Editor/LLMEditor.cs @@ -111,6 +111,7 @@ public override void AddModelSettings(SerializedObject llmScriptSO) if (llmScriptSO.FindProperty("advancedOptions").boolValue) { attributeClasses.Add(typeof(ModelAdvancedAttribute)); + if (LLMUnitySetup.FullLlamaLib) attributeClasses.Add(typeof(ModelExtrasAttribute)); } ShowPropertiesOfClass("", llmScriptSO, attributeClasses, false); Space(); @@ -444,12 +445,18 @@ private void CopyToClipboard(string text) te.Copy(); } + public void AddExtrasToggle() + { + if (ToggleButton("Use extras", LLMUnitySetup.FullLlamaLib)) LLMUnitySetup.SetFullLlamaLib(!LLMUnitySetup.FullLlamaLib); + } + public override void AddOptionsToggles(SerializedObject llmScriptSO) { AddDebugModeToggle(); EditorGUILayout.BeginHorizontal(); AddAdvancedOptionsToggle(llmScriptSO); + AddExtrasToggle(); EditorGUILayout.EndHorizontal(); Space(); } diff --git a/README.md b/README.md index 26a84a59..ad242081 100644 --- a/README.md +++ b/README.md @@ -499,7 +499,8 @@ Save the scene, run and enjoy! ### LLM Settings - `Show/Hide Advanced Options` Toggle to show/hide advanced options from below -- `Log Level` select how verbose the log messages arequants) +- `Log Level` select how verbose the log messages are +- `Use extras` select to install and allow the use of extra features (flash attention and IQ quants) #### 💻 Setup Settings @@ -550,6 +551,7 @@ If the user's GPU is not supported, the LLM will fall back to the CPU - `Chat Template` the chat template being used for the LLM - `Lora` the path of the LoRAs being used (relative to the Assets/StreamingAssets folder) - `Lora Weights` the weights of the LoRAs being used + - `Flash Attention` click to use flash attention in the model (if `Use extras` is enabled) @@ -557,6 +559,7 @@ If the user's GPU is not supported, the LLM will fall back to the CPU - `Show/Hide Advanced Options` Toggle to show/hide advanced options from below - `Log Level` select how verbose the log messages are +- `Use extras` select to install and allow the use of extra features (flash attention and IQ quants) #### 💻 Setup Settings
diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index 7b3d2571..3564d5bf 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -59,6 +59,8 @@ public class LLM : MonoBehaviour [ModelAdvanced] public string lora = ""; /// the weights of the LORA models being used. [ModelAdvanced] public string loraWeights = ""; + /// enable use of flash attention + [ModelExtras] public bool flashAttention = false; /// API key to use for the server (optional) public string APIKey; @@ -430,6 +432,7 @@ protected virtual string GetLlamaccpArguments() if (numThreadsToUse > 0) arguments += $" -t {numThreadsToUse}"; arguments += loraArgument; arguments += $" -ngl {numGPULayers}"; + if (LLMUnitySetup.FullLlamaLib && flashAttention) arguments += $" --flash-attn"; if (remote) { arguments += $" --port {port} --host 0.0.0.0"; @@ -719,7 +722,7 @@ public void ApplyLoras() json = json.Substring(startIndex, endIndex - startIndex); IntPtr stringWrapper = llmlib.StringWrapper_Construct(); - llmlib.LLM_Lora_Weight(LLMObject, json, stringWrapper); + llmlib.LLM_LoraWeight(LLMObject, json, stringWrapper); llmlib.StringWrapper_Delete(stringWrapper); } diff --git a/Runtime/LLMBuilder.cs b/Runtime/LLMBuilder.cs index c95835b9..e8a6dd2f 100644 --- a/Runtime/LLMBuilder.cs +++ b/Runtime/LLMBuilder.cs @@ -17,6 +17,7 @@ public class LLMBuilder static List movedPairs = new List(); public static string BuildTempDir = Path.Combine(Application.temporaryCachePath, "LLMUnityBuild"); public static string androidPluginDir = Path.Combine(Application.dataPath, "Plugins", "Android", "LLMUnity"); + public static string iOSPluginDir = Path.Combine(Application.dataPath, "Plugins", "iOS", "LLMUnity"); static string movedCache = Path.Combine(BuildTempDir, "moved.json"); [InitializeOnLoadMethod] @@ -87,7 +88,7 @@ public static void MovePath(string source, string target) /// path public static bool DeletePath(string path) { - string[] allowedDirs = new string[] { LLMUnitySetup.GetAssetPath(), BuildTempDir, androidPluginDir}; + string[] allowedDirs = new string[] { LLMUnitySetup.GetAssetPath(), BuildTempDir, androidPluginDir, iOSPluginDir}; bool deleteOK = false; foreach (string allowedDir in allowedDirs) deleteOK = deleteOK || LLMUnitySetup.IsSubPath(path, allowedDir); if (!deleteOK) @@ -148,10 +149,10 @@ static void AddActionAddMeta(string target) } /// - /// Hides all the library platforms apart from the target platform by moving out their library folders outside of StreamingAssets + /// Moves libraries in the correct place for building /// /// target platform - public static void HideLibraryPlatforms(string platform) + public static void BuildLibraryPlatforms(string platform) { List platforms = new List(){ "windows", "macos", "linux", "android", "ios", "setup" }; platforms.Remove(platform); @@ -161,6 +162,8 @@ public static void HideLibraryPlatforms(string platform) foreach (string platformPrefix in platforms) { bool move = sourceName.StartsWith(platformPrefix); + move = move || (sourceName.Contains("cuda") && !sourceName.Contains("full") && LLMUnitySetup.FullLlamaLib); + move = move || (sourceName.Contains("cuda") && sourceName.Contains("full") && !LLMUnitySetup.FullLlamaLib); if (move) { string target = Path.Combine(BuildTempDir, sourceName); @@ -170,13 +173,14 @@ public static void HideLibraryPlatforms(string platform) } } - if (platform == "android") + if (platform == "android" || platform == "ios") { - string source = Path.Combine(LLMUnitySetup.libraryPath, "android"); - string target = Path.Combine(androidPluginDir, LLMUnitySetup.libraryName); + string pluginDir = platform == "android"? androidPluginDir: iOSPluginDir; + string source = Path.Combine(LLMUnitySetup.libraryPath, platform); + string target = Path.Combine(pluginDir, LLMUnitySetup.libraryName); MoveAction(source, target); MoveAction(source + ".meta", target + ".meta"); - AddActionAddMeta(androidPluginDir); + AddActionAddMeta(pluginDir); } } @@ -196,7 +200,7 @@ public static void Build(string platform) { DeletePath(BuildTempDir); Directory.CreateDirectory(BuildTempDir); - HideLibraryPlatforms(platform); + BuildLibraryPlatforms(platform); BuildModels(); } diff --git a/Runtime/LLMCaller.cs b/Runtime/LLMCaller.cs index ac0949a3..2ef586d6 100644 --- a/Runtime/LLMCaller.cs +++ b/Runtime/LLMCaller.cs @@ -128,7 +128,11 @@ protected virtual void AssignLLM() if (remote || llm != null) return; List validLLMs = new List(); +#if UNITY_6000_0_OR_NEWER + foreach (LLM foundllm in FindObjectsByType(typeof(LLM), FindObjectsSortMode.None)) +#else foreach (LLM foundllm in FindObjectsOfType()) +#endif { if (IsValidLLM(foundllm) && IsAutoAssignableLLM(foundllm)) validLLMs.Add(foundllm); } @@ -284,7 +288,9 @@ protected virtual async Task PostRequestRemote(string json, strin } // Start the request asynchronously - var asyncOperation = request.SendWebRequest(); + UnityWebRequestAsyncOperation asyncOperation = request.SendWebRequest(); + await Task.Yield(); // Wait for the next frame so that asyncOperation is properly registered (especially if not in main thread) + float lastProgress = 0f; // Continue updating progress until the request is completed while (!asyncOperation.isDone) diff --git a/Runtime/LLMLib.cs b/Runtime/LLMLib.cs index 729ee6bd..98dcc78c 100644 --- a/Runtime/LLMLib.cs +++ b/Runtime/LLMLib.cs @@ -364,11 +364,98 @@ public static int dlclose(IntPtr handle) public class LLMLib { IntPtr libraryHandle = IntPtr.Zero; - static readonly object staticLock = new object(); static bool has_avx = false; static bool has_avx2 = false; static bool has_avx512 = false; + +#if (UNITY_ANDROID || UNITY_IOS) && !UNITY_EDITOR + + public LLMLib(string arch){} + +#if UNITY_ANDROID + public const string LibraryName = "libundreamai_android"; +#else + public const string LibraryName = "__Internal"; +#endif + + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="Logging")] + public static extern void LoggingStatic(IntPtr stringWrapper); + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="StopLogging")] + public static extern void StopLoggingStatic(); + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_Construct")] + public static extern IntPtr LLM_ConstructStatic(string command); + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_Delete")] + public static extern void LLM_DeleteStatic(IntPtr LLMObject); + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_StartServer")] + public static extern void LLM_StartServerStatic(IntPtr LLMObject); + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_StopServer")] + public static extern void LLM_StopServerStatic(IntPtr LLMObject); + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_Start")] + public static extern void LLM_StartStatic(IntPtr LLMObject); + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_Started")] + public static extern bool LLM_StartedStatic(IntPtr LLMObject); + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_Stop")] + public static extern void LLM_StopStatic(IntPtr LLMObject); + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_SetTemplate")] + public static extern void LLM_SetTemplateStatic(IntPtr LLMObject, string chatTemplate); + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_SetSSL")] + public static extern void LLM_SetSSLStatic(IntPtr LLMObject, string SSLCert, string SSLKey); + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_Tokenize")] + public static extern void LLM_TokenizeStatic(IntPtr LLMObject, string jsonData, IntPtr stringWrapper); + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_Detokenize")] + public static extern void LLM_DetokenizeStatic(IntPtr LLMObject, string jsonData, IntPtr stringWrapper); + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_Embeddings")] + public static extern void LLM_EmbeddingsStatic(IntPtr LLMObject, string jsonData, IntPtr stringWrapper); + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_Lora_Weight")] + public static extern void LLM_LoraWeightStatic(IntPtr LLMObject, string jsonData, IntPtr stringWrapper); + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_Lora_List")] + public static extern void LLM_LoraListStatic(IntPtr LLMObject, IntPtr stringWrapper); + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_Completion")] + public static extern void LLM_CompletionStatic(IntPtr LLMObject, string jsonData, IntPtr stringWrapper); + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_Slot")] + public static extern void LLM_SlotStatic(IntPtr LLMObject, string jsonData, IntPtr stringWrapper); + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_Cancel")] + public static extern void LLM_CancelStatic(IntPtr LLMObject, int idSlot); + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_Status")] + public static extern int LLM_StatusStatic(IntPtr LLMObject, IntPtr stringWrapper); + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="StringWrapper_Construct")] + public static extern IntPtr StringWrapper_ConstructStatic(); + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="StringWrapper_Delete")] + public static extern void StringWrapper_DeleteStatic(IntPtr instance); + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="StringWrapper_GetStringSize")] + public static extern int StringWrapper_GetStringSizeStatic(IntPtr instance); + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="StringWrapper_GetString")] + public static extern void StringWrapper_GetStringStatic(IntPtr instance, IntPtr buffer, int bufferSize, bool clear = false); + + public void Logging(IntPtr stringWrapper){ LoggingStatic(stringWrapper); } + public void StopLogging(){ StopLoggingStatic(); } + public IntPtr LLM_Construct(string command){ return LLM_ConstructStatic(command); } + public void LLM_Delete(IntPtr LLMObject){ LLM_DeleteStatic(LLMObject); } + public void LLM_StartServer(IntPtr LLMObject){ LLM_StartServerStatic(LLMObject); } + public void LLM_StopServer(IntPtr LLMObject){ LLM_StopServerStatic(LLMObject); } + public void LLM_Start(IntPtr LLMObject){ LLM_StartStatic(LLMObject); } + public bool LLM_Started(IntPtr LLMObject){ return LLM_StartedStatic(LLMObject); } + public void LLM_Stop(IntPtr LLMObject){ LLM_StopStatic(LLMObject); } + public void LLM_SetTemplate(IntPtr LLMObject, string chatTemplate){ LLM_SetTemplateStatic(LLMObject, chatTemplate); } + public void LLM_SetSSL(IntPtr LLMObject, string SSLCert, string SSLKey){ LLM_SetSSLStatic(LLMObject, SSLCert, SSLKey); } + public void LLM_Tokenize(IntPtr LLMObject, string jsonData, IntPtr stringWrapper){ LLM_TokenizeStatic(LLMObject, jsonData, stringWrapper); } + public void LLM_Detokenize(IntPtr LLMObject, string jsonData, IntPtr stringWrapper){ LLM_DetokenizeStatic(LLMObject, jsonData, stringWrapper); } + public void LLM_Embeddings(IntPtr LLMObject, string jsonData, IntPtr stringWrapper){ LLM_EmbeddingsStatic(LLMObject, jsonData, stringWrapper); } + public void LLM_LoraWeight(IntPtr LLMObject, string jsonData, IntPtr stringWrapper){ LLM_LoraWeightStatic(LLMObject, jsonData, stringWrapper); } + public void LLM_LoraList(IntPtr LLMObject, IntPtr stringWrapper){ LLM_LoraListStatic(LLMObject, stringWrapper); } + public void LLM_Completion(IntPtr LLMObject, string jsonData, IntPtr stringWrapper){ LLM_CompletionStatic(LLMObject, jsonData, stringWrapper); } + public void LLM_Slot(IntPtr LLMObject, string jsonData, IntPtr stringWrapper){ LLM_SlotStatic(LLMObject, jsonData, stringWrapper); } + public void LLM_Cancel(IntPtr LLMObject, int idSlot){ LLM_CancelStatic(LLMObject, idSlot); } + public int LLM_Status(IntPtr LLMObject, IntPtr stringWrapper){ return LLM_StatusStatic(LLMObject, stringWrapper); } + public IntPtr StringWrapper_Construct(){ return StringWrapper_ConstructStatic(); } + public void StringWrapper_Delete(IntPtr instance){ StringWrapper_DeleteStatic(instance); } + public int StringWrapper_GetStringSize(IntPtr instance){ return StringWrapper_GetStringSizeStatic(instance); } + public void StringWrapper_GetString(IntPtr instance, IntPtr buffer, int bufferSize, bool clear = false){ StringWrapper_GetStringStatic(instance, buffer, bufferSize, clear); } + +#else + static bool has_avx_set = false; + static readonly object staticLock = new object(); static LLMLib() { @@ -427,7 +514,7 @@ public LLMLib(string arch) LLM_Tokenize = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Tokenize"); LLM_Detokenize = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Detokenize"); LLM_Embeddings = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Embeddings"); - LLM_Lora_Weight = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Lora_Weight"); + LLM_LoraWeight = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Lora_Weight"); LLM_LoraList = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Lora_List"); LLM_Completion = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Completion"); LLM_Slot = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Slot"); @@ -441,69 +528,6 @@ public LLMLib(string arch) StopLogging = LibraryLoader.GetSymbolDelegate(libraryHandle, "StopLogging"); } - /// - /// Destroys the LLM library - /// - public void Destroy() - { - if (libraryHandle != IntPtr.Zero) LibraryLoader.FreeLibrary(libraryHandle); - } - - /// - /// Identifies the possible architectures that we can use based on the OS and GPU usage - /// - /// whether to allow GPU architectures - /// possible architectures - public static List PossibleArchitectures(bool gpu = false) - { - List architectures = new List(); - if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer || Application.platform == RuntimePlatform.WindowsServer || - Application.platform == RuntimePlatform.LinuxEditor || Application.platform == RuntimePlatform.LinuxPlayer || Application.platform == RuntimePlatform.LinuxServer) - { - if (gpu) - { - architectures.Add("cuda-cu12.2.0"); - architectures.Add("cuda-cu11.7.1"); - architectures.Add("hip"); - architectures.Add("vulkan"); - } - if (has_avx512) architectures.Add("avx512"); - if (has_avx2) architectures.Add("avx2"); - if (has_avx) architectures.Add("avx"); - architectures.Add("noavx"); - } - else if (Application.platform == RuntimePlatform.OSXEditor || Application.platform == RuntimePlatform.OSXPlayer) - { - string arch = RuntimeInformation.ProcessArchitecture.ToString().ToLower(); - if (arch.Contains("arm")) - { - architectures.Add("arm64-acc"); - architectures.Add("arm64-no_acc"); - } - else - { - if (arch != "x86" && arch != "x64") LLMUnitySetup.LogWarning($"Unknown architecture of processor {arch}! Falling back to x86_64"); - architectures.Add("x64-acc"); - architectures.Add("x64-no_acc"); - } - } - else if (Application.platform == RuntimePlatform.Android) - { - architectures.Add("android"); - } - else if (Application.platform == RuntimePlatform.IPhonePlayer) - { - architectures.Add("ios"); - } - else - { - string error = "Unknown OS"; - LLMUnitySetup.LogError(error); - throw new Exception(error); - } - return architectures; - } - /// /// Gets the path of a library that allows to detect the underlying CPU (Windows / Linux). /// @@ -546,14 +570,6 @@ public static string GetArchitecturePath(string arch) { filename = $"macos-{arch}/libundreamai_macos-{arch}.dylib"; } - else if (Application.platform == RuntimePlatform.Android) - { - return "libundreamai_android.so"; - } - else if (Application.platform == RuntimePlatform.IPhonePlayer) - { - filename = "iOS/libundreamai_iOS.dylib"; - } else { string error = "Unknown OS"; @@ -563,31 +579,6 @@ public static string GetArchitecturePath(string arch) return Path.Combine(LLMUnitySetup.libraryPath, filename); } - /// - /// Allows to retrieve a string from the library (Unity only allows marshalling of chars) - /// - /// string wrapper pointer - /// retrieved string - public string GetStringWrapperResult(IntPtr stringWrapper) - { - string result = ""; - int bufferSize = StringWrapper_GetStringSize(stringWrapper); - if (bufferSize > 1) - { - IntPtr buffer = Marshal.AllocHGlobal(bufferSize); - try - { - StringWrapper_GetString(stringWrapper, buffer, bufferSize); - result = Marshal.PtrToStringAnsi(buffer); - } - finally - { - Marshal.FreeHGlobal(buffer); - } - } - return result; - } - public delegate bool HasArchDelegate(); public delegate void LoggingDelegate(IntPtr stringWrapper); public delegate void StopLoggingDelegate(); @@ -629,7 +620,7 @@ public string GetStringWrapperResult(IntPtr stringWrapper) public LLM_DetokenizeDelegate LLM_Detokenize; public LLM_CompletionDelegate LLM_Completion; public LLM_EmbeddingsDelegate LLM_Embeddings; - public LLM_LoraWeightDelegate LLM_Lora_Weight; + public LLM_LoraWeightDelegate LLM_LoraWeight; public LLM_LoraListDelegate LLM_LoraList; public LLM_SlotDelegate LLM_Slot; public LLM_CancelDelegate LLM_Cancel; @@ -638,6 +629,105 @@ public string GetStringWrapperResult(IntPtr stringWrapper) public StringWrapper_DeleteDelegate StringWrapper_Delete; public StringWrapper_GetStringSizeDelegate StringWrapper_GetStringSize; public StringWrapper_GetStringDelegate StringWrapper_GetString; + +#endif + + /// + /// Identifies the possible architectures that we can use based on the OS and GPU usage + /// + /// whether to allow GPU architectures + /// possible architectures + public static List PossibleArchitectures(bool gpu = false) + { + List architectures = new List(); + if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer || Application.platform == RuntimePlatform.WindowsServer || + Application.platform == RuntimePlatform.LinuxEditor || Application.platform == RuntimePlatform.LinuxPlayer || Application.platform == RuntimePlatform.LinuxServer) + { + if (gpu) + { + if (LLMUnitySetup.FullLlamaLib) + { + architectures.Add("cuda-cu12.2.0-full"); + architectures.Add("cuda-cu11.7.1-full"); + } + else + { + architectures.Add("cuda-cu12.2.0"); + architectures.Add("cuda-cu11.7.1"); + } + architectures.Add("hip"); + architectures.Add("vulkan"); + } + if (has_avx512) architectures.Add("avx512"); + if (has_avx2) architectures.Add("avx2"); + if (has_avx) architectures.Add("avx"); + architectures.Add("noavx"); + } + else if (Application.platform == RuntimePlatform.OSXEditor || Application.platform == RuntimePlatform.OSXPlayer) + { + string arch = RuntimeInformation.ProcessArchitecture.ToString().ToLower(); + if (arch.Contains("arm")) + { + architectures.Add("arm64-acc"); + architectures.Add("arm64-no_acc"); + } + else + { + if (arch != "x86" && arch != "x64") LLMUnitySetup.LogWarning($"Unknown architecture of processor {arch}! Falling back to x86_64"); + architectures.Add("x64-acc"); + architectures.Add("x64-no_acc"); + } + } + else if (Application.platform == RuntimePlatform.Android) + { + architectures.Add("android"); + } + else if (Application.platform == RuntimePlatform.IPhonePlayer) + { + architectures.Add("ios"); + } + else + { + string error = "Unknown OS"; + LLMUnitySetup.LogError(error); + throw new Exception(error); + } + return architectures; + } + + /// + /// Allows to retrieve a string from the library (Unity only allows marshalling of chars) + /// + /// string wrapper pointer + /// retrieved string + public string GetStringWrapperResult(IntPtr stringWrapper) + { + string result = ""; + int bufferSize = StringWrapper_GetStringSize(stringWrapper); + if (bufferSize > 1) + { + IntPtr buffer = Marshal.AllocHGlobal(bufferSize); + try + { + StringWrapper_GetString(stringWrapper, buffer, bufferSize); + result = Marshal.PtrToStringAnsi(buffer); + } + finally + { + Marshal.FreeHGlobal(buffer); + } + } + return result; + } + + /// + /// Destroys the LLM library + /// + public void Destroy() + { + if (libraryHandle != IntPtr.Zero) LibraryLoader.FreeLibrary(libraryHandle); + } } + } /// \endcond diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs index adb0d686..0d50e7ff 100644 --- a/Runtime/LLMUnitySetup.cs +++ b/Runtime/LLMUnitySetup.cs @@ -59,6 +59,7 @@ public class LocalRemoteAttribute : PropertyAttribute {} public class RemoteAttribute : PropertyAttribute {} public class LocalAttribute : PropertyAttribute {} public class ModelAttribute : PropertyAttribute {} + public class ModelExtrasAttribute : PropertyAttribute {} public class ChatAttribute : PropertyAttribute {} public class LLMUnityAttribute : PropertyAttribute {} @@ -100,9 +101,9 @@ public class LLMUnitySetup { // DON'T CHANGE! the version is autocompleted with a GitHub action /// LLM for Unity version - public static string Version = "v2.4.0"; + public static string Version = "v2.4.1"; /// LlamaLib version - public static string LlamaLibVersion = "v1.2.0"; + public static string LlamaLibVersion = "v1.2.1"; /// LlamaLib release url public static string LlamaLibReleaseURL = $"https://github.com/undreamai/LlamaLib/releases/download/{LlamaLibVersion}"; /// LlamaLib name @@ -111,6 +112,8 @@ public class LLMUnitySetup public static string libraryPath = GetAssetPath(libraryName); /// LlamaLib url public static string LlamaLibURL = $"{LlamaLibReleaseURL}/{libraryName}.zip"; + /// LlamaLib extension url + public static string LlamaLibExtensionURL = $"{LlamaLibReleaseURL}/{libraryName}-full.zip"; /// LLMnity store path public static string LLMUnityStore = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData), "LLMUnity"); /// Model download path @@ -150,6 +153,8 @@ public class LLMUnitySetup /// \cond HIDE [LLMUnity] public static DebugModeType DebugMode = DebugModeType.All; static string DebugModeKey = "DebugMode"; + public static bool FullLlamaLib = false; + static string FullLlamaLibKey = "FullLlamaLib"; static List> errorCallbacks = new List>(); static readonly object lockObject = new object(); static Dictionary androidExtractTasks = new Dictionary(); @@ -184,6 +189,7 @@ public static void LogError(string message) static void LoadPlayerPrefs() { DebugMode = (DebugModeType)PlayerPrefs.GetInt(DebugModeKey, (int)DebugModeType.All); + FullLlamaLib = PlayerPrefs.GetInt(FullLlamaLibKey, 0) == 1; } public static void SetDebugMode(DebugModeType newDebugMode) @@ -194,6 +200,18 @@ public static void SetDebugMode(DebugModeType newDebugMode) PlayerPrefs.Save(); } +#if UNITY_EDITOR + public static void SetFullLlamaLib(bool value) + { + if (FullLlamaLib == value) return; + FullLlamaLib = value; + PlayerPrefs.SetInt(FullLlamaLibKey, value ? 1 : 0); + PlayerPrefs.Save(); + _ = DownloadLibrary(); + } + +#endif + public static string GetLibraryName(string version) { return $"undreamai-{version}-llamacpp"; @@ -436,6 +454,9 @@ static async Task DownloadLibrary() // setup LlamaLib in StreamingAssets await DownloadAndExtractInsideDirectory(LlamaLibURL, libraryPath, setupDir); + + // setup LlamaLib extras in StreamingAssets + if (FullLlamaLib) await DownloadAndExtractInsideDirectory(LlamaLibExtensionURL, libraryPath, setupDir); } catch (Exception e) { diff --git a/Samples~/FunctionCalling/Scene.unity b/Samples~/FunctionCalling/Scene.unity index bf5d7c27..87290fa3 100644 --- a/Samples~/FunctionCalling/Scene.unity +++ b/Samples~/FunctionCalling/Scene.unity @@ -13,7 +13,7 @@ OcclusionCullingSettings: --- !u!104 &2 RenderSettings: m_ObjectHideFlags: 0 - serializedVersion: 9 + serializedVersion: 10 m_Fog: 0 m_FogColor: {r: 0.5, g: 0.5, b: 0.5, a: 1} m_FogMode: 3 @@ -38,13 +38,12 @@ RenderSettings: m_ReflectionIntensity: 1 m_CustomReflection: {fileID: 0} m_Sun: {fileID: 0} - m_IndirectSpecularColor: {r: 0.44657832, g: 0.49641222, b: 0.57481664, a: 1} m_UseRadianceAmbientProbe: 0 --- !u!157 &3 LightmapSettings: m_ObjectHideFlags: 0 - serializedVersion: 12 - m_GIWorkflowMode: 1 + serializedVersion: 13 + m_BakeOnSceneLoad: 0 m_GISettings: serializedVersion: 2 m_BounceScale: 1 @@ -67,9 +66,6 @@ LightmapSettings: m_LightmapParameters: {fileID: 0} m_LightmapsBakeMode: 1 m_TextureCompression: 1 - m_FinalGather: 0 - m_FinalGatherFiltering: 1 - m_FinalGatherRayCount: 256 m_ReflectionCompression: 2 m_MixedBakeMode: 2 m_BakeBackend: 1 @@ -602,8 +598,8 @@ MonoBehaviour: m_OnClick: m_PersistentCalls: m_Calls: - - m_Target: {fileID: 0} - m_TargetAssemblyTypeName: SimpleInteraction, Assembly-CSharp + - m_Target: {fileID: 107963747} + m_TargetAssemblyTypeName: LLMUnitySamples.FunctionCalling, Assembly-CSharp m_MethodName: CancelRequests m_Mode: 1 m_Arguments: @@ -613,7 +609,7 @@ MonoBehaviour: m_FloatArgument: 0 m_StringArgument: m_BoolArgument: 0 - m_CallState: 2 + m_CallState: 1 --- !u!114 &724531322 MonoBehaviour: m_ObjectHideFlags: 0 @@ -940,9 +936,8 @@ Light: m_PrefabAsset: {fileID: 0} m_GameObject: {fileID: 909474451} m_Enabled: 1 - serializedVersion: 10 + serializedVersion: 11 m_Type: 1 - m_Shape: 0 m_Color: {r: 1, g: 0.95686275, b: 0.8392157, a: 1} m_Intensity: 1 m_Range: 10 @@ -992,8 +987,12 @@ Light: m_BoundingSphereOverride: {x: 0, y: 0, z: 0, w: 0} m_UseBoundingSphereOverride: 0 m_UseViewFrustumForShadowCasterCull: 1 + m_ForceVisible: 0 m_ShadowRadius: 0 m_ShadowAngle: 0 + m_LightUnit: 1 + m_LuxAtDistance: 1 + m_EnableSpotReflector: 1 --- !u!4 &909474453 Transform: m_ObjectHideFlags: 0 @@ -1048,7 +1047,6 @@ MonoBehaviour: dontDestroyOnLoad: 1 contextSize: 8192 batchSize: 512 - basePrompt: model: chatTemplate: chatml lora: diff --git a/Samples~/MobileDemo/Scene.unity b/Samples~/MobileDemo/Scene.unity index eb7f637e..c1cc66aa 100644 --- a/Samples~/MobileDemo/Scene.unity +++ b/Samples~/MobileDemo/Scene.unity @@ -13,7 +13,7 @@ OcclusionCullingSettings: --- !u!104 &2 RenderSettings: m_ObjectHideFlags: 0 - serializedVersion: 9 + serializedVersion: 10 m_Fog: 0 m_FogColor: {r: 0.5, g: 0.5, b: 0.5, a: 1} m_FogMode: 3 @@ -38,13 +38,12 @@ RenderSettings: m_ReflectionIntensity: 1 m_CustomReflection: {fileID: 0} m_Sun: {fileID: 0} - m_IndirectSpecularColor: {r: 0.44657832, g: 0.49641222, b: 0.57481664, a: 1} m_UseRadianceAmbientProbe: 0 --- !u!157 &3 LightmapSettings: m_ObjectHideFlags: 0 - serializedVersion: 12 - m_GIWorkflowMode: 1 + serializedVersion: 13 + m_BakeOnSceneLoad: 0 m_GISettings: serializedVersion: 2 m_BounceScale: 1 @@ -67,9 +66,6 @@ LightmapSettings: m_LightmapParameters: {fileID: 0} m_LightmapsBakeMode: 1 m_TextureCompression: 1 - m_FinalGather: 0 - m_FinalGatherFiltering: 1 - m_FinalGatherRayCount: 256 m_ReflectionCompression: 2 m_MixedBakeMode: 2 m_BakeBackend: 1 @@ -899,7 +895,7 @@ MonoBehaviour: m_FloatArgument: 0 m_StringArgument: m_BoolArgument: 0 - m_CallState: 0 + m_CallState: 1 --- !u!114 &724531322 MonoBehaviour: m_ObjectHideFlags: 0 @@ -1226,9 +1222,8 @@ Light: m_PrefabAsset: {fileID: 0} m_GameObject: {fileID: 909474451} m_Enabled: 1 - serializedVersion: 10 + serializedVersion: 11 m_Type: 1 - m_Shape: 0 m_Color: {r: 1, g: 0.95686275, b: 0.8392157, a: 1} m_Intensity: 1 m_Range: 10 @@ -1278,8 +1273,12 @@ Light: m_BoundingSphereOverride: {x: 0, y: 0, z: 0, w: 0} m_UseBoundingSphereOverride: 0 m_UseViewFrustumForShadowCasterCull: 1 + m_ForceVisible: 0 m_ShadowRadius: 0 m_ShadowAngle: 0 + m_LightUnit: 1 + m_LuxAtDistance: 1 + m_EnableSpotReflector: 1 --- !u!4 &909474453 Transform: m_ObjectHideFlags: 0 @@ -1334,8 +1333,7 @@ MonoBehaviour: dontDestroyOnLoad: 1 contextSize: 8192 batchSize: 512 - basePrompt: - model: qwen2-0_5b-instruct-q4_k_m.gguf + model: chatTemplate: chatml lora: loraWeights: @@ -1346,9 +1344,9 @@ MonoBehaviour: SSLKey: SSLKeyPath: minContextLength: 0 - maxContextLength: 32768 + maxContextLength: 0 embeddingsOnly: 0 - embeddingLength: 896 + embeddingLength: 0 --- !u!4 &1047848255 Transform: m_ObjectHideFlags: 0 diff --git a/Tests/Runtime/TestLLM.cs b/Tests/Runtime/TestLLM.cs index f2a4813d..a54572e3 100644 --- a/Tests/Runtime/TestLLM.cs +++ b/Tests/Runtime/TestLLM.cs @@ -7,9 +7,41 @@ using System.Collections; using System.IO; using UnityEngine.TestTools; +using UnityEditor; +using UnityEditor.TestTools.TestRunner.Api; namespace LLMUnityTests { + [InitializeOnLoad] + public static class TestRunListener + { + static TestRunListener() + { + var api = ScriptableObject.CreateInstance(); + api.RegisterCallbacks(new TestRunCallbacks()); + } + } + + public class TestRunCallbacks : ICallbacks + { + public void RunStarted(ITestAdaptor testsToRun) { } + + public void RunFinished(ITestResultAdaptor result) + { + LLMUnitySetup.FullLlamaLib = false; + } + + public void TestStarted(ITestAdaptor test) + { + LLMUnitySetup.FullLlamaLib = test.FullName.Contains("CUDA_full"); + } + + public void TestFinished(ITestResultAdaptor result) + { + LLMUnitySetup.FullLlamaLib = false; + } + } + public class TestLLMLoraAssignment { [Test] @@ -231,7 +263,7 @@ public async Task RunTestsTask() await Tests(); llm.OnDestroy(); } - catch (Exception e) + catch (Exception e) { error = e; } @@ -268,7 +300,7 @@ public void TestInitParameters(int nkeep, int chats) public void TestTokens(List tokens) { - Assert.AreEqual(tokens, new List {40}); + Assert.AreEqual(tokens, new List { 40 }); } public void TestWarmup() @@ -292,7 +324,7 @@ public void TestEmbeddings(List embeddings) Assert.That(embeddings.Count == 896); } - public virtual void OnDestroy() {} + public virtual void OnDestroy() { } } public class TestLLM_LLMManager_Load : TestLLM @@ -459,6 +491,10 @@ public override LLMCharacter CreateLLMCharacter() LLMCharacter llmCharacter = base.CreateLLMCharacter(); llmCharacter.save = saveName; llmCharacter.saveCache = true; + foreach (string filename in new string[]{ + llmCharacter.GetJsonSavePath(saveName), + llmCharacter.GetCacheSavePath(saveName) + }) if (File.Exists(filename)) File.Delete(filename); return llmCharacter; } @@ -492,4 +528,51 @@ public void TestSave() } } } + + public class TestLLM_CUDA : TestLLM + { + public override LLM CreateLLM() + { + LLM llm = base.CreateLLM(); + llm.numGPULayers = 10; + return llm; + } + } + + public class TestLLM_CUDA_full : TestLLM_CUDA + { + public override void SetParameters() + { + base.SetParameters(); + if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer) + { + reply1 = "To increase your meme production output, you might consider using more modern tools and techniques to generate memes."; + reply2 = "To increase your meme production output, you can try using various tools and techniques to generate more content quickly"; + } + else + { + reply1 = "To increase your meme production output, you might consider using more advanced tools and techniques to generate memes faster"; + reply2 = "To increase your meme production output, you might consider using more advanced tools and techniques to generate memes faster"; + } + } + } + + public class TestLLM_CUDA_full_attention : TestLLM_CUDA_full + { + public override LLM CreateLLM() + { + LLM llm = base.CreateLLM(); + llm.flashAttention = true; + return llm; + } + + public override void SetParameters() + { + base.SetParameters(); + if (Application.platform == RuntimePlatform.LinuxEditor || Application.platform == RuntimePlatform.LinuxPlayer) + { + reply2 = "To increase your meme production output, you can try using various tools and techniques to generate more memes."; + } + } + } } diff --git a/VERSION b/VERSION index 8721bbc4..a3721209 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -v2.4.0 +v2.4.1 diff --git a/package.json b/package.json index e3775816..396eddc3 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "ai.undream.llm", - "version": "2.4.0", + "version": "2.4.1", "displayName": "LLM for Unity", "description": "LLM for Unity allows to run and distribute Large Language Models (LLMs) in the Unity engine.", "unity": "2022.3",