Skip to content

Commit

Permalink
Use get_max_new_tokens() insted of max_new_tokens field when stopping…
Browse files Browse the repository at this point in the history
… generation
  • Loading branch information
michalkulakowski committed Dec 20, 2024
1 parent 67f2d26 commit 402bba1
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
16 changes: 8 additions & 8 deletions src/cpp/src/sampler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, Sa
}

// check whether group has finished
group.is_done(m_parameters);
group.is_done(m_parameters, this->m_sequence_group->get_prompt_len());

// group cannot continue if there are no valid child beams
if (child_beams_per_group[group_id].size() == 0) {
Expand Down Expand Up @@ -560,14 +560,14 @@ std::vector<int64_t> Sampler::_try_finish_generation(SequenceGroup::Ptr & sequen
std::vector<int64_t> dropped_seq_ids;
for (auto& running_sequence : sequence_group->get_running_sequences()) {
const auto generated_len = running_sequence->get_generated_len();
if (sampling_params.max_new_tokens <= generated_len ||
if (sampling_params.get_max_new_tokens(sequence_group->get_prompt_len()) <= generated_len ||
is_stop_token_id_hit(running_sequence->get_generated_ids().back(), sampling_params.stop_token_ids) && !sampling_params.ignore_eos) {
// stop sequence by max_new_tokens or stop token (eos included)
running_sequence->set_status(SequenceStatus::FINISHED);

if (is_stop_token_id_hit(running_sequence->get_generated_ids().back(), sampling_params.stop_token_ids) && !sampling_params.ignore_eos) {
running_sequence->set_finish_reason(GenerationFinishReason::STOP);
} else if (sampling_params.max_new_tokens == generated_len) {
} else if (sampling_params.get_max_new_tokens(sequence_group->get_prompt_len()) == generated_len) {
running_sequence->set_finish_reason(GenerationFinishReason::LENGTH);
}

Expand Down Expand Up @@ -786,8 +786,8 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
// max counter of needed to be sampled tokens
OPENVINO_ASSERT(running_sequence->get_generated_len() >= token_offset);
size_t generated_and_verified_len = running_sequence->get_generated_len() - token_offset;
OPENVINO_ASSERT(sampling_params.max_new_tokens >= generated_and_verified_len);
size_t max_num_sampled_token = sampling_params.max_new_tokens - generated_and_verified_len;
OPENVINO_ASSERT(sampling_params.get_max_new_tokens(sequence_group->get_prompt_len()) >= generated_and_verified_len);
size_t max_num_sampled_token = sampling_params.get_max_new_tokens(sequence_group->get_prompt_len()) - generated_and_verified_len;
if (max_num_sampled_token == 0) {
stop_sample_tokens(running_sequence, token_offset, max_num_sampled_token, max_removed_tokens_per_request);
break;
Expand Down Expand Up @@ -873,7 +873,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
// check max length stop criteria
std::vector<Sequence::Ptr> running_sequences = sequence_group->get_running_sequences();
if (!sequence_group->has_finished() &&
running_sequences[0]->get_generated_len() == sampling_params.max_new_tokens) {
running_sequences[0]->get_generated_len() == sampling_params.get_max_new_tokens(sequence_group->get_prompt_len())) {
// stop sequence by max_new_tokens
m_beam_search_info.at(request_id).finalize(sampler_output);
}
Expand Down Expand Up @@ -939,7 +939,7 @@ int64_t Sampler::GroupBeamSearcher::Group::finish(Beam beam, const ov::genai::Ge
return preeempted_sequence_id;
}

void Sampler::GroupBeamSearcher::Group::is_done(const ov::genai::GenerationConfig& sampling_params) {
void Sampler::GroupBeamSearcher::Group::is_done(const ov::genai::GenerationConfig& sampling_params, size_t prompt_len) {
assert(sampling_params.num_beams % sampling_params.num_beam_groups == 0 &&
"number of beams should be divisible by number of groups");
size_t group_size = sampling_params.num_beams / sampling_params.num_beam_groups;
Expand All @@ -960,7 +960,7 @@ void Sampler::GroupBeamSearcher::Group::is_done(const ov::genai::GenerationConfi
return;
}
case ov::genai::StopCriteria::NEVER: {
size_t length = sampling_params.length_penalty > 0.0 ? sampling_params.max_new_tokens : cur_len;
size_t length = sampling_params.length_penalty > 0.0 ? sampling_params.get_max_new_tokens() : cur_len;
float highest_attainable_score = best_sum_logprobs / std::pow(float(length), sampling_params.length_penalty);
done = worst_score >= highest_attainable_score;
return;
Expand Down
2 changes: 1 addition & 1 deletion src/cpp/src/sampler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class Sampler::GroupBeamSearcher {
bool done = false;

int64_t finish(Beam beam, const ov::genai::GenerationConfig& sampling_params);
void is_done(const ov::genai::GenerationConfig& sampling_params);
void is_done(const ov::genai::GenerationConfig& sampling_params, size_t prompt_len);
};

SequenceGroup::Ptr m_sequence_group;
Expand Down

0 comments on commit 402bba1

Please sign in to comment.