Skip to content

Commit

Permalink
Implement Gemini-specific chunk merging logic (#201)
Browse files Browse the repository at this point in the history
Fixes #199.

---------

Co-authored-by: Hadley Wickham <[email protected]>
  • Loading branch information
jcheng5 and hadley authored Dec 18, 2024
1 parent ddadc17 commit 7526fb0
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 5 deletions.
143 changes: 140 additions & 3 deletions R/provider-gemini.R
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ method(chat_request, ProviderGemini) <- function(provider,
contents = contents,
tools = tools,
systemInstruction = system,
generation_config = generation_config,
generationConfig = generation_config,
!!!extra_args
))
req <- req_body_json(req, body)
Expand All @@ -131,7 +131,7 @@ method(stream_merge_chunks, ProviderGemini) <- function(provider, result, chunk)
if (is.null(result)) {
chunk
} else {
merge_dicts(result, chunk)
merge_gemini_chunks(result, chunk)
}
}
method(value_turn, ProviderGemini) <- function(provider, result, has_type = FALSE) {
Expand All @@ -158,6 +158,7 @@ method(value_turn, ProviderGemini) <- function(provider, result, has_type = FALS
)
}
})
contents <- compact(contents)
usage <- result$usageMetadata
tokens <- c(
usage$promptTokenCount %||% NA_integer_,
Expand Down Expand Up @@ -193,7 +194,13 @@ method(as_json, list(ProviderGemini, ToolDef)) <- function(provider, x) {
}

method(as_json, list(ProviderGemini, ContentText)) <- function(provider, x) {
list(text = x@text)
if (identical(x@text, "")) {
# Gemini tool call requests can include a Content with empty text,
# but it doesn't like it if you send this back
NULL
} else {
list(text = x@text)
}
}

# https://ai.google.dev/api/caching#FileData
Expand Down Expand Up @@ -249,3 +256,133 @@ method(as_json, list(ProviderGemini, TypeObject)) <- function(provider, x) {
required = as.list(names2(x@properties)[required])
))
}

# Gemini-specific merge logic --------------------------------------------------

merge_last <- function() {
function(left, right, path = NULL) {
right
}
}

merge_identical <- function() {
function(left, right, path = NULL) {
if (!identical(left, right)) {
stop("Expected identical values, but got ", deparse(left), " and ", deparse(right))
}
left
}
}

merge_any_or_empty <- function() {
function(left, right, path = NULL) {
if (!is.null(left) && nzchar(left)) {
left
} else if (!is.null(right) && nzchar(right)) {
right
} else {
""
}
}
}

merge_optional <- function(merge_func) {
function(left, right, path = NULL) {
if (is.null(left) && is.null(right)) {
NULL
} else {
merge_func(left, right, path)
}
}
}

merge_objects <- function(...) {
spec <- list(...)
function(left, right, path = NULL) {
# cat(paste(collapse = "", path), "\n")
stopifnot(is.list(left), is.list(right), all(nzchar(names(spec))))
mapply(names(spec), spec, FUN = function(key, value) {
value(left[[key]], right[[key]], c(path, ".", key))
}, USE.NAMES = TRUE, SIMPLIFY = FALSE)
}
}

merge_candidate_lists <- function(...) {
merge_unindexed <- merge_objects(...)
merge_indexed <- merge_objects(index = merge_identical(), ...)

function(left, right, path = NULL) {
if (length(left) == 1 && length(right) == 1) {
list(merge_unindexed(left[[1]], right[[1]], c(path, "[]")))
} else {
# left and right are lists of objects with [["index"]]
# We need to find the elements that have matching indices and merge them
left_indices <- vapply(left, `[[`, integer(1), "index")
right_indices <- vapply(right, `[[`, integer(1), "index")
# I know this seems weird, but according to Google's Go SDK, we should
# only retain indices on the right that *already* appear on the left.
# Citations:
# https://github.com/google/generative-ai-go/blob/3d14f4039eaef321b15bcbf70839389d7f000233/genai/client_test.go#L655
# https://github.com/google/generative-ai-go/blob/3d14f4039eaef321b15bcbf70839389d7f000233/genai/client.go#L396
lapply(left_indices, function(index) {
left_item <- left[[which(left_indices == index)]]
right_item <- right[[which(right_indices == index)]]
if (is.null(right_item)) {
left_item
} else {
merge_indexed(left_item, right_item, c(path, "[", index, "]"))
}
})
}
}
}

merge_append <- function() {
function(left, right, path = NULL) {
c(left, right)
}
}

merge_parts <- function() {
function(left, right, path = NULL) {
joined <- c(left, right)

# Identify text parts
is_text <- map_lgl(joined, ~is.list(.x) && identical(names(.x), "text"))

# Create groups for contiguous sections
groups <- cumsum(c(TRUE, diff(is_text) != 0))

# Split into groups and process each
split_parts <- split(joined, groups)
merged_split_parts <- map2(split_parts, split(is_text, groups), function(parts, is_text_group) {
if (!is_text_group[[1]]) {
# Non-text group: return parts unchanged
return(parts)
} else {
# Text group: merge text values
text_values <- map_chr(parts, ~.x[["text"]])
list(list(text = paste0(text_values, collapse = "")))
}
})
unlist(merged_split_parts, recursive = FALSE, use.names = FALSE)
}
}

# Put it all together...
merge_gemini_chunks <- merge_objects(
candidates = merge_candidate_lists(
content = merge_objects(
role = merge_any_or_empty(),
parts = merge_parts()
),
finishReason = merge_last(),
safetyRatings = merge_last(),
citationMetadata = merge_optional(
merge_objects(citationSources = merge_append())
),
tokenCount = merge_last()
),
promptFeedback = merge_last(),
usageMetadata = merge_last()
)
1 change: 1 addition & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ last_response_json <- function() {
}
print_json <- function(x) {
cat(pretty_json(x))
cat("\n")
}
pretty_json <- function(x) {
jsonlite::toJSON(x, pretty = TRUE, auto_unbox = TRUE)
Expand Down
7 changes: 5 additions & 2 deletions tests/testthat/helper-provider.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ test_tools_simple <- function(chat_fun) {
chat <- chat_fun(system_prompt = "Be very terse, not even punctuation.")
chat$register_tool(tool(function() "2024-01-01", "Return the current date"))

result <- chat$chat("What's the current date in YMD format?", echo = TRUE)
expect_output(
result <- chat$chat("What's the current date in YMD format?", echo = TRUE),
"2024-01-01"
)
expect_match(result, "2024-01-01")

result <- chat$chat("What month is it? Provide the full name")
Expand Down Expand Up @@ -114,7 +117,7 @@ test_tools_sequential <- function(chat_fun, total_calls) {

test_data_extraction <- function(chat_fun) {
article_summary <- type_object(
"Summary of the article.",
"Summary of the article. Preserve existing case.",
title = type_string("Content title"),
author = type_string("Name of the author")
)
Expand Down
23 changes: 23 additions & 0 deletions tests/testthat/test-provider-gemini.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,26 @@ test_that("can use images", {
test_images_inline(chat_fun)
test_images_remote_error(chat_fun)
})

# chunk merging ----------------------------------------------------------

test_that("can merge text output", {
# output from "tell me a joke" with text changed
messages <- c(
'{"candidates": [{"content": {"parts": [{"text": "a"}],"role": "model"}}],"usageMetadata": {"promptTokenCount": 5,"totalTokenCount": 5},"modelVersion": "gemini-1.5-flash"}',
'{"candidates": [{"content": {"parts": [{"text": "b"}],"role": "model"}}],"usageMetadata": {"promptTokenCount": 5,"totalTokenCount": 5},"modelVersion": "gemini-1.5-flash"}',
'{"candidates": [{"content": {"parts": [{"text": "c"}],"role": "model"},"finishReason": "STOP"}],"usageMetadata": {"promptTokenCount": 5,"candidatesTokenCount": 17,"totalTokenCount": 22},"modelVersion": "gemini-1.5-flash"}'
)
chunks <- lapply(messages, jsonlite::parse_json)

out <- merge_gemini_chunks(chunks[[1]], chunks[[2]])
out <- merge_gemini_chunks(out, chunks[[3]])

expect_equal(out$candidates[[1]]$content$parts[[1]]$text, "abc")
expect_equal(out$usageMetadata, list(
promptTokenCount = 5,
candidatesTokenCount = 17,
totalTokenCount = 22
))
expect_equal(out$candidates[[1]]$finishReason, "STOP")
})

0 comments on commit 7526fb0

Please sign in to comment.