completion: true to register completion models, refs #284
simonw committed Sep 19, 2023
1 parent 9c7792d commit fcff36c
Showing 4 changed files with 33 additions and 2 deletions.
4 changes: 4 additions & 0 deletions docs/other-models.md
@@ -47,6 +47,8 @@ Let's say OpenAI have just released the `gpt-3.5-turbo-0613` model and you want
 ```
 The `model_id` is the identifier that will be recorded in the LLM logs. You can use this to specify the model, or you can optionally include a list of aliases for that model.
 
+If the model is a completion model (such as `gpt-3.5-turbo-instruct`), add `completion: true` to the configuration.
+
 With this configuration in place, the following command should run a prompt against the new model:
 
 ```bash
@@ -87,6 +89,8 @@ If the `api_base` is set, the existing configured `openai` API key will not be s
 
 You can set `api_key_name` to the name of a key stored using the {ref}`api-keys` feature.
 
+Add `completion: true` if the model is a completion model that uses the `/completions` endpoint as opposed to the `/chat/completions` endpoint.
+
 Having configured the model like this, run `llm models` to check that it installed correctly. You can then run prompts against it like so:
 
 ```bash
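Putting the documented options together, a hypothetical `extra-openai-models.yaml` entry for a completion model could look like this (the model names and URL mirror the tests in this commit; `aliases` is optional):

```yaml
# Hypothetical extra-openai-models.yaml entry; names and URL are illustrative.
- model_id: completion-babbage
  model_name: babbage
  aliases:
  - babbage
  api_base: "http://localai.localhost"
  completion: true
```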
6 changes: 5 additions & 1 deletion llm/default_plugins/openai_models.py
@@ -51,7 +51,11 @@ def register_models(register):
     api_version = extra_model.get("api_version")
     api_engine = extra_model.get("api_engine")
     headers = extra_model.get("headers")
-    chat_model = Chat(
+    if extra_model.get("completion"):
+        klass = Completion
+    else:
+        klass = Chat
+    chat_model = klass(
         model_id,
         model_name=model_name,
         api_base=api_base,
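The class-selection change above can be sketched in isolation: a truthy `completion` key in the model definition selects the completion class, otherwise the chat class. In this minimal sketch, `Chat` and `Completion` are stand-in classes, not llm's real model classes:

```python
from dataclasses import dataclass


# Stand-ins for llm's Chat and Completion model classes (illustrative only).
@dataclass
class Chat:
    model_id: str


@dataclass
class Completion:
    model_id: str


def pick_model_class(extra_model: dict):
    """Mirror the diff's dispatch: a truthy "completion" key selects Completion."""
    return Completion if extra_model.get("completion") else Chat


print(pick_model_class({"model_id": "completion-babbage", "completion": 1}).__name__)  # Completion
print(pick_model_class({"model_id": "orca"}).__name__)  # Chat
```

Note that the test YAML below uses `completion: 1` rather than `completion: true`; any truthy value works with this `get()`-based check.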
12 changes: 11 additions & 1 deletion tests/conftest.py
@@ -318,7 +318,7 @@ def mocked_openai_completion_logprobs(requests_mock):
 
 @pytest.fixture
 def mocked_localai(requests_mock):
-    return requests_mock.post(
+    requests_mock.post(
         "http://localai.localhost/chat/completions",
         json={
             "model": "orca",
@@ -327,6 +327,16 @@ def mocked_localai(requests_mock):
         },
         headers={"Content-Type": "application/json"},
     )
+    requests_mock.post(
+        "http://localai.localhost/completions",
+        json={
+            "model": "completion-babbage",
+            "usage": {},
+            "choices": [{"text": "Hello"}],
+        },
+        headers={"Content-Type": "application/json"},
+    )
+    return requests_mock

 
 @pytest.fixture
13 changes: 13 additions & 0 deletions tests/test_llm.py
@@ -438,6 +438,10 @@ def test_openai_completion_logprobs_nostream(
 - model_id: orca
   model_name: orca-mini-3b
   api_base: "http://localai.localhost"
+- model_id: completion-babbage
+  model_name: babbage
+  api_base: "http://localai.localhost"
+  completion: 1
 """


@@ -458,6 +462,15 @@ def test_openai_localai_configuration(mocked_localai, user_path):
         "messages": [{"role": "user", "content": "three names \nfor a pet pelican"}],
         "stream": False,
     }
+    # And check the completion model too
+    result2 = runner.invoke(cli, ["--no-stream", "--model", "completion-babbage", "hi"])
+    assert result2.exit_code == 0
+    assert result2.output == "Hello\n"
+    assert json.loads(mocked_localai.last_request.text) == {
+        "model": "babbage",
+        "prompt": "hi",
+        "stream": False,
+    }


EXPECTED_OPTIONS = """
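The two requests asserted in the test above show the wire-level difference that the `completion: true` flag maps onto: chat models post a `messages` array to `/chat/completions`, while completion models post a plain `prompt` string to `/completions`. A minimal sketch of building each request body (the `build_payload` helper is hypothetical, not part of llm):

```python
def build_payload(model_name: str, prompt: str, completion: bool, stream: bool = False) -> dict:
    """Hypothetical helper: shape the request body for each endpoint style."""
    if completion:
        # The /completions endpoint takes a plain prompt string.
        return {"model": model_name, "prompt": prompt, "stream": stream}
    # The /chat/completions endpoint takes a list of role-tagged messages.
    return {
        "model": model_name,
        "messages": [{"role": "user", "content": prompt}],
        "stream": stream,
    }


print(build_payload("babbage", "hi", completion=True))
print(build_payload("orca-mini-3b", "hi", completion=False))
```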
