Reorganize implementation
Wovchena committed Dec 18, 2023
1 parent a5f4cd2 commit 9bf75c3
Showing 7 changed files with 110 additions and 82 deletions.
35 changes: 26 additions & 9 deletions .github/workflows/causal_lm_cpp.yml
@@ -51,9 +51,25 @@ jobs:
- name: Compare
run: |
source ./ov/setupvars.sh
python ./text_generation/causal_lm/cpp/convert_tokenizers.py ./TinyLlama-1.1B-Chat-v0.6/
python ./text_generation/causal_lm/cpp/convert_tokenizers.py ./TinyLlama-1.1B-Chat-v0.6/pytorch/dldt/FP16/
timeout 25s ./build/beam_search_causal_lm ./TinyLlama-1.1B-Chat-v0.6/pytorch/dldt/FP16/openvino_model.xml ./tokenizer.xml ./detokenizer.xml 69 > ./pred.txt
timeout 25s ./build/beam_search_causal_lm ./TinyLlama-1.1B-Chat-v0.6/pytorch/dldt/FP16/ "Why is the Sun yellow?" > ./pred.txt
python -c "
import transformers
with open('pred.txt', 'r') as file:
predictions = file.read()
tokenizer = transformers.LlamaTokenizer.from_pretrained('./TinyLlama-1.1B-Chat-v0.6/')
tokenized = tokenizer('Why is the Sun yellow?', return_tensors='pt')
for beam in transformers.LlamaForCausalLM.from_pretrained('./TinyLlama-1.1B-Chat-v0.6/').generate(**tokenized, num_beam_groups=3, num_beams=15, num_return_sequences=15, diversity_penalty=1.0, max_new_tokens=20, early_stopping=False, length_penalty=1.0, no_repeat_ngram_size=9**9, do_sample=False):
ref = ': ' + tokenizer.decode(beam[tokenized['input_ids'].numel():], skip_special_tokens=True) + '\n'
idx = predictions.find(ref)
if -1 == idx:
raise RuntimeError(f'Missing "{ref=}" from predictions')
predictions = predictions[:idx] + predictions[idx + len(ref):]
"
echo Why is the Sun yellow? passed
timeout 25s ./build/beam_search_causal_lm ./TinyLlama-1.1B-Chat-v0.6/pytorch/dldt/FP16/ 69 > ./pred.txt
python -c "
import transformers
with open('pred.txt', 'r') as file:
@@ -69,7 +85,7 @@ jobs:
"
echo 69 passed
timeout 25s ./build/beam_search_causal_lm ./TinyLlama-1.1B-Chat-v0.6/pytorch/dldt/FP16/openvino_model.xml ./tokenizer.xml ./detokenizer.xml Hi > ./pred.txt
timeout 25s ./build/beam_search_causal_lm ./TinyLlama-1.1B-Chat-v0.6/pytorch/dldt/FP16/ Hi > ./pred.txt
python -c "
import transformers
with open('pred.txt', 'r') as file:
"
echo Hi passed
timeout 25s ./build/beam_search_causal_lm ./TinyLlama-1.1B-Chat-v0.6/pytorch/dldt/FP16/openvino_model.xml ./tokenizer.xml ./detokenizer.xml "return 0" > ./pred.txt
timeout 25s ./build/beam_search_causal_lm ./TinyLlama-1.1B-Chat-v0.6/pytorch/dldt/FP16/ "return 0" > ./pred.txt
python -c "
import transformers
with open('pred.txt', 'r') as file:
"
echo return 0 passed
./build/beam_search_causal_lm ./TinyLlama-1.1B-Chat-v0.6/pytorch/dldt/FP16/openvino_model.xml ./tokenizer.xml ./detokenizer.xml "" > ./pred.txt
./build/beam_search_causal_lm ./TinyLlama-1.1B-Chat-v0.6/pytorch/dldt/FP16/ "" > ./pred.txt
python -c "
import transformers
with open('pred.txt', 'r') as file:
"
echo '""' passed
./build/beam_search_causal_lm ./TinyLlama-1.1B-Chat-v0.6/pytorch/dldt/FP16/openvino_model.xml ./tokenizer.xml ./detokenizer.xml "你好! 你好嗎?" > ./pred.txt
./build/beam_search_causal_lm ./TinyLlama-1.1B-Chat-v0.6/pytorch/dldt/FP16/ "你好! 你好嗎?" > ./pred.txt
python -c "
import transformers
with open('pred.txt', 'r') as file:
"
echo 你好! 你好嗎? passed
cpp-beam_search_causal_lm-windows:
if: false # TODO: enable after openvino package with fix is published
runs-on: windows-latest
steps:
- uses: actions/checkout@v4
@@ -168,13 +185,13 @@ jobs:
shell: cmd
run: |
call w_openvino_toolkit_windows_2023.3.0.dev20231214_x86_64\setupvars.bat
python .\text_generation\causal_lm\cpp\convert_tokenizers.py .\TinyLlama-1.1B-Chat-v0.6\
python .\text_generation\causal_lm\cpp\convert_tokenizers.py .\TinyLlama-1.1B-Chat-v0.6\pytorch\dldt\FP16\
.\build\Release\beam_search_causal_lm.exe .\TinyLlama-1.1B-Chat-v0.6\pytorch\dldt\FP16\openvino_model.xml .\tokenizer.xml .\detokenizer.xml 69 > .\pred.txt
.\build\Release\beam_search_causal_lm.exe .\TinyLlama-1.1B-Chat-v0.6\pytorch\dldt\FP16\ "Why is the Sun yellow?" > .\pred.txt
echo import transformers > ref.py
echo predictions = open('pred.txt', 'r').read() >> ref.py
echo tokenizer = transformers.LlamaTokenizer.from_pretrained(r'.\TinyLlama-1.1B-Chat-v0.6') >> ref.py
echo tokenized = tokenizer('69', return_tensors='pt') >> ref.py
echo tokenized = tokenizer('Why is the Sun yellow?', return_tensors='pt') >> ref.py
echo for beam in transformers.LlamaForCausalLM.from_pretrained(r'.\TinyLlama-1.1B-Chat-v0.6').generate(**tokenized, num_beam_groups=3, num_beams=15, num_return_sequences=15, diversity_penalty=1.0, max_new_tokens=20, early_stopping=False, length_penalty=1.0, no_repeat_ngram_size=9**9, do_sample=False): >> ref.py
echo ref = ': ' + tokenizer.decode(beam[tokenized['input_ids'].numel():], skip_special_tokens=True) + '\n' >> ref.py
echo idx = predictions.find(ref) >> ref.py
22 changes: 0 additions & 22 deletions text_generation/causal_lm/cpp/CMakeLists.txt
@@ -14,16 +14,6 @@ find_package(OpenVINO REQUIRED COMPONENTS Runtime)
target_link_libraries(causal_lm PRIVATE openvino::runtime user_ov_extensions)
set_target_properties(causal_lm PROPERTIES CXX_STANDARD 17)
set_target_properties(causal_lm PROPERTIES CXX_STANDARD_REQUIRED ON)
if(MSVC)
target_compile_options(
causal_lm PRIVATE
/Wall # Display all warnings
/wd4710 /wd4711 # Disable the inline warnings
/EHsc # Enable standard C++ stack unwinding, assume functions with extern "C" never throw
)
else()
target_compile_options(causal_lm PRIVATE -Wall) # Display all warnings
endif()

add_executable(beam_search_causal_lm beam_search_causal_lm.cpp)
target_compile_definitions(beam_search_causal_lm PRIVATE USER_OV_EXTENSIONS_PATH=\"$<TARGET_FILE:user_ov_extensions>\")
@@ -32,15 +22,3 @@ find_package(OpenVINO REQUIRED COMPONENTS Runtime)
target_link_libraries(beam_search_causal_lm PRIVATE openvino::runtime user_ov_extensions)
set_target_properties(beam_search_causal_lm PROPERTIES CXX_STANDARD 17)
set_target_properties(beam_search_causal_lm PROPERTIES CXX_STANDARD_REQUIRED ON)
if(MSVC)
target_compile_options(
beam_search_causal_lm PRIVATE
/Wall # Display all warnings
/wd4626 /wd5027 # Disable the implicit definition of assignment operator as deleted warnings
/wd4710 /wd4711 # Disable the inline warnings
/wd4820 # Disable the padding addition warning after data members
/EHsc # Enable standard C++ stack unwinding, assume functions with extern "C" never throw
)
else()
target_compile_options(beam_search_causal_lm PRIVATE -Wall) # Display all warnings
endif()
10 changes: 5 additions & 5 deletions text_generation/causal_lm/cpp/README.md
@@ -55,17 +55,17 @@ The `--upgrade-strategy eager` option is needed to ensure `optimum-intel` is upg
source <INSTALL_DIR>/setupvars.sh
python -m pip install --upgrade-strategy eager "optimum[openvino]>=1.14" -r ../../../llm_bench/python/requirements.txt ../../../thirdparty/openvino_contrib/modules/custom_operations/[transformers] --extra-index-url https://download.pytorch.org/whl/cpu
python ../../../llm_bench/python/convert.py --model_id meta-llama/Llama-2-7b-hf --output_dir ./Llama-2-7b-hf/ --precision FP16 --stateful
python ./convert_tokenizers.py --streaming-detokenizer ./Llama-2-7b-hf/
python ./convert_tokenizers.py --streaming-detokenizer ./Llama-2-7b-hf/pytorch/dldt/FP16/
```

## Run

Usage:
1. `causal_lm <openvino_model.xml> <tokenizer.xml> <detokenizer.xml> "<prompt>"`
2. `beam_search_causal_lm <openvino_model.xml> <tokenizer.xml> <detokenizer.xml> "<prompt>"`
1. `causal_lm <MODEL_DIR> "<PROMPT>"`
2. `beam_search_causal_lm <MODEL_DIR> "<PROMPT>"`

Examples:
1. `./build/causal_lm ./Llama-2-7b-hf/pytorch/dldt/FP16/openvino_model.xml ./tokenizer.xml ./detokenizer.xml "Why is the Sun yellow?"`
2. `./build/beam_search_causal_lm ./Llama-2-7b-hf/pytorch/dldt/FP16/openvino_model.xml ./tokenizer.xml ./detokenizer.xml "Why is the Sun yellow?"`
1. `./build/causal_lm ./Llama-2-7b-hf/pytorch/dldt/FP16/ "Why is the Sun yellow?"`
2. `./build/beam_search_causal_lm ./Llama-2-7b-hf/pytorch/dldt/FP16/ "Why is the Sun yellow?"`

To enable Unicode characters for Windows cmd open `Region` settings from `Control panel`. `Administrative`->`Change system locale`->`Beta: Use Unicode UTF-8 for worldwide language support`->`OK`. Reboot.
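
After this reorganization, each sample takes a single model directory and expects the tokenizer and detokenizer IRs to sit next to `openvino_model.xml`, where `convert_tokenizers.py` now writes them. A rough sketch of the resulting layout and invocation (the tree below is illustrative, not part of this diff; exact contents depend on how the model was exported):

```sh
# Illustrative layout of <MODEL_DIR> after conversion (assumed, not from this diff):
# Llama-2-7b-hf/pytorch/dldt/FP16/
#   openvino_model.xml / openvino_model.bin              <- produced by llm_bench convert.py
#   openvino_tokenizer.xml / openvino_tokenizer.bin      <- produced by convert_tokenizers.py
#   openvino_detokenizer.xml / openvino_detokenizer.bin  <- produced by convert_tokenizers.py
./build/causal_lm ./Llama-2-7b-hf/pytorch/dldt/FP16/ "Why is the Sun yellow?"
```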
86 changes: 50 additions & 36 deletions text_generation/causal_lm/cpp/beam_search_causal_lm.cpp
@@ -6,7 +6,17 @@
#include <openvino_extensions/strings.hpp>

namespace {
std::pair<ov::Tensor, ov::Tensor> tokenize(ov::InferRequest&& tokenizer, std::string_view prompt) {
std::tuple<ov::InferRequest, ov::InferRequest, ov::InferRequest> compile_models(const std::string model_dir) {
ov::Core core;
core.add_extension(USER_OV_EXTENSIONS_PATH); // USER_OV_EXTENSIONS_PATH is defined in CMakeLists.txt
return {
core.compile_model(model_dir + "/openvino_model.xml", "CPU").create_infer_request(),
core.compile_model(model_dir + "/openvino_tokenizer.xml", "CPU").create_infer_request(),
core.compile_model(model_dir + "/openvino_detokenizer.xml", "CPU").create_infer_request()
};
}

std::pair<ov::Tensor, ov::Tensor> tokenize(ov::InferRequest& tokenizer, std::string_view prompt) {
ov::Tensor destination = tokenizer.get_input_tensor();
openvino_extensions::pack_strings(std::array{prompt}, destination);
tokenizer.infer();
@@ -23,55 +33,59 @@ std::string detokenize(ov::InferRequest& detokenizer, const std::vector<int64_t>
detokenizer.infer();
return openvino_extensions::unpack_strings(detokenizer.get_output_tensor()).front();
}

void initialize_inputs(ov::InferRequest& lm, const ov::Tensor& input_ids, const ov::Tensor& attention_mask) {
lm.set_tensor("input_ids", input_ids);
lm.set_tensor("attention_mask", attention_mask);
ov::Tensor position_ids = lm.get_tensor("position_ids");
position_ids.set_shape(input_ids.get_shape());
std::iota(position_ids.data<int64_t>(), position_ids.data<int64_t>() + position_ids.get_size(), 0);
lm.get_tensor("beam_idx").set_shape({1});
lm.get_tensor("beam_idx").data<int32_t>()[0] = 0;
}

void set_pointers(
ov::InferRequest& lm, std::vector<int64_t>& next_tokens, std::vector<int32_t>& next_beams) {
size_t batch_size = next_tokens.size();
lm.set_tensor("input_ids", ov::Tensor{ov::element::i64, {batch_size, 1}, next_tokens.data()});
lm.set_tensor("beam_idx", ov::Tensor{ov::element::i32, {batch_size}, next_beams.data()});
}

void set_auxiliary_inputs(ov::InferRequest& lm) {
size_t batch_size = lm.get_tensor("input_ids").get_shape().front();
ov::Tensor attention_mask = lm.get_tensor("attention_mask");
ov::Shape mask_shape{batch_size, attention_mask.get_shape().at(1) + 1};
attention_mask.set_shape(mask_shape);
std::fill_n(attention_mask.data<int64_t>(), ov::shape_size(mask_shape), 1);
lm.get_tensor("position_ids").set_shape({batch_size, 1});
std::fill_n(lm.get_tensor("position_ids").data<int64_t>(), batch_size, mask_shape.at(1) - 1);
}
}

int main(int argc, char* argv[]) try {
if (argc != 5) {
throw std::runtime_error(std::string{"Usage: "} + argv[0]
+ " <openvino_model.xml> <tokenizer.xml> <detokenizer.xml> '<prompt>'");
if (argc != 3) {
throw std::runtime_error(std::string{"Usage: "} + argv[0] + " <MODEL_DIR> '<PROMPT>'");
}
ov::Core core;
core.add_extension(USER_OV_EXTENSIONS_PATH); // USER_OV_EXTENSIONS_PATH is defined in root CMakeLists.txt
auto [input_ids, mask] = tokenize(core.compile_model(argv[2], "CPU").create_infer_request(), argv[4]);
ov::InferRequest detokenizer = core.compile_model(argv[3], "CPU").create_infer_request();
ov::InferRequest ireq = core.compile_model(argv[1], "CPU").create_infer_request();
ireq.set_tensor("input_ids", input_ids);
ireq.set_tensor("attention_mask", mask);
ov::Tensor position_ids = ireq.get_tensor("position_ids");
position_ids.set_shape(input_ids.get_shape());
std::iota(position_ids.data<int64_t>(), position_ids.data<int64_t>() + position_ids.get_size(), 0);
ireq.get_tensor("beam_idx").set_shape({1});
ireq.get_tensor("beam_idx").data<int32_t>()[0] = 0;
Parameters parameters;
auto [lm, tokenizer, detokenizer] = compile_models(argv[1]);
auto [input_ids, attention_mask] = tokenize(tokenizer, argv[2]);
initialize_inputs(lm, input_ids, attention_mask);
const int64_t* prompt_data = input_ids.data<const int64_t>();
parameters.prompt = std::vector<int64_t>{prompt_data, prompt_data + input_ids.get_size()};
Parameters parameters{std::vector<int64_t>{prompt_data, prompt_data + input_ids.get_size()}};
GroupBeamSearcher group_beam_searcher{parameters};
std::vector<int64_t> next_tokens;
std::vector<int32_t> next_beams;
for (size_t length_count = 0; length_count < parameters.max_new_tokens; ++length_count) {
ireq.infer();
std::tie(next_tokens, next_beams) = group_beam_searcher.process(ireq.get_tensor("logits"));
lm.infer();
std::tie(next_tokens, next_beams) = group_beam_searcher.process(lm.get_tensor("logits"));
if (next_tokens.empty()) {
break;
}
size_t batch_size = next_tokens.size();
ireq.set_tensor("input_ids", ov::Tensor{ov::element::i64, {batch_size, 1}, next_tokens.data()});
ov::Tensor attention_mask = ireq.get_tensor("attention_mask");
ov::Shape mask_shape{batch_size, attention_mask.get_shape().at(1) + 1};
attention_mask.set_shape(mask_shape);
std::fill_n(attention_mask.data<int64_t>(), shape_size(mask_shape), 1);
position_ids.set_shape({batch_size, 1});
std::fill_n(position_ids.data<int64_t>(), batch_size, mask_shape.at(1) - 1);
ireq.set_tensor("beam_idx", ov::Tensor{ov::element::i32, {batch_size}, next_beams.data()});
set_pointers(lm, next_tokens, next_beams);
set_auxiliary_inputs(lm);
}
for (Group& group : group_beam_searcher.groups) {
if (!group.done) {
for (Beam& beam : group.ongoing) {
group.finish(std::move(beam), parameters);
}
}
for (const std::vector<Beam>& group : finalize(std::move(group_beam_searcher))) {
std::cout << "Group:\n";
for (const Beam& beam : group.min_heap) {
for (const Beam& beam : group) {
std::cout << beam.score << ": " << detokenize(detokenizer, beam.tokens) << '\n';
}
}
17 changes: 10 additions & 7 deletions text_generation/causal_lm/cpp/causal_lm.cpp
@@ -5,7 +5,7 @@
#include <openvino_extensions/strings.hpp>

namespace {
std::pair<ov::Tensor, ov::Tensor> tokenize(ov::InferRequest&& tokenizer, std::string_view prompt) {
std::pair<ov::Tensor, ov::Tensor> tokenize(ov::InferRequest& tokenizer, std::string_view prompt) {
constexpr size_t BATCH_SIZE = 1;
ov::Tensor destination = tokenizer.get_input_tensor();
openvino_extensions::pack_strings(std::array<std::string_view, BATCH_SIZE>{prompt}, destination);
@@ -24,14 +24,17 @@ void print_token(ov::InferRequest& detokenizer, int64_t out_token) {
}

int main(int argc, char* argv[]) try {
if (argc != 5) {
throw std::runtime_error(std::string{"Usage: "} + argv[0] + " <openvino_model.xml> <tokenizer.xml> <detokenizer.xml> '<prompt>'");
if (argc != 3) {
throw std::runtime_error(std::string{"Usage: "} + argv[0] + " <MODEL_DIR> '<PROMPT>'");
}
ov::Core core;
core.add_extension(USER_OV_EXTENSIONS_PATH); // USER_OV_EXTENSIONS_PATH is defined in CMakeLists.txt
auto [input_ids, attention_mask] = tokenize(core.compile_model(argv[2], "CPU").create_infer_request(), argv[4]);
ov::InferRequest detokenizer = core.compile_model(argv[3], "CPU").create_infer_request();
std::shared_ptr<ov::Model> model = core.read_model(argv[1]);
ov::InferRequest tokenizer = core.compile_model(
std::string{argv[1]} + "/openvino_tokenizer.xml", "CPU").create_infer_request();
auto [input_ids, attention_mask] = tokenize(tokenizer, argv[2]);
ov::InferRequest detokenizer = core.compile_model(
std::string{argv[1]} + "/openvino_detokenizer.xml", "CPU").create_infer_request();
std::shared_ptr<ov::Model> model = core.read_model(std::string{argv[1]} + "/openvino_model.xml");
constexpr size_t BATCH_SIZE = 1;
std::map<size_t, ov::PartialShape> shapes = {
{0, ov::PartialShape{
@@ -73,7 +76,7 @@ int main(int argc, char* argv[]) try {
ireq.get_tensor("input_ids").data<int64_t>()[0] = out_token;
ireq.get_tensor("attention_mask").set_shape({BATCH_SIZE, ireq.get_tensor("attention_mask").get_shape()[1] + 1});
std::fill_n(ireq.get_tensor("attention_mask").data<int64_t>(), ireq.get_tensor("attention_mask").get_size(), 1);
ireq.get_tensor("position_ids").data<int64_t>()[0] = ireq.get_tensor("attention_mask").get_size() - 2;
ireq.get_tensor("position_ids").data<int64_t>()[0] = int64_t(ireq.get_tensor("attention_mask").get_size() - 2);
for (size_t idx = 3; idx < inputs.size(); ++idx) {
ireq.set_input_tensor(idx, ireq.get_output_tensor(idx - 2));
}
7 changes: 4 additions & 3 deletions text_generation/causal_lm/cpp/convert_tokenizers.py
@@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

import argparse
import pathlib

import openvino
import openvino_tokenizers
@@ -12,13 +13,13 @@
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--streaming-detokenizer', action='store_true')
parser.add_argument('pretrained_model_name_or_path')
parser.add_argument('pretrained_model_name_or_path', type=pathlib.Path)
args = parser.parse_args()
tokenizer, detokenizer = openvino_tokenizers.convert_tokenizer(
transformers.AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path),
with_detokenizer=True, streaming_detokenizer=args.streaming_detokenizer)
openvino.save_model(tokenizer, "tokenizer.xml")
openvino.save_model(detokenizer, "detokenizer.xml")
openvino.save_model(tokenizer, args.pretrained_model_name_or_path / "openvino_tokenizer.xml")
openvino.save_model(detokenizer, args.pretrained_model_name_or_path / "openvino_detokenizer.xml")


if __name__ == '__main__':
15 changes: 15 additions & 0 deletions text_generation/causal_lm/cpp/group_beam_searcher.hpp
@@ -240,3 +240,18 @@ struct GroupBeamSearcher {
return {next_tokens, next_beams};
}
};

// Consume group_beam_searcher because its beams are moved out of it
std::vector<std::vector<Beam>> finalize(GroupBeamSearcher&& group_beam_searcher) {
std::vector<std::vector<Beam>> finalized;
finalized.reserve(group_beam_searcher.groups.size());
for (Group& group : group_beam_searcher.groups) {
if (!group.done) {
for (Beam& beam : group.ongoing) {
group.finish(std::move(beam), group_beam_searcher.parameters);
}
}
finalized.push_back(std::move(group.min_heap));
}
return finalized;
}
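
For context, a minimal call-site sketch of the new `finalize` helper, mirroring how the updated `beam_search_causal_lm.cpp` above uses it (the searcher is passed by rvalue reference and must not be reused afterwards, since its beams are moved out):

```cpp
// Sketch only; assumes the Parameters, GroupBeamSearcher and Beam types from this header.
GroupBeamSearcher group_beam_searcher{parameters};
// ... run the generation loop, feeding each step's logits to group_beam_searcher.process() ...
for (const std::vector<Beam>& group : finalize(std::move(group_beam_searcher))) {
    for (const Beam& beam : group) {
        std::cout << beam.score << ": " << detokenize(detokenizer, beam.tokens) << '\n';
    }
}
```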
