Return token_strings in Client.tokenize() when offline=True #494

Open
wants to merge 1 commit into base: main
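
For context, a minimal usage sketch of the behavior this change targets. The api_key value is a placeholder, the offline path only falls back to the API if local tokenization fails, and the expected values below are taken from the tests in this diff:

import cohere

co = cohere.Client(api_key="<YOUR_API_KEY>")  # placeholder; not used by the offline path
response = co.tokenize(text="tokenize me! :D", model="command", offline=True)
print(response.tokens)         # [10002, 2261, 2012, 8, 2792, 43]
print(response.token_strings)  # previously [], now ["token", "ize", " me", "!", " :", "D"]
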
6 changes: 4 additions & 2 deletions src/cohere/client.py
@@ -212,7 +212,8 @@ def tokenize(
         if offline:
             try:
                 tokens = local_tokenizers.local_tokenize(self, text=text, model=model)
-                return TokenizeResponse(tokens=tokens, token_strings=[])
+                token_strings = local_tokenizers.local_tokens_to_token_strings(self, tokens=tokens, model=model)
+                return TokenizeResponse(tokens=tokens, token_strings=token_strings)
             except Exception:
                 # Fallback to calling the API.
                 opts["additional_headers"] = opts.get("additional_headers", {})
@@ -392,7 +393,8 @@ async def tokenize(
         if offline:
             try:
                 tokens = await local_tokenizers.async_local_tokenize(self, model=model, text=text)
-                return TokenizeResponse(tokens=tokens, token_strings=[])
+                token_strings = await local_tokenizers.async_local_tokens_to_token_strings(self, tokens=tokens, model=model)
+                return TokenizeResponse(tokens=tokens, token_strings=token_strings)
             except Exception:
                 opts["additional_headers"] = opts.get("additional_headers", {})
                 opts["additional_headers"]["sdk-api-warning-message"] = "offline_tokenizer_failed"
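
The async client gets the same treatment. A sketch of the equivalent AsyncClient call, again with a placeholder key and with output values taken from the async test further down:

import asyncio

import cohere

async def main() -> None:
    co = cohere.AsyncClient(api_key="<YOUR_API_KEY>")  # placeholder
    response = await co.tokenize(text="tokenize me! :D", model="command", offline=True)
    print(response.token_strings)  # ["token", "ize", " me", "!", " :", "D"]

asyncio.run(main())
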
12 changes: 12 additions & 0 deletions src/cohere/manually_maintained/tokenizers.py
@@ -41,6 +41,12 @@ def local_tokenize(co: "Client", model: str, text: str) -> typing.List[int]:
     return tokenizer.encode(text, add_special_tokens=False).ids


+def local_tokens_to_token_strings(co: "Client", model: str, tokens: typing.Sequence[int]) -> typing.List[str]:
+    """Decodes a list of token ints to a list of token strings."""
+    tokenizer = get_hf_tokenizer(co, model)
+    return tokenizer.decode_batch([[token] for token in tokens], skip_special_tokens=False)
+
+
 def local_detokenize(co: "Client", model: str, tokens: typing.Sequence[int]) -> str:
     """Decodes a given list of tokens using a local tokenizer."""
     tokenizer = get_hf_tokenizer(co, model)
@@ -73,6 +79,12 @@ async def async_local_tokenize(co: "AsyncClient", model: str, text: str) -> typing.List[int]:
     return tokenizer.encode(text, add_special_tokens=False).ids


+async def async_local_tokens_to_token_strings(co: "AsyncClient", model: str, tokens: typing.Sequence[int]) -> typing.List[str]:
+    """Decodes a list of token ints to a list of token strings."""
+    tokenizer = await async_get_hf_tokenizer(co, model)
+    return tokenizer.decode_batch([[token] for token in tokens], skip_special_tokens=False)
+
+
 async def async_local_detokenize(co: "AsyncClient", model: str, tokens: typing.Sequence[int]) -> str:
     """Decodes a given list of tokens using a local tokenizer."""
     tokenizer = await async_get_hf_tokenizer(co, model)
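
The core trick in both helpers is decoding each token id as its own one-element batch, so decode_batch returns one string per token rather than a single joined string. A standalone sketch with the Hugging Face tokenizers library; the pretrained model name here is illustrative, not the tokenizer the SDK actually downloads:

from tokenizers import Tokenizer

tokenizer = Tokenizer.from_pretrained("bert-base-uncased")  # stand-in for the SDK's cached tokenizer
ids = tokenizer.encode("tokenize me!", add_special_tokens=False).ids

# One inner list per id -> one decoded string per token.
token_strings = tokenizer.decode_batch([[i] for i in ids], skip_special_tokens=False)
print(list(zip(ids, token_strings)))
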
3 changes: 2 additions & 1 deletion tests/test_async_client.py
@@ -383,7 +383,8 @@ async def test_local_tokenize(self) -> None:
             model="command",
             text="tokenize me! :D"
         )
-        print(response)
+        self.assertEqual(response.tokens, [10002, 2261, 2012, 8, 2792, 43])
+        self.assertEqual(response.token_strings, ["token", "ize", " me", "!", " :", "D"])

     async def test_local_detokenize(self) -> None:
         response = await self.co.detokenize(
3 changes: 2 additions & 1 deletion tests/test_client.py
@@ -397,7 +397,8 @@ def test_local_tokenize(self) -> None:
             model="command",
             text="tokenize me! :D"
         )
-        print(response)
+        self.assertEqual(response.tokens, [10002, 2261, 2012, 8, 2792, 43])
+        self.assertEqual(response.token_strings, ["token", "ize", " me", "!", " :", "D"])

     def test_local_detokenize(self) -> None:
         response = co.detokenize(