diff --git a/src/semchunk/semchunk.py b/src/semchunk/semchunk.py
index 291122e..914e92c 100644
--- a/src/semchunk/semchunk.py
+++ b/src/semchunk/semchunk.py
@@ -1,5 +1,7 @@
 import re
+from bisect import bisect_left
 from functools import cache, wraps
+from itertools import accumulate
 
 _memoised_token_counters = {}
 """A map of token counters to their memoised versions."""
@@ -114,41 +116,31 @@ def chunk_legacy(text: str, chunk_size: int, token_counter: callable, memoize: b
     return chunks
 
-def count(text: str, max_size: int, counter: callable) -> int:
-    """Counts the number of tokens in a text, with a heuristic to accelerate long texts"""
-    heuritistic = 6*max_size
-
-    # There is a rare failure case for the below heuristic where superfluous tokens
-    # may be added from a longer, existing token being split before it was finished.
-    # e.g. Australia -> 1 token
-    #      Australi -> 3 token
-    #
-    # We mitigate this failure case by adding the len(longest token)-1 such that
-    # any ongoing token will be able to finish
-    #
-    # Using the cl100k tokenset, the length of the longest non-symbol token is 42
-    # See: https://gist.github.com/Yardanico/623b3092d0b707119f8c7d90a3596afe
-    max_token = 42 - 1
-
-    if len(text) > heuritistic and counter(text[:heuritistic+max_token]) > max_size:
-        return max_size+1
-    return counter(text)
 
 def find_split(splits: list[str], max_size: int, splitter: str, counter: callable) -> tuple[int, str]:
     """Binary search for the optimal split point where the accumulated_token_count < max_size."""
-    low, high = 0, len(splits) + 1
+
+    # Start with a low chars-per-token average so the first probe is short and cheaply establishes a real average.
+    avg, low, high = 0.2, 0, len(splits) + 1
+    sums = list(accumulate(map(len, splits), initial=0))
+    sums.append(sums[-1])
+
     while low < high:
-        # As the main performance hit comes from running the token_counter on long texts
-        # we can bias the binary search to favour guessing towards shorter sequences.
-        # This is done below by using > 2 as the divisor
-        mid = low + (high - low) // 8
-        if count(splitter.join(splits[:mid]), max_size, counter) > max_size:
+        idx = bisect_left(sums[low:high + 1], max_size * avg)
+        mid = min(idx + low, high - 1)
+
+        tokens = counter(splitter.join(splits[:mid]))
+
+        avg = sums[mid]/tokens if sums[mid] else avg
+
+        if tokens > max_size:
             high = mid
         else:
             low = mid + 1
+
     return low-1, splitter.join(splits[:low-1])
 
+
 def chunk(text: str, chunk_size: int, token_counter: callable, memoize: bool=True, _recursion_depth: int = 0) -> list[str]:
     """Split text into semantically meaningful chunks of a specified size as determined by the provided token counter.
@@ -179,7 +171,7 @@ def chunk(text: str, chunk_size: int, token_counter: callable, memoize: bool=Tru
             continue
 
         # If the split is over the chunk size, recursively chunk it.
-        if count(split, chunk_size, token_counter) > chunk_size:
+        if token_counter(split) > chunk_size:
            chunks.extend(chunk(split, chunk_size, token_counter, memoize, _recursion_depth+1))
 
         # If the split is equal to or under the chunk size, merge it with all subsequent splits until the chunk size is reached.
@@ -195,11 +187,8 @@ def chunk(text: str, chunk_size: int, token_counter: callable, memoize: bool=Tru
 
         # If the splitter is not whitespace and the split is not the last split, add the splitter to the end of the last chunk if doing so would not cause it to exceed the chunk size otherwise add the splitter as a new chunk.
         if not splitter_is_whitespace and not (i == len(splits) - 1 or all(j in skips for j in range(i+1, len(splits)))):
-            # We seperately add tokens(prior chunk) and tokens(splitter) to ensure O(1) - (both will be in cache).
-            # There is a failure case where tokens(get_last_token(prior_chunk) + splitter) == 1 however this is
-            # quite uncommon and leads to a negligible impact
-            if token_counter(chunks[-1]) + token_counter(splitter) <= chunk_size:
-                chunks[-1] += splitter
+            if token_counter(last_chunk_with_splitter:=chunks[-1]+splitter) <= chunk_size:
+                chunks[-1] = last_chunk_with_splitter
             else:
                 chunks.append(splitter)
diff --git a/tests/bench.py b/tests/bench.py
index 9660cf6..6e386ec 100644
--- a/tests/bench.py
+++ b/tests/bench.py
@@ -6,7 +6,8 @@ import semchunk
 
-chunk_sizes = [8,16,32,64,128,256,512,1024]
+chunk_sizes = [8,16,32,64,128,256,512,1024,2048,4096,8192]
+
 semantic_text_splitter_chunker = TextSplitter.from_tiktoken_model('gpt-4')
 encoder = tiktoken.encoding_for_model('gpt-4')
@@ -31,7 +32,7 @@ def bench_semantic_text_splitter(text: str, chunk_size: int) -> None:
 libraries = {
     'semchunk': bench_semchunkv1,
     'semchunkv2': bench_semchunkv2,
-    'semantic_text_splitter': bench_semantic_text_splitter,
+    # 'semantic_text_splitter': bench_semantic_text_splitter,
 }
 
 def bench() -> dict[str, float]:
@@ -41,10 +42,13 @@ def bench() -> dict[str, float]:
         semchunk.semchunk._memoised_token_counters = {}
         for fileid in test_semchunk.gutenberg.fileids():
             sample = test_semchunk.gutenberg.raw(fileid)
+            results = []
             for library, function in libraries.items():
                 start = time.time()
-                function(sample, chunk_size)
+                results.append(function(sample, chunk_size))
                 benchmarks[library][i] += time.time() - start
+            if len(results) > 1:
+                assert results[-1] == results[-2]
     return benchmarks
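
Note on the new find_split strategy: rather than always probing a fixed fraction of the search range, the loop keeps a running chars-per-token estimate (avg), converts the max_size token budget into an expected character budget, and uses bisect_left over the prefix character sums to jump straight to the most plausible split index, refining avg after every counter call it pays for. The sketch below is illustrative only, not the patched code verbatim; the name find_split_sketch and the toy whitespace token counter are assumptions for demonstration.

# Illustrative sketch of the adaptive probing used by the new find_split.
# Assumptions: the name find_split_sketch and the toy whitespace token
# counter below are for demonstration only.
from bisect import bisect_left
from itertools import accumulate

def find_split_sketch(splits: list[str], max_size: int, splitter: str, counter) -> tuple[int, str]:
    avg = 0.2                                             # low initial chars-per-token estimate
    low, high = 0, len(splits) + 1
    sums = list(accumulate(map(len, splits), initial=0))  # cumulative character counts per prefix
    sums.append(sums[-1])                                 # sentinel so sums[high] is always valid
    while low < high:
        # Convert the token budget into a character budget and map it to a split index.
        idx = bisect_left(sums[low:high + 1], max_size * avg)
        mid = min(idx + low, high - 1)
        tokens = counter(splitter.join(splits[:mid]))
        # Refine the chars-per-token estimate using the probe we just paid for.
        avg = sums[mid] / tokens if sums[mid] else avg
        if tokens > max_size:
            high = mid
        else:
            low = mid + 1
    return low - 1, splitter.join(splits[:low - 1])

words = 'the quick brown fox jumps over the lazy dog'.split(' ')
count_words = lambda text: len(text.split(' ')) if text else 0
print(find_split_sketch(words, 4, ' ', count_words))      # (4, 'the quick brown fox')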
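
The joined-count check for the splitter is a correctness fix as well as a simplification: byte-pair tokenizers can merge tokens across the chunk/splitter boundary, so token_counter(chunks[-1]) + token_counter(splitter) may overcount token_counter(chunks[-1] + splitter) and reject a splitter that would actually fit, whereas counting the joined string once is exact. A quick demonstration with the same GPT-4 encoding that tests/bench.py uses; the example strings are arbitrary, chosen only to make a boundary merge likely.

# Why the patch counts the joined string in a single token_counter call:
# BPE merges can cross the boundary, so summing two separate counts may
# overestimate. Example strings are arbitrary, for illustration only.
import tiktoken

encoder = tiktoken.encoding_for_model('gpt-4')
token_counter = lambda text: len(encoder.encode(text))

last_chunk, splitter = 'The meeting is on Mon', 'day'
separate = token_counter(last_chunk) + token_counter(splitter)
joined = token_counter(last_chunk + splitter)
print(separate, joined)  # joined can be strictly smaller when the boundary tokens merge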