Removing unnecessary chunking and merging of long texts #575

Merged 1 commit on May 13, 2024
1 change: 1 addition & 0 deletions pyproject.toml
@@ -29,6 +29,7 @@ dependencies = [
    "haliax>=1.4.dev291",
    "equinox>=0.11.4",
    "jaxtyping>=0.2.20",
    "tokenizers>=0.15.2",
    "transformers>=4.39.3",
    "optax>=0.1.9",
    "wandb~=0.16.6",
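Note: the new lower bound tracks the upstream fix referenced in the TODO deleted below (huggingface/tokenizers#1449), which presumably makes tokenizing very long strings fast enough to skip manual chunking. A minimal sketch, assuming `packaging` is installed and the `gpt2` tokenizer can be downloaded, of verifying the floor and tokenizing a long document in a single call:

```python
# Illustrative sketch only (not part of this PR): confirm the installed `tokenizers`
# meets the new floor, then tokenize a long document in one call with no manual chunking.
from importlib.metadata import version

from packaging.version import Version
from transformers import AutoTokenizer

assert Version(version("tokenizers")) >= Version("0.15.2")

tokenizer = AutoTokenizer.from_pretrained("gpt2")
long_doc = "lorem ipsum dolor sit amet " * 10_000  # ~270K characters, well past the old 100_000 threshold
encoding = tokenizer([long_doc], return_attention_mask=False, verbose=False)
print(len(encoding["input_ids"][0]))
```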
89 changes: 0 additions & 89 deletions src/levanter/data/text.py
@@ -312,9 +312,6 @@ def _maybe_force_tokenizer_parallelism(tokenizer: PreTrainedTokenizerBase):
        os.environ["TOKENIZERS_PARALLELISM"] = "true"


LONG_STRING_WORKAROUND = 100_000


ws = regex.compile(r"\s")


@@ -332,7 +329,6 @@ def __init__(
        *,
        batch_size=128,
        override_resources=None,
        _workaround_len=LONG_STRING_WORKAROUND,
        return_attention_mask=False,
        padding=False,
        max_length=None,
@@ -368,7 +364,6 @@ def __init__(

        self._need_to_add_eos = should_append_eos
        self._need_to_add_bos = should_append_bos
        self._workaround_len = _workaround_len

    def __call__(self, batch: Sequence[str]) -> BatchEncoding:
        orig_lengths = [len(d) for d in batch]
@@ -378,97 +373,13 @@ def __call__(self, batch: Sequence[str]) -> BatchEncoding:
        if self._need_to_add_eos:
            batch = [d + " " + self.tokenizer.eos_token for d in batch]

        if self._needs_long_sequence_workaround:
            # break any strings that are longer than 50K characters into smaller chunks
            orig_batch = batch
            batch = []
            needs_merge = []
            for i, d in enumerate(orig_batch):
                needs_merge.append(False)
                orig_len = orig_lengths[i]
                while len(d) > self._workaround_len:
                    # we'd rather break strings at whitespace, so find the first whitespace
                    match = ws.search(d, self._workaround_len)
                    # this is vanishingly unlikely, but if we can't find a whitespace, just break it at the limit
                    if match is None:
                        split = len(d)
                    else:
                        split = match.start()

                    batch.append(d[:split])
                    needs_merge.append(True)

                    d = d[split:]
                    orig_len -= split

                batch.append(d)
        else:
            needs_merge = []

        if self.padding is not False:
            encoding = self.tokenizer(batch, return_attention_mask=self.return_attention_mask, verbose=False, padding=self.padding, max_length=self.max_length, truncation=True)  # type: ignore
        else:
            encoding = self.tokenizer(batch, return_attention_mask=self.return_attention_mask, verbose=False)  # type: ignore

        if needs_merge:
            new_encoding = self._merge_split_encodings(batch, encoding, needs_merge)
            encoding = BatchEncoding(new_encoding)

        return encoding

    @staticmethod
    def _merge_split_encodings(batch, encoding, needs_merge):
        # merge the encodings back together
        # we might need to merge multiple encodings together
        # needs merge marks the first n-1 encodings that need to be merged for each document
        new_encoding = {}
        for k, v in encoding.items():
            if len(v) == 0:
                continue
            if isinstance(v[0], np.ndarray):
                assert len(v) == len(batch)
                v_out = []
                vs_to_merge = []
                for i in range(len(batch)):
                    if not needs_merge[i]:
                        v_out.append(np.concatenate(vs_to_merge))
                        vs_to_merge = []
                    vs_to_merge.append(v[i])

                if len(vs_to_merge) > 0:
                    v_out.append(np.concatenate(vs_to_merge))

                new_encoding[k] = v_out
            elif isinstance(v[0], list):
                v_out = []
                vs_to_merge = []
                for i in range(len(batch)):
                    if not needs_merge[i]:
                        if len(vs_to_merge) > 0:
                            v_out.append(list(chain(*vs_to_merge)))
                        vs_to_merge = []
                    vs_to_merge.append(v[i])

                if len(vs_to_merge) > 0:
                    v_out.append(list(chain(*vs_to_merge)))
                new_encoding[k] = v_out
            else:
                raise ValueError(f"Unknown type {type(v[0])}")
        return new_encoding

    # TODO remove this when it's resolved https://github.com/huggingface/tokenizers/issues/1449
    @cached_property
    def _needs_long_sequence_workaround(self):
        if isinstance(self.tokenizer, PreTrainedTokenizerFast):
            normalizer = self.tokenizer.backend_tokenizer.normalizer
            if normalizer is None:
                return False
            # if there's a "Replace" normalizer, then we need to do the workaround
            # inexplicably there's no way to see inside a Sequence so we also have to assume it needs it
            return isinstance(normalizer, (normalizers.Replace, normalizers.Sequence))
        else:
            return False

    @property
    def num_cpus(self) -> int:
        if self.override_resources is not None:
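For reference, the heuristic encoded by the deleted `_needs_long_sequence_workaround` property can still be reproduced as a standalone check. An illustrative sketch using the same `tokenizers` introspection as the removed code (the helper name is ours, not from the codebase):

```python
# Standalone sketch of the check the removed `_needs_long_sequence_workaround` property
# performed: a fast tokenizer whose normalizer is a Replace (or an opaque Sequence, which
# cannot be inspected) was the case that triggered the old chunk-and-merge path.
from tokenizers import normalizers
from transformers import AutoTokenizer, PreTrainedTokenizerFast


def has_replace_like_normalizer(tokenizer) -> bool:
    if not isinstance(tokenizer, PreTrainedTokenizerFast):
        return False
    normalizer = tokenizer.backend_tokenizer.normalizer
    if normalizer is None:
        return False
    return isinstance(normalizer, (normalizers.Replace, normalizers.Sequence))


print(has_replace_like_normalizer(AutoTokenizer.from_pretrained("gpt2")))  # False: GPT-2's fast tokenizer has no normalizer
```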
26 changes: 0 additions & 26 deletions tests/test_text.py
@@ -41,29 +41,3 @@ def test_lm_example_handles_ignore_id():
    no_ignore_loss = next_token_loss(Pos, Vocab, distr, tokens, loss_mask=ex_no_ignore.loss_mask)

    assert no_ignore_loss.item() >= ignored_loss.item() + 100 / Pos.size


def test_merge_split_encodings():
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # make this very short for testing

    lorem = """Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."""

    short_batch_tokenizer = BatchTokenizer(tokenizer, _workaround_len=len(lorem) // 3)
    # force this
    short_batch_tokenizer._needs_long_sequence_workaround = True

    batch_tokenizer = BatchTokenizer(tokenizer, _workaround_len=50000)
    batch = [lorem]

    short_out = short_batch_tokenizer(batch)
    reg_out = batch_tokenizer(batch)

    assert short_out == reg_out


@skip_if_hf_model_not_accessible("meta-llama/Llama-2-7b-hf")
def test_llama_tokenizer_needs_long_sequence_workaround():
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    batch_tokenizer = BatchTokenizer(tokenizer)
    assert batch_tokenizer._needs_long_sequence_workaround
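
With the workaround and its tests removed, long documents go straight through `BatchTokenizer` without any pre-splitting. A minimal usage sketch, assuming levanter and its dependencies are installed and that the constructor defaults follow the signature visible in the diff above:

```python
# Minimal usage sketch (not a test from this PR): tokenize a long document in one pass.
from transformers import AutoTokenizer

from levanter.data.text import BatchTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
batch_tokenizer = BatchTokenizer(tokenizer)

long_doc = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " * 5_000
encoding = batch_tokenizer([long_doc])
print(len(encoding["input_ids"][0]))
```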