Skip to content

Commit

Permalink
Performance, reliably this time
Browse files Browse the repository at this point in the history
  • Loading branch information
R0bk committed Apr 22, 2024
1 parent eca19be commit 69ff066
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 34 deletions.
51 changes: 20 additions & 31 deletions src/semchunk/semchunk.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import re
from bisect import bisect_left
from functools import cache, wraps
from itertools import accumulate

_memoised_token_counters = {}
"""A map of token counters to their memoised versions."""
Expand Down Expand Up @@ -114,41 +116,31 @@ def chunk_legacy(text: str, chunk_size: int, token_counter: callable, memoize: b

return chunks

def count(text: str, max_size: int, counter: callable) -> int:
"""Counts the number of tokens in a text, with a heuristic to accelerate long texts"""
heuritistic = 6*max_size

# There is a rare failure case for the below heuristic where superfluous tokens
# may be added from a longer, existing token being split before it was finished.
# e.g. Australia -> 1 token
# Australi -> 3 token
#
# We mitigate this failure case by adding the len(longest token)-1 such that
# any ongoing token will be able to finish
#
# Using the cl100k tokenset, the length of the longest non-symbol token is 42
# See: https://gist.github.com/Yardanico/623b3092d0b707119f8c7d90a3596afe
max_token = 42 - 1

if len(text) > heuritistic and counter(text[:heuritistic+max_token]) > max_size:
return max_size+1
return counter(text)

def find_split(splits: list[str], max_size: int, splitter: str, counter: callable) -> tuple[int, str]:
"""Binary search for the optimal split point where the accumulated_token_count < max_size."""
low, high = 0, len(splits) + 1

# Start avg low for fast calc of first real avg
avg, low, high = 0.2, 0, len(splits) + 1
sums = list(accumulate(map(len, splits), initial=0))
sums.append(sums[-1])

while low < high:
# As the main performance hit comes from running the token_counter on long texts
# we can bias the binary search to favour guessing towards shorter sequences.
# This is done below by using > 2 as the divisor
mid = low + (high - low) // 8
if count(splitter.join(splits[:mid]), max_size, counter) > max_size:
idx = bisect_left(sums[low:high + 1], max_size * avg)
mid = min(idx + low, high - 1)

tokens = counter(splitter.join(splits[:mid]))

avg = sums[mid]/tokens if sums[mid] else avg

if tokens > max_size:
high = mid
else:
low = mid + 1

return low-1, splitter.join(splits[:low-1])


def chunk(text: str, chunk_size: int, token_counter: callable, memoize: bool=True, _recursion_depth: int = 0) -> list[str]:
"""Split text into semantically meaningful chunks of a specified size as determined by the provided token counter.
Expand Down Expand Up @@ -179,7 +171,7 @@ def chunk(text: str, chunk_size: int, token_counter: callable, memoize: bool=Tru
continue

# If the split is over the chunk size, recursively chunk it.
if count(split, chunk_size, token_counter) > chunk_size:
if token_counter(split) > chunk_size:
chunks.extend(chunk(split, chunk_size, token_counter, memoize, _recursion_depth+1))

# If the split is equal to or under the chunk size, merge it with all subsequent splits until the chunk size is reached.
Expand All @@ -195,11 +187,8 @@ def chunk(text: str, chunk_size: int, token_counter: callable, memoize: bool=Tru

# If the splitter is not whitespace and the split is not the last split, add the splitter to the end of the last chunk if doing so would not cause it to exceed the chunk size otherwise add the splitter as a new chunk.
if not splitter_is_whitespace and not (i == len(splits) - 1 or all(j in skips for j in range(i+1, len(splits)))):
# We seperately add tokens(prior chunk) and tokens(splitter) to ensure O(1) - (both will be in cache).
# There is a failure case where tokens(get_last_token(prior_chunk) + splitter) == 1 however this is
# quite uncommon and leads to a negligible impact
if token_counter(chunks[-1]) + token_counter(splitter) <= chunk_size:
chunks[-1] += splitter
if token_counter(last_chunk_with_splitter:=chunks[-1]+splitter) <= chunk_size:
chunks[-1] = last_chunk_with_splitter
else:
chunks.append(splitter)

Expand Down
10 changes: 7 additions & 3 deletions tests/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

import semchunk

chunk_sizes = [8,16,32,64,128,256,512,1024]
chunk_sizes = [8,16,32,64,128,256,512,1024,2048,4096,8192]

semantic_text_splitter_chunker = TextSplitter.from_tiktoken_model('gpt-4')

encoder = tiktoken.encoding_for_model('gpt-4')
Expand All @@ -31,7 +32,7 @@ def bench_semantic_text_splitter(text: str, chunk_size: int) -> None:
libraries = {
'semchunk': bench_semchunkv1,
'semchunkv2': bench_semchunkv2,
'semantic_text_splitter': bench_semantic_text_splitter,
# 'semantic_text_splitter': bench_semantic_text_splitter,
}

def bench() -> dict[str, float]:
Expand All @@ -41,10 +42,13 @@ def bench() -> dict[str, float]:
semchunk.semchunk._memoised_token_counters = {}
for fileid in test_semchunk.gutenberg.fileids():
sample = test_semchunk.gutenberg.raw(fileid)
results = []
for library, function in libraries.items():
start = time.time()
function(sample, chunk_size)
results.append(function(sample, chunk_size))
benchmarks[library][i] += time.time() - start
if len(results) > 1:
assert results[-1] == results[-2]

return benchmarks

Expand Down

0 comments on commit 69ff066

Please sign in to comment.