diff --git a/pyproject.toml b/pyproject.toml index 8aa8db4ad..01e7617b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "haliax>=1.4.dev291", "equinox>=0.11.4", "jaxtyping>=0.2.20", + "tokenizers>=0.15.2", "transformers>=4.39.3", "optax>=0.1.9", "wandb~=0.16.6", diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index c635a98ea..f9845ef0d 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -312,9 +312,6 @@ def _maybe_force_tokenizer_parallelism(tokenizer: PreTrainedTokenizerBase): os.environ["TOKENIZERS_PARALLELISM"] = "true" -LONG_STRING_WORKAROUND = 100_000 - - ws = regex.compile(r"\s") @@ -332,7 +329,6 @@ def __init__( *, batch_size=128, override_resources=None, - _workaround_len=LONG_STRING_WORKAROUND, return_attention_mask=False, padding=False, max_length=None, @@ -368,7 +364,6 @@ def __init__( self._need_to_add_eos = should_append_eos self._need_to_add_bos = should_append_bos - self._workaround_len = _workaround_len def __call__(self, batch: Sequence[str]) -> BatchEncoding: orig_lengths = [len(d) for d in batch] @@ -378,97 +373,13 @@ def __call__(self, batch: Sequence[str]) -> BatchEncoding: if self._need_to_add_eos: batch = [d + " " + self.tokenizer.eos_token for d in batch] - if self._needs_long_sequence_workaround: - # break any strings that are longer than 50K characters into smaller chunks - orig_batch = batch - batch = [] - needs_merge = [] - for i, d in enumerate(orig_batch): - needs_merge.append(False) - orig_len = orig_lengths[i] - while len(d) > self._workaround_len: - # we'd rather break strings at whitespace, so find the first whitespace - match = ws.search(d, self._workaround_len) - # this is vanishingly unlikely, but if we can't find a whitespace, just break it at the limit - if match is None: - split = len(d) - else: - split = match.start() - - batch.append(d[:split]) - needs_merge.append(True) - - d = d[split:] - orig_len -= split - - batch.append(d) - else: - needs_merge = [] - if self.padding is not False: encoding = self.tokenizer(batch, return_attention_mask=self.return_attention_mask, verbose=False, padding=self.padding, max_length=self.max_length, truncation=True) # type: ignore else: encoding = self.tokenizer(batch, return_attention_mask=self.return_attention_mask, verbose=False) # type: ignore - if needs_merge: - new_encoding = self._merge_split_encodings(batch, encoding, needs_merge) - encoding = BatchEncoding(new_encoding) - return encoding - @staticmethod - def _merge_split_encodings(batch, encoding, needs_merge): - # merge the encodings back together - # we might need to merge multiple encodings together - # needs merge marks the first n-1 encodings that need to be merged for each document - new_encoding = {} - for k, v in encoding.items(): - if len(v) == 0: - continue - if isinstance(v[0], np.ndarray): - assert len(v) == len(batch) - v_out = [] - vs_to_merge = [] - for i in range(len(batch)): - if not needs_merge[i]: - v_out.append(np.concatenate(vs_to_merge)) - vs_to_merge = [] - vs_to_merge.append(v[i]) - - if len(vs_to_merge) > 0: - v_out.append(np.concatenate(vs_to_merge)) - - new_encoding[k] = v_out - elif isinstance(v[0], list): - v_out = [] - vs_to_merge = [] - for i in range(len(batch)): - if not needs_merge[i]: - if len(vs_to_merge) > 0: - v_out.append(list(chain(*vs_to_merge))) - vs_to_merge = [] - vs_to_merge.append(v[i]) - - if len(vs_to_merge) > 0: - v_out.append(list(chain(*vs_to_merge))) - new_encoding[k] = v_out - else: - raise ValueError(f"Unknown type {type(v[0])}") - return new_encoding - - # TODO remove this when it's resolved https://github.com/huggingface/tokenizers/issues/1449 - @cached_property - def _needs_long_sequence_workaround(self): - if isinstance(self.tokenizer, PreTrainedTokenizerFast): - normalizer = self.tokenizer.backend_tokenizer.normalizer - if normalizer is None: - return False - # if there's a "Replace" normalizer, then we need to do the workaround - # inexplicably there's no way to see inside a Sequence so we also have to assume it needs it - return isinstance(normalizer, (normalizers.Replace, normalizers.Sequence)) - else: - return False - @property def num_cpus(self) -> int: if self.override_resources is not None: diff --git a/tests/test_text.py b/tests/test_text.py index 70b2d26a7..d30c7c6b8 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -41,29 +41,3 @@ def test_lm_example_handles_ignore_id(): no_ignore_loss = next_token_loss(Pos, Vocab, distr, tokens, loss_mask=ex_no_ignore.loss_mask) assert no_ignore_loss.item() >= ignored_loss.item() + 100 / Pos.size - - -def test_merge_split_encodings(): - tokenizer = AutoTokenizer.from_pretrained("gpt2") - # make this very short for testing - - lorem = """Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.""" - - short_batch_tokenizer = BatchTokenizer(tokenizer, _workaround_len=len(lorem) // 3) - # force this - short_batch_tokenizer._needs_long_sequence_workaround = True - - batch_tokenizer = BatchTokenizer(tokenizer, _workaround_len=50000) - batch = [lorem] - - short_out = short_batch_tokenizer(batch) - reg_out = batch_tokenizer(batch) - - assert short_out == reg_out - - -@skip_if_hf_model_not_accessible("meta-llama/Llama-2-7b-hf") -def test_llama_tokenizer_needs_long_sequence_workaround(): - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") - batch_tokenizer = BatchTokenizer(tokenizer) - assert batch_tokenizer._needs_long_sequence_workaround