diff --git a/.github/doxygen/Doxyfile b/.github/doxygen/Doxyfile
index c81da466..d10a6532 100644
--- a/.github/doxygen/Doxyfile
+++ b/.github/doxygen/Doxyfile
@@ -48,7 +48,7 @@ PROJECT_NAME = "LLM for Unity"
# could be handy for archiving the generated documentation or if some version
# control system is used.
-PROJECT_NUMBER = v2.4.0
+PROJECT_NUMBER = v2.4.1
# Using the PROJECT_BRIEF tag one can provide an optional one line description
# for a project that appears at the top of each page and should give viewer a
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9ad39aaa..64f3ed8a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,13 @@
+## v2.4.1
+#### 🚀 Features
+
+- Static library linking on mobile (fixes iOS signing) (PR: #289)
+
+#### 🐛 Fixes
+
+- Fix support for extras (flash attention, iQ quants) (PR: #292)
+
+
## v2.4.0
#### 🚀 Features
diff --git a/CHANGELOG.release.md b/CHANGELOG.release.md
index 5abea998..8733cb3d 100644
--- a/CHANGELOG.release.md
+++ b/CHANGELOG.release.md
@@ -1,16 +1,8 @@
### 🚀 Features
-- iOS deployment (PR: #267)
-- Improve building process (PR: #282)
-- Add structured output / function calling sample (PR: #281)
-- Update LlamaLib to v1.2.0 (llama.cpp b4218) (PR: #283)
+- Static library linking on mobile (fixes iOS signing) (PR: #289)
### 🐛 Fixes
-- Clear temp build directory before building (PR: #278)
-
-### 📦 General
-
-- Remove support for extras (flash attention, iQ quants) (PR: #284)
-- remove support for LLM base prompt (PR: #285)
+- Fix support for extras (flash attention, iQ quants) (PR: #292)
diff --git a/Editor/LLMBuildProcessor.cs b/Editor/LLMBuildProcessor.cs
index f0638a88..96d0f3a8 100644
--- a/Editor/LLMBuildProcessor.cs
+++ b/Editor/LLMBuildProcessor.cs
@@ -2,6 +2,9 @@
using UnityEditor.Build;
using UnityEditor.Build.Reporting;
using UnityEngine;
+#if UNITY_IOS
+using UnityEditor.iOS.Xcode;
+#endif
namespace LLMUnity
{
@@ -43,9 +46,27 @@ private void OnBuildError(string condition, string stacktrace, LogType type)
if (type == LogType.Error) BuildCompleted();
}
+#if UNITY_IOS
+ ///
+ /// Adds the Accelerate framework (for iOS)
+ ///
+ public static void AddAccelerate(string outputPath)
+ {
+ string projPath = PBXProject.GetPBXProjectPath(outputPath);
+ PBXProject proj = new PBXProject();
+ proj.ReadFromFile(projPath);
+ proj.AddFrameworkToProject(proj.GetUnityMainTargetGuid(), "Accelerate.framework", false);
+ proj.AddFrameworkToProject(proj.GetUnityFrameworkTargetGuid(), "Accelerate.framework", false);
+ proj.WriteToFile(projPath);
+ }
+#endif
+
// called after the build
public void OnPostprocessBuild(BuildReport report)
{
+#if UNITY_IOS
+ AddAccelerate(report.summary.outputPath);
+#endif
BuildCompleted();
}
diff --git a/Editor/LLMEditor.cs b/Editor/LLMEditor.cs
index d2b7ca49..5ed29afe 100644
--- a/Editor/LLMEditor.cs
+++ b/Editor/LLMEditor.cs
@@ -111,6 +111,7 @@ public override void AddModelSettings(SerializedObject llmScriptSO)
if (llmScriptSO.FindProperty("advancedOptions").boolValue)
{
attributeClasses.Add(typeof(ModelAdvancedAttribute));
+ if (LLMUnitySetup.FullLlamaLib) attributeClasses.Add(typeof(ModelExtrasAttribute));
}
ShowPropertiesOfClass("", llmScriptSO, attributeClasses, false);
Space();
@@ -444,12 +445,18 @@ private void CopyToClipboard(string text)
te.Copy();
}
+ public void AddExtrasToggle()
+ {
+ if (ToggleButton("Use extras", LLMUnitySetup.FullLlamaLib)) LLMUnitySetup.SetFullLlamaLib(!LLMUnitySetup.FullLlamaLib);
+ }
+
public override void AddOptionsToggles(SerializedObject llmScriptSO)
{
AddDebugModeToggle();
EditorGUILayout.BeginHorizontal();
AddAdvancedOptionsToggle(llmScriptSO);
+ AddExtrasToggle();
EditorGUILayout.EndHorizontal();
Space();
}
diff --git a/README.md b/README.md
index 26a84a59..ad242081 100644
--- a/README.md
+++ b/README.md
@@ -499,7 +499,8 @@ Save the scene, run and enjoy!
### LLM Settings
- `Show/Hide Advanced Options` Toggle to show/hide advanced options from below
-- `Log Level` select how verbose the log messages arequants)
+- `Log Level` select how verbose the log messages are
+- `Use extras` select to install and allow the use of extra features (flash attention and IQ quants)
#### 💻 Setup Settings
@@ -550,6 +551,7 @@ If the user's GPU is not supported, the LLM will fall back to the CPU
- `Chat Template` the chat template being used for the LLM
- `Lora` the path of the LoRAs being used (relative to the Assets/StreamingAssets folder)
- `Lora Weights` the weights of the LoRAs being used
+ - `Flash Attention` click to use flash attention in the model (if `Use extras` is enabled)
@@ -557,6 +559,7 @@ If the user's GPU is not supported, the LLM will fall back to the CPU
- `Show/Hide Advanced Options` Toggle to show/hide advanced options from below
- `Log Level` select how verbose the log messages are
+- `Use extras` select to install and allow the use of extra features (flash attention and IQ quants)
#### 💻 Setup Settings
diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs
index 7b3d2571..3564d5bf 100644
--- a/Runtime/LLM.cs
+++ b/Runtime/LLM.cs
@@ -59,6 +59,8 @@ public class LLM : MonoBehaviour
[ModelAdvanced] public string lora = "";
///
the weights of the LORA models being used.
[ModelAdvanced] public string loraWeights = "";
+ ///
enable use of flash attention
+ [ModelExtras] public bool flashAttention = false;
///
API key to use for the server (optional)
public string APIKey;
@@ -430,6 +432,7 @@ protected virtual string GetLlamaccpArguments()
if (numThreadsToUse > 0) arguments += $" -t {numThreadsToUse}";
arguments += loraArgument;
arguments += $" -ngl {numGPULayers}";
+ if (LLMUnitySetup.FullLlamaLib && flashAttention) arguments += $" --flash-attn";
if (remote)
{
arguments += $" --port {port} --host 0.0.0.0";
@@ -719,7 +722,7 @@ public void ApplyLoras()
json = json.Substring(startIndex, endIndex - startIndex);
IntPtr stringWrapper = llmlib.StringWrapper_Construct();
- llmlib.LLM_Lora_Weight(LLMObject, json, stringWrapper);
+ llmlib.LLM_LoraWeight(LLMObject, json, stringWrapper);
llmlib.StringWrapper_Delete(stringWrapper);
}
diff --git a/Runtime/LLMBuilder.cs b/Runtime/LLMBuilder.cs
index c95835b9..e8a6dd2f 100644
--- a/Runtime/LLMBuilder.cs
+++ b/Runtime/LLMBuilder.cs
@@ -17,6 +17,7 @@ public class LLMBuilder
static List
movedPairs = new List();
public static string BuildTempDir = Path.Combine(Application.temporaryCachePath, "LLMUnityBuild");
public static string androidPluginDir = Path.Combine(Application.dataPath, "Plugins", "Android", "LLMUnity");
+ public static string iOSPluginDir = Path.Combine(Application.dataPath, "Plugins", "iOS", "LLMUnity");
static string movedCache = Path.Combine(BuildTempDir, "moved.json");
[InitializeOnLoadMethod]
@@ -87,7 +88,7 @@ public static void MovePath(string source, string target)
/// path
public static bool DeletePath(string path)
{
- string[] allowedDirs = new string[] { LLMUnitySetup.GetAssetPath(), BuildTempDir, androidPluginDir};
+ string[] allowedDirs = new string[] { LLMUnitySetup.GetAssetPath(), BuildTempDir, androidPluginDir, iOSPluginDir};
bool deleteOK = false;
foreach (string allowedDir in allowedDirs) deleteOK = deleteOK || LLMUnitySetup.IsSubPath(path, allowedDir);
if (!deleteOK)
@@ -148,10 +149,10 @@ static void AddActionAddMeta(string target)
}
///
- /// Hides all the library platforms apart from the target platform by moving out their library folders outside of StreamingAssets
+ /// Moves libraries into the correct place for building
///
/// target platform
- public static void HideLibraryPlatforms(string platform)
+ public static void BuildLibraryPlatforms(string platform)
{
List platforms = new List(){ "windows", "macos", "linux", "android", "ios", "setup" };
platforms.Remove(platform);
@@ -161,6 +162,8 @@ public static void HideLibraryPlatforms(string platform)
foreach (string platformPrefix in platforms)
{
bool move = sourceName.StartsWith(platformPrefix);
+ move = move || (sourceName.Contains("cuda") && !sourceName.Contains("full") && LLMUnitySetup.FullLlamaLib);
+ move = move || (sourceName.Contains("cuda") && sourceName.Contains("full") && !LLMUnitySetup.FullLlamaLib);
if (move)
{
string target = Path.Combine(BuildTempDir, sourceName);
@@ -170,13 +173,14 @@ public static void HideLibraryPlatforms(string platform)
}
}
- if (platform == "android")
+ if (platform == "android" || platform == "ios")
{
- string source = Path.Combine(LLMUnitySetup.libraryPath, "android");
- string target = Path.Combine(androidPluginDir, LLMUnitySetup.libraryName);
+ string pluginDir = platform == "android"? androidPluginDir: iOSPluginDir;
+ string source = Path.Combine(LLMUnitySetup.libraryPath, platform);
+ string target = Path.Combine(pluginDir, LLMUnitySetup.libraryName);
MoveAction(source, target);
MoveAction(source + ".meta", target + ".meta");
- AddActionAddMeta(androidPluginDir);
+ AddActionAddMeta(pluginDir);
}
}
@@ -196,7 +200,7 @@ public static void Build(string platform)
{
DeletePath(BuildTempDir);
Directory.CreateDirectory(BuildTempDir);
- HideLibraryPlatforms(platform);
+ BuildLibraryPlatforms(platform);
BuildModels();
}
diff --git a/Runtime/LLMCaller.cs b/Runtime/LLMCaller.cs
index ac0949a3..2ef586d6 100644
--- a/Runtime/LLMCaller.cs
+++ b/Runtime/LLMCaller.cs
@@ -128,7 +128,11 @@ protected virtual void AssignLLM()
if (remote || llm != null) return;
List validLLMs = new List();
+#if UNITY_6000_0_OR_NEWER
+ foreach (LLM foundllm in FindObjectsByType(typeof(LLM), FindObjectsSortMode.None))
+#else
foreach (LLM foundllm in FindObjectsOfType())
+#endif
{
if (IsValidLLM(foundllm) && IsAutoAssignableLLM(foundllm)) validLLMs.Add(foundllm);
}
@@ -284,7 +288,9 @@ protected virtual async Task PostRequestRemote(string json, strin
}
// Start the request asynchronously
- var asyncOperation = request.SendWebRequest();
+ UnityWebRequestAsyncOperation asyncOperation = request.SendWebRequest();
+ await Task.Yield(); // Wait for the next frame so that asyncOperation is properly registered (especially if not in main thread)
+
float lastProgress = 0f;
// Continue updating progress until the request is completed
while (!asyncOperation.isDone)
diff --git a/Runtime/LLMLib.cs b/Runtime/LLMLib.cs
index 729ee6bd..98dcc78c 100644
--- a/Runtime/LLMLib.cs
+++ b/Runtime/LLMLib.cs
@@ -364,11 +364,98 @@ public static int dlclose(IntPtr handle)
public class LLMLib
{
IntPtr libraryHandle = IntPtr.Zero;
- static readonly object staticLock = new object();
static bool has_avx = false;
static bool has_avx2 = false;
static bool has_avx512 = false;
+
+#if (UNITY_ANDROID || UNITY_IOS) && !UNITY_EDITOR
+
+ public LLMLib(string arch){}
+
+#if UNITY_ANDROID
+ public const string LibraryName = "libundreamai_android";
+#else
+ public const string LibraryName = "__Internal";
+#endif
+
+ [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="Logging")]
+ public static extern void LoggingStatic(IntPtr stringWrapper);
+ [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="StopLogging")]
+ public static extern void StopLoggingStatic();
+ [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_Construct")]
+ public static extern IntPtr LLM_ConstructStatic(string command);
+ [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_Delete")]
+ public static extern void LLM_DeleteStatic(IntPtr LLMObject);
+ [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_StartServer")]
+ public static extern void LLM_StartServerStatic(IntPtr LLMObject);
+ [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_StopServer")]
+ public static extern void LLM_StopServerStatic(IntPtr LLMObject);
+ [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_Start")]
+ public static extern void LLM_StartStatic(IntPtr LLMObject);
+ [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_Started")]
+ public static extern bool LLM_StartedStatic(IntPtr LLMObject);
+ [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_Stop")]
+ public static extern void LLM_StopStatic(IntPtr LLMObject);
+ [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_SetTemplate")]
+ public static extern void LLM_SetTemplateStatic(IntPtr LLMObject, string chatTemplate);
+ [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_SetSSL")]
+ public static extern void LLM_SetSSLStatic(IntPtr LLMObject, string SSLCert, string SSLKey);
+ [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_Tokenize")]
+ public static extern void LLM_TokenizeStatic(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
+ [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_Detokenize")]
+ public static extern void LLM_DetokenizeStatic(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
+ [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_Embeddings")]
+ public static extern void LLM_EmbeddingsStatic(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
+ [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_Lora_Weight")]
+ public static extern void LLM_LoraWeightStatic(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
+ [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_Lora_List")]
+ public static extern void LLM_LoraListStatic(IntPtr LLMObject, IntPtr stringWrapper);
+ [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_Completion")]
+ public static extern void LLM_CompletionStatic(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
+ [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_Slot")]
+ public static extern void LLM_SlotStatic(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
+ [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_Cancel")]
+ public static extern void LLM_CancelStatic(IntPtr LLMObject, int idSlot);
+ [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="LLM_Status")]
+ public static extern int LLM_StatusStatic(IntPtr LLMObject, IntPtr stringWrapper);
+ [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="StringWrapper_Construct")]
+ public static extern IntPtr StringWrapper_ConstructStatic();
+ [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="StringWrapper_Delete")]
+ public static extern void StringWrapper_DeleteStatic(IntPtr instance);
+ [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="StringWrapper_GetStringSize")]
+ public static extern int StringWrapper_GetStringSizeStatic(IntPtr instance);
+ [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint="StringWrapper_GetString")]
+ public static extern void StringWrapper_GetStringStatic(IntPtr instance, IntPtr buffer, int bufferSize, bool clear = false);
+
+ public void Logging(IntPtr stringWrapper){ LoggingStatic(stringWrapper); }
+ public void StopLogging(){ StopLoggingStatic(); }
+ public IntPtr LLM_Construct(string command){ return LLM_ConstructStatic(command); }
+ public void LLM_Delete(IntPtr LLMObject){ LLM_DeleteStatic(LLMObject); }
+ public void LLM_StartServer(IntPtr LLMObject){ LLM_StartServerStatic(LLMObject); }
+ public void LLM_StopServer(IntPtr LLMObject){ LLM_StopServerStatic(LLMObject); }
+ public void LLM_Start(IntPtr LLMObject){ LLM_StartStatic(LLMObject); }
+ public bool LLM_Started(IntPtr LLMObject){ return LLM_StartedStatic(LLMObject); }
+ public void LLM_Stop(IntPtr LLMObject){ LLM_StopStatic(LLMObject); }
+ public void LLM_SetTemplate(IntPtr LLMObject, string chatTemplate){ LLM_SetTemplateStatic(LLMObject, chatTemplate); }
+ public void LLM_SetSSL(IntPtr LLMObject, string SSLCert, string SSLKey){ LLM_SetSSLStatic(LLMObject, SSLCert, SSLKey); }
+ public void LLM_Tokenize(IntPtr LLMObject, string jsonData, IntPtr stringWrapper){ LLM_TokenizeStatic(LLMObject, jsonData, stringWrapper); }
+ public void LLM_Detokenize(IntPtr LLMObject, string jsonData, IntPtr stringWrapper){ LLM_DetokenizeStatic(LLMObject, jsonData, stringWrapper); }
+ public void LLM_Embeddings(IntPtr LLMObject, string jsonData, IntPtr stringWrapper){ LLM_EmbeddingsStatic(LLMObject, jsonData, stringWrapper); }
+ public void LLM_LoraWeight(IntPtr LLMObject, string jsonData, IntPtr stringWrapper){ LLM_LoraWeightStatic(LLMObject, jsonData, stringWrapper); }
+ public void LLM_LoraList(IntPtr LLMObject, IntPtr stringWrapper){ LLM_LoraListStatic(LLMObject, stringWrapper); }
+ public void LLM_Completion(IntPtr LLMObject, string jsonData, IntPtr stringWrapper){ LLM_CompletionStatic(LLMObject, jsonData, stringWrapper); }
+ public void LLM_Slot(IntPtr LLMObject, string jsonData, IntPtr stringWrapper){ LLM_SlotStatic(LLMObject, jsonData, stringWrapper); }
+ public void LLM_Cancel(IntPtr LLMObject, int idSlot){ LLM_CancelStatic(LLMObject, idSlot); }
+ public int LLM_Status(IntPtr LLMObject, IntPtr stringWrapper){ return LLM_StatusStatic(LLMObject, stringWrapper); }
+ public IntPtr StringWrapper_Construct(){ return StringWrapper_ConstructStatic(); }
+ public void StringWrapper_Delete(IntPtr instance){ StringWrapper_DeleteStatic(instance); }
+ public int StringWrapper_GetStringSize(IntPtr instance){ return StringWrapper_GetStringSizeStatic(instance); }
+ public void StringWrapper_GetString(IntPtr instance, IntPtr buffer, int bufferSize, bool clear = false){ StringWrapper_GetStringStatic(instance, buffer, bufferSize, clear); }
+
+#else
+
static bool has_avx_set = false;
+ static readonly object staticLock = new object();
static LLMLib()
{
@@ -427,7 +514,7 @@ public LLMLib(string arch)
LLM_Tokenize = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Tokenize");
LLM_Detokenize = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Detokenize");
LLM_Embeddings = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Embeddings");
- LLM_Lora_Weight = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Lora_Weight");
+ LLM_LoraWeight = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Lora_Weight");
LLM_LoraList = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Lora_List");
LLM_Completion = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Completion");
LLM_Slot = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Slot");
@@ -441,69 +528,6 @@ public LLMLib(string arch)
StopLogging = LibraryLoader.GetSymbolDelegate(libraryHandle, "StopLogging");
}
- ///
- /// Destroys the LLM library
- ///
- public void Destroy()
- {
- if (libraryHandle != IntPtr.Zero) LibraryLoader.FreeLibrary(libraryHandle);
- }
-
- ///
- /// Identifies the possible architectures that we can use based on the OS and GPU usage
- ///
- /// whether to allow GPU architectures
- /// possible architectures
- public static List PossibleArchitectures(bool gpu = false)
- {
- List architectures = new List();
- if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer || Application.platform == RuntimePlatform.WindowsServer ||
- Application.platform == RuntimePlatform.LinuxEditor || Application.platform == RuntimePlatform.LinuxPlayer || Application.platform == RuntimePlatform.LinuxServer)
- {
- if (gpu)
- {
- architectures.Add("cuda-cu12.2.0");
- architectures.Add("cuda-cu11.7.1");
- architectures.Add("hip");
- architectures.Add("vulkan");
- }
- if (has_avx512) architectures.Add("avx512");
- if (has_avx2) architectures.Add("avx2");
- if (has_avx) architectures.Add("avx");
- architectures.Add("noavx");
- }
- else if (Application.platform == RuntimePlatform.OSXEditor || Application.platform == RuntimePlatform.OSXPlayer)
- {
- string arch = RuntimeInformation.ProcessArchitecture.ToString().ToLower();
- if (arch.Contains("arm"))
- {
- architectures.Add("arm64-acc");
- architectures.Add("arm64-no_acc");
- }
- else
- {
- if (arch != "x86" && arch != "x64") LLMUnitySetup.LogWarning($"Unknown architecture of processor {arch}! Falling back to x86_64");
- architectures.Add("x64-acc");
- architectures.Add("x64-no_acc");
- }
- }
- else if (Application.platform == RuntimePlatform.Android)
- {
- architectures.Add("android");
- }
- else if (Application.platform == RuntimePlatform.IPhonePlayer)
- {
- architectures.Add("ios");
- }
- else
- {
- string error = "Unknown OS";
- LLMUnitySetup.LogError(error);
- throw new Exception(error);
- }
- return architectures;
- }
-
///
/// Gets the path of a library that allows to detect the underlying CPU (Windows / Linux).
///
@@ -546,14 +570,6 @@ public static string GetArchitecturePath(string arch)
{
filename = $"macos-{arch}/libundreamai_macos-{arch}.dylib";
}
- else if (Application.platform == RuntimePlatform.Android)
- {
- return "libundreamai_android.so";
- }
- else if (Application.platform == RuntimePlatform.IPhonePlayer)
- {
- filename = "iOS/libundreamai_iOS.dylib";
- }
else
{
string error = "Unknown OS";
@@ -563,31 +579,6 @@ public static string GetArchitecturePath(string arch)
return Path.Combine(LLMUnitySetup.libraryPath, filename);
}
- ///
- /// Allows to retrieve a string from the library (Unity only allows marshalling of chars)
- ///
- /// string wrapper pointer
- /// retrieved string
- public string GetStringWrapperResult(IntPtr stringWrapper)
- {
- string result = "";
- int bufferSize = StringWrapper_GetStringSize(stringWrapper);
- if (bufferSize > 1)
- {
- IntPtr buffer = Marshal.AllocHGlobal(bufferSize);
- try
- {
- StringWrapper_GetString(stringWrapper, buffer, bufferSize);
- result = Marshal.PtrToStringAnsi(buffer);
- }
- finally
- {
- Marshal.FreeHGlobal(buffer);
- }
- }
- return result;
- }
-
public delegate bool HasArchDelegate();
public delegate void LoggingDelegate(IntPtr stringWrapper);
public delegate void StopLoggingDelegate();
@@ -629,7 +620,7 @@ public string GetStringWrapperResult(IntPtr stringWrapper)
public LLM_DetokenizeDelegate LLM_Detokenize;
public LLM_CompletionDelegate LLM_Completion;
public LLM_EmbeddingsDelegate LLM_Embeddings;
- public LLM_LoraWeightDelegate LLM_Lora_Weight;
+ public LLM_LoraWeightDelegate LLM_LoraWeight;
public LLM_LoraListDelegate LLM_LoraList;
public LLM_SlotDelegate LLM_Slot;
public LLM_CancelDelegate LLM_Cancel;
@@ -638,6 +629,105 @@ public string GetStringWrapperResult(IntPtr stringWrapper)
public StringWrapper_DeleteDelegate StringWrapper_Delete;
public StringWrapper_GetStringSizeDelegate StringWrapper_GetStringSize;
public StringWrapper_GetStringDelegate StringWrapper_GetString;
+
+#endif
+
+ ///
+ /// Identifies the possible architectures that we can use based on the OS and GPU usage
+ ///
+ /// whether to allow GPU architectures
+ /// possible architectures
+ public static List PossibleArchitectures(bool gpu = false)
+ {
+ List architectures = new List();
+ if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer || Application.platform == RuntimePlatform.WindowsServer ||
+ Application.platform == RuntimePlatform.LinuxEditor || Application.platform == RuntimePlatform.LinuxPlayer || Application.platform == RuntimePlatform.LinuxServer)
+ {
+ if (gpu)
+ {
+ if (LLMUnitySetup.FullLlamaLib)
+ {
+ architectures.Add("cuda-cu12.2.0-full");
+ architectures.Add("cuda-cu11.7.1-full");
+ }
+ else
+ {
+ architectures.Add("cuda-cu12.2.0");
+ architectures.Add("cuda-cu11.7.1");
+ }
+ architectures.Add("hip");
+ architectures.Add("vulkan");
+ }
+ if (has_avx512) architectures.Add("avx512");
+ if (has_avx2) architectures.Add("avx2");
+ if (has_avx) architectures.Add("avx");
+ architectures.Add("noavx");
+ }
+ else if (Application.platform == RuntimePlatform.OSXEditor || Application.platform == RuntimePlatform.OSXPlayer)
+ {
+ string arch = RuntimeInformation.ProcessArchitecture.ToString().ToLower();
+ if (arch.Contains("arm"))
+ {
+ architectures.Add("arm64-acc");
+ architectures.Add("arm64-no_acc");
+ }
+ else
+ {
+ if (arch != "x86" && arch != "x64") LLMUnitySetup.LogWarning($"Unknown architecture of processor {arch}! Falling back to x86_64");
+ architectures.Add("x64-acc");
+ architectures.Add("x64-no_acc");
+ }
+ }
+ else if (Application.platform == RuntimePlatform.Android)
+ {
+ architectures.Add("android");
+ }
+ else if (Application.platform == RuntimePlatform.IPhonePlayer)
+ {
+ architectures.Add("ios");
+ }
+ else
+ {
+ string error = "Unknown OS";
+ LLMUnitySetup.LogError(error);
+ throw new Exception(error);
+ }
+ return architectures;
+ }
+
+ ///
+ /// Allows to retrieve a string from the library (Unity only allows marshalling of chars)
+ ///
+ /// string wrapper pointer
+ /// retrieved string
+ public string GetStringWrapperResult(IntPtr stringWrapper)
+ {
+ string result = "";
+ int bufferSize = StringWrapper_GetStringSize(stringWrapper);
+ if (bufferSize > 1)
+ {
+ IntPtr buffer = Marshal.AllocHGlobal(bufferSize);
+ try
+ {
+ StringWrapper_GetString(stringWrapper, buffer, bufferSize);
+ result = Marshal.PtrToStringAnsi(buffer);
+ }
+ finally
+ {
+ Marshal.FreeHGlobal(buffer);
+ }
+ }
+ return result;
+ }
+
+ ///
+ /// Destroys the LLM library
+ ///
+ public void Destroy()
+ {
+ if (libraryHandle != IntPtr.Zero) LibraryLoader.FreeLibrary(libraryHandle);
+ }
}
+
}
/// \endcond
diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs
index adb0d686..0d50e7ff 100644
--- a/Runtime/LLMUnitySetup.cs
+++ b/Runtime/LLMUnitySetup.cs
@@ -59,6 +59,7 @@ public class LocalRemoteAttribute : PropertyAttribute {}
public class RemoteAttribute : PropertyAttribute {}
public class LocalAttribute : PropertyAttribute {}
public class ModelAttribute : PropertyAttribute {}
+ public class ModelExtrasAttribute : PropertyAttribute {}
public class ChatAttribute : PropertyAttribute {}
public class LLMUnityAttribute : PropertyAttribute {}
@@ -100,9 +101,9 @@ public class LLMUnitySetup
{
// DON'T CHANGE! the version is autocompleted with a GitHub action
/// LLM for Unity version
- public static string Version = "v2.4.0";
+ public static string Version = "v2.4.1";
/// LlamaLib version
- public static string LlamaLibVersion = "v1.2.0";
+ public static string LlamaLibVersion = "v1.2.1";
/// LlamaLib release url
public static string LlamaLibReleaseURL = $"https://github.com/undreamai/LlamaLib/releases/download/{LlamaLibVersion}";
/// LlamaLib name
@@ -111,6 +112,8 @@ public class LLMUnitySetup
public static string libraryPath = GetAssetPath(libraryName);
/// LlamaLib url
public static string LlamaLibURL = $"{LlamaLibReleaseURL}/{libraryName}.zip";
+ /// LlamaLib extension url
+ public static string LlamaLibExtensionURL = $"{LlamaLibReleaseURL}/{libraryName}-full.zip";
/// LLMnity store path
public static string LLMUnityStore = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData), "LLMUnity");
/// Model download path
@@ -150,6 +153,8 @@ public class LLMUnitySetup
/// \cond HIDE
[LLMUnity] public static DebugModeType DebugMode = DebugModeType.All;
static string DebugModeKey = "DebugMode";
+ public static bool FullLlamaLib = false;
+ static string FullLlamaLibKey = "FullLlamaLib";
static List> errorCallbacks = new List>();
static readonly object lockObject = new object();
static Dictionary androidExtractTasks = new Dictionary();
@@ -184,6 +189,7 @@ public static void LogError(string message)
static void LoadPlayerPrefs()
{
DebugMode = (DebugModeType)PlayerPrefs.GetInt(DebugModeKey, (int)DebugModeType.All);
+ FullLlamaLib = PlayerPrefs.GetInt(FullLlamaLibKey, 0) == 1;
}
public static void SetDebugMode(DebugModeType newDebugMode)
@@ -194,6 +200,18 @@ public static void SetDebugMode(DebugModeType newDebugMode)
PlayerPrefs.Save();
}
+#if UNITY_EDITOR
+ public static void SetFullLlamaLib(bool value)
+ {
+ if (FullLlamaLib == value) return;
+ FullLlamaLib = value;
+ PlayerPrefs.SetInt(FullLlamaLibKey, value ? 1 : 0);
+ PlayerPrefs.Save();
+ _ = DownloadLibrary();
+ }
+
+#endif
+
public static string GetLibraryName(string version)
{
return $"undreamai-{version}-llamacpp";
@@ -436,6 +454,9 @@ static async Task DownloadLibrary()
// setup LlamaLib in StreamingAssets
await DownloadAndExtractInsideDirectory(LlamaLibURL, libraryPath, setupDir);
+
+ // setup LlamaLib extras in StreamingAssets
+ if (FullLlamaLib) await DownloadAndExtractInsideDirectory(LlamaLibExtensionURL, libraryPath, setupDir);
}
catch (Exception e)
{
diff --git a/Samples~/FunctionCalling/Scene.unity b/Samples~/FunctionCalling/Scene.unity
index bf5d7c27..87290fa3 100644
--- a/Samples~/FunctionCalling/Scene.unity
+++ b/Samples~/FunctionCalling/Scene.unity
@@ -13,7 +13,7 @@ OcclusionCullingSettings:
--- !u!104 &2
RenderSettings:
m_ObjectHideFlags: 0
- serializedVersion: 9
+ serializedVersion: 10
m_Fog: 0
m_FogColor: {r: 0.5, g: 0.5, b: 0.5, a: 1}
m_FogMode: 3
@@ -38,13 +38,12 @@ RenderSettings:
m_ReflectionIntensity: 1
m_CustomReflection: {fileID: 0}
m_Sun: {fileID: 0}
- m_IndirectSpecularColor: {r: 0.44657832, g: 0.49641222, b: 0.57481664, a: 1}
m_UseRadianceAmbientProbe: 0
--- !u!157 &3
LightmapSettings:
m_ObjectHideFlags: 0
- serializedVersion: 12
- m_GIWorkflowMode: 1
+ serializedVersion: 13
+ m_BakeOnSceneLoad: 0
m_GISettings:
serializedVersion: 2
m_BounceScale: 1
@@ -67,9 +66,6 @@ LightmapSettings:
m_LightmapParameters: {fileID: 0}
m_LightmapsBakeMode: 1
m_TextureCompression: 1
- m_FinalGather: 0
- m_FinalGatherFiltering: 1
- m_FinalGatherRayCount: 256
m_ReflectionCompression: 2
m_MixedBakeMode: 2
m_BakeBackend: 1
@@ -602,8 +598,8 @@ MonoBehaviour:
m_OnClick:
m_PersistentCalls:
m_Calls:
- - m_Target: {fileID: 0}
- m_TargetAssemblyTypeName: SimpleInteraction, Assembly-CSharp
+ - m_Target: {fileID: 107963747}
+ m_TargetAssemblyTypeName: LLMUnitySamples.FunctionCalling, Assembly-CSharp
m_MethodName: CancelRequests
m_Mode: 1
m_Arguments:
@@ -613,7 +609,7 @@ MonoBehaviour:
m_FloatArgument: 0
m_StringArgument:
m_BoolArgument: 0
- m_CallState: 2
+ m_CallState: 1
--- !u!114 &724531322
MonoBehaviour:
m_ObjectHideFlags: 0
@@ -940,9 +936,8 @@ Light:
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 909474451}
m_Enabled: 1
- serializedVersion: 10
+ serializedVersion: 11
m_Type: 1
- m_Shape: 0
m_Color: {r: 1, g: 0.95686275, b: 0.8392157, a: 1}
m_Intensity: 1
m_Range: 10
@@ -992,8 +987,12 @@ Light:
m_BoundingSphereOverride: {x: 0, y: 0, z: 0, w: 0}
m_UseBoundingSphereOverride: 0
m_UseViewFrustumForShadowCasterCull: 1
+ m_ForceVisible: 0
m_ShadowRadius: 0
m_ShadowAngle: 0
+ m_LightUnit: 1
+ m_LuxAtDistance: 1
+ m_EnableSpotReflector: 1
--- !u!4 &909474453
Transform:
m_ObjectHideFlags: 0
@@ -1048,7 +1047,6 @@ MonoBehaviour:
dontDestroyOnLoad: 1
contextSize: 8192
batchSize: 512
- basePrompt:
model:
chatTemplate: chatml
lora:
diff --git a/Samples~/MobileDemo/Scene.unity b/Samples~/MobileDemo/Scene.unity
index eb7f637e..c1cc66aa 100644
--- a/Samples~/MobileDemo/Scene.unity
+++ b/Samples~/MobileDemo/Scene.unity
@@ -13,7 +13,7 @@ OcclusionCullingSettings:
--- !u!104 &2
RenderSettings:
m_ObjectHideFlags: 0
- serializedVersion: 9
+ serializedVersion: 10
m_Fog: 0
m_FogColor: {r: 0.5, g: 0.5, b: 0.5, a: 1}
m_FogMode: 3
@@ -38,13 +38,12 @@ RenderSettings:
m_ReflectionIntensity: 1
m_CustomReflection: {fileID: 0}
m_Sun: {fileID: 0}
- m_IndirectSpecularColor: {r: 0.44657832, g: 0.49641222, b: 0.57481664, a: 1}
m_UseRadianceAmbientProbe: 0
--- !u!157 &3
LightmapSettings:
m_ObjectHideFlags: 0
- serializedVersion: 12
- m_GIWorkflowMode: 1
+ serializedVersion: 13
+ m_BakeOnSceneLoad: 0
m_GISettings:
serializedVersion: 2
m_BounceScale: 1
@@ -67,9 +66,6 @@ LightmapSettings:
m_LightmapParameters: {fileID: 0}
m_LightmapsBakeMode: 1
m_TextureCompression: 1
- m_FinalGather: 0
- m_FinalGatherFiltering: 1
- m_FinalGatherRayCount: 256
m_ReflectionCompression: 2
m_MixedBakeMode: 2
m_BakeBackend: 1
@@ -899,7 +895,7 @@ MonoBehaviour:
m_FloatArgument: 0
m_StringArgument:
m_BoolArgument: 0
- m_CallState: 0
+ m_CallState: 1
--- !u!114 &724531322
MonoBehaviour:
m_ObjectHideFlags: 0
@@ -1226,9 +1222,8 @@ Light:
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 909474451}
m_Enabled: 1
- serializedVersion: 10
+ serializedVersion: 11
m_Type: 1
- m_Shape: 0
m_Color: {r: 1, g: 0.95686275, b: 0.8392157, a: 1}
m_Intensity: 1
m_Range: 10
@@ -1278,8 +1273,12 @@ Light:
m_BoundingSphereOverride: {x: 0, y: 0, z: 0, w: 0}
m_UseBoundingSphereOverride: 0
m_UseViewFrustumForShadowCasterCull: 1
+ m_ForceVisible: 0
m_ShadowRadius: 0
m_ShadowAngle: 0
+ m_LightUnit: 1
+ m_LuxAtDistance: 1
+ m_EnableSpotReflector: 1
--- !u!4 &909474453
Transform:
m_ObjectHideFlags: 0
@@ -1334,8 +1333,7 @@ MonoBehaviour:
dontDestroyOnLoad: 1
contextSize: 8192
batchSize: 512
- basePrompt:
- model: qwen2-0_5b-instruct-q4_k_m.gguf
+ model:
chatTemplate: chatml
lora:
loraWeights:
@@ -1346,9 +1344,9 @@ MonoBehaviour:
SSLKey:
SSLKeyPath:
minContextLength: 0
- maxContextLength: 32768
+ maxContextLength: 0
embeddingsOnly: 0
- embeddingLength: 896
+ embeddingLength: 0
--- !u!4 &1047848255
Transform:
m_ObjectHideFlags: 0
diff --git a/Tests/Runtime/TestLLM.cs b/Tests/Runtime/TestLLM.cs
index f2a4813d..a54572e3 100644
--- a/Tests/Runtime/TestLLM.cs
+++ b/Tests/Runtime/TestLLM.cs
@@ -7,9 +7,41 @@
using System.Collections;
using System.IO;
using UnityEngine.TestTools;
+using UnityEditor;
+using UnityEditor.TestTools.TestRunner.Api;
namespace LLMUnityTests
{
+ [InitializeOnLoad]
+ public static class TestRunListener
+ {
+ static TestRunListener()
+ {
+ var api = ScriptableObject.CreateInstance();
+ api.RegisterCallbacks(new TestRunCallbacks());
+ }
+ }
+
+ public class TestRunCallbacks : ICallbacks
+ {
+ public void RunStarted(ITestAdaptor testsToRun) { }
+
+ public void RunFinished(ITestResultAdaptor result)
+ {
+ LLMUnitySetup.FullLlamaLib = false;
+ }
+
+ public void TestStarted(ITestAdaptor test)
+ {
+ LLMUnitySetup.FullLlamaLib = test.FullName.Contains("CUDA_full");
+ }
+
+ public void TestFinished(ITestResultAdaptor result)
+ {
+ LLMUnitySetup.FullLlamaLib = false;
+ }
+ }
+
public class TestLLMLoraAssignment
{
[Test]
@@ -231,7 +263,7 @@ public async Task RunTestsTask()
await Tests();
llm.OnDestroy();
}
- catch (Exception e)
+ catch (Exception e)
{
error = e;
}
@@ -268,7 +300,7 @@ public void TestInitParameters(int nkeep, int chats)
public void TestTokens(List tokens)
{
- Assert.AreEqual(tokens, new List {40});
+ Assert.AreEqual(tokens, new List { 40 });
}
public void TestWarmup()
@@ -292,7 +324,7 @@ public void TestEmbeddings(List embeddings)
Assert.That(embeddings.Count == 896);
}
- public virtual void OnDestroy() {}
+ public virtual void OnDestroy() { }
}
public class TestLLM_LLMManager_Load : TestLLM
@@ -459,6 +491,10 @@ public override LLMCharacter CreateLLMCharacter()
LLMCharacter llmCharacter = base.CreateLLMCharacter();
llmCharacter.save = saveName;
llmCharacter.saveCache = true;
+ foreach (string filename in new string[]{
+ llmCharacter.GetJsonSavePath(saveName),
+ llmCharacter.GetCacheSavePath(saveName)
+ }) if (File.Exists(filename)) File.Delete(filename);
return llmCharacter;
}
@@ -492,4 +528,51 @@ public void TestSave()
}
}
}
+
+ public class TestLLM_CUDA : TestLLM
+ {
+ public override LLM CreateLLM()
+ {
+ LLM llm = base.CreateLLM();
+ llm.numGPULayers = 10;
+ return llm;
+ }
+ }
+
+ public class TestLLM_CUDA_full : TestLLM_CUDA
+ {
+ public override void SetParameters()
+ {
+ base.SetParameters();
+ if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer)
+ {
+ reply1 = "To increase your meme production output, you might consider using more modern tools and techniques to generate memes.";
+ reply2 = "To increase your meme production output, you can try using various tools and techniques to generate more content quickly";
+ }
+ else
+ {
+ reply1 = "To increase your meme production output, you might consider using more advanced tools and techniques to generate memes faster";
+ reply2 = "To increase your meme production output, you might consider using more advanced tools and techniques to generate memes faster";
+ }
+ }
+ }
+
+ public class TestLLM_CUDA_full_attention : TestLLM_CUDA_full
+ {
+ public override LLM CreateLLM()
+ {
+ LLM llm = base.CreateLLM();
+ llm.flashAttention = true;
+ return llm;
+ }
+
+ public override void SetParameters()
+ {
+ base.SetParameters();
+ if (Application.platform == RuntimePlatform.LinuxEditor || Application.platform == RuntimePlatform.LinuxPlayer)
+ {
+ reply2 = "To increase your meme production output, you can try using various tools and techniques to generate more memes.";
+ }
+ }
+ }
}
diff --git a/VERSION b/VERSION
index 8721bbc4..a3721209 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-v2.4.0
+v2.4.1
diff --git a/package.json b/package.json
index e3775816..396eddc3 100644
--- a/package.json
+++ b/package.json
@@ -1,6 +1,6 @@
{
"name": "ai.undream.llm",
- "version": "2.4.0",
+ "version": "2.4.1",
"displayName": "LLM for Unity",
"description": "LLM for Unity allows to run and distribute Large Language Models (LLMs) in the Unity engine.",
"unity": "2022.3",