feat(llm): add huggingface provider (#13484)
srb3 authored Nov 25, 2024
1 parent 33407df commit c5199ff
Showing 16 changed files with 877 additions and 2 deletions.
6 changes: 6 additions & 0 deletions changelog/unreleased/kong/feat-add-huggingface-llm-driver.yml
@@ -0,0 +1,6 @@
message: |
Added a new LLM driver for interfacing with the Hugging Face inference API.
The driver supports both serverless and dedicated LLM instances hosted by
Hugging Face for conversational and text generation tasks.
type: feature
scope: Core
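
For context, here is a minimal sketch of the plugin configuration shape the new driver consumes, written as a spec-test-style Lua table. The field names mirror what the driver code below reads (conf.route_type, conf.auth.*, conf.model.*); the model id and token are placeholders, and the surrounding ai-proxy schema is assumed rather than shown in this diff:

local conf = {
  route_type = "llm/v1/chat",
  auth = {
    header_name = "Authorization",
    header_value = "Bearer hf_xxx",              -- placeholder Hugging Face API token
  },
  model = {
    provider = "huggingface",
    name = "mistralai/Mistral-7B-Instruct-v0.2", -- placeholder serverless model id
    options = {
      -- upstream_url = "https://...",           -- for dedicated endpoints, instead of a model name
      huggingface = {
        use_cache = false,
        wait_for_model = true,
      },
    },
  },
}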
1 change: 1 addition & 0 deletions kong-3.9.0-0.rockspec
@@ -633,6 +633,7 @@ build = {
["kong.llm.drivers.llama2"] = "kong/llm/drivers/llama2.lua",
["kong.llm.drivers.gemini"] = "kong/llm/drivers/gemini.lua",
["kong.llm.drivers.bedrock"] = "kong/llm/drivers/bedrock.lua",
["kong.llm.drivers.huggingface"] = "kong/llm/drivers/huggingface.lua",


["kong.llm.plugin.base"] = "kong/llm/plugin/base.lua",
328 changes: 328 additions & 0 deletions kong/llm/drivers/huggingface.lua
@@ -0,0 +1,328 @@
local _M = {}

-- imports
local cjson = require("cjson.safe")
local fmt = string.format
local ai_shared = require("kong.llm.drivers.shared")
local socket_url = require("socket.url")
--

local DRIVER_NAME = "huggingface"

function _M.pre_request(conf, body)
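-- nothing to mutate before proxying to Hugging Face; report success and continue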
return true, nil
end

local function from_huggingface(response_string, model_info, route_type)
local response_table, err = cjson.decode(response_string)
if not response_table then
ngx.log(ngx.ERR, "Failed to decode JSON response from HuggingFace API: ", err)
return nil, "Failed to decode response"
end

if response_table.error or response_table.message then
local error_msg = response_table.error or response_table.message
ngx.log(ngx.ERR, "Error from HuggingFace API: ", error_msg)
return nil, "API error: " .. error_msg
end

local transformed_response = {
model = model_info.name,
object = response_table.object or route_type,
choices = {},
usage = {},
}

-- Chat reports usage, generation does not
transformed_response.usage = response_table.usage or {}

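-- text-generation responses arrive as an array of { generated_text = ... } objects,
-- while chat-style responses carry an OpenAI-like choices array; normalize both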
response_table.generated_text = response_table[1] and response_table[1].generated_text or nil
if response_table.generated_text then
table.insert(transformed_response.choices, {
message = { content = response_table.generated_text },
index = 0,
finish_reason = "complete",
})
elseif response_table.choices then
for i, choice in ipairs(response_table.choices) do
local content = choice.message and choice.message.content or ""
table.insert(transformed_response.choices, {
message = { content = content },
index = i - 1,
finish_reason = "complete",
})
end
else
ngx.log(ngx.ERR, "Unexpected response format from Hugging Face API")
return nil, "Invalid response format"
end

local result_string, err = cjson.encode(transformed_response)
if not result_string then
ngx.log(ngx.ERR, "Failed to encode transformed response: ", err)
return nil, "Failed to encode response"
end
return result_string, nil
end

local function set_huggingface_options(model_info)
local use_cache = false
local wait_for_model = false

if model_info and model_info.options and model_info.options.huggingface then
use_cache = model_info.options.huggingface.use_cache or false
wait_for_model = model_info.options.huggingface.wait_for_model or false
end

return {
use_cache = use_cache,
wait_for_model = wait_for_model,
}
end
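-- Example (hypothetical values): with use_cache = false and wait_for_model = true
-- configured under model.options.huggingface, the upstream request body gains
-- {"options": {"use_cache": false, "wait_for_model": true}}, the switches the
-- Hugging Face inference API uses for response caching and cold-start handling.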

local function set_default_parameters(request_table)
local parameters = request_table.parameters or {}
if parameters.top_k == nil then
parameters.top_k = request_table.top_k
end
if parameters.top_p == nil then
parameters.top_p = request_table.top_p
end
if parameters.temperature == nil then
parameters.temperature = request_table.temperature
end
if parameters.max_tokens == nil then
if request_table.messages then
-- conversational models use the max_length param
-- https://huggingface.co/docs/api-inference/en/detailed_parameters?code=curl#conversational-task
parameters.max_length = request_table.max_tokens
else
parameters.max_new_tokens = request_table.max_tokens
end
end
request_table.top_k = nil
request_table.top_p = nil
request_table.temperature = nil
request_table.max_tokens = nil

return parameters
end
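-- Illustrative mapping (assumed values): a chat request
--   { temperature = 0.7, top_p = 0.9, max_tokens = 256, messages = {...} }
-- is rewritten to
--   { parameters = { temperature = 0.7, top_p = 0.9, max_length = 256 }, messages = {...} }
-- whereas a completions request maps max_tokens to parameters.max_new_tokens.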

local function to_huggingface(task, request_table, model_info)
local parameters = set_default_parameters(request_table)
local options = set_huggingface_options(model_info)
if task == "llm/v1/completions" then
request_table.inputs = request_table.prompt
request_table.prompt = nil
end
request_table.options = options
request_table.parameters = parameters
request_table.model = model_info.name or request_table.model

return request_table, "application/json", nil
end

local function safe_access(tbl, ...)
local value = tbl
for _, key in ipairs({ ... }) do
value = value and value[key]
if not value then
return nil
end
end
return value
end
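-- e.g. safe_access(event, "choices", 1, "delta", "content") yields the delta text,
-- or nil (rather than raising) when any intermediate key is absent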

local function handle_huggingface_stream(event_t, model_info, route_type)
-- discard empty frames, it should either be a random new line, or comment
if (not event_t.data) or (#event_t.data < 1) then
return
end
local event, err = cjson.decode(event_t.data)

if err then
ngx.log(ngx.WARN, "failed to decode stream event frame from Hugging Face: " .. err)
return nil, "failed to decode stream event frame from Hugging Face", nil
end

local new_event
if route_type == "stream/llm/v1/chat" then
local content = safe_access(event, "choices", 1, "delta", "content") or ""
new_event = {
choices = {
[1] = {
delta = {
content = content,
role = "assistant",
},
index = 0,
},
},
model = event.model or model_info.name,
object = "chat.completion.chunk",
}
else
local text = safe_access(event, "token", "text") or ""
new_event = {
choices = {
[1] = {
text = text,
index = 0,
},
},
model = model_info.name,
object = "text_completion",
}
end
return cjson.encode(new_event), nil, nil
end
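-- Illustrative frames (assumed shapes): a chat stream event such as
--   {"choices": [{"delta": {"content": "Hi"}}]}
-- is re-emitted as an OpenAI-style "chat.completion.chunk", while a completions
-- stream event such as {"token": {"text": "Hi"}} becomes a "text_completion" chunk.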

local transformers_from = {
["llm/v1/chat"] = from_huggingface,
["llm/v1/completions"] = from_huggingface,
["stream/llm/v1/chat"] = handle_huggingface_stream,
["stream/llm/v1/completions"] = handle_huggingface_stream,
}

function _M.from_format(response_string, model_info, route_type)
ngx.log(ngx.DEBUG, "converting from ", model_info.provider, "://", route_type, " type to kong")

-- MUST return a string, set as the response body
if not transformers_from[route_type] then
return nil, fmt("no transformer available from format %s://%s", model_info.provider, route_type)
end

local ok, response_string, err, metadata =
pcall(transformers_from[route_type], response_string, model_info, route_type)
if not ok or err then
return nil,
fmt("transformation failed from type %s://%s: %s", model_info.provider, route_type, err or "unexpected_error")
end

return response_string, nil, metadata
end

local transformers_to = {
["llm/v1/chat"] = to_huggingface,
["llm/v1/completions"] = to_huggingface,
}

function _M.to_format(request_table, model_info, route_type)
if not transformers_to[route_type] then
return nil, nil, fmt("no transformer for %s://%s", model_info.provider, route_type)
end

request_table = ai_shared.merge_config_defaults(request_table, model_info.options, model_info.route_type)

local ok, response_object, content_type, err =
pcall(transformers_to[route_type], route_type, request_table, model_info)
if err or not ok then
return nil, nil, fmt("error transforming to %s://%s", model_info.provider, route_type)
end

return response_object, content_type, nil
end

local function build_url(base_url, route_type)
return (route_type == "llm/v1/completions") and base_url or (base_url .. "/v1/chat/completions")
end
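-- e.g. a serverless completions route posts to
-- https://api-inference.huggingface.co/models/<model> as-is, while the chat route
-- appends /v1/chat/completions to reach the messages-style endpoint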

local function huggingface_endpoint(conf)
local parsed_url

local base_url
if conf.model.options and conf.model.options.upstream_url then
base_url = conf.model.options.upstream_url
elseif conf.model.name then
base_url = fmt(ai_shared.upstream_url_format[DRIVER_NAME], conf.model.name)
else
return nil
end

local url = build_url(base_url, conf.route_type)
parsed_url = socket_url.parse(url)

return parsed_url
end

function _M.configure_request(conf)
local parsed_url = huggingface_endpoint(conf)
if not parsed_url then
return kong.response.exit(400, "Could not parse the Hugging Face model endpoint")
end
if parsed_url.path then
kong.service.request.set_path(parsed_url.path)
end
kong.service.request.set_scheme(parsed_url.scheme)
kong.service.set_target(parsed_url.host, tonumber(parsed_url.port) or 443)

local auth_header_name = conf.auth and conf.auth.header_name
local auth_header_value = conf.auth and conf.auth.header_value

if auth_header_name and auth_header_value then
kong.service.request.set_header(auth_header_name, auth_header_value)
end
return true, nil
end

function _M.post_request(conf)
-- Clear any response headers if needed
if ai_shared.clear_response_headers[DRIVER_NAME] then
for _, v in ipairs(ai_shared.clear_response_headers[DRIVER_NAME]) do
kong.response.clear_header(v)
end
end
end

function _M.subrequest(body, conf, http_opts, return_res_table)
-- Encode the request body as JSON
local body_string, err = cjson.encode(body)
if not body_string then
return nil, nil, "Failed to encode body to JSON: " .. (err or "unknown error")
end

-- Construct the Hugging Face API URL
local url = huggingface_endpoint(conf)
if not url then
return nil, nil, "Could not parse the Hugging Face model endpoint"
end
local url_string = url.scheme .. "://" .. url.host .. (url.path or "")

local headers = {
["Accept"] = "application/json",
["Content-Type"] = "application/json",
}

if conf.auth and conf.auth.header_name then
headers[conf.auth.header_name] = conf.auth.header_value
end

local method = "POST"

local res, err, httpc = ai_shared.http_request(url_string, body_string, method, headers, http_opts, return_res_table)

-- Handle the response
if not res then
return nil, nil, "Request to Hugging Face API failed: " .. (err or "unknown error")
end

-- Check if the response should be returned as a table
if return_res_table then
return {
status = res.status,
headers = res.headers,
body = res.body,
},
res.status,
nil,
httpc
else
if res.status >= 200 and res.status < 300 then
return res.body, res.status, nil
else
return res.body, res.status, "Hugging Face API returned status " .. res.status
end
end
end

return _M
12 changes: 11 additions & 1 deletion kong/llm/drivers/shared.lua
@@ -94,7 +94,8 @@ _M.upstream_url_format = {
gemini = "https://generativelanguage.googleapis.com",
gemini_vertex = "https://%s",
bedrock = "https://bedrock-runtime.%s.amazonaws.com",
mistral = "https://api.mistral.ai:443"
mistral = "https://api.mistral.ai:443",
huggingface = "https://api-inference.huggingface.co/models/%s",
}

_M.operation_map = {
@@ -147,6 +148,15 @@ _M.operation_map = {
gemini_vertex = {
["llm/v1/chat"] = {
path = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
},
},
huggingface = {
["llm/v1/completions"] = {
path = "/models/%s",
method = "POST",
},
["llm/v1/chat"] = {
path = "/models/%s",
method = "POST",
},
},

1 comment on commit c5199ff

@github-actions
Contributor


Bazel Build

Docker image available kong/kong:c5199ff0c386bc8091273c8a99d6760eb382f102
Artifacts available https://github.com/Kong/kong/actions/runs/12004819836
