
Refactors #27

Open · wants to merge 60 commits into main

Conversation

doomslide (Contributor) commented:

Added full attention scores, some memory optimizations for the KV cache, and other minor things.

doomslide (Contributor, Author) commented:

Did a bit more cleaning. The biggest change is factoring generate out of main, and factoring initialize out of generate; they have different logic, so it's easier to debug this way.
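A minimal, self-contained sketch of the split being described, with a toy stand-in for the model; only the initialize/generate/main factoring mirrors the comment, and everything else here is illustrative:

# Toy illustration of the factoring: setup lives in initialize(), the decode
# loop lives in generate(), and main() only wires things together.
from typing import Generator

def initialize(prompt: str) -> list[str]:
  # One-time, per-prompt setup (in the real code: token batch, attention mask,
  # RoPE tables, and a fresh KV cache).
  return prompt.split()

def generate(prompt: str) -> Generator[str, None, None]:
  # The decode loop proper: consumes the state built by initialize() and
  # yields output piece by piece, so the two stages can be debugged separately.
  state = initialize(prompt)
  for token in state:
    yield token.upper()

def main() -> None:
  # Entry point: no generation logic lives here anymore.
  for piece in generate("factor setup out of the loop"):
    print(piece, end=" ")

if __name__ == "__main__":
  main()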

xjdr-alt mentioned this pull request on Oct 7, 2024.
qdbp (Contributor) commented on Oct 8, 2024:

@hallucinomeny: since this PR looks like it's ahead of my #40 in the queue, and you also "lift" the generator function out of main, would you be open to moving to the dataclass-based implementation I have there?

I think that provides better encapsulation and will also play more nicely with the "sampler interface standardization" I have in mind. Specifically:

# Standard-library and JAX imports needed to run this; the project-local names
# (XfmrWeights, ModelParams, Tokenizer, EntropySampler, xfmr, build_attn_mask,
# precompute_freqs_cis, KVCache, and the type vars Cfg_contra and ST) are
# assumed to already be importable from the entropix codebase.
from dataclasses import dataclass
from typing import Generator, Generic

import jax.numpy as jnp

@dataclass(kw_only=True)
class TokenGenerator(Generic[Cfg_contra, ST]):
  weights: XfmrWeights
  model_params: ModelParams
  tokenizer: Tokenizer
  sampler: EntropySampler[Cfg_contra, ST]
  sampler_cfg: Cfg_contra

  def generate_from_prompt(self, init_tokens) -> Generator[str, None, None]:
    cur_pos = 0
    # Create the batch of tokens from the prompt
    tokens = jnp.array([init_tokens], jnp.int32)
    bsz, seqlen = tokens.shape
    attn_mask = build_attn_mask(seqlen, cur_pos)
    mp = self.model_params
    freqs_cis = precompute_freqs_cis(mp.head_dim, mp.max_seq_len, mp.rope_theta, mp.use_scaled_rope)
    kvcache = KVCache.new(mp.n_layers, bsz, mp.max_seq_len, mp.n_local_kv_heads, mp.head_dim)
    # Prefill: run the full prompt through the model, then greedily take the first token
    logits, kvcache, _, _ = xfmr(self.weights, mp, tokens, cur_pos, freqs_cis[:seqlen], kvcache, attn_mask=attn_mask)
    next_token = jnp.argmax(logits[:, -1], axis=-1, keepdims=True).astype(jnp.int32)
    gen_tokens = next_token

    yield self.tokenizer.decode([next_token.item()])

    cur_pos = seqlen
    stop = jnp.array([128001, 128008, 128009])  # Llama 3 stop token ids
    state: ST | None = None
    while cur_pos < 8192:  # hard cap on total decode length
      cur_pos += 1
      logits, kvcache, scores, _ = xfmr(
        self.weights, mp, next_token, cur_pos, freqs_cis[cur_pos : cur_pos + 1], kvcache
      )
      next_token, state = self.sampler(gen_tokens, logits, scores, cfg=self.sampler_cfg, state=state)
      # Accumulate along the sequence axis so gen_tokens keeps shape (bsz, n_generated)
      gen_tokens = jnp.concatenate((gen_tokens, next_token), axis=1)
      yield self.tokenizer.decode(next_token.tolist()[0])
      if jnp.isin(next_token, stop).any():
        break

Obviously this includes changes (e.g. sampler and state being fields) that are only part of my PR, but if this structure is adopted, those can be deconflicted easily later.
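For context, a hypothetical usage sketch of the proposed interface; the load_* helpers, my_sampler, my_cfg, and prompt_tokens below are stand-ins for illustration, not APIs from either PR:

# Hypothetical wiring of the proposed TokenGenerator; every load_* name and
# the sampler/config values are placeholders, not real entropix functions.
generator = TokenGenerator(
  weights=load_weights(),            # placeholder loader
  model_params=load_model_params(),  # placeholder loader
  tokenizer=load_tokenizer(),        # placeholder loader
  sampler=my_sampler,                # any EntropySampler-compatible callable
  sampler_cfg=my_cfg,
)
for piece in generator.generate_from_prompt(prompt_tokens):
  print(piece, end="", flush=True)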
