Port from 24.6 release to master (#1356)
- ~~#1302~~ (didn't port this PR because of issue CVS-159227)
- #1262
- #1336
- #1331

---------

Co-authored-by: Andrei Kochin <[email protected]>
Co-authored-by: Vladimir Zlobin <[email protected]>
Co-authored-by: Ilya Lavrenov <[email protected]>
4 people authored Dec 11, 2024
1 parent 15176a8 commit 8a74d24
Showing 28 changed files with 691 additions and 170 deletions.
28 changes: 14 additions & 14 deletions .github/workflows/causal_lm_cpp.yml
@@ -63,13 +63,13 @@ jobs:
PYTHONPATH: "./build"
- run: >
. ./ov/setupvars.sh
- && timeout 25s ./build/samples/cpp/greedy_causal_lm/greedy_causal_lm ./open_llama_3b_v2/ "return 0"
- | diff <(timeout 25s samples/python/greedy_causal_lm/greedy_causal_lm.py ./open_llama_3b_v2/ "return 0") -
+ && timeout 25s ./build/samples/cpp/text_generation/greedy_causal_lm ./open_llama_3b_v2/ "return 0"
+ | diff <(timeout 25s samples/python/text_generation/greedy_causal_lm.py ./open_llama_3b_v2/ "return 0") -
env:
PYTHONPATH: "./build"
- run: >
. ./ov/setupvars.sh
- && samples/python/greedy_causal_lm/lora.py ./TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T/ adapter_model.safetensors "How to create a table with two columns, one of them has type float, another one has type int?"
+ && samples/python/text_generation/lora.py ./TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T/ adapter_model.safetensors "How to create a table with two columns, one of them has type float, another one has type int?"
env:
PYTHONPATH: "./build"
@@ -249,7 +249,7 @@ jobs:
- run: >
set PATH=.\build\openvino_genai\;%PATH%
&& call .\ov\setupvars.bat
- && .\build\samples\cpp\greedy_causal_lm\Release\greedy_causal_lm.exe .\TinyLlama-1.1B-Chat-v1.0\ 69 > .\cpp.txt
+ && .\build\samples\cpp\text_generation\Release\greedy_causal_lm.exe .\TinyLlama-1.1B-Chat-v1.0\ 69 > .\cpp.txt
- run: |
echo import transformers > ref.py
echo predictions = open('cpp.txt', 'r').read() >> ref.py
@@ -266,13 +266,13 @@ jobs:
set PATH=.\build\openvino_genai\;%PATH%
&& set "PYTHONPATH=./build/"
&& call .\ov\setupvars.bat
- && python samples\python\greedy_causal_lm\greedy_causal_lm.py .\TinyLlama-1.1B-Chat-v1.0\ 69 > .\py.txt
+ && python samples\python\text_generation\greedy_causal_lm.py .\TinyLlama-1.1B-Chat-v1.0\ 69 > .\py.txt
- run: fc .\cpp.txt .\py.txt
- run: >
set PATH=.\build\openvino_genai\;%PATH%
&& set "PYTHONPATH=./build/"
&& call .\ov\setupvars.bat
- && python samples\python\greedy_causal_lm\lora.py .\TinyLlama\TinyLlama-1.1B-intermediate-step-1431k-3T\ adapter_model.safetensors "How to create a table with two columns, one of them has type float, another one has type int?"
+ && python samples\python\text_generation\lora.py .\TinyLlama\TinyLlama-1.1B-intermediate-step-1431k-3T\ adapter_model.safetensors "How to create a table with two columns, one of them has type float, another one has type int?"
cpp-greedy_causal_lm-Qwen-7B-Chat:
runs-on: ubuntu-20.04-16-cores
@@ -304,7 +304,7 @@ jobs:
optimum-cli export openvino --trust-remote-code --weight-format fp16 --model Qwen/Qwen-7B-Chat Qwen-7B-Chat
- run: >
. ./ov/setupvars.sh
- && timeout 2m ./build/samples/cpp/greedy_causal_lm/greedy_causal_lm ./Qwen-7B-Chat/ 69 | diff <(timeout 2m samples/python/greedy_causal_lm/greedy_causal_lm.py ./Qwen-7B-Chat/ 69) -
+ && timeout 2m ./build/samples/cpp/text_generation/greedy_causal_lm ./Qwen-7B-Chat/ 69 | diff <(timeout 2m samples/python/text_generation/greedy_causal_lm.py ./Qwen-7B-Chat/ 69) -
env:
PYTHONPATH: "./build"
@@ -446,7 +446,7 @@ jobs:
run: |
source ./ov/setupvars.sh
./build/samples/cpp/speculative_decoding_lm/speculative_decoding_lm ./dolly-v2-7b/ ./dolly-v2-3b/ "Alan Turing was a" > predictions_speculative.txt
- ./build/samples/cpp/greedy_causal_lm/greedy_causal_lm ./dolly-v2-7b/ "Alan Turing was a" > predictions_greedy.txt
+ ./build/samples/cpp/text_generation/greedy_causal_lm ./dolly-v2-7b/ "Alan Turing was a" > predictions_greedy.txt
python ./samples/python/speculative_decoding_lm/speculative_decoding_lm.py ./dolly-v2-7b/ ./dolly-v2-3b/ "Alan Turing was a" > predictions_py.txt
python -c "
with open('predictions_greedy.txt', 'r') as f:
@@ -504,7 +504,7 @@ jobs:
A:' > ./prompt.txt
./build/samples/cpp/prompt_lookup_decoding_lm/prompt_lookup_decoding_lm ./TinyLlama-1.1B-Chat-v1.0/ "$(<prompt.txt)" > predictions_prompt_lookup.txt
- ./build/samples/cpp/greedy_causal_lm/greedy_causal_lm ./TinyLlama-1.1B-Chat-v1.0/ "$(<prompt.txt)" > predictions_greedy.txt
+ ./build/samples/cpp/text_generation/greedy_causal_lm ./TinyLlama-1.1B-Chat-v1.0/ "$(<prompt.txt)" > predictions_greedy.txt
python -c "
with open('predictions_greedy.txt', 'r') as f:
predicted_greedy = f.readline()
@@ -525,7 +525,7 @@ jobs:
A:' > ./prompt.txt
./build/samples/cpp/prompt_lookup_decoding_lm/prompt_lookup_decoding_lm ./Qwen-7B-Chat/ "$(<prompt.txt)" > predictions_prompt_lookup.txt
- ./build/samples/cpp/greedy_causal_lm/greedy_causal_lm ./Qwen-7B-Chat/ "$(<prompt.txt)" > predictions_greedy.txt
+ ./build/samples/cpp/text_generation/greedy_causal_lm ./Qwen-7B-Chat/ "$(<prompt.txt)" > predictions_greedy.txt
python -c "
with open('predictions_greedy.txt', 'r') as f:
predicted_greedy = f.readline()
@@ -566,7 +566,7 @@ jobs:
- name: Run Generation
run: |
source ./ov/setupvars.sh
- timeout 50s ./build/samples/cpp/greedy_causal_lm/greedy_causal_lm ./phi-1_5/ "Alan Turing was a" > ./pred_greedy.txt
+ timeout 50s ./build/samples/cpp/text_generation/greedy_causal_lm ./phi-1_5/ "Alan Turing was a" > ./pred_greedy.txt
- name: Compare
run: |
python -c "
@@ -585,7 +585,7 @@ jobs:
echo Phi-1_5 passed
- run: >
. ./ov/setupvars.sh
- && timeout 50s samples/python/greedy_causal_lm/greedy_causal_lm.py ./phi-1_5/ "Alan Turing was a"
+ && timeout 50s samples/python/text_generation/greedy_causal_lm.py ./phi-1_5/ "Alan Turing was a"
| diff ./pred_greedy.txt -
env:
PYTHONPATH: "./build"
@@ -621,7 +621,7 @@ jobs:
- name: Run Generation
run: |
source ./ov/setupvars.sh
- timeout 50s ./build/samples/cpp/greedy_causal_lm/greedy_causal_lm ./redpajama-3b-chat/ "Alan Turing was a" > ./pred_greedy.txt
+ timeout 50s ./build/samples/cpp/text_generation/greedy_causal_lm ./redpajama-3b-chat/ "Alan Turing was a" > ./pred_greedy.txt
- name: Compare
run: |
python -c "
@@ -640,7 +640,7 @@ jobs:
echo "Alan Turing was a" passed
- run: >
. ./ov/setupvars.sh
- && timeout 50s samples/python/greedy_causal_lm/greedy_causal_lm.py ./redpajama-3b-chat/ "Alan Turing was a"
+ && timeout 50s samples/python/text_generation/greedy_causal_lm.py ./redpajama-3b-chat/ "Alan Turing was a"
| diff ./pred_greedy.txt -
env:
PYTHONPATH: "./build"
6 changes: 3 additions & 3 deletions samples/CMakeLists.txt
@@ -5,7 +5,7 @@
add_subdirectory(cpp/beam_search_causal_lm)
add_subdirectory(cpp/benchmark_genai)
add_subdirectory(cpp/chat_sample)
- add_subdirectory(cpp/greedy_causal_lm)
+ add_subdirectory(cpp/text_generation)
add_subdirectory(cpp/lora_greedy_causal_lm)
add_subdirectory(cpp/multinomial_causal_lm)
add_subdirectory(cpp/prompt_lookup_decoding_lm)
@@ -25,7 +25,7 @@ install(DIRECTORY
cpp/beam_search_causal_lm
cpp/benchmark_genai
cpp/chat_sample
- cpp/greedy_causal_lm
+ cpp/text_generation
cpp/image_generation
cpp/lora_greedy_causal_lm
cpp/multinomial_causal_lm
@@ -39,7 +39,7 @@ install(DIRECTORY
python/beam_search_causal_lm
python/benchmark_genai
python/chat_sample
- python/greedy_causal_lm
+ python/text_generation
python/image_generation
python/multinomial_causal_lm
python/speculative_decoding_lm
samples/cpp/{greedy_causal_lm → text_generation}/CMakeLists.txt
@@ -20,3 +20,16 @@ install(TARGETS greedy_causal_lm
RUNTIME DESTINATION samples_bin/
COMPONENT samples_bin
EXCLUDE_FROM_ALL)

add_executable(encrypted_model_causal_lm encrypted_model_causal_lm.cpp)
target_link_libraries(encrypted_model_causal_lm PRIVATE openvino::genai)
set_target_properties(encrypted_model_causal_lm PROPERTIES
COMPILE_PDB_NAME encrypted_model_causal_lm
# Ensure out of box LC_RPATH on macOS with SIP
INSTALL_RPATH_USE_LINK_PATH ON)
target_compile_features(encrypted_model_causal_lm PRIVATE cxx_std_11)

install(TARGETS encrypted_model_causal_lm
RUNTIME DESTINATION samples_bin/
COMPONENT samples_bin
EXCLUDE_FROM_ALL)
samples/cpp/{greedy_causal_lm → text_generation}/README.md
@@ -24,6 +24,18 @@ Discrete GPUs (dGPUs) usually provide better performance compared to CPUs. It is

See https://github.com/openvinotoolkit/openvino.genai/blob/master/src/README.md#supported-models for the list of supported models.

## Using encrypted models

LLMPipeline and Tokenizer objects can be initialized directly from a memory buffer, e.g., when the user stores only encrypted files and decrypts them on the fly.
The following code snippet demonstrates how to load a model from a memory buffer:

```cpp
auto [model_str, weights_tensor] = decrypt_model(models_path + "/openvino_model.xml", models_path + "/openvino_model.bin");
ov::genai::Tokenizer tokenizer(models_path);
ov::genai::LLMPipeline pipe(model_str, weights_tensor, tokenizer, device);
```
For brevity, the code above does not include Tokenizer decryption. For more details, see the encrypted_model_causal_lm sample.

### Troubleshooting
#### Unicode characters encoding error on Windows
59 changes: 59 additions & 0 deletions samples/cpp/text_generation/encrypted_model_causal_lm.cpp
@@ -0,0 +1,59 @@
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include "openvino/genai/llm_pipeline.hpp"

std::pair<std::string, ov::Tensor> decrypt_model(const std::string& model_path, const std::string& weights_path) {
    std::ifstream model_file(model_path);
    std::ifstream weights_file(weights_path, std::ios::binary);
    if (!model_file.is_open() || !weights_file.is_open()) {
        throw std::runtime_error("Cannot open model or weights file");
    }

    // User can add file decryption of model_file and weights_file in memory here.

    std::string model_str((std::istreambuf_iterator<char>(model_file)), std::istreambuf_iterator<char>());
    std::vector<char> weights_buffer((std::istreambuf_iterator<char>(weights_file)), std::istreambuf_iterator<char>());
    // Copy the weights into a tensor that owns its memory: wrapping
    // weights_buffer.data() directly would leave the tensor pointing at a
    // buffer that is destroyed when this function returns.
    ov::Tensor weights_tensor(ov::element::u8, {weights_buffer.size()});
    std::memcpy(weights_tensor.data(), weights_buffer.data(), weights_buffer.size());
    return {model_str, weights_tensor};
}

ov::genai::Tokenizer decrypt_tokenizer(const std::string& models_path) {
    std::string tok_model_path = models_path + "/openvino_tokenizer.xml";
    std::string tok_weights_path = models_path + "/openvino_tokenizer.bin";
    auto [tok_model_str, tok_weights_tensor] = decrypt_model(tok_model_path, tok_weights_path);

    std::string detok_model_path = models_path + "/openvino_detokenizer.xml";
    std::string detok_weights_path = models_path + "/openvino_detokenizer.bin";
    // Decrypt the detokenizer files here, not the tokenizer files a second time.
    auto [detok_model_str, detok_weights_tensor] = decrypt_model(detok_model_path, detok_weights_path);

    return ov::genai::Tokenizer(tok_model_str, tok_weights_tensor, detok_model_str, detok_weights_tensor);
}

int main(int argc, char* argv[]) try {
    if (argc < 3)
        throw std::runtime_error(std::string{"Usage: "} + argv[0] + " <MODEL_DIR> \"<PROMPT>\"");

    std::string device = "CPU";  // GPU, NPU can be used as well
    std::string models_path = argv[1];
    std::string prompt = argv[2];

    auto [model_str, model_weights] = decrypt_model(models_path + "/openvino_model.xml", models_path + "/openvino_model.bin");
    ov::genai::Tokenizer tokenizer = decrypt_tokenizer(models_path);

    ov::genai::LLMPipeline pipe(model_str, model_weights, tokenizer, device);

    std::string result = pipe.generate(prompt, ov::genai::max_new_tokens(100));
    std::cout << result << std::endl;
} catch (const std::exception& error) {
    try {
        std::cerr << error.what() << '\n';
    } catch (const std::ios_base::failure&) {}
    return EXIT_FAILURE;
} catch (...) {
    try {
        std::cerr << "Non-exception object thrown\n";
    } catch (const std::ios_base::failure&) {}
    return EXIT_FAILURE;
}
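As an illustration of the "file decryption ... in memory" hook in `decrypt_model` above, the sketch below applies a toy XOR transform in place. The `decrypt_buffer` helper and the demo key are hypothetical and not part of this commit; a real deployment would substitute a proper cipher such as AES.

```cpp
#include <string>
#include <vector>

// Hypothetical helper: undo a toy XOR "encryption" in place. The same
// template covers both the XML text (std::string) and the binary weights
// (std::vector<char>) read in decrypt_model().
template <typename Container>
void decrypt_buffer(Container& buffer, unsigned char key) {
    for (auto& byte : buffer)
        byte = static_cast<char>(static_cast<unsigned char>(byte) ^ key);
}

// Inside decrypt_model(), after the two reads and before building the tensor:
//   decrypt_buffer(model_str, 0x5A);       // 0x5A is an arbitrary demo key
//   decrypt_buffer(weights_buffer, 0x5A);
```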
2 changes: 1 addition & 1 deletion samples/cpp/visual_language_chat/visual_language_chat.cpp
@@ -18,7 +18,7 @@ int main(int argc, char* argv[]) try {

std::string device = "CPU"; // GPU can be used as well
ov::AnyMap enable_compile_cache;
if ("GPU" == device) {
if (device == "GPU") {
// Cache compiled models on disk for GPU to save time on the
// next run. It's not beneficial for CPU.
enable_compile_cache.insert({ov::cache_dir("vlm_cache")});
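For context, the `enable_compile_cache` map built in this hunk is then handed to the pipeline constructor later in the sample. A minimal sketch of that presumed continuation (the constructor call is assumed, not shown in this hunk):

```cpp
// With cache_dir set, compiled GPU blobs are written to ./vlm_cache on the
// first run and reloaded on later runs; on CPU the map stays empty.
ov::genai::VLMPipeline pipe(models_path, device, enable_compile_cache);
```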
File renamed without changes.
File renamed without changes.
26 changes: 26 additions & 0 deletions src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp
@@ -88,6 +88,32 @@ class OPENVINO_GENAI_EXPORTS ContinuousBatchingPipeline {
const ov::AnyMap& properties = {}
);

/**
* @brief Constructs a ContinuousBatchingPipeline from an existing model and tokenizer.
*
* This constructor allows for the creation of a ContinuousBatchingPipeline using an existing model
* represented as a string and a weights tensor, along with a manually initialized tokenizer.
* This is useful when the model and tokenizer are already loaded or created in memory and do not
* need to be loaded from files.
*
* @param model_str A string representation of the model.
* @param weights_tensor A tensor containing the weights of the model.
* @param tokenizer A manually initialized ov::genai::Tokenizer.
* @param scheduler_config Configuration for the scheduler.
* @param device The device to run the pipeline on (e.g., CPU, GPU).
* @param properties Optional properties for the pipeline.
* @param generation_config Optional generation configuration for the pipeline.
*/
ContinuousBatchingPipeline(
const std::string& model_str,
const ov::Tensor& weights_tensor,
const ov::genai::Tokenizer& tokenizer,
const SchedulerConfig& scheduler_config,
const std::string& device,
const ov::AnyMap& properties = {},
const ov::genai::GenerationConfig& generation_config = {}
);

ov::genai::Tokenizer get_tokenizer();

ov::genai::GenerationConfig get_config() const;
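A minimal sketch of calling the new from-memory constructor. The `decrypt_model` helper comes from the encrypted_model_causal_lm sample above; `models_path` and the default-constructed `SchedulerConfig` are illustrative assumptions:

```cpp
// Build a ContinuousBatchingPipeline from in-memory artifacts instead of files.
auto [model_str, weights] = decrypt_model(models_path + "/openvino_model.xml",
                                          models_path + "/openvino_model.bin");
ov::genai::SchedulerConfig scheduler_config;  // default scheduling parameters
ov::genai::ContinuousBatchingPipeline pipe(
    model_str, weights, ov::genai::Tokenizer(models_path), scheduler_config, "CPU");
```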
17 changes: 17 additions & 0 deletions src/cpp/include/openvino/genai/llm_pipeline.hpp
@@ -112,6 +112,15 @@ class OPENVINO_GENAI_EXPORTS LLMPipeline {
const ov::AnyMap& properties = {}
);

LLMPipeline(
const std::string& model_str,
const ov::Tensor& weights_tensor,
const ov::genai::Tokenizer& tokenizer,
const std::string& device,
const ov::AnyMap& properties = {},
const ov::genai::GenerationConfig& generation_config = {}
);

OPENVINO_DEPRECATED("Please, specify device explicitly when create LLMPipeline. This overload will be removed in 2025.0.0 release")
explicit LLMPipeline(const std::filesystem::path& path) :
LLMPipeline(path, "CPU") { }
@@ -274,6 +283,14 @@ class OPENVINO_GENAI_EXPORTS LLMPipeline {
OPENVINO_GENAI_EXPORTS std::pair<std::string, Any> streamer(StreamerVariant func);
OPENVINO_GENAI_EXPORTS std::pair<std::string, Any> generation_config(const GenerationConfig& config);

OPENVINO_GENAI_EXPORTS std::pair<std::string, Any> draft_model(
std::string& model_str,
ov::Tensor& weights_tensor,
const ov::genai::Tokenizer& tokenizer,
const std::string& device = {},
const ov::AnyMap& properties = {},
const ov::genai::GenerationConfig& generation_config = {});

OPENVINO_GENAI_EXPORTS std::pair<std::string, Any> draft_model(
const std::filesystem::path& models_path,
const std::string& device = {},
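The new `draft_model` overload mirrors the path-based one that follows it, so speculative decoding can run with a draft model that exists only in memory. A sketch, assuming `main_models_path` and `draft_path` point at exported models and reusing `decrypt_model` from the sample above:

```cpp
// Speculative decoding with an in-memory draft model.
auto [draft_str, draft_weights] = decrypt_model(draft_path + "/openvino_model.xml",
                                                draft_path + "/openvino_model.bin");
ov::genai::LLMPipeline pipe(
    main_models_path, "CPU",
    ov::genai::draft_model(draft_str, draft_weights,
                           ov::genai::Tokenizer(draft_path), "CPU"));
std::string result = pipe.generate("Alan Turing was a", ov::genai::max_new_tokens(100));
```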
70 changes: 65 additions & 5 deletions src/cpp/include/openvino/genai/tokenizer.hpp
@@ -28,12 +28,72 @@ struct TokenizedInputs {
class OPENVINO_GENAI_EXPORTS Tokenizer {
public:
/**
- * @brief ov::genai::Tokenizer constructor.
- * @param tokenizer_path openvino_tokenizer.xml and openvino_detokenizer.xml should be located in the tokenizer_path
- * @param properties Properties passed to ov::Core::compile_model
- */
- Tokenizer(const std::filesystem::path& tokenizer_path, const ov::AnyMap& properties = {});
+ * @brief ov::genai::Tokenizer constructor.
+ * @param tokenizer_path openvino_tokenizer.xml and openvino_detokenizer.xml should be located in the tokenizer_path
+ * @param properties Properties passed to ov::Core::compile_model
+ */
+ explicit Tokenizer(const std::filesystem::path& tokenizer_path, const ov::AnyMap& properties = {});

/**
* @brief ov::genai::Tokenizer constructor to initialize directly from model and weights
*
* This constructor is used when the tokenizer and detokenizer are separate models already loaded into memory.
* When this constructor is used, the bos, eos, and pad token ids are expected to be in the IR.
* If the IR is older (< 2024.3), these tokens are default-initialized to be ignored.
* @param tokenizer_model_str tokenizer model string
* @param tokenizer_weights_tensor ov::Tensor with tokenizer weights
* @param detokenizer_model_str detokenizer model string
* @param detokenizer_weights_tensor ov::Tensor with detokenizer weights
* @param properties Properties passed to ov::Core::compile_model
*/
Tokenizer(
const std::string& tokenizer_model_str,
ov::Tensor& tokenizer_weights_tensor,
std::string& detokenizer_model_str,
ov::Tensor& detokenizer_weights_tensor,
const ov::AnyMap& properties = {}
);

/**
* @brief ov::genai::Tokenizer constructor to initialize directly from model and weights.
*
* This constructor is used when a tokenizer (or detokenizer) is already loaded into memory. Whether it is a
* tokenizer or a detokenizer is determined from the model's input signature. When this constructor is used,
* the bos, eos, and pad token ids are expected to be in the IR. If the IR is older (< 2024.3), these tokens
* are default-initialized to be ignored.
* @param model_str model string
* @param weights_tensor ov::Tensor with model weights
* @param properties Properties passed to ov::Core::compile_model
*/
Tokenizer(const std::string& model_str, ov::Tensor& weights_tensor, const ov::AnyMap& properties = {});

/**
* @brief ov::genai::Tokenizer constructor with variable number of properties
* @param tokenizer_model_str tokenizer model string
* @param tokenizer_weights_tensor ov::Tensor with tokenizer weights
* @param detokenizer_model_str detokenizer model string
* @param detokenizer_weights_tensor ov::Tensor with detokenizer weights
* @param properties optional properties
*/
template <typename... Properties, typename std::enable_if<ov::util::StringAny<Properties...>::value, bool>::type = true>
Tokenizer(
const std::string& tokenizer_model_str,
ov::Tensor& tokenizer_weights_tensor,
std::string& detokenizer_model_str,
ov::Tensor& detokenizer_weights_tensor,
Properties&&... properties
) : Tokenizer(tokenizer_model_str, tokenizer_weights_tensor, detokenizer_model_str, detokenizer_weights_tensor, ov::AnyMap{std::forward<Properties>(properties)...}) { }

/**
* @brief ov::genai::Tokenizer constructor with variable number of properties
* @param model_str model string
* @param weights_tensor ov::Tensor with model weights
* @param properties optional properties
*/
template <typename... Properties, typename std::enable_if<ov::util::StringAny<Properties...>::value, bool>::type = true>
Tokenizer(const std::string& model_str, ov::Tensor& weights_tensor,
Properties&&... properties)
: Tokenizer(model_str, weights_tensor, ov::AnyMap{std::forward<Properties>(properties)...}) { }

/**
* @brief ov::genai::Tokenizer constructor with variable number of properties
* @param tokenizer_path openvino_tokenizer.xml and openvino_detokenizer.xml should be located in the tokenizer_path
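To tie the new Tokenizer constructors together, a sketch of the single-model overload, which infers tokenizer vs. detokenizer from the model's input signature, with a compile property forwarded through the variadic overload. The names and cache directory are illustrative; `decrypt_model` is the sample helper from above:

```cpp
// Tokenizer from an in-memory IR, with a compile property passed through.
auto [tok_str, tok_weights] = decrypt_model(models_path + "/openvino_tokenizer.xml",
                                            models_path + "/openvino_tokenizer.bin");
ov::genai::Tokenizer tokenizer(tok_str, tok_weights, ov::cache_dir("tokenizer_cache"));
ov::genai::TokenizedInputs inputs = tokenizer.encode("Alan Turing was a");
```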