sampling : fix off-by-one in tail-free sampling #9604
Conversation
This will change the behavior compared to the pre-refactor implementation, so it would be good to check that the implementation is actually correct. I have been trying to compare the results of the tests with the reference implementation linked in the paper, and I get very different results, but the code required some changes so maybe I broke something. I get these results:
```python
import tensorflow as tf
import math

def tail_free(logits_ar, z, temperature=1.0):
    """
    Inputs:
        * logits (tensorflow tensor, Batch size x number of tokens) - takes in the neural network output
        * z (float) - hyperparameter for where to draw the tail. Recommend a value of 0.9 - 0.95. The lower
          the value the fewer tokens are kept (tighter the tail is).
        * temperature (float) - optional temperature parameter.
    Outputs:
        * samples - tensor (Batch size x 1) - randomly sampled tokens post pruning
    """
    logits = tf.convert_to_tensor(logits_ar)
    logits = tf.log(logits)
    logits = tf.expand_dims(logits, 0)
    logits = logits / tf.to_float(temperature)
    sps = tf.sort(tf.nn.softmax(logits, axis=1), direction='DESCENDING', axis=1)
    grad = sps[:, 1:] - sps[:, :-1]    # first derivative
    grad = grad[:, 1:] - grad[:, :-1]  # this is the 2nd derivative
    only_pos = tf.math.abs(grad)
    sec_indices = tf.range(grad.shape[1].value)
    sec_weights = only_pos / tf.math.reduce_sum(only_pos, axis=1, keepdims=True)
    tail_ids = tf.cast(tf.argmax(tf.cast(tf.cumsum(sec_weights, axis=1) > z, tf.int8), axis=1), tf.int32) + 1
    # adding one to put it in the center of the tail.
    logit_inds = tf.stack([tf.range(0, logits.shape[0].value), tail_ids], axis=1)
    tail_min_vals = tf.expand_dims(tf.gather_nd(logits, logit_inds), 1)
    # removes any tokens below the tail location by setting their values to be very very small.
    pruned = tf.where(
        logits < tail_min_vals,
        tf.ones_like(logits, dtype=logits.dtype) * -1e10,
        logits,
    )
    # do not need to convert to softmax again (renormalize) before passing to tf.multinomial
    samples = tf.multinomial(pruned, num_samples=1, output_dtype=tf.int32)
    with tf.Session() as sess:
        tail_min_vals = sess.run(tail_min_vals).flatten()[0]
        tail_min_vals = math.exp(tail_min_vals)
        l = [l for l in logits_ar if l >= tail_min_vals]
        l.sort(reverse=True)
        print(f"{logits_ar}, {z} -> {l} (cutoff = {tail_min_vals})")
    return samples

# test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
tail_free([0.1, 0.15, 0.2, 0.25, 0.3], 0.25)
# test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.50f);
tail_free([0.1, 0.15, 0.2, 0.25, 0.3], 0.50)
# test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f, 0.20f}, 0.80f);
tail_free([0.1, 0.15, 0.2, 0.25, 0.3], 0.80)
# test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.99f);
tail_free([0.1, 0.15, 0.2, 0.25, 0.3], 0.99)
```
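For comparison, the cutoff computation above can be re-derived in plain NumPy (an unverified sketch; `tfs_kept` is a name introduced here for illustration). Two assumptions are baked in: softmax of log-probabilities at temperature 1 is just renormalization, so the log/softmax round-trip is skipped; and the linear test distribution used above makes the second derivative identically zero, so a curved distribution is used instead:

```python
import numpy as np

def tfs_kept(probs, z):
    # Sketch of the reference tail-free cutoff, not the verified implementation.
    sps = np.sort(np.asarray(probs, dtype=np.float64))[::-1]
    sps = sps / sps.sum()                    # sorted descending, renormalized
    d2 = np.abs(np.diff(sps, n=2))           # |second derivative|
    weights = d2 / d2.sum()                  # normalized curvature weights
    tail_id = int(np.argmax(np.cumsum(weights) > z)) + 1  # "+1" as in the reference
    return sps[: tail_id + 1].tolist()       # keep tokens 0..tail_id inclusive

print(tfs_kept([0.5, 0.25, 0.12, 0.08, 0.05], 0.9))  # keeps 3 tokens
print(tfs_kept([0.5, 0.25, 0.12, 0.08, 0.05], 0.5))  # keeps 2 tokens
```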
Btw, note the last test with `z = 0.99`. Also, the reference implementation has the comment: "The lower the value the fewer tokens are kept (tighter the tail is)."
I haven't been able to run the python script yet (I think I have a tensorflow version mismatch), but from the results you posted, they seem to go the other way around: a lower value keeps more tokens.
It needs tensorflow 1.14. I definitely got something wrong, but I am not sure what. I tried looking at the pruned distribution like this:

```python
with tf.Session() as sess:
    tail_min_vals = sess.run(tail_min_vals).flatten()[0]
    p = sess.run(pruned).flatten().tolist()
    p = [math.exp(l) for l in p if l > -1e10]
    f = [f"{l:.2f}" for l in p]
    s = ', '.join(f)
    print(f"{logits_ar}, {z} -> [{s}] (cutoff = {math.exp(tail_min_vals):.2f})")
```
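On the "no renormalization needed" comment in the script: sampling from logits is invariant to renormalization because softmax is shift-invariant, which a small NumPy check (illustrative, not part of the PR) confirms:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))  # numerically stable
    return e / e.sum()

logits = np.array([1.0, 0.5, -1e10])  # -1e10 stands in for a pruned token
# Adding any constant to all logits leaves the distribution unchanged,
# so there is no need to renormalize before categorical sampling.
assert np.allclose(softmax(logits), softmax(logits + 3.7))
```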
Yup, I got the script running by using Qwen 2.5 to rewrite it for TF v2, but haven't been able to figure out what is causing the discrepancy.

TF v2 compatible script (unverified):

```python
import tensorflow as tf
import numpy as np

def tail_free(logits_ar, z, temperature=1.0):
    """
    Inputs:
        * logits (tensorflow tensor, Batch size x number of tokens) - takes in the neural network output
        * z (float) - hyperparameter for where to draw the tail. Recommend a value of 0.9 - 0.95. The lower
          the value the fewer tokens are kept (tighter the tail is).
        * temperature (float) - optional temperature parameter.
    Outputs:
        * samples - tensor (Batch size x 1) - randomly sampled tokens post pruning
    """
    logits = tf.convert_to_tensor(logits_ar)
    logits = tf.math.log(logits)  # gg: is this needed?
    logits = tf.expand_dims(logits, 0)
    logits = logits / temperature
    sps = tf.sort(tf.nn.softmax(logits, axis=1), direction='DESCENDING', axis=1)
    grad = sps[:, 1:] - sps[:, :-1]    # first derivative
    grad = grad[:, 1:] - grad[:, :-1]  # this is the 2nd derivative
    only_pos = tf.math.abs(grad)
    sec_indices = tf.range(grad.shape[1])
    sec_weights = only_pos / tf.reduce_sum(only_pos, axis=1, keepdims=True)
    tail_ids = tf.cast(tf.argmax(tf.cast(tf.cumsum(sec_weights, axis=1) > z, tf.int8), axis=1), tf.int32) + 1
    # adding one to put it in the center of the tail.
    logit_inds = tf.stack([tf.range(0, logits.shape[0]), tail_ids], axis=1)
    tail_min_vals = tf.expand_dims(tf.gather_nd(logits, logit_inds), 1)
    # removes any tokens below the tail location by setting their values to be very very small.
    pruned = tf.where(
        logits < tail_min_vals,
        tf.ones_like(logits, dtype=logits.dtype) * -1e10,
        logits,
    )
    # do not need to convert to softmax again (renormalize) before passing to tf.random.categorical
    samples = tf.random.categorical(pruned, num_samples=1, dtype=tf.int32)
    tail_min_vals_np = tail_min_vals.numpy().flatten()[0]
    tail_min_vals_exp = np.exp(tail_min_vals_np)
    l = [l for l in logits_ar if l >= tail_min_vals_exp]
    l.sort(reverse=True)
    print(f"{logits_ar}, {z} -> {l} (cutoff = {tail_min_vals_exp})")
    return samples

# test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
print(tail_free([0.1, 0.15, 0.2, 0.25, 0.3], 0.25).numpy())
# test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.50f);
print(tail_free([0.1, 0.15, 0.2, 0.25, 0.3], 0.50).numpy())
# test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f, 0.20f}, 0.80f);
print(tail_free([0.1, 0.15, 0.2, 0.25, 0.3], 0.80).numpy())
# test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.99f);
print(tail_free([0.1, 0.15, 0.2, 0.25, 0.3], 0.99).numpy())
```
I wonder if we should remove this sampler to simplify things.
Yes, if we can't verify that it works as intended, it would be better to remove it. If somebody is interested in this sampler, they can review this.
The tail-free sampler was incorrectly using the tail index as the size of the candidates list, making the size one less than it should be.
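The off-by-one can be illustrated with a minimal sketch (variable names here are hypothetical, not taken from the llama.cpp code): if the computed tail position is the *index* of the last token to keep, the candidate count must be that index plus one.

```python
probs = [0.3, 0.25, 0.2, 0.15, 0.1]  # sorted candidate probabilities
tail_id = 2                          # index of the last token to keep

buggy = probs[:tail_id]       # size == index: drops the tail token itself
fixed = probs[:tail_id + 1]   # size == index + 1: keeps tokens 0..tail_id

assert buggy == [0.3, 0.25]
assert fixed == [0.3, 0.25, 0.2]
```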