Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use get_max_new_tokens() instead of max_new_tokens field when stopping… #1417

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/cpp/src/sampler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, Sa
}

// check whether group has finished
group.is_done(m_parameters);
group.is_done(m_parameters, this->m_sequence_group->get_prompt_len());

// group cannot continue if there are no valid child beams
if (child_beams_per_group[group_id].size() == 0) {
Expand Down Expand Up @@ -560,14 +560,14 @@ std::vector<int64_t> Sampler::_try_finish_generation(SequenceGroup::Ptr & sequen
std::vector<int64_t> dropped_seq_ids;
for (auto& running_sequence : sequence_group->get_running_sequences()) {
const auto generated_len = running_sequence->get_generated_len();
if (sampling_params.max_new_tokens <= generated_len ||
if (sampling_params.get_max_new_tokens(sequence_group->get_prompt_len()) <= generated_len ||
is_stop_token_id_hit(running_sequence->get_generated_ids().back(), sampling_params.stop_token_ids) && !sampling_params.ignore_eos) {
// stop sequence by max_new_tokens or stop token (eos included)
running_sequence->set_status(SequenceStatus::FINISHED);

if (is_stop_token_id_hit(running_sequence->get_generated_ids().back(), sampling_params.stop_token_ids) && !sampling_params.ignore_eos) {
running_sequence->set_finish_reason(GenerationFinishReason::STOP);
} else if (sampling_params.max_new_tokens == generated_len) {
} else if (sampling_params.get_max_new_tokens(sequence_group->get_prompt_len()) == generated_len) {
running_sequence->set_finish_reason(GenerationFinishReason::LENGTH);
}

Expand Down Expand Up @@ -786,8 +786,8 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
// max counter of needed to be sampled tokens
OPENVINO_ASSERT(running_sequence->get_generated_len() >= token_offset);
size_t generated_and_verified_len = running_sequence->get_generated_len() - token_offset;
OPENVINO_ASSERT(sampling_params.max_new_tokens >= generated_and_verified_len);
size_t max_num_sampled_token = sampling_params.max_new_tokens - generated_and_verified_len;
OPENVINO_ASSERT(sampling_params.get_max_new_tokens(sequence_group->get_prompt_len()) >= generated_and_verified_len);
size_t max_num_sampled_token = sampling_params.get_max_new_tokens(sequence_group->get_prompt_len()) - generated_and_verified_len;
if (max_num_sampled_token == 0) {
stop_sample_tokens(running_sequence, token_offset, max_num_sampled_token, max_removed_tokens_per_request);
break;
Expand Down Expand Up @@ -873,7 +873,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
// check max length stop criteria
std::vector<Sequence::Ptr> running_sequences = sequence_group->get_running_sequences();
if (!sequence_group->has_finished() &&
running_sequences[0]->get_generated_len() == sampling_params.max_new_tokens) {
running_sequences[0]->get_generated_len() == sampling_params.get_max_new_tokens(sequence_group->get_prompt_len())) {
// stop sequence by max_new_tokens
m_beam_search_info.at(request_id).finalize(sampler_output);
}
Expand Down Expand Up @@ -939,7 +939,7 @@ int64_t Sampler::GroupBeamSearcher::Group::finish(Beam beam, const ov::genai::Ge
return preeempted_sequence_id;
}

void Sampler::GroupBeamSearcher::Group::is_done(const ov::genai::GenerationConfig& sampling_params) {
void Sampler::GroupBeamSearcher::Group::is_done(const ov::genai::GenerationConfig& sampling_params, size_t prompt_len) {
assert(sampling_params.num_beams % sampling_params.num_beam_groups == 0 &&
"number of beams should be divisible by number of groups");
size_t group_size = sampling_params.num_beams / sampling_params.num_beam_groups;
Expand All @@ -960,7 +960,7 @@ void Sampler::GroupBeamSearcher::Group::is_done(const ov::genai::GenerationConfi
return;
}
case ov::genai::StopCriteria::NEVER: {
size_t length = sampling_params.length_penalty > 0.0 ? sampling_params.max_new_tokens : cur_len;
size_t length = sampling_params.length_penalty > 0.0 ? sampling_params.get_max_new_tokens(prompt_len) : cur_len;
float highest_attainable_score = best_sum_logprobs / std::pow(float(length), sampling_params.length_penalty);
done = worst_score >= highest_attainable_score;
return;
Expand Down
2 changes: 1 addition & 1 deletion src/cpp/src/sampler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class Sampler::GroupBeamSearcher {
bool done = false;

int64_t finish(Beam beam, const ov::genai::GenerationConfig& sampling_params);
void is_done(const ov::genai::GenerationConfig& sampling_params);
void is_done(const ov::genai::GenerationConfig& sampling_params, size_t prompt_len);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can store the prompt_len in the Group's members at construction time, which would avoid passing it as a parameter to this method.

};

SequenceGroup::Ptr m_sequence_group;
Expand Down
Loading