diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index a7ee04a103..306b2005e5 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -34,7 +34,7 @@ repos:
       - id: black
         additional_dependencies: ['click==8.0.4']
   - repo: https://github.com/PyCQA/isort
-    rev: 5.10.1
+    rev: 5.12.0
     hooks:
       - id: isort
         args: ["--profile", "black"]
diff --git a/cpp/src/examples/CMakeLists.txt b/cpp/src/examples/CMakeLists.txt
index 6f8441d190..b7c97691df 100644
--- a/cpp/src/examples/CMakeLists.txt
+++ b/cpp/src/examples/CMakeLists.txt
@@ -4,7 +4,7 @@ set(MNIST_SOURCE_FILES "")
 list(APPEND MNIST_SOURCE_FILES ${MNIST_SRC_DIR}/mnist_handler.cc)
 add_library(mnist_handler SHARED ${MNIST_SOURCE_FILES})
 target_include_directories(mnist_handler PUBLIC ${MNIST_SRC_DIR})
-target_link_libraries(mnist_handler PRIVATE ts_backends_torch_scripted ts_utils ${TORCH_LIBRARIES}) 
+target_link_libraries(mnist_handler PRIVATE ts_backends_torch_scripted ts_utils ${TORCH_LIBRARIES})
 
 set(LLM_SRC_DIR "${torchserve_cpp_SOURCE_DIR}/src/examples/llamacpp")
 set(LLAMACPP_SRC_DIR "/home/ubuntu/llama.cpp")
@@ -20,10 +20,12 @@ set(MY_OBJECT_FILES
   ${LLAMACPP_SRC_DIR}/ggml.o
   ${LLAMACPP_SRC_DIR}/llama.o
   ${LLAMACPP_SRC_DIR}/common.o
-  ${LLAMACPP_SRC_DIR}/k_quants.o
+  ${LLAMACPP_SRC_DIR}/ggml-quants.o
   ${LLAMACPP_SRC_DIR}/ggml-alloc.o
   ${LLAMACPP_SRC_DIR}/grammar-parser.o
   ${LLAMACPP_SRC_DIR}/console.o
+  ${LLAMACPP_SRC_DIR}/build-info.o
+  ${LLAMACPP_SRC_DIR}/ggml-backend.o
 )
diff --git a/cpp/src/examples/llamacpp/llamacpp_handler.cc b/cpp/src/examples/llamacpp/llamacpp_handler.cc
index 1ad72e05cc..fc01858bec 100644
--- a/cpp/src/examples/llamacpp/llamacpp_handler.cc
+++ b/cpp/src/examples/llamacpp/llamacpp_handler.cc
@@ -53,7 +53,8 @@ LlamacppHandler::LoadModel(
     llama_backend_init(params.numa);
     ctx_params = llama_context_default_params();
 
-    llamamodel = llama_load_model_from_file(params.model.c_str(), ctx_params);
+    model_params = llama_model_default_params();
+    llamamodel = llama_load_model_from_file(params.model.c_str(), model_params);
 
     return std::make_pair(module, device);
   } catch (const c10::Error& e) {
@@ -74,7 +75,6 @@ std::vector<torch::jit::IValue> LlamacppHandler::Preprocess(
     std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
     std::shared_ptr<torchserve::InferenceRequestBatch>& request_batch,
     std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch) {
-  initialize_context();
 
   std::vector<torch::jit::IValue> batch_ivalue;
@@ -181,8 +181,7 @@ torch::Tensor LlamacppHandler::Inference(
     // evaluate the transformer
 
     if (llama_eval(llama_ctx, tokens_list.data(), int(tokens_list.size()),
-                   llama_get_kv_cache_token_count(llama_ctx),
-                   params.n_threads)) {
+                   llama_get_kv_cache_token_count(llama_ctx))) {
       std::cout << "Failed to eval\n" << __func__ << std::endl;
       break;
     }
@@ -194,7 +193,7 @@ torch::Tensor LlamacppHandler::Inference(
     llama_token new_token_id = 0;
 
     auto logits = llama_get_logits(llama_ctx);
-    auto n_vocab = llama_n_vocab(llama_ctx);
+    auto n_vocab = llama_n_vocab(llamamodel);
 
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
@@ -210,7 +209,7 @@ torch::Tensor LlamacppHandler::Inference(
       new_token_id = llama_sample_token_greedy(llama_ctx, &candidates_p);
 
       // is it an end of stream ?
-      if (new_token_id == llama_token_eos(llama_ctx)) {
+      if (new_token_id == llama_token_eos(llamamodel)) {
         std::cout << "Reached [end of text]\n";
         break;
       }
diff --git a/cpp/src/examples/llamacpp/llamacpp_handler.hh b/cpp/src/examples/llamacpp/llamacpp_handler.hh
index 54de782fad..520099f2d6 100644
--- a/cpp/src/examples/llamacpp/llamacpp_handler.hh
+++ b/cpp/src/examples/llamacpp/llamacpp_handler.hh
@@ -13,6 +13,7 @@ namespace llm {
 class LlamacppHandler : public torchserve::torchscripted::BaseHandler {
  private:
   gpt_params params;
+  llama_model_params model_params;
   llama_model* llamamodel;
   llama_context_params ctx_params;
   llama_context* llama_ctx;
@@ -52,4 +53,4 @@ class LlamacppHandler : public torchserve::torchscripted::BaseHandler {
       override;
 };
 } // namespace llm
-#endif // LLAMACPP_HANDLER_HH_
\ No newline at end of file
+#endif // LLAMACPP_HANDLER_HH_
diff --git a/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc b/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc
index e841c57ea1..16fedc660a 100644
--- a/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc
+++ b/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc
@@ -84,8 +84,7 @@ TEST_F(TorchScriptedBackendTest, TestLoadPredictLlmHandler) {
           "test/resources/torchscript_model/llamacpp/llamacpp_handler", "llm",
           -1, "", "", 1, false),
       "test/resources/torchscript_model/llamacpp/llamacpp_handler",
-      "test/resources/torchscript_model/llamacpp/prompt.txt", "llm_ts",
-      200);
+      "test/resources/torchscript_model/llamacpp/prompt.txt", "llm_ts", 200);
 }
 
 TEST_F(TorchScriptedBackendTest, TestBackendInitWrongModelDir) {