Commit 7c3268f: Allow remote Ollama

ccreutzi committed Jul 22, 2024 · 1 parent 4f65c2e
Showing 5 changed files with 36 additions and 7 deletions.
+llms/+internal/callOllamaChatAPI.m (5 additions & 1 deletion)

@@ -37,9 +37,13 @@
     nvp.Seed
     nvp.TimeOut
     nvp.StreamFun
+    nvp.Endpoint
 end

-URL = "http://localhost:11434/api/chat";
+URL = nvp.Endpoint + "/api/chat";
+if ~startsWith(URL,"http")
+    URL = "http://" + URL;
+end

 % The JSON for StopSequences must have an array, and cannot say "stop": "foo".
 % The easiest way to ensure that is to never pass in a scalar …
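The hunk's effect is easiest to see in isolation. A minimal sketch of the new URL normalization, using a hypothetical endpoint value (not part of the commit):

```matlab
% Sketch of the endpoint normalization added above (hypothetical value).
endpoint = "ollamaServer:11434";   % as a user might pass Endpoint
URL = endpoint + "/api/chat";      % "ollamaServer:11434/api/chat"
if ~startsWith(URL,"http")
    URL = "http://" + URL;         % "http://ollamaServer:11434/api/chat"
end
disp(URL)
```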
.github/workflows/ci.yml (6 additions & 1 deletion)

@@ -14,19 +14,23 @@ jobs:
       run: |
         # Run in the background; there is no way to daemonise at the moment
         ollama serve &
+        # Run a second server to test a different endpoint
+        OLLAMA_HOST=127.0.0.1:11435 OLLAMA_MODELS=/tmp/ollama/models ollama serve &
         # A short pause is required before the HTTP port is opened
         sleep 5
         # This endpoint blocks until ready
         time curl -i http://localhost:11434
+        time curl -i http://localhost:11435
         # For debugging, record Ollama version
         ollama --version
-    - name: Pull mistral model
+    - name: Pull models
       run: |
         ollama pull mistral
+        OLLAMA_HOST=127.0.0.1:11435 ollama pull qwen2:0.5b
     - name: Set up MATLAB
       uses: matlab-actions/setup-matlab@v2
       with:
@@ -39,6 +43,7 @@ jobs:
         AZURE_OPENAI_DEPLOYMENT: ${{ secrets.AZURE_DEPLOYMENT }}
         AZURE_OPENAI_ENDPOINT: ${{ secrets.AZURE_ENDPOINT }}
         AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_KEY }}
+        SECOND_OLLAMA_ENDPOINT: 127.0.0.1:11435
       uses: matlab-actions/run-tests@v2
       with:
         test-results-junit: test-results/results.xml
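The second server gives the test suite a non-default endpoint to talk to. As a quick local check (an assumption, not part of the commit), both servers can be probed from MATLAB once they are up; Ollama's root endpoint typically replies with a short status string:

```matlab
% Probe both servers from MATLAB (assumes they were started as in the
% CI step above); each should answer "Ollama is running".
disp(webread("http://localhost:11434"))
disp(webread("http://127.0.0.1:11435"))
```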
doc/Ollama.md (7 additions & 0 deletions)

@@ -95,3 +95,10 @@ chat = ollamaChat("mistral", StreamFun=sf);
 txt = generate(chat,"What is Model-Based Design and how is it related to Digital Twin?");
 % Should stream the response token by token
 ```
+
+## Establishing a connection to remote LLMs using Ollama
+
+To connect to a remote Ollama server, use the `Endpoint` parameter. Include the server name and port number; Ollama starts on port 11434 by default:
+```matlab
+chat = ollamaChat("mistral",Endpoint="ollamaServer:11434");
+```
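The scheme prefix is optional: the normalization added in `callOllamaChatAPI.m` prepends `"http://"` when it is missing, and the new test below exercises both forms. A short sketch with a hypothetical host name:

```matlab
% Equivalent: the "http://" prefix is prepended automatically when absent.
chat1 = ollamaChat("mistral",Endpoint="ollamaServer:11434");
chat2 = ollamaChat("mistral",Endpoint="http://ollamaServer:11434");
```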
ollamaChat.m (7 additions & 3 deletions)

@@ -64,8 +64,9 @@
     % Copyright 2024 The MathWorks, Inc.

     properties
-        Model (1,1) string
-        TopK (1,1) {mustBeReal,mustBePositive} = Inf
+        Model (1,1) string
+        Endpoint (1,1) string
+        TopK (1,1) {mustBeReal,mustBePositive} = Inf
         TailFreeSamplingZ (1,1) {mustBeReal} = 1
     end

@@ -82,6 +83,7 @@
             nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = 120
             nvp.TailFreeSamplingZ (1,1) {mustBeReal} = 1
             nvp.StreamFun (1,1) {mustBeA(nvp.StreamFun,'function_handle')}
+            nvp.Endpoint (1,1) string = "127.0.0.1:11434"
         end

         if isfield(nvp,"StreamFun")
@@ -105,6 +107,7 @@
             this.TailFreeSamplingZ = nvp.TailFreeSamplingZ;
             this.StopSequences = nvp.StopSequences;
             this.TimeOut = nvp.TimeOut;
+            this.Endpoint = nvp.Endpoint;
         end

         function [text, message, response] = generate(this, messages, nvp)
@@ -147,7 +150,8 @@
                 TailFreeSamplingZ=this.TailFreeSamplingZ,...
                 StopSequences=this.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ...
                 ResponseFormat=this.ResponseFormat,Seed=nvp.Seed, ...
-                TimeOut=this.TimeOut, StreamFun=this.StreamFun);
+                TimeOut=this.TimeOut, StreamFun=this.StreamFun, ...
+                Endpoint=this.Endpoint);

             if isfield(response.Body.Data,"error")
                 err = response.Body.Data.error;
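Since `nvp.Endpoint` defaults to `"127.0.0.1:11434"`, existing call sites that never pass `Endpoint` keep talking to the local server; a brief sketch (hypothetical remote host name):

```matlab
% Unchanged call sites keep the local default endpoint (127.0.0.1:11434).
chatLocal  = ollamaChat("mistral");
% Remote use opts in explicitly.
chatRemote = ollamaChat("mistral",Endpoint="ollamaServer:11434");
```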
tests/tollamaChat.m (11 additions & 2 deletions)

@@ -98,7 +98,6 @@ function seedFixesResult(testCase)
     testCase.verifyEqual(response1,response2);
 end

-
 function streamFunc(testCase)
     function seen = sf(str)
         persistent data;
@@ -118,6 +117,17 @@ function streamFunc(testCase)
     testCase.verifyGreaterThan(numel(sf("")), 1);
 end

+function reactToEndpoint(testCase)
+    testCase.assumeTrue(isenv("SECOND_OLLAMA_ENDPOINT"),...
+        "Test point assumes a second Ollama server is running " + ...
+        "and $SECOND_OLLAMA_ENDPOINT points to it.");
+    chat = ollamaChat("qwen2:0.5b",Endpoint=getenv("SECOND_OLLAMA_ENDPOINT"));
+    testCase.verifyWarningFree(@() generate(chat,"dummy"));
+    % also make sure "http://" can be included
+    chat = ollamaChat("qwen2:0.5b",Endpoint="http://" + getenv("SECOND_OLLAMA_ENDPOINT"));
+    testCase.verifyWarningFree(@() generate(chat,"dummy"));
+end
+
 function doReturnErrors(testCase)
     testCase.assumeFalse( ...
         any(startsWith(ollamaChat.models,"abcdefghijklmnop")), ...
@@ -126,7 +136,6 @@ function doReturnErrors(testCase)
     testCase.verifyError(@() generate(chat,"hi!"), "llms:apiReturnedError");
 end

-
 function invalidInputsConstructor(testCase, InvalidConstructorInput)
     testCase.verifyError(@() ollamaChat("mistral", InvalidConstructorInput.Input{:}), InvalidConstructorInput.Error);
 end
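To run the new `reactToEndpoint` test point outside CI, set the environment variable in the MATLAB session first; a sketch assuming a second server was started and stocked as in the workflow above:

```matlab
% Assumes a second server: OLLAMA_HOST=127.0.0.1:11435 ollama serve &
% and that qwen2:0.5b has been pulled to it.
setenv("SECOND_OLLAMA_ENDPOINT","127.0.0.1:11435");
results = runtests("tests/tollamaChat.m");
disp(table(results))
```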
