diff --git a/R/provider-gemini.R b/R/provider-gemini.R
index 293661e..b84a241 100644
--- a/R/provider-gemini.R
+++ b/R/provider-gemini.R
@@ -107,7 +107,7 @@ method(chat_request, ProviderGemini) <- function(provider,
     contents = contents,
     tools = tools,
     systemInstruction = system,
-    generation_config = generation_config,
+    generationConfig = generation_config,
     !!!extra_args
   ))
   req <- req_body_json(req, body)
@@ -131,7 +131,7 @@ method(stream_merge_chunks, ProviderGemini) <- function(provider, result, chunk)
   if (is.null(result)) {
     chunk
   } else {
-    merge_dicts(result, chunk)
+    merge_gemini_chunks(result, chunk)
   }
 }
 method(value_turn, ProviderGemini) <- function(provider, result, has_type = FALSE) {
@@ -158,6 +158,7 @@ method(value_turn, ProviderGemini) <- function(provider, result, has_type = FALS
       )
     }
   })
+  contents <- compact(contents)
   usage <- result$usageMetadata
   tokens <- c(
     usage$promptTokenCount %||% NA_integer_,
@@ -193,7 +194,13 @@ method(as_json, list(ProviderGemini, ToolDef)) <- function(provider, x) {
 }
 
 method(as_json, list(ProviderGemini, ContentText)) <- function(provider, x) {
-  list(text = x@text)
+  if (identical(x@text, "")) {
+    # Gemini tool call requests can include a Content with empty text,
+    # but it doesn't like it if you send this back
+    NULL
+  } else {
+    list(text = x@text)
+  }
 }
 
 # https://ai.google.dev/api/caching#FileData
@@ -249,3 +256,133 @@ method(as_json, list(ProviderGemini, TypeObject)) <- function(provider, x) {
     required = as.list(names2(x@properties)[required])
   ))
 }
+
+# Gemini-specific merge logic --------------------------------------------------
+
+merge_last <- function() {
+  function(left, right, path = NULL) {
+    right
+  }
+}
+
+merge_identical <- function() {
+  function(left, right, path = NULL) {
+    if (!identical(left, right)) {
+      stop("Expected identical values, but got ", deparse(left), " and ", deparse(right))
+    }
+    left
+  }
+}
+
+merge_any_or_empty <- function() {
+  function(left, right, path = NULL) {
+    if (!is.null(left) && nzchar(left)) {
+      left
+    } else if (!is.null(right) && nzchar(right)) {
+      right
+    } else {
+      ""
+    }
+  }
+}
+
+merge_optional <- function(merge_func) {
+  function(left, right, path = NULL) {
+    if (is.null(left) && is.null(right)) {
+      NULL
+    } else {
+      merge_func(left, right, path)
+    }
+  }
+}
+
+merge_objects <- function(...) {
+  spec <- list(...)
+  function(left, right, path = NULL) {
+    # cat(paste(collapse = "", path), "\n")
+    stopifnot(is.list(left), is.list(right), all(nzchar(names(spec))))
+    mapply(names(spec), spec, FUN = function(key, value) {
+      value(left[[key]], right[[key]], c(path, ".", key))
+    }, USE.NAMES = TRUE, SIMPLIFY = FALSE)
+  }
+}
+
+merge_candidate_lists <- function(...) {
+  merge_unindexed <- merge_objects(...)
+  merge_indexed <- merge_objects(index = merge_identical(), ...)
+
+  function(left, right, path = NULL) {
+    if (length(left) == 1 && length(right) == 1) {
+      list(merge_unindexed(left[[1]], right[[1]], c(path, "[]")))
+    } else {
+      # left and right are lists of objects with [["index"]]
+      # We need to find the elements that have matching indices and merge them
+      left_indices <- vapply(left, `[[`, integer(1), "index")
+      right_indices <- vapply(right, `[[`, integer(1), "index")
+      # I know this seems weird, but according to Google's Go SDK, we should
+      # only retain indices on the right that *already* appear on the left.
+      # Citations:
+      # https://github.com/google/generative-ai-go/blob/3d14f4039eaef321b15bcbf70839389d7f000233/genai/client_test.go#L655
+      # https://github.com/google/generative-ai-go/blob/3d14f4039eaef321b15bcbf70839389d7f000233/genai/client.go#L396
+      lapply(left_indices, function(index) {
+        left_item <- left[[which(left_indices == index)]]
+        right_item <- right[[which(right_indices == index)]]
+        if (is.null(right_item)) {
+          left_item
+        } else {
+          merge_indexed(left_item, right_item, c(path, "[", index, "]"))
+        }
+      })
+    }
+  }
+}
+
+merge_append <- function() {
+  function(left, right, path = NULL) {
+    c(left, right)
+  }
+}
+
+merge_parts <- function() {
+  function(left, right, path = NULL) {
+    joined <- c(left, right)
+
+    # Identify text parts
+    is_text <- map_lgl(joined, ~is.list(.x) && identical(names(.x), "text"))
+
+    # Create groups for contiguous sections
+    groups <- cumsum(c(TRUE, diff(is_text) != 0))
+
+    # Split into groups and process each
+    split_parts <- split(joined, groups)
+    merged_split_parts <- map2(split_parts, split(is_text, groups), function(parts, is_text_group) {
+      if (!is_text_group[[1]]) {
+        # Non-text group: return parts unchanged
+        return(parts)
+      } else {
+        # Text group: merge text values
+        text_values <- map_chr(parts, ~.x[["text"]])
+        list(list(text = paste0(text_values, collapse = "")))
+      }
+    })
+    unlist(merged_split_parts, recursive = FALSE, use.names = FALSE)
+  }
+}
+
+# Put it all together...
+merge_gemini_chunks <- merge_objects(
+  candidates = merge_candidate_lists(
+    content = merge_objects(
+      role = merge_any_or_empty(),
+      parts = merge_parts()
+    ),
+    finishReason = merge_last(),
+    safetyRatings = merge_last(),
+    citationMetadata = merge_optional(
+      merge_objects(citationSources = merge_append())
+    ),
+    tokenCount = merge_last()
+  ),
+  promptFeedback = merge_last(),
+  usageMetadata = merge_last()
+)
diff --git a/R/utils.R b/R/utils.R
index a612087..d497733 100644
--- a/R/utils.R
+++ b/R/utils.R
@@ -47,6 +47,7 @@ last_response_json <- function() {
 }
 print_json <- function(x) {
   cat(pretty_json(x))
+  cat("\n")
 }
 pretty_json <- function(x) {
   jsonlite::toJSON(x, pretty = TRUE, auto_unbox = TRUE)
diff --git a/tests/testthat/helper-provider.R b/tests/testthat/helper-provider.R
index b0dde20..55edd43 100644
--- a/tests/testthat/helper-provider.R
+++ b/tests/testthat/helper-provider.R
@@ -52,7 +52,10 @@ test_tools_simple <- function(chat_fun) {
   chat <- chat_fun(system_prompt = "Be very terse, not even punctuation.")
   chat$register_tool(tool(function() "2024-01-01", "Return the current date"))
 
-  result <- chat$chat("What's the current date in YMD format?", echo = TRUE)
+  expect_output(
+    result <- chat$chat("What's the current date in YMD format?", echo = TRUE),
+    "2024-01-01"
+  )
   expect_match(result, "2024-01-01")
 
   result <- chat$chat("What month is it? Provide the full name")
@@ -114,7 +117,7 @@ test_tools_sequential <- function(chat_fun, total_calls) {
 
 test_data_extraction <- function(chat_fun) {
   article_summary <- type_object(
-    "Summary of the article.",
+    "Summary of the article. Preserve existing case.",
     title = type_string("Content title"),
     author = type_string("Name of the author")
   )
diff --git a/tests/testthat/test-provider-gemini.R b/tests/testthat/test-provider-gemini.R
index a2546de..90d675d 100644
--- a/tests/testthat/test-provider-gemini.R
+++ b/tests/testthat/test-provider-gemini.R
@@ -48,3 +48,26 @@ test_that("can use images", {
   test_images_inline(chat_fun)
   test_images_remote_error(chat_fun)
 })
+
+# chunk merging ----------------------------------------------------------
+
+test_that("can merge text output", {
+  # output from "tell me a joke" with text changed
+  messages <- c(
+    '{"candidates": [{"content": {"parts": [{"text": "a"}],"role": "model"}}],"usageMetadata": {"promptTokenCount": 5,"totalTokenCount": 5},"modelVersion": "gemini-1.5-flash"}',
+    '{"candidates": [{"content": {"parts": [{"text": "b"}],"role": "model"}}],"usageMetadata": {"promptTokenCount": 5,"totalTokenCount": 5},"modelVersion": "gemini-1.5-flash"}',
+    '{"candidates": [{"content": {"parts": [{"text": "c"}],"role": "model"},"finishReason": "STOP"}],"usageMetadata": {"promptTokenCount": 5,"candidatesTokenCount": 17,"totalTokenCount": 22},"modelVersion": "gemini-1.5-flash"}'
+  )
+  chunks <- lapply(messages, jsonlite::parse_json)
+
+  out <- merge_gemini_chunks(chunks[[1]], chunks[[2]])
+  out <- merge_gemini_chunks(out, chunks[[3]])
+
+  expect_equal(out$candidates[[1]]$content$parts[[1]]$text, "abc")
+  expect_equal(out$usageMetadata, list(
+    promptTokenCount = 5,
+    candidatesTokenCount = 17,
+    totalTokenCount = 22
+  ))
+  expect_equal(out$candidates[[1]]$finishReason, "STOP")
+})
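
A minimal usage sketch, not part of the diff itself: merge_gemini_chunks() is an internal helper, so this assumes the package code is loaded (for example via devtools::load_all()). The chunk JSON is abbreviated from the test fixtures above, and the accumulation loop mirrors the stream_merge_chunks() method in the first hunk.

chunks <- lapply(c(
  '{"candidates": [{"content": {"parts": [{"text": "a"}], "role": "model"}}]}',
  '{"candidates": [{"content": {"parts": [{"text": "b"}], "role": "model"}}]}',
  '{"candidates": [{"content": {"parts": [{"text": "c"}], "role": "model"}, "finishReason": "STOP"}]}'
), jsonlite::parse_json)

result <- NULL
for (chunk in chunks) {
  # Same pattern as stream_merge_chunks above: keep the first chunk as-is,
  # fold each later chunk into the accumulated result
  result <- if (is.null(result)) chunk else merge_gemini_chunks(result, chunk)
}
result$candidates[[1]]$content$parts[[1]]$text  # "abc"
result$candidates[[1]]$finishReason             # "STOP"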