Skip to content

Commit

Permalink
revert top_k fix
Browse files Browse the repository at this point in the history
  • Loading branch information
mzegla committed Jul 23, 2024
1 parent ff57c77 commit 73e9b1a
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 15 deletions.
16 changes: 13 additions & 3 deletions src/cpp/src/logit_processor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,24 @@ class TopKFilter : public ILogitTransformer {

// If this transform is used together with top_p, it should be applied after top_p, because top_p sorts the entire vector whereas top_k sorts it only partially
void apply(Logits& logits) override {

/*
TODO: Uncommenting this section requires changes in reference texts in tests
if (m_top_k >= logits.m_size)
return;
*/

if (!logits.vector_initialized()) {
// Initialize and partially sort vector
logits.initialize_vector();
std::partial_sort(logits.m_vector.begin(), logits.m_vector.begin() + m_top_k, logits.m_vector.end(), [](const Token& lhs, const Token& rhs) {return lhs.m_log_prob > rhs.m_log_prob; });
// TODO: Re-enabling the partial_sort below also requires re-enabling the commented-out (m_top_k >= logits.m_size) early return at the top of apply()
// std::partial_sort(logits.m_vector.begin(), logits.m_vector.begin() + m_top_k, logits.m_vector.end(), [](const Token& lhs, const Token& rhs) {return lhs.m_log_prob > rhs.m_log_prob; });

std::sort(logits.m_vector.begin(), logits.m_vector.end(), [](const Token& lhs, const Token& rhs) {return lhs.m_log_prob > rhs.m_log_prob; });
}
logits.resize(m_top_k);
if (m_top_k < logits.m_size)
logits.resize(m_top_k);
}

protected:
Expand Down Expand Up @@ -320,7 +329,8 @@ class LogitProcessor {
if (sampling_params.top_p != 1.0f) {
m_logit_transformers.emplace_back(new LogitTransformers::TopPFilter(sampling_params.top_p));
}
if (sampling_params.top_k > 0 && sampling_params.top_k < std::numeric_limits<size_t>::max()) {
// TODO: Re-enabling the commented-out upper-bound check in the condition below requires updating the reference texts in the tests
if (sampling_params.top_k > 0 /* && sampling_params.top_k < std::numeric_limits<size_t>::max() */) {
m_logit_transformers.emplace_back(new LogitTransformers::TopKFilter(sampling_params.top_k));
}
}
Expand Down
4 changes: 4 additions & 0 deletions tests/cpp/logit_filtering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ INSTANTIATE_TEST_SUITE_P(VariousInputs,
TopKFilteringTest,
testing::ValuesIn(TOP_K_TRANSFORM_TEST_CASES));

/*
TODO: Uncomment when top_k transform condition is fixed
TEST(TopKFilteringTest, FilterNotAppliedTopKGreaterThanInputSize) {
float input[]{0.090031, 0.244728, 0.665241};
float expected_output[]{0.090031, 0.244728, 0.665241}; // no change expected
Expand All @@ -126,6 +129,7 @@ TEST(TopKFilteringTest, FilterNotAppliedTopKGreaterThanInputSize) {
EXPECT_EQ(logits.m_data[i], expected_output[i]);
}
}
*/

struct RepetitionPenaltyTransformTestStruct {
static inline const size_t size = 3;
Expand Down
10 changes: 5 additions & 5 deletions tests/python_tests/test_preemption.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_preemption(tmp_path, params):
ref_texts=get_current_plarform_ref_texts({
"linux": [
[
"\n\nOpenVINO is a programming language with a lot of benefits. It has been designed in such a way that it is probably not suitable for"
"\n\nOpenVINO is a live platform that allows users to create and manage a new library for open source applications.\n\nOpenVINO is"
],
[
" You're getting much better results from doing this, than you are by not doing this. I have a BH and I was so far"
Expand Down Expand Up @@ -109,12 +109,12 @@ def test_preemption_with_multinomial(tmp_path, dynamic_split_fuse):
ref_texts=get_current_plarform_ref_texts({
"linux": [
[
" Buzzfeed ESPN CNBC MSNBC CBS\nFox News is on top of the list.\nIf a news station tries to afford real estate"
"\nI've seen this expression used too many times without making sense.\nAs an AI engineer, and as a scientist, we should make everything easier"
],
[
" condition of the leg?\nIt's been quite a while since I've seen it, so I didn't really know if it was good or bad",
' ratio of (-9)/(-12)*(-128)/(-32)?\n-1/2\nEvaluate (1*-9)/((',
' ratio of (-4 + (-5)/(-5))*-3?\n-6\nEvaluate ((-108)/(-32))/('
" position of the Z-shaped groove?\n0.41\nWhat is the current position of the Z-shaped groove?\n0.11\n",
" status of all of this? I can't stop thinking about it.\nIt's been a while since I've seen it. I found it a",
" status of your blog? Do you accept feedback?\nYes, I’m happy to accept feedback at this time (I’m a"
],
[
"\nIt's in the middle of nowhere if you haven’t seen one yet! It might be more convenient there than anywhere else.. maybe take",
Expand Down
14 changes: 7 additions & 7 deletions tests/python_tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class RandomSamplingTestStruct:
prompts=["What is OpenVINO?"],
ref_texts=[
[
"\n\nOpenVINO is a new open source virtual cold storage solution for virtual machines under the Ubuntu Foundation. OpenVINO is a virtualized virtual"
"\n\nOpenVINO is a software development platform developed by OpenVINO, a set of technology companies and startups that enables developers to use the most"
]
],
),
Expand Down Expand Up @@ -174,7 +174,7 @@ class RandomSamplingTestStruct:
prompts=["What is OpenVINO?"],
ref_texts=[
[
'\nOpen Vino (OLIN) was launched on April 12, 2016 and has resulted in around 15% of all virtual meetings being hosted by companies'
"\nOpen Vino's are a new and improved way to find cheap, fast-investment frozen vegetables that have no waste or calories. They're"
]
],
),
Expand All @@ -183,9 +183,9 @@ class RandomSamplingTestStruct:
prompts=["What is location of"],
ref_texts=[
[
" the sensor?\nIt's a sensor on the back of the phone.\nGotcha, very cool.\nGood job man!",
' this website?\n\nTasty Big Fish, New York, NY\n\nFounded in 2018 by award-winning authors, Including the creative minds',
" this?\nIt's actually in this sub."
" the exact same image?\nI've tried multiple times to find it, but I'm still not sure. I am sure it's the exact same",
" your new house?\nAnywhere that has a GPS. It will be up to you.",
" your cat? He is more likely to be on the floor with him.\nTalduck"
]
],
),
Expand Down Expand Up @@ -216,7 +216,7 @@ class RandomSamplingTestStruct:
prompts=["What is OpenVINO?"],
ref_texts=[
[
'\n\nOpenVINO is a new open source virtual application that lets you create and modify all kinds of virtual machines in your environment. OpenVINO'
"\n\nOpenVINO is a software development platform developed by OpenVINO, Inc., which uses a RESTful API for server-side web applications"
]
],
),
Expand All @@ -225,7 +225,7 @@ class RandomSamplingTestStruct:
prompts=["What is OpenVINO?"],
ref_texts=[
[
'\nOpenVINO is a technology for low-power video streaming by building high-efficiency, decoupling and shrinking cards. This strategy is important'
"\n\nOpenVINO is a software development platform developed by OpenVINO, Inc., which offers the Linux-based platform. OpenVINO's"
]
],
),
Expand Down

0 comments on commit 73e9b1a

Please sign in to comment.