Skip to content

Commit

Permalink
Bump default gpt-3.5-turbo-instruct max tokens to 256, refs #284
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Sep 19, 2023
1 parent 4d46eba commit 4d18da4
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
10 changes: 9 additions & 1 deletion llm/default_plugins/openai_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def register_models(register):
register(Chat("gpt-4"), aliases=("4", "gpt4"))
register(Chat("gpt-4-32k"), aliases=("4-32k",))
register(
Completion("gpt-3.5-turbo-instruct"),
Completion("gpt-3.5-turbo-instruct", default_max_tokens=256),
aliases=("3.5-instruct", "chatgpt-instruct"),
)
# Load extra models
Expand Down Expand Up @@ -126,6 +126,8 @@ class Chat(Model):
key_env_var = "OPENAI_API_KEY"
can_stream: bool = True

default_max_tokens = None

class Options(llm.Options):
temperature: Optional[float] = Field(
description=(
Expand Down Expand Up @@ -280,6 +282,8 @@ def execute(self, prompt, stream, response, conversation=None):

def build_kwargs(self, prompt):
kwargs = dict(not_nulls(prompt.options))
if "max_tokens" not in kwargs and self.default_max_tokens is not None:
kwargs["max_tokens"] = self.default_max_tokens
if self.api_base:
kwargs["api_base"] = self.api_base
if self.api_type:
Expand All @@ -301,6 +305,10 @@ def build_kwargs(self, prompt):


class Completion(Chat):
def __init__(self, *args, default_max_tokens=None, **kwargs):
super().__init__(*args, **kwargs)
self.default_max_tokens = default_max_tokens

def __str__(self):
return "OpenAI Completion: {}".format(self.model_id)

Expand Down
9 changes: 9 additions & 0 deletions tests/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,15 @@ def test_openai_completion(mocked_openai_completion, user_path):
)
assert result.exit_code == 0
assert result.output == "\n\nThis is indeed a test\n"

# Should have requested 256 tokens
assert json.loads(mocked_openai_completion.last_request.text) == {
"model": "gpt-3.5-turbo-instruct",
"prompt": "Say this is a test",
"stream": False,
"max_tokens": 256,
}

# Check it was logged
rows = list(log_db["responses"].rows)
assert len(rows) == 1
Expand Down

0 comments on commit 4d18da4

Please sign in to comment.