
sampling : fix off-by-one in tail-free sampling #9604

Closed
wants to merge 1 commit into from

Conversation

ggerganov
Owner

The tail-free sampler was incorrectly using the index to set the size of the candidates, resulting in the size being one less than it should be.

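To illustrate the off-by-one, here is a minimal Python sketch with made-up names (not the actual sampler code): if last_idx is the index of the last candidate that survives the cutoff, the kept size has to be last_idx + 1.

# Hypothetical illustration of the off-by-one; not the actual C++ sampler code.
candidates = [0.3, 0.25, 0.2, 0.15, 0.1]   # sorted candidate probabilities
last_idx = 2                               # index of the last candidate above the cutoff

kept_buggy = candidates[:last_idx]         # size == last_idx, drops one candidate too many
kept_fixed = candidates[:last_idx + 1]     # size == last_idx + 1, keeps indices 0..last_idx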
@github-actions github-actions bot added the testing label Sep 23, 2024
@slaren
Collaborator

slaren commented Sep 23, 2024

This will change the behavior compared to the pre-refactor implementation, so it would be good to check that the implementation is actually correct. I have been trying to compare the results of the tests with the reference implementation linked in the paper, and I get very different results, though the code required some changes, so maybe I broke something. I get these results:

[0.1, 0.15, 0.2, 0.25, 0.3], 0.25 -> [0.3, 0.25, 0.2, 0.15] (cutoff = 0.1499999978930995)
[0.1, 0.15, 0.2, 0.25, 0.3], 0.5 -> [0.3, 0.25, 0.2, 0.15] (cutoff = 0.1499999978930995)
[0.1, 0.15, 0.2, 0.25, 0.3], 0.8 -> [0.3, 0.25, 0.2] (cutoff = 0.19999999398584362)
[0.1, 0.15, 0.2, 0.25, 0.3], 0.99 -> [0.3, 0.25] (cutoff = 0.24999999904767284)
import tensorflow as tf
import math

def tail_free(logits_ar, z, temperature=1.0):
    """
    Inputs:
    * logits (tensorflow tensor, Batch size x number of tokens) - takes in the neural network output
    * z (float) - hyperparameter for where to draw the tail. Recommend a value of 0.9 - 0.95. The lower
    the value the fewer tokens are kept (tighter the tail is).
    * temperature (float) - optional temperature parameter.

    Outputs:
    * samples - tensor (Batch size x 1) - randomly sampled tokens post pruning
    """
    logits = tf.convert_to_tensor(logits_ar)
    logits = tf.log(logits)
    logits = tf.expand_dims(logits, 0)

    logits = logits / tf.to_float(temperature)
    sps = tf.sort(tf.nn.softmax(logits, axis=1), direction='DESCENDING',axis=1)
    grad = sps[:,1:]-sps[:,:-1] # first derivative
    grad = grad[:,1:]-grad[:,:-1] #this is the 2nd derivative

    only_pos = tf.math.abs(grad)
    sec_indices = tf.range(grad.shape[1].value)
    sec_weights = only_pos/ tf.math.reduce_sum( only_pos, axis=1, keepdims=True )

    tail_ids = tf.cast(tf.argmax(tf.cast(tf.cumsum(sec_weights, axis=1)>z, tf.int8), axis=1), tf.int32)+1
    # adding one to put it in the center of the tail.

    logit_inds = tf.stack([tf.range(0,logits.shape[0].value), tail_ids], axis=1)
    tail_min_vals = tf.expand_dims(tf.gather_nd(logits, logit_inds),1)

    # removes any tokens below the tail location by setting their values to be very very small.
    pruned = tf.where(
            logits < tail_min_vals,
            tf.ones_like(logits, dtype=logits.dtype) * -1e10,
            logits,
        )
    # do not need to convert to softmax again (renormalize) before passing to tf.multinomial
    samples = tf.multinomial(pruned, num_samples=1, output_dtype=tf.int32)

    with tf.Session() as sess:
        tail_min_vals = sess.run(tail_min_vals).flatten()[0]
        tail_min_vals = math.exp(tail_min_vals)
        l = [l for l in logits_ar if l >= tail_min_vals]
        l.sort(reverse=True)
        print(f"{logits_ar}, {z} -> {l} (cutoff = {tail_min_vals})")

    return samples


# test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f},               0.25f);
tail_free([0.1, 0.15, 0.2, 0.25, 0.3], 0.25)

# test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f},        0.50f);
tail_free([0.1, 0.15, 0.2, 0.25, 0.3], 0.50)

# test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f, 0.20f}, 0.80f);
tail_free([0.1, 0.15, 0.2, 0.25, 0.3], 0.80)

# test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f},        0.99f);
tail_free([0.1, 0.15, 0.2, 0.25, 0.3], 0.99)
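
For cross-checking without TensorFlow 1.x, here is a minimal NumPy sketch of the same pruning steps (the function name tail_free_np and the zero-sum guard are mine, not part of the reference; it skips the log/softmax round trip, which is a no-op when the input probabilities already sum to 1):

import numpy as np

def tail_free_np(probs, z):
    # Sort probabilities in descending order (inputs assumed to sum to 1,
    # so softmax(log(probs)) == probs and can be skipped here).
    sps = np.sort(np.asarray(probs, dtype=np.float64))[::-1]
    d1 = np.diff(sps)            # first derivative
    d2 = np.abs(np.diff(d1))     # |second derivative|
    total = d2.sum()
    # Note: for a perfectly linear distribution like the test input, the second
    # derivative is all zeros and the reference divides by zero when normalizing;
    # the guard below only keeps this sketch finite.
    weights = d2 / total if total > 0 else np.full_like(d2, 1.0 / len(d2))
    tail_id = int(np.argmax(np.cumsum(weights) > z)) + 1  # +1 as in the reference
    cutoff = sps[tail_id]
    kept = [p for p in sps if p >= cutoff]
    return kept, cutoff

print(tail_free_np([0.1, 0.15, 0.2, 0.25, 0.3], 0.25))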

@ggerganov
Owner Author

Btw, with the C++ implementation, the last test with z=0.99f gives me 3 candidates in the result: {0.3f, 0.25f, 0.20f}.

Also, the reference implementation has the comment:

z (float) - hyperparameter for where to draw the tail. Recommend a value of 0.9 - 0.95. The lower the value the fewer tokens are kept (tighter the tail is).

I haven't been able to run the python script yet (I think I have a tensorflow version mismatch), but from the results you posted, they seem to go the other way around: a lower value keeps more tokens.

@slaren
Collaborator

slaren commented Sep 24, 2024

I haven't been able to run the python script yet (I think I have a tensorflow version mismatch), but from the results you posted, they seem to go the other way around: a lower value keeps more tokens.

It needs tensorflow 1.14. I definitely got something wrong, but I am not sure what. I tried looking at the pruned list instead, but the results are the same.

    with tf.Session() as sess:
        tail_min_vals = sess.run(tail_min_vals).flatten()[0]
        p = sess.run(pruned).flatten().tolist()
        p = [math.exp(l) for l in p if l > -1e10]
        f = [f"{l:.2f}" for l in p]
        s = ', '.join(f)
        print(f"{logits_ar}, {z} -> [{s}] (cutoff = {math.exp(tail_min_vals):.2f})")

@ggerganov
Owner Author

Yup, I got the script running by using Qwen 2.5 to rewrite it for TF v2, but I haven't been able to figure out what is causing the discrepancy.

TF v2 compatible script (unverified)
import tensorflow as tf
import numpy as np

def tail_free(logits_ar, z, temperature=1.0):
    """
    Inputs:
    * logits (tensorflow tensor, Batch size x number of tokens) - takes in the neural network output
    * z (float) - hyperparameter for where to draw the tail. Recommend a value of 0.9 - 0.95. The lower
    the value the fewer tokens are kept (tighter the tail is).
    * temperature (float) - optional temperature parameter.

    Outputs:
    * samples - tensor (Batch size x 1) - randomly sampled tokens post pruning
    """
    logits = tf.convert_to_tensor(logits_ar)
    logits = tf.math.log(logits) # gg: is this needed?
    logits = tf.expand_dims(logits, 0)

    logits = logits / temperature
    sps = tf.sort(tf.nn.softmax(logits, axis=1), direction='DESCENDING',axis=1)
    grad = sps[:,1:]-sps[:,:-1] # first derivative
    grad = grad[:,1:]-grad[:,:-1] #this is the 2nd derivative

    only_pos = tf.math.abs(grad)
    sec_indices = tf.range(grad.shape[1])
    sec_weights = only_pos / tf.reduce_sum(only_pos, axis=1, keepdims=True)

    tail_ids = tf.cast(tf.argmax(tf.cast(tf.cumsum(sec_weights, axis=1) > z, tf.int8), axis=1), tf.int32)+1
    # adding one to put it in the center of the tail.

    logit_inds = tf.stack([tf.range(0, logits.shape[0]), tail_ids], axis=1)
    tail_min_vals = tf.expand_dims(tf.gather_nd(logits, logit_inds), 1)

    # removes any tokens below the tail location by setting their values to be very very small.
    pruned = tf.where(
        logits < tail_min_vals,
        tf.ones_like(logits, dtype=logits.dtype) * -1e10,
        logits,
    )
    # do not need to convert to softmax again (renormalize) before passing to tf.multinomial
    samples = tf.random.categorical(pruned, num_samples=1, dtype=tf.int32)

    tail_min_vals_np = tail_min_vals.numpy().flatten()[0]
    tail_min_vals_exp = np.exp(tail_min_vals_np)
    l = [l for l in logits_ar if l >= tail_min_vals_exp]
    l.sort(reverse=True)
    print(f"{logits_ar}, {z} -> {l} (cutoff = {tail_min_vals_exp})")

    return samples

# test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f},               0.25f);
print(tail_free([0.1, 0.15, 0.2, 0.25, 0.3], 0.25).numpy())

# test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f},        0.50f);
print(tail_free([0.1, 0.15, 0.2, 0.25, 0.3], 0.50).numpy())

# test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f, 0.20f}, 0.80f);
print(tail_free([0.1, 0.15, 0.2, 0.25, 0.3], 0.80).numpy())

# test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f},        0.99f);
print(tail_free([0.1, 0.15, 0.2, 0.25, 0.3], 0.99).numpy())

@ggerganov
Owner Author

I wonder if we should remove this sampler to simplify things.

@slaren
Collaborator

slaren commented Oct 27, 2024

Yes, if we can't verify that it works as intended, it would be better to remove it. If somebody is interested in this sampler, they can review this.
