Skip to content

Commit

Permalink
Merge pull request #25 from jepler/gpt-4-turbo
Browse files Browse the repository at this point in the history
Add gpt-4-1106-preview (gpt-4-turbo) to model list
  • Loading branch information
jepler authored Nov 9, 2023
2 parents 4454071 + 08b221d commit 03e0adf
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 29 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

# Type-check the chap package. --no-warn-unused-ignores lets per-Python-version
# "type: ignore" comments exist without tripping --strict on other versions.
.PHONY: mypy
mypy: venv/bin/mypy
	venv/bin/mypy --strict --no-warn-unused-ignores -p chap

venv/bin/mypy:
python -mvenv venv
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,5 @@ write_to = "src/chap/__version__.py"
[tool.setuptools.dynamic]
readme = {file = ["README.md"], content-type="text/markdown"}
dependencies = {file = "requirements.txt"}
# Ship the PEP 561 marker file so type checkers consume chap's inline
# annotations. The key must be the actual package name ("pkgname" was a
# placeholder copied from the setuptools docs and matched no package, so
# py.typed was silently omitted from built distributions).
[tool.setuptools.package-data]
"chap" = ["py.typed"]
45 changes: 19 additions & 26 deletions src/chap/backends/openai_chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import functools
import json
import warnings
from dataclasses import dataclass
from typing import AsyncGenerator, cast

Expand All @@ -20,54 +21,46 @@ class EncodingMeta:
# tiktoken encoder used to turn strings into token ids
encoding: tiktoken.Encoding
# fixed token cost the chat framing adds for each message
tokens_per_message: int
# extra token cost when a message carries a name (not yet applied —
# chap doesn't use message.name yet)
tokens_per_name: int
# constant overhead added once per conversation
tokens_overhead: int

@functools.lru_cache()
def encode(self, s: str) -> list[int]:
    """Return the token ids for *s*, memoizing per (self, s).

    NOTE(review): lru_cache on an instance method keys on ``self`` and keeps
    every EncodingMeta instance alive for the cache's lifetime (ruff B019).
    Probably acceptable here because instances are themselves memoized per
    model via ``from_model``'s ``functools.cache`` — confirm.
    """
    return self.encoding.encode(s)

def num_tokens_for_message(self, message: Message) -> int:
    """Count the tokens *message* contributes to a request.

    Encoded role plus encoded content plus the fixed per-message framing
    cost. n.b. chap doesn't use message.name yet, so ``tokens_per_name``
    is not applied.
    """
    # The pasted diff interleaved the superseded return (which omitted
    # tokens_per_message) with this one; only the corrected form is kept.
    return (
        len(self.encode(message.role))
        + len(self.encode(message.content))
        + self.tokens_per_message
    )

def num_tokens_for_messages(self, messages: Session) -> int:
    """Count the total prompt tokens for the whole conversation.

    Sum of the per-message costs plus the constant ``tokens_overhead``
    charged once per conversation.
    """
    # Diff residue carried both the old hard-coded "+ 3" return and this
    # one; the tokens_overhead-based version is the current code.
    per_messages = sum(
        self.num_tokens_for_message(message) for message in messages
    )
    return per_messages + self.tokens_overhead

@classmethod
@functools.cache
def from_model(cls, model: str) -> "EncodingMeta":
    """Build (and memoize) token-counting parameters for *model*.

    Bare ``gpt-3.5-turbo`` / ``gpt-4`` aliases are pinned to a dated
    snapshot so counts stay stable as the aliases move. Unknown models
    fall back to the ``cl100k_base`` encoding with a warning instead of
    raising, so newly released models keep working.
    """
    # Pin floating aliases to a specific snapshot for stable counts.
    if model == "gpt-3.5-turbo":
        model = "gpt-3.5-turbo-0613"
    if model == "gpt-4":
        model = "gpt-4-0613"

    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        warnings.warn("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")

    # Defaults for current chat models; only the legacy 0301 snapshot
    # used a different framing.
    tokens_per_message = 3
    tokens_per_name = 1
    tokens_overhead = 3

    if model == "gpt-3.5-turbo-0301":
        # every message follows <|start|>{role/name}\n{content}<|end|>\n
        tokens_per_message = 4
        tokens_per_name = -1  # if there's a name, the role is omitted

    return cls(encoding, tokens_per_message, tokens_per_name, tokens_overhead)


class ChatGPT:
Expand Down
11 changes: 9 additions & 2 deletions src/chap/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pkgutil
import subprocess
from dataclasses import MISSING, dataclass, fields
from types import UnionType
from typing import Any, AsyncGenerator, Callable, cast

import click
Expand Down Expand Up @@ -171,7 +172,10 @@ def format_backend_help(api: Backend, formatter: click.HelpFormatter) -> None:
if doc:
doc += " "
doc += f"(Default: {default!r})"
typename = f.type.__name__
f_type = f.type
if isinstance(f_type, UnionType):
f_type = f_type.__args__[0]
typename = f_type.__name__
rows.append((f"-B {name}:{typename.upper()}", doc))
formatter.write_dl(rows)

Expand All @@ -191,8 +195,11 @@ def set_one_backend_option(kv: tuple[str, str]) -> None:
field = all_fields.get(name)
if field is None:
raise click.BadParameter(f"Invalid parameter {name}")
f_type = field.type
if isinstance(f_type, UnionType):
f_type = f_type.__args__[0]
try:
tv = field.type(value)
tv = f_type(value)
except ValueError as e:
raise click.BadParameter(
f"Invalid value for {name} with value {value}: {e}"
Expand Down
Empty file added src/chap/py.typed
Empty file.

0 comments on commit 03e0adf

Please sign in to comment.