📈 performance optimisation #3

Merged
merged 13 commits on May 13, 2024
3 changes: 1 addition & 2 deletions .gitignore
@@ -15,11 +15,10 @@
!.gitignore

# Finally, exclude anything in the above inclusions that we don't want.
# Exclude common Python files and folders.
*.pyc
*.pyo
*.ipynb
__pycache__/
.pytest_cache/
tests/profiler.py
tests/test_bench.py
tests/test_bench.py
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -1,6 +1,10 @@
## Changelog 🔄
All notable changes to `semchunk` will be documented here. This project adheres to [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) and [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased] - 2024-XX-XX
### Changed
- Improved chunking performance with larger chunk sizes by switching from linear to binary search for the identification of optimal chunk boundaries.

## [0.2.3] - 2024-03-11
### Fixed
- Ensured that memoization does not overwrite `chunk()`'s function signature.
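As an aside on the "switching from linear to binary search" entry above, the sketch below contrasts the two strategies on toy data: merging splits one at a time costs one token-counter call per split considered, while binary-searching the number of splits that fit needs only a logarithmic number of calls. Everything in the sketch (the whitespace token counter, the sample splits, and both helper functions) is illustrative and deliberately omits the character-length heuristic that the PR's actual implementation, shown further down in this diff, uses to pick its probe points.

```python
def count_tokens(text: str) -> int:
    """A toy token counter: one token per whitespace-delimited word."""
    return len(text.split())

splits = ["alpha beta", "gamma", "delta epsilon zeta", "eta", "theta iota"] * 20
splitter = " "
chunk_size = 50

def merge_linear(splits: list[str], chunk_size: int, splitter: str) -> tuple[str, int]:
    """Old-style merge: grow the chunk one split at a time, counting tokens for every candidate."""
    calls = 0
    merged = splits[0]
    for next_split in splits[1:]:
        calls += 1
        candidate = merged + splitter + next_split
        if count_tokens(candidate) > chunk_size:
            break
        merged = candidate
    return merged, calls

def merge_binary(splits: list[str], chunk_size: int, splitter: str) -> tuple[str, int]:
    """New-style merge: binary-search the largest number of splits whose joined text still fits."""
    calls = 0
    low, high = 1, len(splits)  # assumes the first split on its own fits within chunk_size
    while low < high:
        mid = (low + high + 1) // 2
        calls += 1
        if count_tokens(splitter.join(splits[:mid])) > chunk_size:
            high = mid - 1
        else:
            low = mid
    return splitter.join(splits[:low]), calls

linear_chunk, linear_calls = merge_linear(splits, chunk_size, splitter)
binary_chunk, binary_calls = merge_binary(splits, chunk_size, splitter)
assert linear_chunk == binary_chunk
print(f"token counter calls: linear={linear_calls}, binary={binary_calls}")
```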
2 changes: 1 addition & 1 deletion README.md
Expand Up @@ -3,7 +3,7 @@

`semchunk` is a fast and lightweight pure Python library for splitting text into semantically meaningful chunks.

Owing to its complex yet highly efficient chunking algorithm, `semchunk` is both more semantically accurate than [`langchain.text_splitter.RecursiveCharacterTextSplitter`](https://python.langchain.com/docs/modules/data_connection/document_transformers/text_splitters/recursive_text_splitter) (see [How It Works 🔍](https://github.com/umarbutler/semchunk#how-it-works-)) and is also over 70% faster than [`semantic-text-splitter`](https://pypi.org/project/semantic-text-splitter/) (see the [Benchmarks 📊](https://github.com/umarbutler/semchunk#benchmarks-)).
Owing to its complex yet highly efficient chunking algorithm, `semchunk` is both more semantically accurate than [`langchain.text_splitter.RecursiveCharacterTextSplitter`](https://python.langchain.com/docs/modules/data_connection/document_transformers/text_splitters/recursive_text_splitter) (see [How It Works 🔍](https://github.com/umarbutler/semchunk#how-it-works-)) and is also over 80% faster than [`semantic-text-splitter`](https://pypi.org/project/semantic-text-splitter/) (see the [Benchmarks 📊](https://github.com/umarbutler/semchunk#benchmarks-)).

## Installation 📦
`semchunk` may be installed with `pip`:
62 changes: 41 additions & 21 deletions src/semchunk/semchunk.py
@@ -1,5 +1,10 @@
import re

from bisect import bisect_left
from typing import Callable
from functools import cache, wraps
from itertools import accumulate


_memoised_token_counters = {}
"""A map of token counters to their memoised versions."""
@@ -45,7 +50,33 @@ def _split_text(text: str) -> tuple[str, bool, list[str]]:
# Return the splitter and the split text.
return splitter, splitter_is_whitespace, text.split(splitter)

def chunk(text: str, chunk_size: int, token_counter: callable, memoize: bool=True, _recursion_depth: int = 0) -> list[str]:

def merge_splits(splits: list[str], chunk_size: int, splitter: str, token_counter: Callable) -> tuple[int, str]:
"""Merge splits until a chunk size is reached, returning the index of the last split included in the merged chunk along with the merged chunk itself."""

average = 0.2
low = 0
high = len(splits) + 1
cumulative_lengths = tuple(accumulate(map(len, splits), initial=0))
cumulative_lengths += (cumulative_lengths[-1],)

    while low < high:
        # Use the characters-per-token estimate to guess how many splits will fit, then verify the guess with a real token count.
        i = bisect_left(cumulative_lengths[low : high + 1], chunk_size * average)
        midpoint = min(i + low, high - 1)

        tokens = token_counter(splitter.join(splits[:midpoint]))

        # Refine the characters-per-token estimate from the observed count.
        average = cumulative_lengths[midpoint] / tokens if cumulative_lengths[midpoint] else average

        if tokens > chunk_size:
            high = midpoint
        else:
            low = midpoint + 1

    return low - 1, splitter.join(splits[:low - 1])
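To make the behaviour of the new helper concrete, here is a small illustrative call, assuming the internal module path shown in this diff and a toy whitespace token counter (neither the import path nor the sample data is part of the PR):

```python
from semchunk.semchunk import merge_splits  # internal module path as laid out in this diff

def word_counter(text: str) -> int:
    """A toy token counter: one token per whitespace-delimited word."""
    return len(text.split())

splits = ["The quick brown fox", "jumps over", "the lazy dog", "and keeps running"]

# Merge as many leading splits as fit within six of our toy tokens, joining them with a space.
index, merged = merge_splits(splits, chunk_size=6, splitter=" ", token_counter=word_counter)

print(index, repr(merged))  # expected: 2 'The quick brown fox jumps over'
```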


def chunk(text: str, chunk_size: int, token_counter: Callable, memoize: bool = True, _recursion_depth: int = 0) -> list[str]:
"""Split text into semantically meaningful chunks of a specified size as determined by the provided token counter.

Args:
@@ -76,35 +107,23 @@ def chunk(text: str, chunk_size: int, token_counter: callable, memoize: bool=True, _recursion_depth: int = 0) -> list[str]:

# If the split is over the chunk size, recursively chunk it.
if token_counter(split) > chunk_size:
chunks.extend(chunk(split, chunk_size, token_counter=token_counter, memoize=memoize, _recursion_depth=_recursion_depth+1))
chunks.extend(chunk(split, chunk_size, token_counter = token_counter, memoize = memoize, _recursion_depth = _recursion_depth + 1))

# If the split is equal to or under the chunk size, merge it with all subsequent splits until the chunk size is reached.
# If the split is equal to or under the chunk size, add it and any subsequent splits to a new chunk until the chunk size is reached.
else:
# Initialise the new chunk.
new_chunk = split
# Merge the split with subsequent splits until the chunk size is reached.
final_split_in_chunk_i, new_chunk = merge_splits(splits[i:], chunk_size, splitter, token_counter)

# Iterate through each subsequent split until the chunk size is reached.
for j, next_split in enumerate(splits[i+1:], start=i+1):
# Check whether the next split can be added to the chunk without exceeding the chunk size.
if token_counter(updated_chunk:=new_chunk+splitter+next_split) <= chunk_size:
# Add the next split to the new chunk.
new_chunk = updated_chunk

# Add the index of the next split to the list of indices to skip.
skips.add(j)

# If the next split cannot be added to the chunk without exceeding the chunk size, break.
else:
break
# Mark any splits included in the new chunk for exclusion from future chunks.
skips.update(range(i + 1, i + final_split_in_chunk_i))

# Add the chunk.
chunks.append(new_chunk)

# If the splitter is not whitespace and the split is not the last split, add the splitter to the end of the last chunk if doing so would not cause it to exceed the chunk size; otherwise, add the splitter as a new chunk.
if not splitter_is_whitespace and not (i == len(splits) - 1 or all(j in skips for j in range(i+1, len(splits)))):
if token_counter(last_chunk_with_splitter:=chunks[-1]+splitter) <= chunk_size:
if not splitter_is_whitespace and not (i == len(splits) - 1 or all(j in skips for j in range(i + 1, len(splits)))):
if token_counter(last_chunk_with_splitter := chunks[-1] + splitter) <= chunk_size:
chunks[-1] = last_chunk_with_splitter

else:
chunks.append(splitter)

@@ -114,4 +133,5 @@ def chunk(text: str, chunk_size: int, token_counter: callable, memoize: bool=True, _recursion_depth: int = 0) -> list[str]:

return chunks


# Memoise chunk() while preserving its original signature and docstring (see the 0.2.3 changelog entry above).
chunk = wraps(chunk)(cache(chunk))
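And finally, a minimal usage sketch of the public entry point after these changes, assuming the top-level semchunk.chunk export and using a toy whitespace token counter (both the counter and the sample text are illustrative):

```python
import semchunk

def word_counter(text: str) -> int:
    """A toy token counter: one token per whitespace-delimited word."""
    return len(text.split())

text = (
    "The quick brown fox jumps over the lazy dog. "
    "The dog, thoroughly unimpressed, rolls over and goes back to sleep."
)

# Split the text into chunks of at most eight of our toy tokens.
chunks = semchunk.chunk(text, chunk_size=8, token_counter=word_counter)

for piece in chunks:
    print(repr(piece))
```

Any callable that maps a string to a token count can be substituted for the toy counter.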