diff --git a/Makefile b/Makefile index 94b117d..ae5bd44 100644 --- a/Makefile +++ b/Makefile @@ -4,7 +4,7 @@ .PHONY: mypy mypy: venv/bin/mypy - venv/bin/mypy --strict -p chap + venv/bin/mypy --strict --no-warn-unused-ignores -p chap venv/bin/mypy: python -mvenv venv diff --git a/pyproject.toml b/pyproject.toml index 2bb0e55..6311380 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,3 +42,5 @@ write_to = "src/chap/__version__.py" [tool.setuptools.dynamic] readme = {file = ["README.md"], content-type="text/markdown"} dependencies = {file = "requirements.txt"} +[tool.setuptools.package-data] +"chap" = ["py.typed"] diff --git a/src/chap/backends/openai_chatgpt.py b/src/chap/backends/openai_chatgpt.py index a6d8fc7..6dd6e5a 100644 --- a/src/chap/backends/openai_chatgpt.py +++ b/src/chap/backends/openai_chatgpt.py @@ -4,6 +4,7 @@ import functools import json +import warnings from dataclasses import dataclass from typing import AsyncGenerator, cast @@ -20,6 +21,7 @@ class EncodingMeta: encoding: tiktoken.Encoding tokens_per_message: int tokens_per_name: int + tokens_overhead: int @functools.lru_cache() def encode(self, s: str) -> list[int]: @@ -27,47 +29,38 @@ def encode(self, s: str) -> list[int]: def num_tokens_for_message(self, message: Message) -> int: # n.b. chap doesn't use message.name yet - return len(self.encode(message.role)) + len(self.encode(message.content)) + return ( + len(self.encode(message.role)) + + len(self.encode(message.content)) + + self.tokens_per_message + ) def num_tokens_for_messages(self, messages: Session) -> int: - return sum(self.num_tokens_for_message(message) for message in messages) + 3 + return ( + sum(self.num_tokens_for_message(message) for message in messages) + + self.tokens_overhead + ) @classmethod @functools.cache def from_model(cls, model: str) -> "EncodingMeta": - if model == "gpt-3.5-turbo": - # print("Warning: gpt-3.5-turbo may update over time. &#13;
Returning num tokens assuming gpt-3.5-turbo-0613.") - model = "gpt-3.5-turbo-0613" - if model == "gpt-4": - # print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.") - model = "gpt-4-0613" - try: encoding = tiktoken.encoding_for_model(model) except KeyError: - print("Warning: model not found. Using cl100k_base encoding.") + warnings.warn("Warning: model not found. Using cl100k_base encoding.") encoding = tiktoken.get_encoding("cl100k_base") - if model in { - "gpt-3.5-turbo-0613", - "gpt-3.5-turbo-16k-0613", - "gpt-4-0314", - "gpt-4-32k-0314", - "gpt-4-0613", - "gpt-4-32k-0613", - }: - tokens_per_message = 3 - tokens_per_name = 1 - elif model == "gpt-3.5-turbo-0301": + tokens_per_message = 3 + tokens_per_name = 1 + tokens_overhead = 3 + + if model == "gpt-3.5-turbo-0301": tokens_per_message = ( 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n ) tokens_per_name = -1 # if there's a name, the role is omitted - else: - raise NotImplementedError( - f"""EncodingMeta is not implemented for model {model}. 
See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.""" - ) - return cls(encoding, tokens_per_message, tokens_per_name) + + return cls(encoding, tokens_per_message, tokens_per_name, tokens_overhead) class ChatGPT: diff --git a/src/chap/core.py b/src/chap/core.py index 5f4e6a3..df05cf4 100644 --- a/src/chap/core.py +++ b/src/chap/core.py @@ -11,6 +11,7 @@ import pkgutil import subprocess from dataclasses import MISSING, dataclass, fields +from types import UnionType from typing import Any, AsyncGenerator, Callable, cast import click @@ -171,7 +172,10 @@ def format_backend_help(api: Backend, formatter: click.HelpFormatter) -> None: if doc: doc += " " doc += f"(Default: {default!r})" - typename = f.type.__name__ + f_type = f.type + if isinstance(f_type, UnionType): + f_type = f_type.__args__[0] + typename = f_type.__name__ rows.append((f"-B {name}:{typename.upper()}", doc)) formatter.write_dl(rows) @@ -191,8 +195,11 @@ def set_one_backend_option(kv: tuple[str, str]) -> None: field = all_fields.get(name) if field is None: raise click.BadParameter(f"Invalid parameter {name}") + f_type = field.type + if isinstance(f_type, UnionType): + f_type = f_type.__args__[0] try: - tv = field.type(value) + tv = f_type(value) except ValueError as e: raise click.BadParameter( f"Invalid value for {name} with value {value}: {e}" diff --git a/src/chap/py.typed b/src/chap/py.typed new file mode 100644 index 0000000..e69de29