BabyLlama with CPP backend #2544
Closed. shrinath-suresh wants to merge 25 commits into pytorch:cpp_backend from shrinath-suresh:baby_llama_integration.
Changes shown are from 23 of the 25 commits.
Commits (all by shrinath-suresh):
641a708 Baby Llama - Porting run.c for integration and fixed clang type conve…
016e4f1 Custom preprocess implementation
38d3e93 Free memory only after the inference is done
52a7927 Implement Postprocess
c675664 Setting Fast compiler option
374a2e8 Reading checkpoint path and tokenizer path from config file using folly
48f522c Removing run.c from cmake
49a3015 Replace auto with appropriate data type
aeb1bb0 Using smartpointers and initializing the vector with appropriate size…
ee20424 Using smartpointers
f5d9799 Directly converting the tensor values to prompt token ids
9b3de26 Moving run.c and common variables to .cc file
3e0e2c3 Moving run.c to a separate folder
5c0495e Uncommenting the original run.c main method
e75a5ae Implemented destructor to free up resources
9afce52 Supporting files for unit test
0d12619 Processing all the batch inputs
bd03fd8 Setting InferenceMode guard
d2dc632 Updating InferenceMode to use torch::InferenceMode
67b46aa Updating class name to BabyLlamaHandler
f30aab2 Renaming llm_handler target to babyllama_handler
7174cde Adding dummy pt file
6dc025b Typo Fix
450b85d Calculate tokens/per second for batch input
8d279be Adding README.md for babyllama example
New file (+297 lines): src/examples/babyllama/baby_llama_handler.cc
#include "src/examples/babyllama/baby_llama_handler.hh" | ||
|
||
#include <typeinfo> | ||
|
||
#include "src/examples/babyllama/llama2.c/run.c" | ||
|
||
namespace llm { | ||
|
||
Transformer transformer; | ||
Tokenizer tokenizer; | ||
Sampler sampler; | ||
int steps = 256; | ||
|
||
std::pair<std::shared_ptr<torch::jit::script::Module>, | ||
std::shared_ptr<torch::Device>> | ||
BabyLlamaHandler::LoadModel( | ||
std::shared_ptr<torchserve::LoadModelRequest>& load_model_request) { | ||
try { | ||
auto device = GetTorchDevice(load_model_request); | ||
// Load dummy model | ||
auto module = std::make_shared<torch::jit::script::Module>( | ||
torch::jit::load(fmt::format("{}/{}", load_model_request->model_dir, | ||
manifest_->GetModel().serialized_file), | ||
*device)); | ||
|
||
const std::string configFilePath = | ||
fmt::format("{}/{}", load_model_request->model_dir, "config.json"); | ||
std::string jsonContent; | ||
if (!folly::readFile(configFilePath.c_str(), jsonContent)) { | ||
std::cerr << "config.json not found at: " << configFilePath << std::endl; | ||
throw; | ||
} | ||
folly::dynamic json; | ||
json = folly::parseJson(jsonContent); | ||
std::string checkpoint_path; | ||
std::string tokenizer_path; | ||
if (json.find("checkpoint_path") != json.items().end() && | ||
json.find("tokenizer_path") != json.items().end()) { | ||
checkpoint_path = json["checkpoint_path"].asString(); | ||
tokenizer_path = json["tokenizer_path"].asString(); | ||
} else { | ||
      std::cerr << "Required fields 'checkpoint_path' and 'tokenizer_path' "
                   "not found in config.json."
                << std::endl;
      throw std::runtime_error(
          "Required fields 'checkpoint_path' and 'tokenizer_path' not found "
          "in config.json.");
    }

    build_transformer(&transformer,
                      const_cast<char*>(checkpoint_path.c_str()));

    build_tokenizer(&tokenizer, const_cast<char*>(tokenizer_path.c_str()),
                    transformer.config.vocab_size);

    float temperature =
        1.0f;  // 0.0 = greedy deterministic. 1.0 = original. don't set higher
    float topp = 0.9f;  // top-p in nucleus sampling. 1.0 = off. 0.9 works well,
                        // but slower
    unsigned long long rng_seed =
        (unsigned int)time(nullptr);  // seed the sampler as run.c does by default
    // build the Sampler
    build_sampler(&sampler, transformer.config.vocab_size, temperature, topp,
                  rng_seed);

    return std::make_pair(module, device);
  } catch (const c10::Error& e) {
    TS_LOGF(ERROR, "loading the model: {}, device id: {}, error: {}",
            load_model_request->model_name, load_model_request->gpu_id,
            e.msg());
    throw e;
  } catch (const std::runtime_error& e) {
    TS_LOGF(ERROR, "loading the model: {}, device id: {}, error: {}",
            load_model_request->model_name, load_model_request->gpu_id,
            e.what());
    throw e;
  }
}

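Aside (not part of the diff): LoadModel above loads a placeholder TorchScript module, then reads a config.json from the model directory and requires the two string fields checked in the code. A minimal example of such a file might look like the following; the paths are placeholders, not values from the PR.

{
  "checkpoint_path": "/path/to/model_checkpoint.bin",
  "tokenizer_path": "/path/to/tokenizer.bin"
}

The checkpoint path is handed to build_transformer() and the tokenizer path to build_tokenizer(), exactly as in the code above.
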
std::vector<torch::jit::IValue> BabyLlamaHandler::Preprocess(
    std::shared_ptr<torch::Device>& device,
    std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
    std::shared_ptr<torchserve::InferenceRequestBatch>& request_batch,
    std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch) {
  std::vector<torch::jit::IValue> batch_ivalue;
  std::vector<torch::Tensor> batch_tensors;
  uint8_t idx = 0;
  for (auto& request : *request_batch) {
    try {
      (*response_batch)[request.request_id] =
          std::make_shared<torchserve::InferenceResponse>(request.request_id);
      idx_to_req_id.first += idx_to_req_id.first.empty()
                                 ? request.request_id
                                 : "," + request.request_id;

      auto data_it = request.parameters.find(
          torchserve::PayloadType::kPARAMETER_NAME_DATA);
      auto dtype_it =
          request.headers.find(torchserve::PayloadType::kHEADER_NAME_DATA_TYPE);
      if (data_it == request.parameters.end()) {
        data_it = request.parameters.find(
            torchserve::PayloadType::kPARAMETER_NAME_BODY);
        dtype_it = request.headers.find(
            torchserve::PayloadType::kHEADER_NAME_BODY_TYPE);
      }

      if (data_it == request.parameters.end() ||
          dtype_it == request.headers.end()) {
        TS_LOGF(ERROR, "Empty payload for request id: {}", request.request_id);
        (*response_batch)[request.request_id]->SetResponse(
            500, "data_type", torchserve::PayloadType::kCONTENT_TYPE_TEXT,
            "Empty payload");
        continue;
      }

      std::string msg = torchserve::Converter::VectorToStr(data_it->second);

      int num_prompt_tokens = 0;

      std::unique_ptr<char[], void (*)(char*)> msgCStr(
          new char[msg.size() + 1], [](char* ptr) { delete[] ptr; });

      std::strcpy(msgCStr.get(), msg.c_str());

      std::unique_ptr<int[]> prompt_tokens(new int[msg.length()]);

      encode(&tokenizer, msgCStr.get(), 1, 0, prompt_tokens.get(),
             &num_prompt_tokens);

      std::vector<torch::Tensor> tensor_vector;
      for (int64_t i = 0; i < num_prompt_tokens; ++i) {
        int token = prompt_tokens[i];
        torch::Tensor tensor = torch::tensor(token, torch::kInt64);
        tensor_vector.push_back(tensor);
      }
      torch::Tensor stacked_tensor = torch::stack(tensor_vector);
      batch_ivalue.push_back(stacked_tensor);

      idx_to_req_id.second[idx++] = request.request_id;

    } catch (const std::runtime_error& e) {
      TS_LOGF(ERROR, "Failed to load tensor for request id: {}, error: {}",
              request.request_id, e.what());
      auto response = (*response_batch)[request.request_id];
      response->SetResponse(500, "data_type",
                            torchserve::PayloadType::kDATA_TYPE_STRING,
                            "runtime_error, failed to load tensor");
    } catch (const c10::Error& e) {
      TS_LOGF(ERROR, "Failed to load tensor for request id: {}, c10 error: {}",
              request.request_id, e.msg());
      auto response = (*response_batch)[request.request_id];
      response->SetResponse(500, "data_type",
                            torchserve::PayloadType::kDATA_TYPE_STRING,
                            "c10 error, failed to load tensor");
    }
  }

  return batch_ivalue;
}

torch::Tensor BabyLlamaHandler::Inference(
    std::shared_ptr<torch::jit::script::Module> model,
    std::vector<torch::jit::IValue>& inputs,
    std::shared_ptr<torch::Device>& device,
    std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
    std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch) {
  torch::InferenceMode guard;
  std::vector<torch::Tensor> batch_output_vector;
  for (const torch::jit::IValue& input : inputs) {
[Review comment] This for loop predicts each inference request one by one. Can we optimize this section to leverage C++ multithreading or GPU batching? (An illustrative parallelization sketch follows this function.)

    std::vector<torch::Tensor> tensor_vector;
    tensor_vector.reserve(steps);
    torch::Tensor tokens_list_tensor = input.toTensor();

    int64_t num_elements = tokens_list_tensor.numel();

    int64_t* data_ptr = tokens_list_tensor.data_ptr<int64_t>();

    std::unique_ptr<int[]> prompt_tokens(new int[num_elements]);

    for (int64_t i = 0; i < num_elements; ++i) {
      prompt_tokens[i] = data_ptr[i];
    }

    // start the main loop
    long start = 0;  // used to time our code, only initialized after first iteration
    int next;        // will store the next token in the sequence
    int token = prompt_tokens[0];  // kick off with the first token in the prompt
    int pos = 0;                   // position in the sequence
    while (pos < steps) {
      // forward the transformer to get logits for the next token
      float* logits = forward(&transformer, token, pos);

      // advance the state machine
      if (pos < num_elements - 1) {
        // if we are still processing the input prompt, force the next prompt
        // token
        next = prompt_tokens[pos + 1];
      } else {
        // otherwise sample the next token from the logits
        next = sample(&sampler, logits);
      }
      pos++;

      torch::Tensor tensor = torch::tensor(next, torch::kLong);
      tensor_vector.push_back(tensor);

      // data-dependent terminating condition: the BOS (=1) token delimits
      // sequences
      if (next == 1) {
        break;
      }
      token = next;

      // init the timer here because the first iteration can be slower
      if (start == 0) {
        start = time_in_ms();
      }
    }

    // report achieved tok/s (pos-1 because the timer starts after first
    // iteration)
    if (pos > 1) {
      long end = time_in_ms();
      double token_per_sec = (pos - 1) / (double)(end - start) * 1000;
      std::cout << "Achieved tok per sec: " << token_per_sec << std::endl;
    }

    torch::Tensor stacked_tensor = torch::stack(tensor_vector);

    batch_output_vector.push_back(stacked_tensor);
  }

  return torch::stack(batch_output_vector);
}

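Aside (illustration only, not part of the PR): regarding the review comment above about per-request generation, a minimal sketch of fanning the loop out with std::async is shown below. GenerateOne is a hypothetical stand-in for the per-request token-generation loop; real parallelism would also need one Transformer/RunState per task, since run.c's forward() reuses a single preallocated activation buffer and is not safe to call concurrently on a shared instance.

// sketch_parallel_generation.cc -- illustrative only, not part of the PR.
#include <cstdint>
#include <future>
#include <vector>

// Hypothetical stand-in for the per-request token-generation loop above.
static std::vector<int64_t> GenerateOne(std::vector<int64_t> prompt_tokens) {
  // ... run the forward/sample loop here with a task-local transformer ...
  return prompt_tokens;  // placeholder: echo the prompt
}

// Fan the per-request work out across threads and collect the results in order.
std::vector<std::vector<int64_t>> GenerateBatch(
    const std::vector<std::vector<int64_t>>& prompts) {
  std::vector<std::future<std::vector<int64_t>>> futures;
  futures.reserve(prompts.size());
  for (const auto& prompt : prompts) {
    futures.push_back(std::async(std::launch::async, GenerateOne, prompt));
  }
  std::vector<std::vector<int64_t>> outputs;
  outputs.reserve(futures.size());
  for (auto& f : futures) {
    outputs.push_back(f.get());
  }
  return outputs;
}

True GPU batching would instead require a forward pass that accepts a batch dimension, which the ported run.c does not provide.
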
void BabyLlamaHandler::Postprocess(
    const torch::Tensor& data,
    std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
    std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch) {
  for (const auto& kv : idx_to_req_id.second) {
    try {
      int64_t num_elements = data.numel();
      int64_t* data_ptr = data.data_ptr<int64_t>();
      int64_t token = 1;
      std::string concatenated_string;
      for (int64_t i = 0; i < num_elements; ++i) {
        char* piece = decode(&tokenizer, token, data_ptr[i]);
        std::string piece_string(piece);
        token = data_ptr[i];
        concatenated_string += piece_string;
      }

      std::cout << "Generated String: " << concatenated_string << std::endl;

      auto response = (*response_batch)[kv.second];

      response->SetResponse(200, "data_type",
                            torchserve::PayloadType::kDATA_TYPE_STRING,
                            concatenated_string);
    } catch (const std::runtime_error& e) {
      TS_LOGF(ERROR, "Failed to load tensor for request id: {}, error: {}",
              kv.second, e.what());
      auto response = (*response_batch)[kv.second];
      response->SetResponse(500, "data_type",
                            torchserve::PayloadType::kDATA_TYPE_STRING,
                            "runtime_error, failed to postprocess tensor");
    } catch (const c10::Error& e) {
      TS_LOGF(ERROR,
              "Failed to postprocess tensor for request id: {}, error: {}",
              kv.second, e.msg());
      auto response = (*response_batch)[kv.second];
      response->SetResponse(500, "data_type",
                            torchserve::PayloadType::kDATA_TYPE_STRING,
                            "c10 error, failed to postprocess tensor");
    }
  }
}

BabyLlamaHandler::~BabyLlamaHandler() noexcept {
  free_sampler(&sampler);
  free_tokenizer(&tokenizer);
  free_transformer(&transformer);
}

}  // namespace llm

#if defined(__linux__) || defined(__APPLE__)
extern "C" {
torchserve::torchscripted::BaseHandler* allocatorBabyLlamaHandler() {
  return new llm::BabyLlamaHandler();
}

void deleterBabyLlamaHandler(torchserve::torchscripted::BaseHandler* p) {
  if (p != nullptr) {
    delete static_cast<llm::BabyLlamaHandler*>(p);
  }
}
}
#endif
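Aside (illustration only): the extern "C" allocator/deleter pair above suggests the handler is built as a shared library and resolved by symbol name at runtime. The backend's actual loading code is not part of this diff; the sketch below shows how a host process could resolve the two symbols with dlopen/dlsym, where the library name libbabyllama_handler.so is an assumption based on the babyllama_handler CMake target mentioned in the commits.

// sketch_load_handler.cc -- illustrative only; not the backend's loader.
#include <dlfcn.h>

#include <cstdio>

namespace torchserve {
namespace torchscripted {
class BaseHandler;  // opaque here; defined in base_handler.hh
}  // namespace torchscripted
}  // namespace torchserve

int main() {
  // Library name is an assumption (CMake target: babyllama_handler).
  void* lib = dlopen("./libbabyllama_handler.so", RTLD_NOW);
  if (lib == nullptr) {
    std::fprintf(stderr, "dlopen failed: %s\n", dlerror());
    return 1;
  }
  auto alloc = reinterpret_cast<torchserve::torchscripted::BaseHandler* (*)()>(
      dlsym(lib, "allocatorBabyLlamaHandler"));
  auto del = reinterpret_cast<void (*)(torchserve::torchscripted::BaseHandler*)>(
      dlsym(lib, "deleterBabyLlamaHandler"));
  if (alloc != nullptr && del != nullptr) {
    auto* handler = alloc();  // construct a handler behind the base pointer
    del(handler);             // destroy it through the matching deleter
  }
  dlclose(lib);
  return 0;
}
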
New file (+48 lines): src/examples/babyllama/baby_llama_handler.hh
#ifndef BABYLLAMA_HANDLER_HH_
#define BABYLLAMA_HANDLER_HH_

#include <folly/FileUtil.h>
#include <folly/json.h>

#include <iostream>

// #include "run.c"
#include "src/backends/torch_scripted/handler/base_handler.hh"

namespace llm {
class BabyLlamaHandler : public torchserve::torchscripted::BaseHandler {
 public:
  // NOLINTBEGIN(bugprone-exception-escape)
  BabyLlamaHandler() = default;
  // NOLINTEND(bugprone-exception-escape)
  ~BabyLlamaHandler() noexcept;

  void initialize_context();

  virtual std::pair<std::shared_ptr<torch::jit::script::Module>,
                    std::shared_ptr<torch::Device>>
  LoadModel(std::shared_ptr<torchserve::LoadModelRequest>& load_model_request);

  std::vector<torch::jit::IValue> Preprocess(
      std::shared_ptr<torch::Device>& device,
      std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
      std::shared_ptr<torchserve::InferenceRequestBatch>& request_batch,
      std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch)
      override;

  torch::Tensor Inference(
      std::shared_ptr<torch::jit::script::Module> model,
      std::vector<torch::jit::IValue>& inputs,
      std::shared_ptr<torch::Device>& device,
      std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
      std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch)
      override;

  void Postprocess(
      const torch::Tensor& data,
      std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
      std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch)
      override;
};
}  // namespace llm
#endif  // BABYLLAMA_HANDLER_HH_
New file (+21 lines): LICENSE (MIT)
MIT License

Copyright (c) 2023 Andrej

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
[Review comment] The current cpp backend only supports one device id, which means there is no partitioning across GPU devices. I assume this example only works on a single GPU.
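Aside (illustration only, not the backend's actual GetTorchDevice implementation): mapping a single device id to one torch::Device, which is all the comment above says the current backend supports, typically looks like the sketch below in libtorch.

#include <torch/torch.h>

// Illustrative device selection for a single gpu_id; no cross-GPU partitioning.
torch::Device SelectDevice(int gpu_id) {
  if (gpu_id >= 0 && torch::cuda::is_available()) {
    return torch::Device(torch::kCUDA, static_cast<c10::DeviceIndex>(gpu_id));
  }
  return torch::Device(torch::kCPU);
}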