Removing unnecessary chunking and merging of long texts (#575)
versae authored May 13, 2024
1 parent 2888a35 commit c2bc833
Showing 3 changed files with 1 addition and 115 deletions.
pyproject.toml (1 addition, 0 deletions)
@@ -29,6 +29,7 @@ dependencies = [
     "haliax>=1.4.dev291",
     "equinox>=0.11.4",
     "jaxtyping>=0.2.20",
+    "tokenizers>=0.15.2",
     "transformers>=4.39.3",
     "optax>=0.1.9",
     "wandb~=0.16.6",
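The new tokenizers pin is what makes the deletions below safe: the chunk-and-merge path removed from src/levanter/data/text.py existed only as a workaround for huggingface/tokenizers#1449 (referenced in the deleted TODO), which affected fast tokenizers carrying a Replace normalizer, such as Llama's. The pin presumably tracks a release where that issue no longer applies. As a rough sketch, not part of the commit and assuming an ordinary transformers install with the gpt2 tokenizer available, a long document can now be passed to the tokenizer in one piece:

# Sketch only (not part of this commit): with a sufficiently recent tokenizers
# release, a long document is tokenized directly, with no pre-chunking.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative model choice
long_doc = "lorem ipsum dolor sit amet " * 50_000  # ~1.35M characters of made-up input

encoding = tokenizer([long_doc], return_attention_mask=False, verbose=False)
print(len(encoding["input_ids"][0]))  # one uninterrupted token sequence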
src/levanter/data/text.py (0 additions, 89 deletions)
@@ -312,9 +312,6 @@ def _maybe_force_tokenizer_parallelism(tokenizer: PreTrainedTokenizerBase):
         os.environ["TOKENIZERS_PARALLELISM"] = "true"
 
 
-LONG_STRING_WORKAROUND = 100_000
-
-
 ws = regex.compile(r"\s")
 
 
@@ -332,7 +329,6 @@ def __init__(
         *,
         batch_size=128,
         override_resources=None,
-        _workaround_len=LONG_STRING_WORKAROUND,
         return_attention_mask=False,
         padding=False,
         max_length=None,
@@ -368,7 +364,6 @@ def __init__(
 
         self._need_to_add_eos = should_append_eos
         self._need_to_add_bos = should_append_bos
-        self._workaround_len = _workaround_len
 
     def __call__(self, batch: Sequence[str]) -> BatchEncoding:
         orig_lengths = [len(d) for d in batch]
@@ -378,97 +373,13 @@ def __call__(self, batch: Sequence[str]) -> BatchEncoding:
         if self._need_to_add_eos:
             batch = [d + " " + self.tokenizer.eos_token for d in batch]
 
-        if self._needs_long_sequence_workaround:
-            # break any strings that are longer than 50K characters into smaller chunks
-            orig_batch = batch
-            batch = []
-            needs_merge = []
-            for i, d in enumerate(orig_batch):
-                needs_merge.append(False)
-                orig_len = orig_lengths[i]
-                while len(d) > self._workaround_len:
-                    # we'd rather break strings at whitespace, so find the first whitespace
-                    match = ws.search(d, self._workaround_len)
-                    # this is vanishingly unlikely, but if we can't find a whitespace, just break it at the limit
-                    if match is None:
-                        split = len(d)
-                    else:
-                        split = match.start()
-
-                    batch.append(d[:split])
-                    needs_merge.append(True)
-
-                    d = d[split:]
-                    orig_len -= split
-
-                batch.append(d)
-        else:
-            needs_merge = []
-
         if self.padding is not False:
             encoding = self.tokenizer(batch, return_attention_mask=self.return_attention_mask, verbose=False, padding=self.padding, max_length=self.max_length, truncation=True) # type: ignore
         else:
             encoding = self.tokenizer(batch, return_attention_mask=self.return_attention_mask, verbose=False) # type: ignore
 
-        if needs_merge:
-            new_encoding = self._merge_split_encodings(batch, encoding, needs_merge)
-            encoding = BatchEncoding(new_encoding)
-
         return encoding
 
-    @staticmethod
-    def _merge_split_encodings(batch, encoding, needs_merge):
-        # merge the encodings back together
-        # we might need to merge multiple encodings together
-        # needs merge marks the first n-1 encodings that need to be merged for each document
-        new_encoding = {}
-        for k, v in encoding.items():
-            if len(v) == 0:
-                continue
-            if isinstance(v[0], np.ndarray):
-                assert len(v) == len(batch)
-                v_out = []
-                vs_to_merge = []
-                for i in range(len(batch)):
-                    if not needs_merge[i]:
-                        v_out.append(np.concatenate(vs_to_merge))
-                        vs_to_merge = []
-                    vs_to_merge.append(v[i])
-
-                if len(vs_to_merge) > 0:
-                    v_out.append(np.concatenate(vs_to_merge))
-
-                new_encoding[k] = v_out
-            elif isinstance(v[0], list):
-                v_out = []
-                vs_to_merge = []
-                for i in range(len(batch)):
-                    if not needs_merge[i]:
-                        if len(vs_to_merge) > 0:
-                            v_out.append(list(chain(*vs_to_merge)))
-                        vs_to_merge = []
-                    vs_to_merge.append(v[i])
-
-                if len(vs_to_merge) > 0:
-                    v_out.append(list(chain(*vs_to_merge)))
-                new_encoding[k] = v_out
-            else:
-                raise ValueError(f"Unknown type {type(v[0])}")
-        return new_encoding
-
-    # TODO remove this when it's resolved https://github.com/huggingface/tokenizers/issues/1449
-    @cached_property
-    def _needs_long_sequence_workaround(self):
-        if isinstance(self.tokenizer, PreTrainedTokenizerFast):
-            normalizer = self.tokenizer.backend_tokenizer.normalizer
-            if normalizer is None:
-                return False
-            # if there's a "Replace" normalizer, then we need to do the workaround
-            # inexplicably there's no way to see inside a Sequence so we also have to assume it needs it
-            return isinstance(normalizer, (normalizers.Replace, normalizers.Sequence))
-        else:
-            return False
-
     @property
     def num_cpus(self) -> int:
         if self.override_resources is not None:
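For context, the deleted path broke each document at a whitespace boundary roughly every _workaround_len characters, tokenized the pieces separately, and stitched the per-piece input_ids back together in _merge_split_encodings. Below is a compressed, standalone sketch of that round trip; the helper name chunk_at_whitespace and its limit parameter are hypothetical, re stands in for the original regex import, and gpt2 is used only as an accessible example:

# Standalone sketch of the deleted chunk-then-merge behaviour (hypothetical
# helper, not the Levanter API).
import re
from itertools import chain

from transformers import AutoTokenizer

ws = re.compile(r"\s")


def chunk_at_whitespace(text: str, limit: int) -> list[str]:
    # Prefer to break at the first whitespace at or after `limit`, as the
    # removed __call__ did; if none is found, keep the rest in one piece.
    pieces = []
    while len(text) > limit:
        match = ws.search(text, limit)
        split = match.start() if match else len(text)
        pieces.append(text[:split])
        text = text[split:]
    pieces.append(text)
    return pieces


tokenizer = AutoTokenizer.from_pretrained("gpt2")
doc = "lorem ipsum dolor sit amet " * 10_000
pieces = chunk_at_whitespace(doc, limit=100_000)

merged = list(chain.from_iterable(tokenizer(p, verbose=False)["input_ids"] for p in pieces))
direct = tokenizer(doc, verbose=False)["input_ids"]
assert merged == direct  # the deleted test asserted this same equivalence for gpt2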
tests/test_text.py (0 additions, 26 deletions)
@@ -41,29 +41,3 @@ def test_lm_example_handles_ignore_id():
     no_ignore_loss = next_token_loss(Pos, Vocab, distr, tokens, loss_mask=ex_no_ignore.loss_mask)
 
     assert no_ignore_loss.item() >= ignored_loss.item() + 100 / Pos.size
-
-
-def test_merge_split_encodings():
-    tokenizer = AutoTokenizer.from_pretrained("gpt2")
-    # make this very short for testing
-
-    lorem = """Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."""
-
-    short_batch_tokenizer = BatchTokenizer(tokenizer, _workaround_len=len(lorem) // 3)
-    # force this
-    short_batch_tokenizer._needs_long_sequence_workaround = True
-
-    batch_tokenizer = BatchTokenizer(tokenizer, _workaround_len=50000)
-    batch = [lorem]
-
-    short_out = short_batch_tokenizer(batch)
-    reg_out = batch_tokenizer(batch)
-
-    assert short_out == reg_out
-
-
-@skip_if_hf_model_not_accessible("meta-llama/Llama-2-7b-hf")
-def test_llama_tokenizer_needs_long_sequence_workaround():
-    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
-    batch_tokenizer = BatchTokenizer(tokenizer)
-    assert batch_tokenizer._needs_long_sequence_workaround

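The second deleted test exercised the heuristic that decided when chunking was needed: a fast tokenizer whose backend normalizer is a Replace (or an opaque Sequence), as Llama's is. Should that inspection ever be useful again, it can be done directly against the tokenizers backend; the function below is a hypothetical sketch, not part of the commit, and uses gpt2 only because it is freely downloadable:

from tokenizers import normalizers
from transformers import AutoTokenizer, PreTrainedTokenizerFast


def has_replace_normalizer(tokenizer) -> bool:
    # Mirrors the deleted _needs_long_sequence_workaround heuristic: only fast
    # tokenizers expose a backend normalizer, and a Replace (or a Sequence,
    # which cannot be inspected further) is what used to trigger chunking.
    if not isinstance(tokenizer, PreTrainedTokenizerFast):
        return False
    normalizer = tokenizer.backend_tokenizer.normalizer
    if normalizer is None:
        return False
    return isinstance(normalizer, (normalizers.Replace, normalizers.Sequence))


print(has_replace_normalizer(AutoTokenizer.from_pretrained("gpt2")))  # False: gpt2 defines no normalizer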