From 9d8270b499177044d9948b1d19fbc7edb4dbe1b3 Mon Sep 17 00:00:00 2001
From: Devin Robison
Date: Mon, 6 Nov 2023 14:53:30 -0700
Subject: [PATCH 1/2] Docs updates and added ability to use NeMo or OpenAI services in RAG pipeline

---
 examples/llm/common/utils.py                 | 12 ++++--
 examples/llm/rag/README.md                   |  8 ++--
 examples/llm/rag/persistant_pipeline.py      | 39 ++++++++++++--------
 examples/llm/rag/run.py                      | 29 +++++++++++----
 examples/llm/rag/standalone_pipeline.py      | 36 ++++++++++++------
 morpheus/llm/nodes/extracter_node.py         |  2 +-
 morpheus/llm/services/openai_chat_service.py |  2 +-
 7 files changed, 83 insertions(+), 45 deletions(-)

diff --git a/examples/llm/common/utils.py b/examples/llm/common/utils.py
index 158e97535d..7121f3abf4 100644
--- a/examples/llm/common/utils.py
+++ b/examples/llm/common/utils.py
@@ -17,6 +17,7 @@
 from langchain.embeddings import HuggingFaceEmbeddings
 
 from morpheus.llm.services.nemo_llm_service import NeMoLLMService
+from morpheus.llm.services.openai_chat_service import OpenAIChatService
 from morpheus.service.vdb.milvus_vector_db_service import MilvusVectorDBService
 from morpheus.service.vdb.utils import VectorDBServiceFactory
 
@@ -30,12 +31,17 @@ def build_huggingface_embeddings(model_name: str, model_kwargs: dict = None, enc
 
 
 def build_llm_service(model_name: str, model_type, **model_kwargs):
-    if (model_type.lower() in ('nemo', )):
-        llm_service = NeMoLLMService()
+    if (model_type.lower() in ('nemollm',)):
+        llm_service_cls = NeMoLLMService
+    elif (model_type.lower() in ('openai',)):
+        llm_service_cls = OpenAIChatService
+        if ("tokens_to_generate" in model_kwargs):
+            model_kwargs.pop("tokens_to_generate")
     else:
-        # TODO(Devin) : Add additional options
         raise RuntimeError(f"Unsupported LLM model type: {model_type}")
 
+    llm_service = llm_service_cls()
+
     return llm_service.get_client(model_name, **model_kwargs)
 
 
diff --git a/examples/llm/rag/README.md b/examples/llm/rag/README.md
index 3af143c02e..972838c0e5 100644
--- a/examples/llm/rag/README.md
+++ b/examples/llm/rag/README.md
@@ -214,14 +214,14 @@ pipeline option of `rag`:
 
 ```bash
 export NGC_API_KEY=[YOUR_KEY_HERE]
-NGC_API_KEY=${NGC_API_KEY} python examples/llm/main.py rag pipeline
+python examples/llm/main.py rag pipeline --model_name gpt-43b-002 --model_type NemoLLM
 ```
 
 **Using OpenAI LLM models**
 
 ```bash
 export OPENAI_API_KEY=[YOUR_KEY_HERE]
-OPENAI_API_KEY=${OPENAI_API_KEY} python examples/llm/main.py rag pipeline
+python examples/llm/main.py rag pipeline --model_name gpt-3.5-turbo --model_type OpenAI
 ```
 
 ### Run example (Persistent Pipeline):
@@ -232,14 +232,14 @@
 
 ```bash
 export NGC_API_KEY=[YOUR_KEY_HERE]
-python examples/llm/main.py rag persistent
+python examples/llm/main.py rag persistent --model_name gpt-43b-002 --model_type NemoLLM
 ```
 
 **Using OpenAI LLM models**
 
 ```bash
 export OPENAI_API_KEY=[YOUR_KEY_HERE]
-python examples/llm/main.py rag persistent
+python examples/llm/main.py rag persistent --model_name gpt-3.5-turbo --model_type OpenAI
 ```
 
 ### Options:
diff --git a/examples/llm/rag/persistant_pipeline.py b/examples/llm/rag/persistant_pipeline.py
index ed3018b902..f38405a621 100644
--- a/examples/llm/rag/persistant_pipeline.py
+++ b/examples/llm/rag/persistant_pipeline.py
@@ -59,7 +59,6 @@ def supports_cpp_node(self):
         return False
 
     def _build(self, builder: mrc.Builder, in_ports_streams: typing.List[StreamPair]) -> typing.List[StreamPair]:
-
         assert len(in_ports_streams) == 1, "Only 1 input supported"
 
         # Create a broadcast node
@@ -84,7 +83,6 @@ def filter_lower_fn(data: MessageMeta):
 
 
 def _build_engine(model_name: str, vdb_service: VectorDBResourceService):
-
     engine = LLMEngine()
 
     engine.add_node("extracter", node=ExtracterNode())
@@ -110,47 +108,54 @@ def _build_engine(model_name: str, vdb_service: VectorDBResourceService):
 
 
 def pipeline(
-    num_threads,
-    pipeline_batch_size,
-    model_max_batch_size,
-    embedding_size,
-    model_name,
+        num_threads,
+        pipeline_batch_size,
+        model_max_batch_size,
+        embedding_size,
+        model_name,
 ):
-
+    # Initialize the configuration object for the pipeline
     config = Config()
     config.mode = PipelineModes.OTHER
 
-    # Below properties are specified by the command line
     config.num_threads = num_threads
    config.pipeline_batch_size = pipeline_batch_size
     config.model_max_batch_size = model_max_batch_size
     config.mode = PipelineModes.NLP
+
+    # Set a buffer size for stages to pass data between each other
     config.edge_buffer_size = 128
 
+    # Build a vector database service with a specified embedding size
     vdb_service = build_milvus_service(embedding_size=embedding_size)
 
+    # Define tasks for upload and retrieval operations
     upload_task = {"task_type": "upload", "task_dict": {"input_keys": ["questions"], }}
     retrieve_task = {"task_type": "retrieve", "task_dict": {"input_keys": ["questions", "embedding"], }}
 
     pipe = Pipeline(config)
 
-    # Source of the retrieval queries
+    # Add a Kafka source stage to ingest retrieval queries
     retrieve_source = pipe.add_stage(KafkaSourceStage(config, bootstrap_servers="auto", input_topic=["retrieve_input"]))
 
+    # Deserialize the messages for the retrieve queries
     retrieve_deserialize = pipe.add_stage(
         DeserializeStage(config, message_type=ControlMessage, task_type="llm_engine", task_payload=retrieve_task))
 
+    # Connect the Kafka source to the deserialize stage for retrieve queries
     pipe.add_edge(retrieve_source, retrieve_deserialize)
 
-    # Source of continually uploading documents
+    # Add a Kafka source stage to ingest documents for uploading
     upload_source = pipe.add_stage(KafkaSourceStage(config, bootstrap_servers="auto", input_topic=["upload"]))
 
+    # Deserialize the messages for the upload documents
     upload_deserialize = pipe.add_stage(
         DeserializeStage(config, message_type=ControlMessage, task_type="llm_engine", task_payload=upload_task))
 
+    # Connect the Kafka source to the deserialize stage for upload documents
     pipe.add_edge(upload_source, upload_deserialize)
 
-    # Join the sources into one for tokenization
+    # Preprocess stage for NLP tasks that joins both upload and retrieve sources
     preprocess = pipe.add_stage(
         PreprocessNLPStage(config,
                            vocab_hash_file="data/bert-base-uncased-hash.txt",
@@ -159,9 +164,11 @@ def pipeline(
                            add_special_tokens=False,
                            column='content'))
 
+    # Connect deserialize stages to the preprocess stage
     pipe.add_edge(upload_deserialize, preprocess)
     pipe.add_edge(retrieve_deserialize, preprocess)
 
+    # Inference stage configured to use a Triton server
     inference = pipe.add_stage(
         TritonInferenceStage(config,
                              model_name=model_name,
@@ -170,21 +177,22 @@ def pipeline(
                              use_shared_memory=True))
     pipe.add_edge(preprocess, inference)
 
-    # Split the results based on the task
+    # Split the results based on the task type
     split = pipe.add_stage(SplitStage(config))
     pipe.add_edge(inference, split)
 
-    # If it's a retrieve task, branch to the LLM engine for RAG
+    # For retrieve tasks, connect to an LLM engine stage configured for RAG
     retrieve_llm_engine = pipe.add_stage(
         LLMEngineStage(config,
                        engine=_build_engine(model_name=model_name, vdb_service=vdb_service.load_resource("RSS"))))
     pipe.add_edge(split.output_ports[0], retrieve_llm_engine)
 
+    # Write retrieve results to a Kafka topic
     retrieve_results = pipe.add_stage(
         WriteToKafkaStage(config, bootstrap_servers="auto", output_topic="retrieve_output"))
     pipe.add_edge(retrieve_llm_engine, retrieve_results)
 
-    # If its an upload task, then send it to the database
+    # For upload tasks, send the data to the vector database
     upload_vdb = pipe.add_stage(
         WriteToVectorDBStage(config,
                              resource_name="RSS",
@@ -195,7 +203,6 @@ def pipeline(
 
     start_time = time.time()
 
-    # Run the pipeline
     pipe.run()
 
     return start_time
diff --git a/examples/llm/rag/run.py b/examples/llm/rag/run.py
index 54ddb6b734..7aa30a0197 100644
--- a/examples/llm/rag/run.py
+++ b/examples/llm/rag/run.py
@@ -44,12 +44,19 @@ def run():
     type=click.IntRange(min=1),
     help="Max batch size to use for the model",
 )
+@click.option(
+    "--model_type",
+    type=click.Choice(['OpenAI', 'NemoLLM'], case_sensitive=False),
+    default='NemoLLM',
+    help="Type of the large language model to use",
+)
 @click.option(
     "--model_name",
-    required=True,
     type=str,
-    default='gpt-43b-002',
-    help="The name of the large language model that is deployed on Triton server",
+    default=None,  # Set default to None to detect if the user provided a value
+    help="The name of the model that is deployed on Triton server",
+    callback=lambda ctx, param, value: (value if value is not None else
+                                        ('gpt-3.5-turbo' if ctx.params['model_type'].lower() == 'openai' else 'gpt-43b-002'))
 )
 @click.option(
     "--vdb_resource_name",
@@ -65,7 +72,6 @@ def run():
     help="Number of times to repeat the input query. Useful for testing performance.",
 )
 def pipeline(**kwargs):
-
     from .standalone_pipeline import standalone
 
     return standalone(**kwargs)
@@ -97,15 +103,22 @@ def pipeline(**kwargs):
     type=click.IntRange(min=1),
     help="The output size of the embedding calculation. Depends on the model supplied by --model_name",
 )
+@click.option(
+    "--model_type",
+    type=click.Choice(['OpenAI', 'NemoLLM'], case_sensitive=False),
+    default='NemoLLM',
+    help="Type of the large language model to use",
+)
 @click.option(
     "--model_name",
-    required=True,
     type=str,
-    default='gpt-43b-002',
+    show_default=True,
+    default=None,  # Set default to None, it will be dynamically determined by the callback
     help="The name of the model that is deployed on Triton server",
+    callback=lambda ctx, param, value: (value if value is not None else
+                                        ('gpt-3.5-turbo' if ctx.params['model_type'].lower() == 'openai' else 'gpt-43b-002'))
 )
-def persistant(**kwargs):
-
+def persistent(**kwargs):
     from .persistant_pipeline import pipeline as _pipeline
 
     return _pipeline(**kwargs)
diff --git a/examples/llm/rag/standalone_pipeline.py b/examples/llm/rag/standalone_pipeline.py
index ae28189d31..13c0000585 100644
--- a/examples/llm/rag/standalone_pipeline.py
+++ b/examples/llm/rag/standalone_pipeline.py
@@ -37,7 +37,7 @@
 logger = logging.getLogger(__name__)
 
 
-def _build_engine(model_name: str, vdb_resource_name: str):
+def _build_engine(model_name: str, model_type: str, vdb_resource_name: str):
 
     engine = LLMEngine()
 
@@ -56,7 +56,7 @@ def _build_engine(model_name: str, vdb_resource_name: str):
     embeddings = build_huggingface_embeddings("sentence-transformers/all-MiniLM-L6-v2",
                                               model_kwargs={'device': 'cuda'},
                                               encode_kwargs={'batch_size': 100})
-    llm_service = build_llm_service(model_name, 'nemo', temperature=0.5, tokens_to_generate=200)
+    llm_service = build_llm_service(model_name, model_type=model_type, temperature=0.5, tokens_to_generate=200)
 
     # Async wrapper around embeddings
     async def calc_embeddings(texts: list[str]) -> list[list[float]]:
@@ -75,49 +75,61 @@ async def calc_embeddings(texts: list[str]) -> list[list[float]]:
 
 
 def standalone(
-    num_threads,
-    pipeline_batch_size,
-    model_max_batch_size,
-    model_name,
-    vdb_resource_name,
-    repeat_count,
+        num_threads,
+        pipeline_batch_size,
+        model_max_batch_size,
+        model_name,
+        model_type,
+        vdb_resource_name,
+        repeat_count,
 ):
+    # Configuration setup for the pipeline
     config = Config()
-    config.mode = PipelineModes.OTHER
+    config.mode = PipelineModes.OTHER  # Initial mode set to OTHER, will be overridden below
 
-    # Below properties are specified by the command line
     config.num_threads = num_threads
     config.pipeline_batch_size = pipeline_batch_size
     config.model_max_batch_size = model_max_batch_size
     config.mode = PipelineModes.NLP
-    config.edge_buffer_size = 128
+    config.edge_buffer_size = 128  # Set edge buffer size for the pipeline stages
 
+    # Create a DataFrame as the data source for the pipeline
     source_dfs = [
         cudf.DataFrame({"questions": ["What are some new attacks discovered in the cyber security industry?."] * 5})
     ]
 
+    # Define a task to be used by the pipeline stages
     completion_task = {"task_type": "completion", "task_dict": {"input_keys": ["questions"], }}
 
+    # Initialize the pipeline with the configuration
     pipe = LinearPipeline(config)
 
+    # Set the source stage of the pipeline with the DataFrame and repeat count
     pipe.set_source(InMemorySourceStage(config, dataframes=source_dfs, repeat=repeat_count))
 
+    # Add deserialization stage to convert messages for processing
     pipe.add_stage(
         DeserializeStage(config, message_type=ControlMessage, task_type="llm_engine", task_payload=completion_task))
 
+    # Add a monitoring stage to observe the source data rate
     pipe.add_stage(MonitorStage(config, description="Source rate", unit='questions'))
 
+    # Add the main LLM engine stage to the pipeline with the model and vector database
     pipe.add_stage(
-        LLMEngineStage(config, engine=_build_engine(model_name=model_name, vdb_resource_name=vdb_resource_name)))
+        LLMEngineStage(config, engine=_build_engine(model_name=model_name, model_type=model_type, vdb_resource_name=vdb_resource_name)))
 
+    # Add a sink stage to collect the output from the pipeline
     sink = pipe.add_stage(InMemorySinkStage(config))
 
+    # Add another monitoring stage to observe the response rate with a delayed start
     pipe.add_stage(MonitorStage(config, description="Response rate", unit="responses", delayed_start=True))
 
     start_time = time.time()
 
     pipe.run()
 
+    # Log the total number of responses received after pipeline completion
     logger.info("Pipeline complete. Received %s responses", len(sink.get_messages()))
 
+    # Return the start time for performance measurement or further processing
     return start_time
diff --git a/morpheus/llm/nodes/extracter_node.py b/morpheus/llm/nodes/extracter_node.py
index cb5baf6179..374753bcc4 100644
--- a/morpheus/llm/nodes/extracter_node.py
+++ b/morpheus/llm/nodes/extracter_node.py
@@ -30,7 +30,7 @@ class ExtracterNode(LLMNodeBase):
     """
 
     def get_input_names(self) -> list[str]:
-        # This node does not receive it's inputs from upstream nodes, but rather from the task itself
+        # This node does not receive its inputs from upstream nodes, but rather from the task itself
         return []
 
     async def execute(self, context: LLMContext) -> LLMContext:
diff --git a/morpheus/llm/services/openai_chat_service.py b/morpheus/llm/services/openai_chat_service.py
index dda00d032d..bd0709d86b 100644
--- a/morpheus/llm/services/openai_chat_service.py
+++ b/morpheus/llm/services/openai_chat_service.py
@@ -206,7 +206,7 @@ def get_client(self,
         The name of the model to create a client for.
 
     set_assistant: bool, optional default=False
-        When `True`, a second input field named `assistant` will be used to proide additional context to the model.
+        When `True`, a second input field named `assistant` will be used to provide additional context to the model.
 
     model_kwargs : dict[str, typing.Any]
         Additional keyword arguments to pass to the model when generating text.
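
Usage note (illustrative, not part of either patch): after PATCH 1/2, `build_llm_service()` in `examples/llm/common/utils.py` selects the LLM backend from `model_type`, accepting `NemoLLM` or `OpenAI` (case-insensitive) and dropping the `tokens_to_generate` keyword for OpenAI, which does not support it. A minimal sketch of how the helper is expected to be called is shown below; the model names mirror the README defaults, and the import path assumes the `examples` tree is importable from the repository root, which may vary by setup.

    # Minimal sketch, assuming the examples package is importable; not part of the patches.
    from examples.llm.common.utils import build_llm_service

    # NeMo-hosted model (needs NGC credentials in the environment).
    nemo_client = build_llm_service("gpt-43b-002",
                                    model_type="NemoLLM",
                                    temperature=0.5,
                                    tokens_to_generate=200)

    # OpenAI chat model (needs OPENAI_API_KEY); tokens_to_generate is stripped by
    # build_llm_service() because OpenAIChatService does not accept that keyword.
    openai_client = build_llm_service("gpt-3.5-turbo",
                                      model_type="OpenAI",
                                      temperature=0.5,
                                      tokens_to_generate=200)

Each call returns a client from the selected service's `get_client()`, so downstream code such as `_build_engine()` in the RAG pipelines does not need to know which backend is in use.
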
From e7b601aff46f75e68ff870975ab3c3365bab0a71 Mon Sep 17 00:00:00 2001
From: Devin Robison
Date: Mon, 6 Nov 2023 15:41:45 -0700
Subject: [PATCH 2/2] Update RAG.split_stage to use typing changes

---
 examples/llm/rag/persistant_pipeline.py | 21 +++++++++++++--------
 1 file changed, 13 insertions(+), 8 deletions(-)

diff --git a/examples/llm/rag/persistant_pipeline.py b/examples/llm/rag/persistant_pipeline.py
index f38405a621..24dd3c65b4 100644
--- a/examples/llm/rag/persistant_pipeline.py
+++ b/examples/llm/rag/persistant_pipeline.py
@@ -29,7 +29,7 @@
 from morpheus.messages import MessageMeta
 from morpheus.pipeline.pipeline import Pipeline
 from morpheus.pipeline.stage import Stage
-from morpheus.pipeline.stream_pair import StreamPair
+from morpheus.pipeline.stage_schema import StageSchema
 from morpheus.service.vdb.vector_db_service import VectorDBResourceService
 from morpheus.stages.inference.triton_inference_stage import TritonInferenceStage
 from morpheus.stages.input.kafka_source_stage import KafkaSourceStage
@@ -58,12 +58,15 @@ def name(self) -> str:
     def supports_cpp_node(self):
         return False
 
-    def _build(self, builder: mrc.Builder, in_ports_streams: typing.List[StreamPair]) -> typing.List[StreamPair]:
-        assert len(in_ports_streams) == 1, "Only 1 input supported"
+    def compute_schema(self, schema: StageSchema):
+        assert len(schema.output_schemas) == 2, "Expected two output schemas"
+
+    def _build(self, builder: mrc.Builder, input_nodes: list[mrc.SegmentObject]) -> list[mrc.SegmentObject]:
+        assert len(input_nodes) == 1, "Only 1 input supported"
 
         # Create a broadcast node
         broadcast = Broadcast(builder, "broadcast")
-        builder.make_edge(in_ports_streams[0][0], broadcast)
+        builder.make_edge(input_nodes[0], broadcast)
 
         def filter_higher_fn(data: MessageMeta):
             return MessageMeta(data.df[data.df["v2"] >= 0.5])
@@ -79,10 +82,10 @@ def filter_lower_fn(data: MessageMeta):
         filter_lower = builder.make_node("filter_lower", ops.map(filter_lower_fn))
         builder.make_edge(broadcast, filter_lower)
 
-        return [(filter_higher, in_ports_streams[0][1]), (filter_lower, in_ports_streams[0][1])]
+        return [filter_higher, filter_lower]
 
 
-def _build_engine(model_name: str, vdb_service: VectorDBResourceService):
+def _build_engine(model_name: str, model_type: str, vdb_service: VectorDBResourceService):
     engine = LLMEngine()
 
     engine.add_node("extracter", node=ExtracterNode())
@@ -96,7 +99,7 @@ def _build_engine(model_name: str, vdb_service: VectorDBResourceService):
 
     Please answer the following question: \n{{ query }}"""
 
-    llm_service = build_llm_service(model_name, model_type="nemo", temperature=0.5, tokens_to_generate=200)
+    llm_service = build_llm_service(model_name, model_type=model_type, temperature=0.5, tokens_to_generate=200)
 
     engine.add_node("rag",
                     inputs=[("/extracter/*", "*")],
@@ -113,6 +116,7 @@ def pipeline(
         model_max_batch_size,
         embedding_size,
         model_name,
+        model_type,
 ):
     # Initialize the configuration object for the pipeline
     config = Config()
@@ -184,7 +188,8 @@ def pipeline(
     # For retrieve tasks, connect to an LLM engine stage configured for RAG
     retrieve_llm_engine = pipe.add_stage(
         LLMEngineStage(config,
-                       engine=_build_engine(model_name=model_name, vdb_service=vdb_service.load_resource("RSS"))))
+                       engine=_build_engine(model_name=model_name, model_type=model_type,
+                                            vdb_service=vdb_service.load_resource("RSS"))))
     pipe.add_edge(split.output_ports[0], retrieve_llm_engine)
 
     # Write retrieve results to a Kafka topic