Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(headers): use as supplied + support multiple headers #260

Merged
merged 1 commit into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion lua/kulala/cmd/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,12 @@ M.run_parser = function(req, callback)
end
end
INT_PROCESSING.redirect_response_body_to_file(result.redirect_response_body_to_files)
PARSER.scripts.javascript.run("post_request", result.scripts.post_request)

local has_post_request_scripts = #result.scripts.post_request.inline > 0
or #result.scripts.pre_request.files > 0
if has_post_request_scripts then
PARSER.scripts.javascript.run("post_request", result.scripts.post_request)
end
Api.trigger("after_request")
end
Fs.delete_request_scripts_files()
Expand Down
2 changes: 1 addition & 1 deletion lua/kulala/globals/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ local M = {}

local plugin_tmp_dir = FS.get_plugin_tmp_dir()

M.VERSION = "4.0.2"
M.VERSION = "4.0.3"
M.UI_ID = "kulala://ui"
M.SCRATCHPAD_ID = "kulala://scratchpad"
M.HEADERS_FILE = plugin_tmp_dir .. "/headers.txt"
Expand Down
41 changes: 31 additions & 10 deletions lua/kulala/internal_processing/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,39 @@ local function get_nested_value(t, key)
return value
end

local get_headers_as_table = function()
---Function to get the last headers as a table
---@description Reads the headers file and returns the headers as a table.
---In some cases the headers file might contain multiple header sections,
---e.g. if you have follow-redirections enabled.
---This function will return the headers of the last response.
---@return table
local get_last_headers_as_table = function()
local headers_file = FS.read_file(GLOBALS.HEADERS_FILE):gsub("\r\n", "\n")
local lines = vim.split(headers_file, "\n")
local headers_table = {}
-- INFO:
-- We only want the headers of the last response
-- so we reset the headers_table only each time the previous line was empty
-- and we also have new headers data
local previously_empty = false
for _, header in ipairs(lines) do
if header:find(":") ~= nil then
local kv = vim.split(header, ":")
local key = kv[1]
-- the value should be everything after the first colon
-- but we can't use slice and join because the value might contain colons
local value = header:sub(#key + 2)
headers_table[key] = vim.trim(value)
local empty_line = header == ""
if empty_line then
previously_empty = true
else
if previously_empty then
headers_table = {}
end
previously_empty = false
if header:find(":") ~= nil then
local kv = vim.split(header, ":")
local key = kv[1]
-- INFO:
-- the value should be everything after the first colon
-- but we can't use slice and join because the value might contain colons
local value = header:sub(#key + 2)
headers_table[key] = vim.trim(value)
end
end
end
return headers_table
Expand Down Expand Up @@ -78,7 +99,7 @@ local get_cookies_as_table = function()
end

local get_lower_headers_as_table = function()
local headers = get_headers_as_table()
local headers = get_last_headers_as_table()
local headers_table = {}
for key, value in pairs(headers) do
headers_table[key:lower()] = value
Expand All @@ -101,7 +122,7 @@ end
M.set_env_for_named_request = function(name, body)
local named_request = {
response = {
headers = get_headers_as_table(),
headers = get_last_headers_as_table(),
body = body,
cookies = get_cookies_as_table(),
},
Expand Down
46 changes: 22 additions & 24 deletions lua/kulala/parser/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ local REQUEST_VARIABLES = require("kulala.parser.request_variables")
local STRING_UTILS = require("kulala.utils.string")
local PARSER_UTILS = require("kulala.parser.utils")
local TS = require("kulala.parser.treesitter")
local PLUGIN_TMP_DIR = FS.get_plugin_tmp_dir()
local CURL_FORMAT_FILE = FS.get_plugin_path({ "parser", "curl-format.json" })
local Logger = require("kulala.logger")

Expand Down Expand Up @@ -342,8 +341,8 @@ M.get_document = function()
-- dynamic variables are defined as `{{$variable_name}}`
local key, value = line:match("^([^:]+):%s*(.*)$")
if key and value then
request.headers[key:lower()] = value
request.headers_raw[key:lower()] = value
request.headers[key] = value
request.headers_raw[key] = value
end
elseif is_request_line == true then
-- Request line (e.g., GET http://example.com HTTP/1.1)
Expand Down Expand Up @@ -455,13 +454,9 @@ end
---@field file string -- The file path to write the response body to
---@field overwrite boolean -- Whether to overwrite the file if it already exists

---@class ScriptsItems
---@field inline table -- Inline post-request handler scripts - each element is a line of the script
---@field files table -- File post-request handler scripts - each element is a file path
---
---@class Scripts
---@field pre_request ScriptsItems -- Pre-request handler scripts
---@field post_request ScriptsItems -- Post-request handler scripts
---@field pre_request ScriptData -- Pre-request handler scripts
---@field post_request ScriptData -- Post-request handler scripts

---@class Request
---@field metadata table[] -- Metadata of the request
Expand Down Expand Up @@ -575,12 +570,12 @@ M.parse = function(start_request_linenr)
res.url, res.headers, res.body =
replace_variables_in_url_headers_body(res, document_variables, env, has_pre_request_scripts)

-- Merge headers from the $shared environment if it exists
-- Merge headers from the $shared environment if it does not exist in the request
-- this ensures that you can always override the headers in the request
if DB.find_unique("http_client_env_shared") then
local default_headers = DB.find_unique("http_client_env_shared")["$default_headers"]
if default_headers then
for key, value in pairs(default_headers) do
key = key:lower()
if res.headers[key] == nil then
res.headers[key] = value
end
Expand Down Expand Up @@ -617,13 +612,15 @@ M.parse = function(start_request_linenr)
table.insert(res.cmd, res.method)

local is_graphql = PARSER_UTILS.contains_meta_tag(res, "graphql")
or PARSER_UTILS.contains_header(res.headers, "x-request-type", "GraphQL")
or PARSER_UTILS.contains_header(res.headers, "x-request-type", "graphql")
if CONFIG.get().treesitter then
-- treesitter parser handles graphql requests before this point
is_graphql = false
end

if res.headers["content-type"] ~= nil and res.body ~= nil then
local content_type_header_name, content_type_header_value = PARSER_UTILS.get_header(res.headers, "content-type")

if content_type_header_name and content_type_header_value and res.body ~= nil then
-- check if we are a graphql query
-- we need this here, because the user could have defined the content-type
-- as application/json, but the body is a graphql query
Expand All @@ -633,9 +630,9 @@ M.parse = function(start_request_linenr)
if gql_json then
table.insert(res.cmd, "--data")
table.insert(res.cmd, gql_json)
res.headers["content-type"] = "application/json"
res.headers[content_type_header_name] = "application/json"
end
elseif res.headers["content-type"]:find("^multipart/form%-data") then
elseif content_type_header_value:find("^multipart/form%-data") then
local tmp_file = FS.get_binary_temp_file(res.body)
if tmp_file ~= nil then
table.insert(res.cmd, "--data-binary")
Expand All @@ -654,31 +651,32 @@ M.parse = function(start_request_linenr)
if gql_json then
table.insert(res.cmd, "--data")
table.insert(res.cmd, gql_json)
res.headers["content-type"] = "application/json"
res.headers[content_type_header_name] = "application/json"
end
end
end

if res.headers["authorization"] then
local auth_header = res.headers["authorization"]
local authtype = auth_header:match("^(%w+)%s+.*")
local auth_header_name, auth_header_value = PARSER_UTILS.get_header(res.headers, "authorization")

if auth_header_name and auth_header_value then
local authtype = auth_header_value:match("^(%w+)%s+.*")
if authtype == nil then
authtype = auth_header:match("^(%w+)%s*$")
authtype = auth_header_value:match("^(%w+)%s*$")
end

if authtype ~= nil then
authtype = authtype:lower()

if authtype == "ntlm" or authtype == "negotiate" or authtype == "digest" or authtype == "basic" then
local match, authuser, authpw = auth_header:match("^(%w+)%s+([^%s:]+)%s*[:%s]%s*([^%s]+)%s*$")
local match, authuser, authpw = auth_header_value:match("^(%w+)%s+([^%s:]+)%s*[:%s]%s*([^%s]+)%s*$")
if match ~= nil or (authtype == "ntlm" or authtype == "negotiate") then
table.insert(res.cmd, "--" .. authtype)
table.insert(res.cmd, "-u")
table.insert(res.cmd, (authuser or "") .. ":" .. (authpw or ""))
res.headers["authorization"] = nil
res.headers[auth_header_name] = nil
end
elseif authtype == "aws" then
local key, secret, optional = auth_header:match("^%w+%s([^%s]+)%s*([^%s]+)[%s$]+(.*)$")
local key, secret, optional = auth_header_value:match("^%w+%s([^%s]+)%s*([^%s]+)[%s$]+(.*)$")
local token = optional:match("token:([^%s]+)")
local region = optional:match("region:([^%s]+)")
local service = optional:match("service:([^%s]+)")
Expand All @@ -697,7 +695,7 @@ M.parse = function(start_request_linenr)
table.insert(res.cmd, "-H")
table.insert(res.cmd, "x-amz-security-token:" .. token)
end
res.headers["authorization"] = nil
res.headers[auth_header_name] = nil
end
end
end
Expand Down
104 changes: 100 additions & 4 deletions lua/kulala/parser/utils.lua
Original file line number Diff line number Diff line change
@@ -1,21 +1,117 @@
local M = {}

-- PERF: we do a lot of if else blocks with repeating loops
-- we could "optimize" this by using a single loop and if else blocks
-- that would make the code more readable and easier to maintain
-- but it would also make it slower

---Check if a request has a specific meta tag
---@param request table The request to check
---@param tag string The meta tag to check for
M.contains_meta_tag = function(request, tag)
tag = tag:lower()
for _, meta in ipairs(request.metadata) do
if meta.name == tag then
if meta.name:lower() == tag then
return true
end
end
return false
end

---Check if a header is present in the request
---@param headers table The headers to check
---@param header string The header name to check
---@param value string|nil The value to check for or nil if only the header name should be checked
---@return boolean
M.contains_header = function(headers, header, value)
for k, v in pairs(headers) do
if k == header and v == value then
return true
header = header:lower()
value = value and value:lower() or nil
vim.print("header: " .. header .. " value: " .. value)
if value == nil then
for k, _ in pairs(headers) do
if k:lower() == header then
return true
end
end
else
for k, v in pairs(headers) do
if k:lower() == header and v:lower() == value then
return true
end
end
end
return false
end

---Get the value of a header from the request
---@param headers table The headers to check
---@param header string The header name to check
---@param dont_ignore_case boolean|nil If true, the header name will be case sensitive
---@return string|nil
M.get_header_value = function(headers, header, dont_ignore_case)
header = dont_ignore_case and header or header:lower()
for k, v in pairs(headers) do
if k == header then
return v
end
end
return nil
end

---Get the name of a header from the request
---@param headers table The headers to check
---@param header string The header name to check
---@param dont_ignore_case boolean|nil If true, the header name will be case sensitive
---@return string|nil
M.get_header_name = function(headers, header, dont_ignore_case)
header = dont_ignore_case and header or header:lower()
for k, _ in pairs(headers) do
if k:lower() == header then
return k
end
end
return nil
end

---Get a header from the request
---@param headers table The headers to check
---@param header string The header name to check
---@param value string|nil The value to check for or nil if only the header name should be checked
---@param dont_ignore_case boolean|nil If true, the header name will be case sensitive
---@return (string|nil), (string|nil) The header name and value or nil if not found
M.get_header = function(headers, header, value, dont_ignore_case)
header = dont_ignore_case and header or header:lower()
value = value and (dont_ignore_case and value or value:lower()) or nil
if dont_ignore_case then
if value == nil then
for k, _ in pairs(headers) do
if k == header then
return k, headers[k]
end
end
else
for k, v in pairs(headers) do
if k == header and v == value then
return k, v
end
end
end
else
if value == nil then
for k, _ in pairs(headers) do
if k:lower() == header then
return k, headers[k]
end
end
else
for k, v in pairs(headers) do
if k:lower() == header and v:lower() == value then
return k, v
end
end
end
end
return nil, nil
end

return M