Merge releases/2024/3 into master (#640)
Co-authored-by: Alina Kladieva <[email protected]>
Co-authored-by: Anastasiia Pnevskaia <[email protected]>
Co-authored-by: Nikita Malinin <[email protected]>
4 people authored Jul 23, 2024
1 parent d24a683 commit 5d21486
Showing 12 changed files with 83 additions and 37 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/causal_lm_cpp.yml
@@ -652,7 +652,7 @@ jobs:
python -m pip install --upgrade-strategy eager -r ./samples/requirements.txt --pre --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly
python -m pip install ./thirdparty/openvino_tokenizers/[transformers] --pre --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly
optimum-cli export openvino --trust-remote-code --weight-format fp16 --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 TinyLlama-1.1B-Chat-v1.0
cmake -DCMAKE_BUILD_TYPE=Releas -S ./ -B ./build/
cmake -DCMAKE_BUILD_TYPE=Release -S ./ -B ./build/
cmake --build ./build/ --config Release -j
- name: Run gtests
run: |
26 changes: 24 additions & 2 deletions CMakeLists.txt
@@ -23,6 +23,9 @@ project(OpenVINOGenAI
HOMEPAGE_URL "https://github.com/openvinotoolkit/openvino.genai"
LANGUAGES CXX)

option(INSTALL_GTEST "Enable installation of googletest. (Projects embedding googletest may want to turn this OFF.)" OFF)
option(RAPIDJSON_BUILD_DOC "Build rapidjson documentation." OFF)

# Find OpenVINODeveloperPackage first to compile with SDL flags
find_package(OpenVINODeveloperPackage QUIET
PATHS "${OpenVINO_DIR}")
@@ -40,13 +43,32 @@ find_file(spda_to_pa_header sdpa_to_paged_attention.hpp

include(cmake/features.cmake)

if(ENABLE_PYTHON)
# the following two calls are required for cross-compilation
if(OpenVINODeveloperPackage_DIR)
ov_find_python3(REQUIRED)
ov_detect_python_module_extension()
else()
if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.18)
find_package(Python3 REQUIRED COMPONENTS Interpreter Development.Module)
else()
find_package(Python3 REQUIRED COMPONENTS Interpreter Development)
endif()
endif()
endif()

add_subdirectory(thirdparty)
add_subdirectory(src)
add_subdirectory(samples)
add_subdirectory(tests/cpp)

install(FILES LICENSE DESTINATION licensing COMPONENT licensing_genai RENAME LICENSE-GENAI)
install(FILES third-party-programs.txt DESTINATION licensing COMPONENT licensing_genai RENAME third-party-programs-genai.txt)
install(FILES LICENSE DESTINATION docs/licensing COMPONENT licensing_genai RENAME LICENSE-GENAI)
install(FILES third-party-programs.txt DESTINATION docs/licensing COMPONENT licensing_genai RENAME third-party-programs-genai.txt)
set(CPACK_ARCHIVE_COMPONENT_INSTALL ON)
set(CPACK_INCLUDE_TOPLEVEL_DIRECTORY OFF)
# Workaround https://gitlab.kitware.com/cmake/cmake/-/issues/2614
set(CPACK_COMPONENTS_ALL core_genai core_genai_dev cpp_samples_genai licensing_genai openvino_tokenizers openvino_tokenizers_licenses)
if(ENABLE_PYTHON)
list(APPEND CPACK_COMPONENTS_ALL pygenai_${Python3_VERSION_MAJOR}_${Python3_VERSION_MINOR})
endif()
include(CPack)
@@ -78,7 +78,9 @@ int main(int argc, char* argv[]) try {
// vLLM specific params
scheduler_config.max_num_seqs = 2;

ov::genai::ContinuousBatchingPipeline pipe(models_path, scheduler_config);
// It's possible to construct a Tokenizer from a different path.
// If the Tokenizer isn't specified, it's loaded from the same folder.
ov::genai::ContinuousBatchingPipeline pipe(models_path, ov::genai::Tokenizer{models_path}, scheduler_config);
std::vector<ov::genai::GenerationResult> generation_results = pipe.generate(prompts, sampling_params);

for (size_t request_id = 0; request_id < generation_results.size(); ++request_id) {
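The updated sample hands the pipeline an explicitly constructed tokenizer. As the new comment points out, the tokenizer files may also live somewhere other than the model folder; a minimal sketch of that variant (the separate tokenizer_dir path is hypothetical):

#include <string>

#include "openvino/genai/continuous_batching_pipeline.hpp"

int main() {
    std::string models_path = "TinyLlama-1.1B-Chat-v1.0";
    std::string tokenizer_dir = "some/other/tokenizer_dir";  // hypothetical: contains openvino_tokenizer.xml / openvino_detokenizer.xml

    ov::genai::SchedulerConfig scheduler_config;
    scheduler_config.max_num_seqs = 2;

    // Tokenizer files are read from tokenizer_dir; the model and generation config come from models_path.
    ov::genai::ContinuousBatchingPipeline pipe(
        models_path, ov::genai::Tokenizer{tokenizer_dir}, scheduler_config);
    return 0;
}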
19 changes: 18 additions & 1 deletion src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp
@@ -32,7 +32,24 @@ class OPENVINO_GENAI_EXPORTS ContinuousBatchingPipeline {
const std::string& device = "CPU",
const ov::AnyMap& plugin_config = {});

std::shared_ptr<ov::genai::Tokenizer> get_tokenizer();
/**
* @brief Constructs a ContinuousBatchingPipeline when an ov::genai::Tokenizer is initialized manually using files from different directories.
*
* @param model_path Path to the directory with the model, tokenizer .xml/.bin files, and generation_configs.json
* @param tokenizer manually initialized ov::genai::Tokenizer
* @param scheduler_config scheduler configuration
* @param device optional device
* @param plugin_config optional plugin_config
*/
ContinuousBatchingPipeline(
const std::string& model_path,
const ov::genai::Tokenizer& tokenizer,
const SchedulerConfig& scheduler_config,
const std::string& device="CPU",
const ov::AnyMap& plugin_config={}
);

ov::genai::Tokenizer get_tokenizer();

ov::genai::GenerationConfig get_config() const;

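Based on the declaration above, device and plugin_config keep their defaults ("CPU" and an empty map), so only the tokenizer argument is new. A sketch of the intended call pattern, including the get_tokenizer() accessor that now returns the Tokenizer by value rather than as a shared_ptr:

#include <string>

#include "openvino/genai/continuous_batching_pipeline.hpp"

void build_pipeline(const std::string& model_path, const ov::genai::SchedulerConfig& scheduler_config) {
    // The tokenizer is created up front and passed in; device and plugin_config are left at their defaults.
    ov::genai::Tokenizer tokenizer(model_path);
    ov::genai::ContinuousBatchingPipeline pipe(model_path, tokenizer, scheduler_config);

    // get_tokenizer() now returns a Tokenizer value, not std::shared_ptr<Tokenizer>.
    ov::genai::Tokenizer tok = pipe.get_tokenizer();
    (void)tok;
}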
4 changes: 2 additions & 2 deletions src/cpp/include/openvino/genai/llm_pipeline.hpp
@@ -116,10 +116,10 @@ class OPENVINO_GENAI_EXPORTS LLMPipeline {
);

/**
* @brief Constructs an LLMPipeline when ov::Tokenizer is initialized manually using files from different directories.
* @brief Constructs an LLMPipeline when ov::genai::Tokenizer is initialized manually using files from different directories.
*
* @param model_path Path to the dir with model, tokenizer .xml/.bin files, and generation_configs.json
* @param tokenizer manually initialized ov::Tokenizer
* @param tokenizer manually initialized ov::genai::Tokenizer
* @param device optional device
* @param plugin_config optional plugin_config
*/
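For comparison, a minimal sketch of the LLMPipeline overload this comment documents, assuming the constructor takes the parameters in the order they are listed (model_path, tokenizer, then the optional device and plugin_config):

#include <string>

#include "openvino/genai/llm_pipeline.hpp"

void build_llm(const std::string& model_path, const std::string& tokenizer_dir) {
    // Tokenizer initialized manually, possibly from a directory other than the model's.
    ov::genai::Tokenizer tokenizer(tokenizer_dir);
    ov::genai::LLMPipeline pipe(model_path, tokenizer, "CPU");
}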
2 changes: 1 addition & 1 deletion src/cpp/include/openvino/genai/tokenizer.hpp
@@ -26,7 +26,7 @@ struct TokenizedInputs {
class OPENVINO_GENAI_EXPORTS Tokenizer {
public:
/**
* @brief ov::Tokenizer constructor.
* @brief ov::genai::Tokenizer constructor.
* @param tokenizer_path openvino_tokenizer.xml and openvino_detokenizer.xml should be located in the tokenizer_path
*/
Tokenizer(const std::string& tokenizer_path);
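A standalone usage sketch for the constructor above; the encode/decode/get_eos_token_id calls mirror the ones that appear in continuous_batching_pipeline.cpp below:

#include <cstdint>
#include <string>
#include <vector>

#include "openvino/genai/tokenizer.hpp"

void tokenize_example(const std::string& tokenizer_path) {
    // tokenizer_path must contain openvino_tokenizer.xml and openvino_detokenizer.xml.
    ov::genai::Tokenizer tokenizer(tokenizer_path);

    ov::Tensor input_ids = tokenizer.encode("What is OpenVINO?").input_ids;
    int64_t eos_token_id = tokenizer.get_eos_token_id();

    // decode() turns generated token ids back into text (here a single id, just to show the call shape).
    std::string text = tokenizer.decode(std::vector<int64_t>{eos_token_id});
    (void)input_ids;
    (void)text;
}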
27 changes: 19 additions & 8 deletions src/cpp/src/continuous_batching_pipeline.cpp
@@ -19,7 +19,7 @@ using namespace ov::genai;
void apply_paged_attention_transformations(std::shared_ptr<ov::Model> model, DeviceConfig& device_config);

class ContinuousBatchingPipeline::Impl {
std::shared_ptr<ov::genai::Tokenizer> m_tokenizer;
ov::genai::Tokenizer m_tokenizer;
std::shared_ptr<Scheduler> m_scheduler;
std::shared_ptr<CacheManager> m_cache_manager;
std::shared_ptr<ModelRunner> m_model_runner;
@@ -70,9 +70,9 @@ class ContinuousBatchingPipeline::Impl {
}

public:
Impl(const std::string& models_path, const SchedulerConfig& scheduler_config, const std::string device, const ov::AnyMap& plugin_config) {
Impl(const std::string& models_path, const Tokenizer& tokenizer, const SchedulerConfig& scheduler_config, const std::string& device, const ov::AnyMap& plugin_config) :
m_tokenizer{tokenizer} {
ov::Core core;
m_tokenizer = std::make_shared<ov::genai::Tokenizer>(models_path);

// The model can be compiled for GPU as well
std::shared_ptr<ov::Model> model = core.read_model(models_path + "/openvino_model.xml");
@@ -105,6 +105,9 @@ class ContinuousBatchingPipeline::Impl {
// read default generation config
}

Impl(const std::string& models_path, const SchedulerConfig& scheduler_config, const std::string& device, const ov::AnyMap& plugin_config)
: Impl{models_path, Tokenizer(models_path), scheduler_config, device, plugin_config} {}

ov::genai::GenerationConfig get_config() const {
return m_generation_config;
}
@@ -113,19 +116,19 @@
return m_pipeline_metrics;
}

std::shared_ptr<ov::genai::Tokenizer> get_tokenizer() {
ov::genai::Tokenizer get_tokenizer() {
return m_tokenizer;
}

GenerationHandle add_request(uint64_t request_id, std::string prompt, ov::genai::GenerationConfig sampling_params) {
sampling_params.set_eos_token_id(m_tokenizer->get_eos_token_id());
sampling_params.set_eos_token_id(m_tokenizer.get_eos_token_id());
sampling_params.validate();

ov::Tensor input_ids;
{
static ManualTimer timer("tokenize");
timer.start();
input_ids = m_tokenizer->encode(prompt).input_ids;
input_ids = m_tokenizer.encode(prompt).input_ids;
timer.end();
}

@@ -263,7 +266,7 @@ class ContinuousBatchingPipeline::Impl {
auto num_outputs = std::min(sampling_params[generation_idx].num_return_sequences, generation_outputs.size());
for (size_t generation_output_idx = 0; generation_output_idx < num_outputs; ++generation_output_idx) {
const auto& generation_output = generation_outputs[generation_output_idx];
std::string output_text = m_tokenizer->decode(generation_output.generated_token_ids);
std::string output_text = m_tokenizer.decode(generation_output.generated_token_ids);
result.m_generation_ids.push_back(output_text);
result.m_scores.push_back(generation_output.score);
}
@@ -283,7 +286,15 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( const std::string& model
m_impl = std::make_shared<Impl>(models_path, scheduler_config, device, plugin_config);
}

std::shared_ptr<ov::genai::Tokenizer> ContinuousBatchingPipeline::get_tokenizer() {
ContinuousBatchingPipeline::ContinuousBatchingPipeline(
const std::string& model_path,
const Tokenizer& tokenizer,
const SchedulerConfig& scheduler_config,
const std::string& device,
const ov::AnyMap& plugin_config
) : m_impl{std::make_shared<Impl>(model_path, tokenizer, scheduler_config, device, plugin_config)} {}

ov::genai::Tokenizer ContinuousBatchingPipeline::get_tokenizer() {
return m_impl->get_tokenizer();
}

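The path-only Impl constructor now simply delegates to the new tokenizer-taking overload, so both entry points share one initialization path. A stripped-down sketch of the same delegating-constructor idiom, with hypothetical names standing in for the real members:

#include <string>

// Hypothetical mini-version of the pattern used by ContinuousBatchingPipeline::Impl.
class PipelineImpl {
    std::string m_tokenizer;  // stands in for ov::genai::Tokenizer

public:
    // Primary constructor: the caller supplies the tokenizer explicitly.
    PipelineImpl(const std::string& models_path, const std::string& tokenizer)
        : m_tokenizer{tokenizer} {
        // read the model from models_path, create the scheduler, cache manager, ...
    }

    // Convenience overload: builds the tokenizer from models_path and delegates.
    explicit PipelineImpl(const std::string& models_path)
        : PipelineImpl{models_path, models_path} {}
};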
16 changes: 3 additions & 13 deletions src/python/CMakeLists.txt
@@ -11,17 +11,7 @@ FetchContent_Declare(
FetchContent_GetProperties(pybind11)
# search for FindPython3.cmake instead of legacy modules
set(PYBIND11_FINDPYTHON ON)
# the following two calls are required for cross-compilation
if(OpenVINODeveloperPackage_DIR)
ov_find_python3(REQUIRED)
ov_detect_python_module_extension()
else()
if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.18)
find_package(Python3 REQUIRED COMPONENTS Interpreter Development.Module)
else()
find_package(Python3 REQUIRED COMPONENTS Interpreter Development)
endif()
endif()

if(NOT pybind11_POPULATED)
FetchContent_Populate(pybind11)
add_subdirectory(${pybind11_SOURCE_DIR} ${pybind11_BINARY_DIR})
@@ -65,10 +55,10 @@ endif()
install(FILES "${CMAKE_CURRENT_SOURCE_DIR}/openvino_genai/__init__.py"
"${CMAKE_BINARY_DIR}/openvino_genai/__version__.py"
DESTINATION python/openvino_genai
COMPONENT pygenai_${Python_VERSION_MAJOR}_${Python_VERSION_MINOR})
COMPONENT pygenai_${Python3_VERSION_MAJOR}_${Python3_VERSION_MINOR})
install(TARGETS py_generate_pipeline
LIBRARY DESTINATION python/openvino_genai
COMPONENT pygenai_${Python_VERSION_MAJOR}_${Python_VERSION_MINOR})
COMPONENT pygenai_${Python3_VERSION_MAJOR}_${Python3_VERSION_MINOR})

install(FILES "${CMAKE_BINARY_DIR}/openvino_genai/__version__.py"
DESTINATION openvino_genai
10 changes: 7 additions & 3 deletions src/python/py_generate_pipeline.cpp
@@ -596,10 +596,14 @@ PYBIND11_MODULE(py_generate_pipeline, m) {
.def_readwrite("max_num_seqs", &SchedulerConfig::max_num_seqs);

py::class_<ContinuousBatchingPipeline>(m, "ContinuousBatchingPipeline")
.def(py::init([](const std::string& model_path, const SchedulerConfig& config) {
.def(py::init([](const std::string& model_path, const SchedulerConfig& scheduler_config, const std::string& device, const std::map<std::string, py::object>& plugin_config) {
ScopedVar env_manager(ov_tokenizers_module_path());
return std::make_unique<ContinuousBatchingPipeline>(model_path, config);
}))
return std::make_unique<ContinuousBatchingPipeline>(model_path, scheduler_config, device, properties_to_any_map(plugin_config));
}), py::arg("model_path"), py::arg("scheduler_config"), py::arg("device") = "CPU", py::arg("plugin_config") = ov::AnyMap({}))
.def(py::init([](const std::string& model_path, const ov::genai::Tokenizer& tokenizer, const SchedulerConfig& scheduler_config, const std::string& device, const std::map<std::string, py::object>& plugin_config) {
ScopedVar env_manager(ov_tokenizers_module_path());
return std::make_unique<ContinuousBatchingPipeline>(model_path, tokenizer, scheduler_config, device, properties_to_any_map(plugin_config));
}), py::arg("model_path"), py::arg("tokenizer"), py::arg("scheduler_config"), py::arg("device") = "CPU", py::arg("plugin_config") = ov::AnyMap({}))
.def("get_tokenizer", &ContinuousBatchingPipeline::get_tokenizer)
.def("get_config", &ContinuousBatchingPipeline::get_config)
.def("add_request", &ContinuousBatchingPipeline::add_request)
2 changes: 1 addition & 1 deletion tests/python_tests/common.py
@@ -273,7 +273,7 @@ def run_continuous_batching(
prompts: List[str],
generation_configs : List[GenerationConfig]
) -> List[GenerationResult]:
pipe = ContinuousBatchingPipeline(model_path.absolute().as_posix(), scheduler_config)
pipe = ContinuousBatchingPipeline(model_path.absolute().as_posix(), scheduler_config, "CPU", {})
output = pipe.generate(prompts, generation_configs)
del pipe
shutil.rmtree(model_path)
6 changes: 3 additions & 3 deletions tests/python_tests/test_sampling.py
@@ -7,8 +7,8 @@
import sys
from dataclasses import dataclass
from pathlib import Path
from openvino_genai import ContinuousBatchingPipeline, GenerationConfig
from typing import List, Optional, TypedDict
from openvino_genai import ContinuousBatchingPipeline, GenerationConfig, Tokenizer
from typing import List, TypedDict

from common import run_test_pipeline, get_models_list, get_model_and_tokenizer, save_ov_model_from_optimum, \
generate_and_compare_with_reference_text, get_greedy, get_beam_search, get_multinomial_temperature, \
@@ -306,7 +306,7 @@ def test_post_oom_health(tmp_path):
model_path : Path = tmp_path / model_id
save_ov_model_from_optimum(model, hf_tokenizer, model_path)

pipe = ContinuousBatchingPipeline(model_path.absolute().as_posix(), scheduler_config)
pipe = ContinuousBatchingPipeline(model_path.absolute().as_posix(), Tokenizer(model_path.absolute().as_posix()), scheduler_config)
# First run should return incomplete response
output = pipe.generate(["What is OpenVINO?"], generation_configs)
assert(len(output))
