diff --git a/mediapipe/tasks/web/genai_experimental/BUILD b/mediapipe/tasks/web/genai_experimental/BUILD new file mode 100644 index 0000000000..e568be4d3f --- /dev/null +++ b/mediapipe/tasks/web/genai_experimental/BUILD @@ -0,0 +1,115 @@ +# This contains the MediaPipe GenAI Tasks. + +load("@build_bazel_rules_nodejs//:index.bzl", "pkg_npm") +load("@npm//@bazel/rollup:index.bzl", "rollup_bundle") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load( + "//mediapipe/framework/tool:mediapipe_files.bzl", + "mediapipe_files", +) + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +mediapipe_files(srcs = [ + "wasm/genai_experimental_wasm_internal.js", + "wasm/genai_experimental_wasm_internal.wasm", + "wasm/genai_experimental_wasm_nosimd_internal.js", + "wasm/genai_experimental_wasm_nosimd_internal.wasm", +]) + +GENAI_EXPERIMENTAL_LIBS = [ + "//mediapipe/tasks/web/core:fileset_resolver", + "//mediapipe/tasks/web/genai_experimental/rag_pipeline", +] + +mediapipe_ts_library( + name = "genai_experimental_lib", + srcs = ["index.ts"], + visibility = ["//visibility:public"], + deps = GENAI_EXPERIMENTAL_LIBS, +) + +mediapipe_ts_library( + name = "genai_experimental_types", + srcs = ["types.ts"], + visibility = ["//visibility:public"], + deps = GENAI_EXPERIMENTAL_LIBS, +) + +rollup_bundle( + name = "genai_experimental_bundle_mjs", + config_file = "//mediapipe/tasks/web:rollup.config.mjs", + entry_point = "index.ts", + format = "esm", + output_dir = False, + sourcemap = "true", + deps = [ + ":genai_experimental_lib", + "@npm//@rollup/plugin-commonjs", + "@npm//@rollup/plugin-node-resolve", + "@npm//@rollup/plugin-terser", + "@npm//google-protobuf", + ], +) + +rollup_bundle( + name = "genai_experimental_bundle_cjs", + config_file = "//mediapipe/tasks/web:rollup.config.mjs", + entry_point = "index.ts", + format = "cjs", + output_dir = False, + sourcemap = "true", + deps = [ + ":genai_experimental_lib", + "@npm//@rollup/plugin-commonjs", + 
"@npm//@rollup/plugin-node-resolve", + "@npm//@rollup/plugin-terser", + "@npm//google-protobuf", + ], +) + +genrule( + name = "genai_experimental_sources", + srcs = [ + ":genai_experimental_bundle_cjs", + ":genai_experimental_bundle_mjs", + ], + outs = [ + "genai_experimental_bundle.cjs", + "genai_experimental_bundle.cjs.map", + "genai_experimental_bundle.mjs", + "genai_experimental_bundle.mjs.map", + ], + cmd = ( + "for FILE in $(SRCS); do " + + " OUT_FILE=$(GENDIR)/mediapipe/tasks/web/genai_experimental/$$(" + + " basename $$FILE | sed -E 's/_([cm])js\\.js/.\\1js/'" + + " ); " + + " echo $$FILE ; echo $$OUT_FILE ; " + + " cp $$FILE $$OUT_FILE ; " + + "done;" + ), +) + +pkg_npm( + name = "genai_experimental_pkg", + package_name = "@mediapipe/tasks-__NAME__", + srcs = [ + "README.md", + "package.json", + ], + substitutions = { + "__NAME__": "genai_experimental", + "__DESCRIPTION__": "MediaPipe GenAI Experimental Tasks", + "__TYPES__": "genai_experimental.d.ts", + }, + tgz = "genai_experimental.tgz", + deps = [ + "package.json", + "wasm/genai_experimental_wasm_internal.js", + "wasm/genai_experimental_wasm_internal.wasm", + "wasm/genai_experimental_wasm_nosimd_internal.js", + "wasm/genai_experimental_wasm_nosimd_internal.wasm", + ":genai_experimental_sources", + ], +) diff --git a/mediapipe/tasks/web/genai_experimental/README.md b/mediapipe/tasks/web/genai_experimental/README.md new file mode 100644 index 0000000000..c0931add94 --- /dev/null +++ b/mediapipe/tasks/web/genai_experimental/README.md @@ -0,0 +1,28 @@ +# MediaPipe Tasks GenAI Package + +This package contains experimental GenAI tasks for MediaPipe. + +## RAG Pipeline Inference + +You can use the RAG Pipeline to augment an LLM Inference Task with existing +knowledge. 
+ +``` +const genaiFileset = await FilesetResolver.forGenAiTasks(); +const genaiExperimentalFileset = + await FilesetResolver.forGenAiExperimentalTasks(); +const llmInference = await LlmInference.createFromModelPath(genaiFileset, ...); +const ragPipeline = await RagPipeline.createWithEmbeddingModel( + genaiExperimentalFileset, + llmInference, + EMBEDDING_MODEL_URL, +); +await ragPipeline.recordBatchedMemory([ + 'Paris is the capital of France.', + 'Berlin is the capital of Germany', +]); +const result = await ragPipeline.generateResponse( + 'What is the capital of France?', +); +console.log(result); +``` diff --git a/mediapipe/tasks/web/genai_experimental/index.ts b/mediapipe/tasks/web/genai_experimental/index.ts new file mode 100644 index 0000000000..7909242f49 --- /dev/null +++ b/mediapipe/tasks/web/genai_experimental/index.ts @@ -0,0 +1,25 @@ +/** + * Copyright 2024 The MediaPipe Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver'; +import {RagPipeline as RagPipelineImpl} from '../../../tasks/web/genai_experimental/rag_pipeline/rag_pipeline'; + +// Declare the variables locally so that Rollup in OSS includes them explicitly +// as exports. 
+const FilesetResolver = FilesetResolverImpl; +const RagPipeline = RagPipelineImpl; + +export {FilesetResolver, RagPipeline}; diff --git a/mediapipe/tasks/web/genai_experimental/package.json b/mediapipe/tasks/web/genai_experimental/package.json new file mode 100644 index 0000000000..0a7311d014 --- /dev/null +++ b/mediapipe/tasks/web/genai_experimental/package.json @@ -0,0 +1,23 @@ +{ + "name": "@mediapipe/tasks-__NAME__", + "version": "__VERSION__", + "description": "__DESCRIPTION__", + "main": "__NAME___bundle.cjs", + "browser": "__NAME___bundle.mjs", + "module": "__NAME___bundle.mjs", + "exports": { + "import": "./__NAME___bundle.mjs", + "require": "./__NAME___bundle.cjs", + "default": "./__NAME___bundle.mjs", + "types": "./__TYPES__" + }, + "dependencies": { + "@mediapipe/tasks-genai": "__VERSION__" + }, + "author": "mediapipe@google.com", + "license": "Apache-2.0", + "type": "module", + "types": "__TYPES__", + "homepage": "http://mediapipe.dev", + "keywords": [ "AR", "ML", "Augmented", "MediaPipe", "MediaPipe Tasks" ] +} diff --git a/mediapipe/tasks/web/genai_experimental/rag_pipeline/BUILD b/mediapipe/tasks/web/genai_experimental/rag_pipeline/BUILD new file mode 100644 index 0000000000..247da3b1e2 --- /dev/null +++ b/mediapipe/tasks/web/genai_experimental/rag_pipeline/BUILD @@ -0,0 +1,19 @@ +# This contains the MediaPipe RAG Pipeline + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "rag_pipeline", + srcs = ["rag_pipeline.ts"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/genai/llm_inference", + "//mediapipe/web/graph_runner:graph_runner_ts", + "//mediapipe/web/graph_runner:graph_runner_wasm_file_reference_ts", + ], +) diff --git a/mediapipe/tasks/web/genai_experimental/rag_pipeline/rag_pipeline.ts 
b/mediapipe/tasks/web/genai_experimental/rag_pipeline/rag_pipeline.ts new file mode 100644 index 0000000000..4ab128af44 --- /dev/null +++ b/mediapipe/tasks/web/genai_experimental/rag_pipeline/rag_pipeline.ts @@ -0,0 +1,270 @@ +/** + * Copyright 2024 The MediaPipe Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {LlmInference} from '../../../../tasks/web/genai/llm_inference/llm_inference'; +import { + FileLocator, + WasmMediaPipeConstructor, + WasmModule, + createMediaPipeLib, +} from '../../../../web/graph_runner/graph_runner'; +import {WasmFileReference} from '../../../../web/graph_runner/graph_runner_wasm_file_reference'; +// Placeholder for internal dependency on trusted resource url + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +declare interface RagWasmModule extends WasmModule { + HEAPF32: Float32Array; + HEAPU8: Uint8Array; + + _addStringVectorEntry(vecPtr: number, strPtr: number): void; + _allocateStringVector(size: number): number; + _release(pointer: number): void; + ccall( + ident: string, + returnType: string, + argTypes: string[], + args: unknown[], + opts?: {async?: boolean}, + ): number; + UTF8ToString(encoded: number): string; + stringToNewUTF8(decoded: string): number; +} + +const PROMPT_TEMPLATE = `<start_of_turn>system +You are an assistant for question-answering tasks. 
You are given facts and you need to answer a question only using the facts provided. +<end_of_turn> +<start_of_turn>context +Here are the facts: +{memory} +<end_of_turn> +<start_of_turn>user +Use the facts to answer questions from the User. +User query:{query} +<end_of_turn> +<start_of_turn>model +`; + +type ProgressListener = (partial: string, done: boolean) => void; + +/** + * RAG (Retrieval-Augmented Generation) Pipeline API for MediaPipe. + * + * This API is highly experimental and will change. + */ +export class RagPipeline { + /** + * Initializes the Wasm runtime and creates a new `RagPipeline` using the + * provided `LLMInference` task. + * + * @export + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. + * @param llmInference The LLM Inference Task to use with this RAG pipeline. + * @param embeddingModel Either the buffer or url to the embedding model that + * will be used in the RAG pipeline to embed texts. + */ + static async createWithEmbeddingModel( + wasmFileset: WasmFileset, + llmInference: LlmInference, + embeddingModel: string | Uint8Array, + ): Promise<RagPipeline> { + const fileLocator: FileLocator = { + locateFile(file): string { + // We currently only use a single .wasm file and a single .data file (for + // the tasks that have to load assets). We need to revisit how we + // initialize the file locator if we ever need to differentiate between + // different files. 
+ if (file.endsWith('.wasm')) { + return wasmFileset.wasmBinaryPath.toString(); + } else if (wasmFileset.assetBinaryPath && file.endsWith('.data')) { + return wasmFileset.assetBinaryPath.toString(); + } + return file; + }, + }; + + const ragPipeline = await createMediaPipeLib( + RagPipeline.bind( + null, + llmInference, + ) as unknown as WasmMediaPipeConstructor<RagPipeline>, + wasmFileset.wasmLoaderPath, + wasmFileset.assetLoaderPath, + /* glCanvas= */ null, + fileLocator, + ); + + let wasmFileRef: WasmFileReference; + if (embeddingModel instanceof Uint8Array) { + wasmFileRef = WasmFileReference.loadFromArray( + ragPipeline.ragModule, + embeddingModel, + ); + } else { + wasmFileRef = await WasmFileReference.loadFromUrl( + ragPipeline.ragModule, + embeddingModel, + ); + } + await ragPipeline.wrapStringPtr(PROMPT_TEMPLATE, (promptStrPtr) => + ragPipeline.ragModule.ccall( + 'initializeChain', + 'void', + ['number', 'number', 'number'], + [wasmFileRef.offset, wasmFileRef.size, promptStrPtr], + {async: true}, + ), + ); + + wasmFileRef.free(); + return ragPipeline; + } + + /** @hideconstructor */ + constructor( + private readonly llmInference: LlmInference, + private readonly ragModule: RagWasmModule, + ) {} + + /** + * Instructs the RAG pipeline to memorize the records in the array. + * + * @export + * @param data The array of records to be remembered by RAG pipeline. + */ + recordBatchedMemory(data: string[]) { + const vecPtr = this.ragModule._allocateStringVector(data.length); + if (!vecPtr) { + throw new Error('Unable to allocate new string vector on heap.'); + } + for (const entry of data) { + this.wrapStringPtr(entry, (entryStringPtr) => { + this.ragModule._addStringVectorEntry(vecPtr, entryStringPtr); + }); + } + return this.ragModule.ccall( + 'recordBatchedMemory', + 'void', + ['number'], + [vecPtr], + {async: true}, + ); + } + + /** + * Uses the RAG pipeline to augment the query. + * + * @param query The users' query. 
+ * @param topK The number of top related entries to be accounted in. + * @return RAG's augmented query. + */ + private async buildPrompt(query: string, topK = 2): Promise<string> { + const result = await this.wrapStringPtr(query, (queryStrPtr) => + this.ragModule.ccall( + 'invoke', + 'number', + ['number', 'number'], + [queryStrPtr, topK], + {async: true}, + ), + ); + return this.ragModule.UTF8ToString(result); + } + + /** + * Uses RAG to augment the query and run LLM Inference. `topK` defaults to 2. + * + * @export + * @param query The users' query. + * @return The generated text result. + */ + generateResponse(query: string): Promise<string>; + /** + * Uses RAG to augment the query and run LLM Inference. + * + * @export + * @param query The users' query. + * @param topK The number of top related entries to be accounted in. + * @return The generated text result. + */ + generateResponse(query: string, topK: number): Promise<string>; + /** + * Uses RAG to augment the query and run LLM Inference. + * + * @export + * @param query The users' query. + * @param progressListener A listener that will be triggered when the task has + * new partial response generated. + * @return The generated text result. + */ + generateResponse( + query: string, + progressListener: ProgressListener, + ): Promise<string>; + /** + * Uses RAG to augment the query and run LLM Inference. `topK` defaults to 2. + * + * @export + * @param query The users' query. + * @param topK The number of top related entries to be accounted in. + * @param progressListener A listener that will be triggered when the task has + * new partial response generated. + * @return The generated text result. + */ + generateResponse( + query: string, + topK: number, + progressListener: ProgressListener, + ): Promise<string>; + /** @export */ + generateResponse( + query: string, + topKOrProgressListener?: number | ProgressListener, + progressListener?: ProgressListener, + ): Promise<string> { + const topK = + typeof topKOrProgressListener === 'number' ? 
topKOrProgressListener : 2; + progressListener = + typeof topKOrProgressListener === 'function' + ? topKOrProgressListener + : progressListener; + return this.buildPrompt(query, topK).then((prompt) => { + if (progressListener) { + return this.llmInference.generateResponse(prompt, progressListener); + } else { + return this.llmInference.generateResponse(prompt); + } + }); + } + + private wrapStringPtr<T>( + stringData: string, + userFunction: (strPtr: number) => T, + ): T { + const stringDataPtr = this.ragModule.stringToNewUTF8(stringData); + const res = userFunction(stringDataPtr); + this.ragModule._release(stringDataPtr); + return res; + } + + /** @export */ + close() {} +} + + diff --git a/mediapipe/tasks/web/genai_experimental/types.ts b/mediapipe/tasks/web/genai_experimental/types.ts new file mode 100644 index 0000000000..0338bac83b --- /dev/null +++ b/mediapipe/tasks/web/genai_experimental/types.ts @@ -0,0 +1,18 @@ +/** + * Copyright 2024 The MediaPipe Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export * from '../../../tasks/web/core/fileset_resolver'; +export * from '../../../tasks/web/genai_experimental/rag_pipeline/rag_pipeline';