diff --git a/src/cohere/client.py b/src/cohere/client.py
index 87cd2a159..150295393 100644
--- a/src/cohere/client.py
+++ b/src/cohere/client.py
@@ -212,7 +212,8 @@ def tokenize(
         if offline:
             try:
                 tokens = local_tokenizers.local_tokenize(self, text=text, model=model)
-                return TokenizeResponse(tokens=tokens, token_strings=[])
+                token_strings = local_tokenizers.local_tokens_to_token_strings(self, tokens=tokens, model=model)
+                return TokenizeResponse(tokens=tokens, token_strings=token_strings)
             except Exception:
                 # Fallback to calling the API.
                 opts["additional_headers"] = opts.get("additional_headers", {})
@@ -392,7 +393,8 @@ async def tokenize(
         if offline:
             try:
                 tokens = await local_tokenizers.async_local_tokenize(self, model=model, text=text)
-                return TokenizeResponse(tokens=tokens, token_strings=[])
+                token_strings = await local_tokenizers.async_local_tokens_to_token_strings(self, tokens=tokens, model=model)
+                return TokenizeResponse(tokens=tokens, token_strings=token_strings)
             except Exception:
                 opts["additional_headers"] = opts.get("additional_headers", {})
                 opts["additional_headers"]["sdk-api-warning-message"] = "offline_tokenizer_failed"
diff --git a/src/cohere/manually_maintained/tokenizers.py b/src/cohere/manually_maintained/tokenizers.py
index 61e4e5e6c..8766078a6 100644
--- a/src/cohere/manually_maintained/tokenizers.py
+++ b/src/cohere/manually_maintained/tokenizers.py
@@ -41,6 +41,12 @@ def local_tokenize(co: "Client", model: str, text: str) -> typing.List[int]:
     return tokenizer.encode(text, add_special_tokens=False).ids
 
 
+def local_tokens_to_token_strings(co: "Client", model: str, tokens: typing.Sequence[int]) -> typing.List[str]:
+    """Decodes a list of token ints to a list of token strings."""
+    tokenizer = get_hf_tokenizer(co, model)
+    return tokenizer.decode_batch([[token] for token in tokens], skip_special_tokens=False)
+
+
 def local_detokenize(co: "Client", model: str, tokens: typing.Sequence[int]) -> str:
     """Decodes a given list of tokens using a local tokenizer."""
     tokenizer = get_hf_tokenizer(co, model)
@@ -73,6 +79,12 @@ async def async_local_tokenize(co: "AsyncClient", model: str, text: str) -> typi
     return tokenizer.encode(text, add_special_tokens=False).ids
 
 
+async def async_local_tokens_to_token_strings(co: "AsyncClient", model: str, tokens: typing.Sequence[int]) -> typing.List[str]:
+    """Decodes a list of token ints to a list of token strings."""
+    tokenizer = await async_get_hf_tokenizer(co, model)
+    return tokenizer.decode_batch([[token] for token in tokens], skip_special_tokens=False)
+
+
 async def async_local_detokenize(co: "AsyncClient", model: str, tokens: typing.Sequence[int]) -> str:
     """Decodes a given list of tokens using a local tokenizer."""
     tokenizer = await async_get_hf_tokenizer(co, model)
diff --git a/tests/test_async_client.py b/tests/test_async_client.py
index a4cec71b8..7b92f92e9 100644
--- a/tests/test_async_client.py
+++ b/tests/test_async_client.py
@@ -383,7 +383,8 @@ async def test_local_tokenize(self) -> None:
             model="command",
             text="tokenize me! :D"
         )
-        print(response)
+        self.assertEqual(response.tokens, [10002, 2261, 2012, 8, 2792, 43])
+        self.assertEqual(response.token_strings, ["token", "ize", " me", "!", " :", "D"])
 
     async def test_local_detokenize(self) -> None:
         response = await self.co.detokenize(
diff --git a/tests/test_client.py b/tests/test_client.py
index 05f2b0911..1bd60b2d3 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -397,7 +397,8 @@ def test_local_tokenize(self) -> None:
            model="command",
            text="tokenize me! :D"
        )
-        print(response)
+        self.assertEqual(response.tokens, [10002, 2261, 2012, 8, 2792, 43])
+        self.assertEqual(response.token_strings, ["token", "ize", " me", "!", " :", "D"])
 
     def test_local_detokenize(self) -> None:
         response = co.detokenize(