From 926035f2d7c28eeb9b45a5b52b2eb5c1e2913643 Mon Sep 17 00:00:00 2001 From: movchan74 Date: Mon, 23 Oct 2023 15:41:56 +0000 Subject: [PATCH 1/7] Aana basic functionality --- .devcontainer/Dockerfile | 2 + .devcontainer/devcontainer.json | 36 +++++++ .gitmodules | 3 + .vscode/settings.json | 6 ++ Dockerfile | 33 ++++++ README.md | 85 ++++++++++++++- aana/__init__.py | 0 aana/api/__init__.py | 0 aana/api/app.py | 136 ++++++++++++++++++++++++ aana/api/request_handler.py | 77 ++++++++++++++ aana/api/responses.py | 26 +++++ aana/configs/__init__.py | 0 aana/configs/deployments.py | 22 ++++ aana/configs/pipeline.py | 84 +++++++++++++++ aana/deployments/__init__.py | 0 aana/deployments/base_deployment.py | 31 ++++++ aana/deployments/vllm_deployment.py | 131 +++++++++++++++++++++++ aana/exceptions/general.py | 83 +++++++++++++++ aana/main.py | 5 + aana/models/pydantic/__init__.py | 0 aana/models/pydantic/llm_request.py | 24 +++++ aana/models/pydantic/prompt.py | 13 +++ aana/models/pydantic/sampling_params.py | 25 +++++ aana/utils/general.py | 73 +++++++++++++ install.sh | 2 + mobius-pipeline | 1 + notebooks/demo.ipynb | 83 +++++++++++++++ pyproject.toml | 25 +++++ 28 files changed, 1005 insertions(+), 1 deletion(-) create mode 100644 .devcontainer/Dockerfile create mode 100644 .devcontainer/devcontainer.json create mode 100644 .gitmodules create mode 100644 .vscode/settings.json create mode 100644 Dockerfile create mode 100644 aana/__init__.py create mode 100644 aana/api/__init__.py create mode 100644 aana/api/app.py create mode 100644 aana/api/request_handler.py create mode 100644 aana/api/responses.py create mode 100644 aana/configs/__init__.py create mode 100644 aana/configs/deployments.py create mode 100644 aana/configs/pipeline.py create mode 100644 aana/deployments/__init__.py create mode 100644 aana/deployments/base_deployment.py create mode 100644 aana/deployments/vllm_deployment.py create mode 100644 aana/exceptions/general.py create mode 100644 aana/main.py create mode 100644 aana/models/pydantic/__init__.py create mode 100644 aana/models/pydantic/llm_request.py create mode 100644 aana/models/pydantic/prompt.py create mode 100644 aana/models/pydantic/sampling_params.py create mode 100644 aana/utils/general.py create mode 100644 install.sh create mode 160000 mobius-pipeline create mode 100644 notebooks/demo.ipynb create mode 100644 pyproject.toml diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 00000000..9439a8e6 --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,2 @@ +FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04 +RUN apt-get update && apt-get install -y libgl1 libglib2.0-0 diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 00000000..19a1975a --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,36 @@ +{ + "name": "Ubuntu", + "build": { + "dockerfile": "Dockerfile" + }, + "features": { + "ghcr.io/devcontainers/features/python:1": { + "installTools": true, + "version": "3.10" + }, + "ghcr.io/devcontainers-contrib/features/poetry:2": { + "version": "latest" + } + }, + "hostRequirements": { + "gpu": "optional" + }, + "mounts": [ + "source=/nas,target=/nas,type=bind", + "source=/nas2,target=/nas2,type=bind" + ], + + "postCreateCommand": "sh ${containerWorkspaceFolder}/install.sh", + "postStartCommand": "git config --global --add safe.directory ${containerWorkspaceFolder}", + "customizations": { + "vscode": { + "extensions": [ + "ms-python.black-formatter", + "ms-python.python", + "ms-python.mypy-type-checker", + "ms-toolsai.jupyter" + ] + } + } + +} diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..75cdb57e --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "mobius-pipeline"] + path = mobius-pipeline + url = ../mobius-pipeline.git diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..d99f2f30 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,6 @@ +{ + "[python]": { + "editor.defaultFormatter": "ms-python.black-formatter" + }, + "python.formatting.provider": "none" +} \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..a95f4d81 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,33 @@ +# Use NVIDIA CUDA as base image +FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04 + +# Set working directory +WORKDIR /app + +# Set environment variables to non-interactive (this prevents some prompts) +ENV DEBIAN_FRONTEND=non-interactive + +# Install required libraries, tools, and Python3 +RUN apt-get update && apt-get install -y libgl1 libglib2.0-0 curl git python3.10 python3.10-dev python3-pip python3.10-venv + +# Install poetry +RUN curl -sSL https://install.python-poetry.org | python3 - + +# Update PATH +RUN echo 'export PATH="/root/.local/bin:$PATH"' >> /root/.bashrc +ENV PATH="/root/.local/bin:$PATH" + +# Copy project files into the container +COPY . /app + +# Install the package with poetry +RUN sh install.sh + +# Disable buffering for stdout and stderr to get the logs in real time +ENV PYTHONUNBUFFERED=1 + +# Expose the desired port +EXPOSE 8000 + +# Set the command to run the SDK when the container starts +CMD ["poetry", "run", "serve", "run", "--port", "8000", "--host", "0.0.0.0", "aana.main:server"] diff --git a/README.md b/README.md index b7d0d1ba..eb57eec4 100644 --- a/README.md +++ b/README.md @@ -1 +1,84 @@ -# aana_sdk \ No newline at end of file +# Aana + +Aana is a multi-model SDK for deploying and serving machine learning models. + +## Installation + +1. Clone this repository. +2. Update submodules. + +```bash +git submodule update --init --recursive +``` + +3. Install additional libraries. + +```bash +apt update && apt install -y libgl1 +``` + +4. Install the package with poetry. + +It will install the package and all dependencies in a virtual environment. + +```bash +sh install.sh +``` + +5. Run the SDK. + +```bash +CUDA_VISIBLE_DEVICES=0 poetry run serve run --port 8000 --host 0.0.0.0 aana.main:server +``` + +The first run might take a while because the models will be downloaded from Google Drive and cached. + +Once you see `Deployed Serve app successfully.` in the logs, the server is ready to accept requests. + +You can change the port and CUDA_VISIBLE_DEVICES environment variable to your needs. + +The server will be available at http://localhost:8000. + +The documentation will be available at http://localhost:8000/docs and http://localhost:8000/redoc. + +For HuggingFace Transformers, you need to specify HF_AUTH environment variable with your HuggingFace API token. + +6. Send a request to the server. + +You can find examples in the [demo notebook](notebooks/demo.ipynb). + +## Run with Docker + +1. Clone this repository. + +2. Update submodules. + +```bash +git submodule update --init --recursive +``` + +3. Build the Docker image. + +```bash +docker build -t aana:0.1.0 . +``` + +4. Run the Docker container. + +```bash +docker run --rm --init -p 8000:8000 --gpus all -e CUDA_VISIBLE_DEVICES=0 -v aana_cache:/root/.aana -v aana_hf_cache:/root/.cache/huggingface --name aana_instance aana:0.1.0 +``` + +The first run might take a while because the models will be downloaded from Google Drive and cached. The models will be stored in the `aana_cache` volume. The HuggingFace models will be stored in the `aana_hf_cache` volume. If you want to remove the cached models, remove the volume. + +Once you see `Deployed Serve app successfully.` in the logs, the server is ready to accept requests. + +You can change the port and gpus parameters to your needs. + +The server will be available at http://localhost:8000. + +The documentation will be available at http://localhost:8000/docs and http://localhost:8000/redoc. + +5. Send a request to the server. + +You can find examples in the [demo notebook](notebooks/demo.ipynb). diff --git a/aana/__init__.py b/aana/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/aana/api/__init__.py b/aana/api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/aana/api/app.py b/aana/api/app.py new file mode 100644 index 00000000..f249a289 --- /dev/null +++ b/aana/api/app.py @@ -0,0 +1,136 @@ +import traceback +from typing import Union +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse +from mobius_pipeline.exceptions import PipelineException +from pydantic import ValidationError +from ray.exceptions import RayTaskError + +from aana.exceptions.general import AanaException + +app = FastAPI() + + +@app.exception_handler(ValidationError) +async def validation_exception_handler(request: Request, exc: ValidationError): + """ + This handler is used to handle pydantic validation errors + + Args: + request (Request): The request object + exc (ValidationError): The validation error + + Returns: + JSONResponse: JSON response with the error details + """ + # TODO: Structure the error response so that it is consistent with the other error responses + return JSONResponse( + status_code=422, + content={"detail": exc.errors()}, + ) + + +def custom_exception_handler( + request: Request, exc: Union[PipelineException, AanaException, RayTaskError] +): + """ + This handler is used to handle custom exceptions raised in the application. + PipelineException is the exception raised by the Mobius Pipeline. + AanaException is the exception raised by the Aana application. + Sometimes custom exception are wrapped into RayTaskError so we need to handle that as well. + + Args: + request (Request): The request object + exc (Union[PipelineException, AanaException, RayTaskError]): The exception raised + + Returns: + JSONResponse: JSON response with the error details. The response contains the following fields: + error: The name of the exception class. + message: The message of the exception. + data: The additional data returned by the exception that can be used to identify the error (e.g. image path, url, model name etc.) + stacktrace: The stacktrace of the exception. + """ + # a PipelineException or AanaException can be wrapped into a RayTaskError + if isinstance(exc, RayTaskError): + # str(e) returns whole stack trace + # if exception is a RayTaskError + # let's use it to get the stack trace + stacktrace = str(exc) + # get the original exception + exc = exc.cause + else: + # if it is not a RayTaskError + # then we need to get the stack trace + stacktrace = traceback.format_exc() + # get the data from the exception + # can be used to return additional info + # like image path, url, model name etc. + data = exc.get_data() + # get the name of the class of the exception + # can be used to identify the type of the error + error = exc.__class__.__name__ + # get the message of the exception + message = str(exc) + return JSONResponse( + status_code=400, + content={ + "error": error, + "message": message, + "data": data, + "stacktrace": stacktrace, + }, + ) + + +@app.exception_handler(PipelineException) +async def pipeline_exception_handler(request: Request, exc: PipelineException): + """ + This handler is used to handle exceptions raised by the Mobius Pipeline. + + Args: + request (Request): The request object + exc (PipelineException): The exception raised + + Returns: + JSONResponse: JSON response with the error details + """ + return custom_exception_handler(request, exc) + + +@app.exception_handler(AanaException) +async def aana_exception_handler(request: Request, exc: AanaException): + """ + This handler is used to handle exceptions raised by the Aana application. + + Args: + request (Request): The request object + exc (AanaException): The exception raised + + Returns: + JSONResponse: JSON response with the error details + """ + return custom_exception_handler(request, exc) + + +@app.exception_handler(RayTaskError) +async def ray_task_error_handler(request: Request, exc: RayTaskError): + """ + This handler is used to handle RayTaskError exceptions. + + Args: + request (Request): The request object + exc (RayTaskError): The exception raised + + Returns: + JSONResponse: JSON response with the error details. The response contains the following fields: + error: The name of the exception class. + message: The message of the exception. + stacktrace: The stacktrace of the exception. + """ + error = exc.__class__.__name__ + stacktrace = traceback.format_exc() + + return JSONResponse( + status_code=400, + content={"error": error, "message": str(exc), "stacktrace": stacktrace}, + ) diff --git a/aana/api/request_handler.py b/aana/api/request_handler.py new file mode 100644 index 00000000..8e61989d --- /dev/null +++ b/aana/api/request_handler.py @@ -0,0 +1,77 @@ +from typing import Dict, List, Tuple +import ray +from ray import serve + +from mobius_pipeline.pipeline import Pipeline + +from aana.api.app import app +from aana.api.responses import AanaJSONResponse +from aana.configs.pipeline import nodes +from aana.models.pydantic.llm_request import LLMRequest + + +async def run_pipeline( + pipeline: Pipeline, data: Dict, required_outputs: List[str] +) -> Tuple[Dict, Dict[str, float]]: + """ + This function is used to run a Mobius Pipeline. + It creates a container from the data, runs the pipeline and returns the output. + + Args: + pipeline (Pipeline): The pipeline to run. + data (dict): The data to create the container from. + required_outputs (List[str]): The required outputs of the pipeline. + + Returns: + tuple[dict, dict[str, float]]: The output of the pipeline and the execution time of the pipeline. + """ + + # create a container from the data + container = pipeline.parse_dict(data) + + # run the pipeline + output, execution_time = await pipeline.run( + container, required_outputs, return_execution_time=True + ) + return output, execution_time + + +@serve.deployment(route_prefix="/", num_replicas=1, ray_actor_options={"num_cpus": 0.1}) +@serve.ingress(app) +class RequestHandler: + """This class is used to handle requests to the Aana application.""" + + def __init__(self, deployments: Dict): + """ + Args: + deployments (Dict): The dictionary of deployments. + It is passed to the context to the pipeline so the pipeline can access the deployments handles. + """ + self.context = { + "deployments": deployments, + } + self.pipeline = Pipeline(nodes, self.context) + + @app.post("/llm/generate") + async def generate_llm(self, llm_request: LLMRequest) -> AanaJSONResponse: + """ + The endpoint for running the LLM. + It is running the pipeline with the given prompt and sampling parameters. + This is here as an example and will be replace with automatic endpoint generation. + + Args: + llm_request (LLMRequest): The LLM request. It contains the prompt and sampling parameters. + + Returns: + AanaJSONResponse: The response containing the output of the pipeline and the execution time. + """ + prompt = llm_request.prompt + sampling_params = llm_request.sampling_params + + output, execution_time = await run_pipeline( + self.pipeline, + {"prompt": prompt, "sampling_params": sampling_params}, + ["vllm_llama2_7b_chat_output"], + ) + output["execution_time"] = execution_time + return AanaJSONResponse(content=output) diff --git a/aana/api/responses.py b/aana/api/responses.py new file mode 100644 index 00000000..e16d8274 --- /dev/null +++ b/aana/api/responses.py @@ -0,0 +1,26 @@ +from typing import Any, Optional +from fastapi.responses import JSONResponse +import orjson + + +class AanaJSONResponse(JSONResponse): + """ + A JSON response class that uses orjson to serialize data. + It has additional support for numpy arrays. + """ + + media_type = "application/json" + option = None + + def __init__(self, option: Optional[int] = orjson.OPT_SERIALIZE_NUMPY, **kwargs): + """ + Initialize the response class with the orjson option. + """ + self.option = option + super().__init__(**kwargs) + + def render(self, content: Any) -> bytes: + """ + Override the render method to use orjson.dumps instead of json.dumps. + """ + return orjson.dumps(content, option=self.option) diff --git a/aana/configs/__init__.py b/aana/configs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/aana/configs/deployments.py b/aana/configs/deployments.py new file mode 100644 index 00000000..14a52961 --- /dev/null +++ b/aana/configs/deployments.py @@ -0,0 +1,22 @@ +from aana.deployments.vllm_deployment import VLLMDeployment +from aana.models.pydantic.sampling_params import SamplingParams +from aana.utils.general import encode_options + +#TODO: add build system to only serve the deployment if it's needed + +deployments = { + "vllm_deployment_llama2_7b_chat": VLLMDeployment.options( + num_replicas=1, + max_concurrent_queries=1000, + ray_actor_options={"num_gpus": 0.5}, + user_config={ + "model": "TheBloke/Llama-2-7b-Chat-AWQ", + "dtype": "auto", + "quantization": "awq", + "gpu_memory_utilization": 0.7, + "default_sampling_params": encode_options( + SamplingParams(temperature=1.0, top_p=1.0, top_k=-1, max_tokens=256) + ), + }, + ).bind(), +} diff --git a/aana/configs/pipeline.py b/aana/configs/pipeline.py new file mode 100644 index 00000000..cd3e79a4 --- /dev/null +++ b/aana/configs/pipeline.py @@ -0,0 +1,84 @@ +""" +This file contains the pipeline configuration for the aana application. +It is used to generate the pipeline and the API endpoints. +""" + +from aana.models.pydantic.prompt import Prompt +from aana.models.pydantic.sampling_params import SamplingParams + +# container data model +# we don't enforce this data model for now but it's a good reference for writing paths and flatten_by +# class Container: +# prompt: Prompt +# sampling_params: SamplingParams +# vllm_llama2_7b_chat_output_stream: str +# vllm_llama2_7b_chat_output: str + + +nodes = [ + { + "name": "prompt", + "type": "input", + "inputs": [], + "outputs": [ + {"name": "prompt", "key": "prompt", "path": "prompt", "data_model": Prompt} + ], + }, + { + "name": "sampling_params", + "type": "input", + "inputs": [], + "outputs": [ + { + "name": "sampling_params", + "key": "sampling_params", + "path": "sampling_params", + "data_model": SamplingParams, + } + ], + }, + { + "name": "vllm_stream_llama2_7b_chat", + "type": "ray_deployment", + "deployment_name": "vllm_deployment_llama2_7b_chat", + "data_type": "generator", + "generator_path": "prompt", + "method": "generate_stream", + "inputs": [ + {"name": "prompt", "key": "prompt", "path": "prompt"}, + { + "name": "sampling_params", + "key": "sampling_params", + "path": "sampling_params", + }, + ], + "outputs": [ + { + "name": "vllm_llama2_7b_chat_output_stream", + "key": "text", + "path": "vllm_llama2_7b_chat_output_stream", + } + ], + }, + { + "name": "vllm_llama2_7b_chat", + "type": "ray_deployment", + "deployment_name": "vllm_deployment_llama2_7b_chat", + "method": "generate", + "inputs": [ + {"name": "prompt", "key": "prompt", "path": "prompt"}, + { + "name": "sampling_params", + "key": "sampling_params", + "path": "sampling_params", + }, + ], + "outputs": [ + { + "name": "vllm_llama2_7b_chat_output", + "key": "text", + "path": "vllm_llama2_7b_chat_output", + } + ], + }, +] diff --git a/aana/deployments/__init__.py b/aana/deployments/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/aana/deployments/base_deployment.py b/aana/deployments/base_deployment.py new file mode 100644 index 00000000..51d2c174 --- /dev/null +++ b/aana/deployments/base_deployment.py @@ -0,0 +1,31 @@ +from aana.utils.general import load_options + + +class BaseDeployment: + """ + Base class for all deployments. + We can use this class to define common methods for all deployments. + For example, we can connect to the database here or download artifacts. + """ + + def __init__(self): + self.config = None + self.configured = False + + async def reconfigure(self, config): + """ + Reconfigure the deployment. + The method is called when the deployment is updated. + """ + self.config = config + # go through the config and try to load options + for key in self.config: + self.config[key] = load_options(self.config[key], ignore_errors=True) + await self.apply_config() + self.configured = True + + async def apply_config(self): + """ + Apply the configuration. + """ + pass diff --git a/aana/deployments/vllm_deployment.py b/aana/deployments/vllm_deployment.py new file mode 100644 index 00000000..a9489d54 --- /dev/null +++ b/aana/deployments/vllm_deployment.py @@ -0,0 +1,131 @@ +from typing import List +from ray import serve +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.sampling_params import SamplingParams as VLLMSamplingParams +from vllm.utils import random_uuid + +from aana.deployments.base_deployment import BaseDeployment +from aana.exceptions.general import InferenceException +from aana.models.pydantic.sampling_params import SamplingParams +from aana.utils.general import merged_options + + +@serve.deployment +class VLLMDeployment(BaseDeployment): + """ + Deployment to serve large language models using vLLM. + """ + + async def apply_config(self): + """ + Apply the configuration. + + The method is called when the deployment is created or updated. + + It loads the model and creates the engine. + + The configuration should contain the following keys: + - model: the model name + - dtype: the data type (optional, default: "auto") + - quantization: the quantization method (optional, default: None) + - gpu_memory_utilization: the GPU memory utilization. + - default_sampling_params: the default sampling parameters. + """ + await super().apply_config() + + # parse the config + model: str = self.config["model"] + dtype: str = self.config.get("dtype", "auto") + quantization: str = self.config.get("quantization", None) + gpu_memory_utilization: float = self.config["gpu_memory_utilization"] + self.default_sampling_params: SamplingParams = self.config[ + "default_sampling_params" + ] + args = AsyncEngineArgs( + model=model, + dtype=dtype, + quantization=quantization, + gpu_memory_utilization=gpu_memory_utilization, + ) + + # TODO: check if the model is already loaded. + # If it is and none of the model parameters changed, we don't need to reload the model. + + # create the engine + self.engine = AsyncLLMEngine.from_engine_args(args) + + async def generate_stream(self, prompt: str, sampling_params: SamplingParams): + """ + Generate completion for the given prompt and stream the results. + + Args: + prompt (str): the prompt + sampling_params (SamplingParams): the sampling parameters + + Yields: + dict: the generated text + """ + prompt = str(prompt) + sampling_params = merged_options(self.default_sampling_params, sampling_params) + request_id = None + try: + # convert SamplingParams to VLLMSamplingParams + sampling_params_vllm = VLLMSamplingParams( + temperature=sampling_params.temperature, + top_p=sampling_params.top_p, + top_k=sampling_params.top_k, + max_tokens=sampling_params.max_tokens, + ) + # start the request + request_id = random_uuid() + results_generator = self.engine.generate( + prompt, sampling_params_vllm, request_id + ) + + num_returned = 0 + async for request_output in results_generator: + text_output = request_output.outputs[0].text[num_returned:] + yield {"text": text_output} + num_returned += len(text_output) + except GeneratorExit as e: + # If the generator is cancelled, we need to cancel the request + if request_id is not None: + await self.engine.abort(request_id) + raise e + except Exception as e: + raise InferenceException() from e + + async def generate(self, prompt: str, sampling_params: SamplingParams): + """ + Generate completion for the given prompt. + + Args: + prompt (str): the prompt + sampling_params (SamplingParams): the sampling parameters + + Returns: + dict: the generated text + """ + generated_text = "" + async for chunk in self.generate_stream(prompt, sampling_params): + generated_text += chunk["text"] + return {"text": generated_text} + + async def generate_batch(self, prompts: List[str], sampling_params: SamplingParams): + """ + Generate completion for the batch of prompts. + + Args: + prompts (List[str]): the prompts + sampling_params (SamplingParams): the sampling parameters + + Returns: + dict: the generated texts + """ + texts = [] + for prompt in prompts: + text = await self.generate(prompt, sampling_params) + texts.append(text["text"]) + + return {"texts": texts} diff --git a/aana/exceptions/general.py b/aana/exceptions/general.py new file mode 100644 index 00000000..740ddd9e --- /dev/null +++ b/aana/exceptions/general.py @@ -0,0 +1,83 @@ +from typing import Any, Dict, Type + + +class AanaException(Exception): + """ + Base class for SDK exceptions. + """ + + extra = {} + + def __str__(self) -> str: + """ + Return a string representation of the exception. + + String is defined as follows: + ``` + (extra_key1=extra_value1, extra_key2=extra_value2, ...) + ``` + """ + class_name = self.__class__.__name__ + extra_str = "" + for key, value in self.extra.items(): + extra_str += f", {key}={value}" + return f"{class_name}({extra_str})" + + def get_data(self) -> Dict[str, Any]: + """ + Get the data to be returned to the client. + + Returns: + Dict[str, Any]: data to be returned to the client + """ + data = self.extra.copy() + return data + + def add_extra(self, key: str, value: Any): + """ + Add extra data to the exception. + + This data will be returned to the user as part of the response. + + How to use: in the exception handler, add the extra data to the exception and raise it again. + + Example: + ``` + try: + ... + except AanaException as e: + e.add_extra('extra_key', 'extra_value') + raise e + ``` + + Args: + key (str): key of the extra data + value (Any): value of the extra data + """ + self.extra[key] = value + + +class InferenceException(AanaException): + """Exception raised when there is an error during inference. + + Attributes: + model_name -- name of the model + """ + + def __init__(self, model_name=""): + """ + Initialize the exception. + + Args: + model_name (str): name of the model that caused the exception + """ + super().__init__() + self.model_name = model_name + self.extra["model_name"] = model_name + + def __reduce__(self): + # This method is called when the exception is pickled + # We need to do this if exception has one or more arguments + # See https://bugs.python.org/issue32696#msg310963 for more info + # TODO: check if there is a better way to do this + return (self.__class__, (self.model_name,)) diff --git a/aana/main.py b/aana/main.py new file mode 100644 index 00000000..f422d34c --- /dev/null +++ b/aana/main.py @@ -0,0 +1,5 @@ +from aana.api.request_handler import RequestHandler +from aana.configs.deployments import deployments + + +server = RequestHandler.bind(deployments) diff --git a/aana/models/pydantic/__init__.py b/aana/models/pydantic/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/aana/models/pydantic/llm_request.py b/aana/models/pydantic/llm_request.py new file mode 100644 index 00000000..910c63d4 --- /dev/null +++ b/aana/models/pydantic/llm_request.py @@ -0,0 +1,24 @@ +from typing import Dict, List, Optional +from pydantic import BaseModel, Extra, Field + +from aana.models.pydantic.prompt import Prompt +from aana.models.pydantic.sampling_params import SamplingParams + + +class LLMRequest(BaseModel): + """ + This class is used to represent a request to LLM. + + Attributes: + prompt (Prompt): A prompt to LLM. + sampling_params (SamplingParams): Sampling parameters for generating text. + """ + + prompt: Prompt = Field(..., description="A prompt to LLM.") + sampling_params: Optional[SamplingParams] = Field( + None, description="Sampling parameters for generating text." + ) + + class Config: + extra = Extra.forbid + schema_extra = {"description": "A request to LLM."} diff --git a/aana/models/pydantic/prompt.py b/aana/models/pydantic/prompt.py new file mode 100644 index 00000000..ff04b368 --- /dev/null +++ b/aana/models/pydantic/prompt.py @@ -0,0 +1,13 @@ +from pydantic import BaseModel + + +class Prompt(BaseModel): + """A model for a user prompt to LLM.""" + + __root__: str + + def __str__(self): + return self.__root__ + + class Config: + schema_extra = {"description": "A prompt to LLM."} diff --git a/aana/models/pydantic/sampling_params.py b/aana/models/pydantic/sampling_params.py new file mode 100644 index 00000000..7e595ac6 --- /dev/null +++ b/aana/models/pydantic/sampling_params.py @@ -0,0 +1,25 @@ +from typing import Optional + +from pydantic import BaseModel, Field + + +class SamplingParams(BaseModel): + """ + A model for sampling parameters of LLM. + + Attributes: + temperature (float): The temperature. + top_p (float): Top-p. + top_k (int): Top-k. + max_tokens (int): The maximum number of tokens to generate. + """ + + temperature: Optional[float] = Field(default=None, description="The temperature.") + top_p: Optional[float] = Field(default=None, description="Top-p.") + top_k: Optional[int] = Field(default=None, description="Top-k.") + max_tokens: Optional[int] = Field( + default=None, description="The maximum number of tokens to generate." + ) + + class Config: + schema_extra = {"description": "Sampling parameters for generating text."} diff --git a/aana/utils/general.py b/aana/utils/general.py new file mode 100644 index 00000000..dd080eee --- /dev/null +++ b/aana/utils/general.py @@ -0,0 +1,73 @@ +import pickle +from typing import Any, TypeVar + +OptionType = TypeVar("OptionType") + + +def load_options(s: str, ignore_errors: bool = True) -> Any: + """ + Load options from a string. + + The string is assumed to be a pickled object encoded in latin1. + If the string cannot be unpickled, return the string itself if ignore_errors is True, + otherwise raise an exception. + + The function is used to pass options using Ray's user_config. + user_config accepts only JSON serializable objects, so we need to encode the options. + + Args: + s (str): string to be unpickled + ignore_errors (bool): if True, return the string itself if it cannot be unpickled, otherwise raise an exception + + Returns:z + unpickled object or the string itself if ignore_errors is True + """ + try: + b = s.encode("latin1") + return pickle.loads(b) + except Exception as e: + if ignore_errors: + return s + raise e + + +def encode_options(options: Any) -> str: + """ + Encode options as a string. + + The string is a pickled object encoded in latin1. + + The function is used to pass options using Ray's user_config. + user_config accepts only JSON serializable objects, so we need to encode the options. + + Args: + options (Any): options to be encoded + + Returns: + str: options encoded as a string + """ + b = pickle.dumps(options) + return b.decode("latin1") + + +def merged_options(default_options: OptionType, options: OptionType) -> OptionType: + """ + Merge options into default_options. + + Args: + default_options (OptionType): default options + options (OptionType): options to be merged into default_options + + Returns: + OptionType: merged options + """ + # if options is None, return default_options + if options is None: + return default_options + # options and default_options have to be of the same type + assert type(default_options) == type(options) + default_options_dict = default_options.dict() + for k, v in options.dict().items(): + if v is not None: + default_options_dict[k] = v + return options.__class__(**default_options_dict) diff --git a/install.sh b/install.sh new file mode 100644 index 00000000..14339833 --- /dev/null +++ b/install.sh @@ -0,0 +1,2 @@ +#!/bin/sh +poetry install diff --git a/mobius-pipeline b/mobius-pipeline new file mode 160000 index 00000000..7600ba01 --- /dev/null +++ b/mobius-pipeline @@ -0,0 +1 @@ +Subproject commit 7600ba015f4d2a47db30f51f63939aaadfda6949 diff --git a/notebooks/demo.ipynb b/notebooks/demo.ipynb new file mode 100644 index 00000000..9f3c4878 --- /dev/null +++ b/notebooks/demo.ipynb @@ -0,0 +1,83 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import requests\n", + "import json" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "data = {\n", + " 'prompt': '[INST] Who is Elon Musk? [/INST]',\n", + " 'sampling_params' : {\n", + " 'temperature': 0.9,\n", + " }\n", + "}\n", + "\n", + "url = 'http://127.0.0.1:8000/llm/generate'" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'vllm_llama2_7b_chat_output': ' Elon Musk is a South African-born entrepreneur, inventor, and business magnate who is best known for being the CEO of SpaceX and Tesla, Inc. He is one of the most successful and influential entrepreneurs of the 21st century, known for his innovative ideas, visionary leadership, and his ability to bring those ideas to life.\\nMusk was born on June 28, 1971, in Pretoria, South Africa. He developed an interest in computing and programming at an early age and taught himself computer programming. He moved to Canada in 1992 to attend college, and later transferred to the University of Pennsylvania, where he graduated with a degree in economics and physics.\\nAfter college, Musk moved to California to pursue a career in technology and entrepreneurship. He co-founded his first company, Zip2, which provided online content publishing software for news organizations. In 1999, he co-founded X.com, which later became PayPal, an online payment system that was acquired by eBay for $1.5 billion in 2002.\\nIn 2',\n", + " 'execution_time': {'prompt': 0,\n", + " 'sampling_params': 0,\n", + " 'vllm_stream_llama2_7b_chat': 0,\n", + " 'vllm_llama2_7b_chat': 4.421537399291992}}" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "response = requests.post(url, json=data)\n", + "response.json()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "aana-vIr3-B0u-py3.10", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..81e9f965 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,25 @@ +[tool.poetry] +name = "aana" +version = "0.1.0" +description = "Multimodal SDK" +authors = ["Mobius Labs GmbH"] +license = "License agreement" +readme = "README.md" + +[tool.poetry.dependencies] +python = "~3.10" +mobius-pipeline = { path = "./mobius-pipeline", develop = true } +pydantic = "<2" +fastapi = "^0.104.0" +ray = {extras = ["serve"], version = "^2.7.1"} +python-multipart = "^0.0.6" +torch = { url = "https://download.pytorch.org/whl/cu118/torch-2.0.1%2Bcu118-cp310-cp310-linux_x86_64.whl" } +torchvision = { url = "https://download.pytorch.org/whl/cu118/torchvision-0.15.2%2Bcu118-cp310-cp310-linux_x86_64.whl" } +vllm = "^0.2.1.post1" + +[tool.poetry.group.dev.dependencies] +ipykernel = "^6.25.2" + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" From c48bb5b4f4f0027d8446016bea26d9231f993f4e Mon Sep 17 00:00:00 2001 From: movchan74 Date: Tue, 24 Oct 2023 11:11:45 +0000 Subject: [PATCH 2/7] Code improvements and test for vllm --- aana/api/app.py | 12 +-- aana/api/request_handler.py | 1 - aana/configs/deployments.py | 3 +- aana/deployments/base_deployment.py | 10 ++- aana/deployments/vllm_deployment.py | 29 +++---- aana/exceptions/general.py | 2 +- aana/main.py | 6 +- aana/models/pydantic/llm_request.py | 2 +- .../tests/deployments/test_vllm_deployment.py | 75 +++++++++++++++++++ aana/utils/general.py | 5 +- pyproject.toml | 6 ++ 11 files changed, 122 insertions(+), 29 deletions(-) create mode 100644 aana/tests/deployments/test_vllm_deployment.py diff --git a/aana/api/app.py b/aana/api/app.py index f249a289..867be7a5 100644 --- a/aana/api/app.py +++ b/aana/api/app.py @@ -31,7 +31,7 @@ async def validation_exception_handler(request: Request, exc: ValidationError): def custom_exception_handler( - request: Request, exc: Union[PipelineException, AanaException, RayTaskError] + request: Request, exc_raw: Union[PipelineException, AanaException, RayTaskError] ): """ This handler is used to handle custom exceptions raised in the application. @@ -41,7 +41,7 @@ def custom_exception_handler( Args: request (Request): The request object - exc (Union[PipelineException, AanaException, RayTaskError]): The exception raised + exc_raw (Union[PipelineException, AanaException, RayTaskError]): The exception raised Returns: JSONResponse: JSON response with the error details. The response contains the following fields: @@ -51,17 +51,19 @@ def custom_exception_handler( stacktrace: The stacktrace of the exception. """ # a PipelineException or AanaException can be wrapped into a RayTaskError - if isinstance(exc, RayTaskError): + if isinstance(exc_raw, RayTaskError): # str(e) returns whole stack trace # if exception is a RayTaskError # let's use it to get the stack trace - stacktrace = str(exc) + stacktrace = str(exc_raw) # get the original exception - exc = exc.cause + exc: Union[PipelineException, AanaException] = exc_raw.cause + assert isinstance(exc, (PipelineException, AanaException)) else: # if it is not a RayTaskError # then we need to get the stack trace stacktrace = traceback.format_exc() + exc = exc_raw # get the data from the exception # can be used to return additional info # like image path, url, model name etc. diff --git a/aana/api/request_handler.py b/aana/api/request_handler.py index 8e61989d..f62314ce 100644 --- a/aana/api/request_handler.py +++ b/aana/api/request_handler.py @@ -1,5 +1,4 @@ from typing import Dict, List, Tuple -import ray from ray import serve from mobius_pipeline.pipeline import Pipeline diff --git a/aana/configs/deployments.py b/aana/configs/deployments.py index 14a52961..63a3e219 100644 --- a/aana/configs/deployments.py +++ b/aana/configs/deployments.py @@ -2,7 +2,6 @@ from aana.models.pydantic.sampling_params import SamplingParams from aana.utils.general import encode_options -#TODO: add build system to only serve the deployment if it's needed deployments = { "vllm_deployment_llama2_7b_chat": VLLMDeployment.options( @@ -18,5 +17,5 @@ SamplingParams(temperature=1.0, top_p=1.0, top_k=-1, max_tokens=256) ), }, - ).bind(), + ), } diff --git a/aana/deployments/base_deployment.py b/aana/deployments/base_deployment.py index 51d2c174..2bdecc0c 100644 --- a/aana/deployments/base_deployment.py +++ b/aana/deployments/base_deployment.py @@ -1,3 +1,4 @@ +from typing import Any, Dict from aana.utils.general import load_options @@ -12,7 +13,7 @@ def __init__(self): self.config = None self.configured = False - async def reconfigure(self, config): + async def reconfigure(self, config: Dict[str, Any]): """ Reconfigure the deployment. The method is called when the deployment is updated. @@ -21,11 +22,14 @@ async def reconfigure(self, config): # go through the config and try to load options for key in self.config: self.config[key] = load_options(self.config[key], ignore_errors=True) - await self.apply_config() + await self.apply_config(config) self.configured = True - async def apply_config(self): + async def apply_config(self, config: Dict[str, Any]): """ Apply the configuration. + + Args: + config (dict): the configuration """ pass diff --git a/aana/deployments/vllm_deployment.py b/aana/deployments/vllm_deployment.py index a9489d54..7611bd80 100644 --- a/aana/deployments/vllm_deployment.py +++ b/aana/deployments/vllm_deployment.py @@ -1,9 +1,10 @@ -from typing import List +from typing import Any, Dict, List from ray import serve from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.sampling_params import SamplingParams as VLLMSamplingParams from vllm.utils import random_uuid +from vllm.model_executor.utils import set_random_seed from aana.deployments.base_deployment import BaseDeployment from aana.exceptions.general import InferenceException @@ -17,7 +18,7 @@ class VLLMDeployment(BaseDeployment): Deployment to serve large language models using vLLM. """ - async def apply_config(self): + async def apply_config(self, config: Dict[str, Any]): """ Apply the configuration. @@ -31,17 +32,18 @@ async def apply_config(self): - quantization: the quantization method (optional, default: None) - gpu_memory_utilization: the GPU memory utilization. - default_sampling_params: the default sampling parameters. + + Args: + config (dict): the configuration of the deployment """ - await super().apply_config() + await super().apply_config(config) # parse the config - model: str = self.config["model"] - dtype: str = self.config.get("dtype", "auto") - quantization: str = self.config.get("quantization", None) - gpu_memory_utilization: float = self.config["gpu_memory_utilization"] - self.default_sampling_params: SamplingParams = self.config[ - "default_sampling_params" - ] + model: str = config["model"] + dtype: str = config.get("dtype", "auto") + quantization: str = config.get("quantization", None) + gpu_memory_utilization: float = config["gpu_memory_utilization"] + self.default_sampling_params: SamplingParams = config["default_sampling_params"] args = AsyncEngineArgs( model=model, dtype=dtype, @@ -72,13 +74,12 @@ async def generate_stream(self, prompt: str, sampling_params: SamplingParams): try: # convert SamplingParams to VLLMSamplingParams sampling_params_vllm = VLLMSamplingParams( - temperature=sampling_params.temperature, - top_p=sampling_params.top_p, - top_k=sampling_params.top_k, - max_tokens=sampling_params.max_tokens, + **sampling_params.dict(exclude_unset=True) ) # start the request request_id = random_uuid() + # set the random seed for reproducibility + set_random_seed(42) results_generator = self.engine.generate( prompt, sampling_params_vllm, request_id ) diff --git a/aana/exceptions/general.py b/aana/exceptions/general.py index 740ddd9e..019af797 100644 --- a/aana/exceptions/general.py +++ b/aana/exceptions/general.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Type +from typing import Any, Dict class AanaException(Exception): diff --git a/aana/main.py b/aana/main.py index f422d34c..648ca499 100644 --- a/aana/main.py +++ b/aana/main.py @@ -1,5 +1,9 @@ from aana.api.request_handler import RequestHandler from aana.configs.deployments import deployments +# TODO: add build system to only serve the deployment if it's needed +binded_deployments = {} +for name, deployment in deployments.items(): + binded_deployments[name] = deployment.bind() -server = RequestHandler.bind(deployments) +server = RequestHandler.bind(binded_deployments) diff --git a/aana/models/pydantic/llm_request.py b/aana/models/pydantic/llm_request.py index 910c63d4..6b0c784c 100644 --- a/aana/models/pydantic/llm_request.py +++ b/aana/models/pydantic/llm_request.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional +from typing import Optional from pydantic import BaseModel, Extra, Field from aana.models.pydantic.prompt import Prompt diff --git a/aana/tests/deployments/test_vllm_deployment.py b/aana/tests/deployments/test_vllm_deployment.py new file mode 100644 index 00000000..0b344f67 --- /dev/null +++ b/aana/tests/deployments/test_vllm_deployment.py @@ -0,0 +1,75 @@ +import random +import pytest +import rapidfuzz +import ray +from ray import serve + +from aana.configs.deployments import deployments +from aana.models.pydantic.sampling_params import SamplingParams + + +def expected_output(name): + if name == "vllm_deployment_llama2_7b_chat": + return ( + " Elon Musk is a South African-born entrepreneur, inventor, and business magnate. " + "He is best known for his revolutionary ideas" + ) + else: + raise ValueError(f"Unknown deployment name: {name}") + + +def ray_setup(deployment): + # Setup ray environment and serve + ray.init(ignore_reinit_error=True) + app = deployment.bind() + # random port from 30000 to 40000 + port = random.randint(30000, 40000) + handle = serve.run(app, port=port) + return handle + + +@pytest.mark.asyncio +async def test_vllm_deployments(): + for name, deployment in deployments.items(): + handle = ray_setup(deployment) + + # test generate method + output = await handle.generate.remote( + prompt="[INST] Who is Elon Musk? [/INST]", + sampling_params=SamplingParams(temperature=1.0, max_tokens=32), + ) + text = output["text"] + expected_text = expected_output(name) + dist = rapidfuzz.distance.Levenshtein.distance(text, expected_text) + assert ( + dist <= len(expected_text) * 0.1 + ) # Allow 10% difference in case of randomness + + # test generate_stream method + stream = handle.options(stream=True).generate_stream.remote( + prompt="[INST] Who is Elon Musk? [/INST]", + sampling_params=SamplingParams(temperature=1.0, max_tokens=32), + ) + text = "" + async for chunk in stream: + chunk = await chunk + text += chunk["text"] + expected_text = expected_output(name) + dist = rapidfuzz.distance.Levenshtein.distance(text, expected_text) + assert dist <= len(expected_text) * 0.1 + + # test generate_batch method + output = await handle.generate_batch.remote( + prompts=[ + "[INST] Who is Elon Musk? [/INST]", + "[INST] Who is Elon Musk? [/INST]", + ], + sampling_params=SamplingParams(temperature=1.0, max_tokens=32), + ) + texts = output["texts"] + expected_text = expected_output(name) + print(texts) + + for text in texts: + dist = rapidfuzz.distance.Levenshtein.distance(text, expected_text) + assert dist <= len(expected_text) * 0.1 diff --git a/aana/utils/general.py b/aana/utils/general.py index dd080eee..c18a4a4a 100644 --- a/aana/utils/general.py +++ b/aana/utils/general.py @@ -1,7 +1,7 @@ import pickle from typing import Any, TypeVar -OptionType = TypeVar("OptionType") +from pydantic import BaseModel def load_options(s: str, ignore_errors: bool = True) -> Any: @@ -50,6 +50,9 @@ def encode_options(options: Any) -> str: return b.decode("latin1") +OptionType = TypeVar("OptionType", bound=BaseModel) + + def merged_options(default_options: OptionType, options: OptionType) -> OptionType: """ Merge options into default_options. diff --git a/pyproject.toml b/pyproject.toml index 81e9f965..446b0d59 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,10 +16,16 @@ python-multipart = "^0.0.6" torch = { url = "https://download.pytorch.org/whl/cu118/torch-2.0.1%2Bcu118-cp310-cp310-linux_x86_64.whl" } torchvision = { url = "https://download.pytorch.org/whl/cu118/torchvision-0.15.2%2Bcu118-cp310-cp310-linux_x86_64.whl" } vllm = "^0.2.1.post1" +scipy = "^1.11.3" +rapidfuzz = "^3.4.0" [tool.poetry.group.dev.dependencies] ipykernel = "^6.25.2" +mypy = "^1.6.1" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + +[tool.pytest.ini_options] +norecursedirs = "mobius-pipeline" From 3d176f23de8dc32b21bc7f18afe87d405fc6861a Mon Sep 17 00:00:00 2001 From: movchan74 Date: Thu, 26 Oct 2023 19:00:24 +0000 Subject: [PATCH 3/7] Applied PR suggestions --- aana/configs/deployments.py | 19 ++++++----- aana/deployments/base_deployment.py | 6 +--- aana/deployments/vllm_deployment.py | 45 +++++++++++++++++--------- aana/exceptions/general.py | 2 +- aana/utils/general.py | 49 +---------------------------- 5 files changed, 43 insertions(+), 78 deletions(-) diff --git a/aana/configs/deployments.py b/aana/configs/deployments.py index 63a3e219..bb8e1535 100644 --- a/aana/configs/deployments.py +++ b/aana/configs/deployments.py @@ -1,6 +1,5 @@ -from aana.deployments.vllm_deployment import VLLMDeployment +from aana.deployments.vllm_deployment import VLLMConfig, VLLMDeployment from aana.models.pydantic.sampling_params import SamplingParams -from aana.utils.general import encode_options deployments = { @@ -8,14 +7,14 @@ num_replicas=1, max_concurrent_queries=1000, ray_actor_options={"num_gpus": 0.5}, - user_config={ - "model": "TheBloke/Llama-2-7b-Chat-AWQ", - "dtype": "auto", - "quantization": "awq", - "gpu_memory_utilization": 0.7, - "default_sampling_params": encode_options( - SamplingParams(temperature=1.0, top_p=1.0, top_k=-1, max_tokens=256) + user_config=VLLMConfig( + model="TheBloke/Llama-2-7b-Chat-AWQ", + dtype="auto", + quantization="awq", + gpu_memory_utilization=0.7, + default_sampling_params=SamplingParams( + temperature=1.0, top_p=1.0, top_k=-1, max_tokens=256 ), - }, + ).dict(), ), } diff --git a/aana/deployments/base_deployment.py b/aana/deployments/base_deployment.py index 2bdecc0c..eda82474 100644 --- a/aana/deployments/base_deployment.py +++ b/aana/deployments/base_deployment.py @@ -1,5 +1,4 @@ from typing import Any, Dict -from aana.utils.general import load_options class BaseDeployment: @@ -19,9 +18,6 @@ async def reconfigure(self, config: Dict[str, Any]): The method is called when the deployment is updated. """ self.config = config - # go through the config and try to load options - for key in self.config: - self.config[key] = load_options(self.config[key], ignore_errors=True) await self.apply_config(config) self.configured = True @@ -32,4 +28,4 @@ async def apply_config(self, config: Dict[str, Any]): Args: config (dict): the configuration """ - pass + raise NotImplementedError diff --git a/aana/deployments/vllm_deployment.py b/aana/deployments/vllm_deployment.py index 7611bd80..2f279f03 100644 --- a/aana/deployments/vllm_deployment.py +++ b/aana/deployments/vllm_deployment.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional +from pydantic import BaseModel, Field from ray import serve from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -12,6 +13,25 @@ from aana.utils.general import merged_options +class VLLMConfig(BaseModel): + """ + The configuration of the vLLM deployment. + + Attributes: + model (str): the model name + dtype (str): the data type (optional, default: "auto") + quantization (str): the quantization method (optional, default: None) + gpu_memory_utilization (float): the GPU memory utilization. + default_sampling_params (SamplingParams): the default sampling parameters. + """ + + model: str + dtype: Optional[str] = Field(default="auto") + quantization: Optional[str] = Field(default=None) + gpu_memory_utilization: float + default_sampling_params: SamplingParams + + @serve.deployment class VLLMDeployment(BaseDeployment): """ @@ -36,19 +56,16 @@ async def apply_config(self, config: Dict[str, Any]): Args: config (dict): the configuration of the deployment """ - await super().apply_config(config) - - # parse the config - model: str = config["model"] - dtype: str = config.get("dtype", "auto") - quantization: str = config.get("quantization", None) - gpu_memory_utilization: float = config["gpu_memory_utilization"] - self.default_sampling_params: SamplingParams = config["default_sampling_params"] + config_obj = VLLMConfig(**config) + self.model = config_obj.model + self.default_sampling_params: SamplingParams = ( + config_obj.default_sampling_params + ) args = AsyncEngineArgs( - model=model, - dtype=dtype, - quantization=quantization, - gpu_memory_utilization=gpu_memory_utilization, + model=config_obj.model, + dtype=config_obj.dtype, + quantization=config_obj.quantization, + gpu_memory_utilization=config_obj.gpu_memory_utilization, ) # TODO: check if the model is already loaded. @@ -95,7 +112,7 @@ async def generate_stream(self, prompt: str, sampling_params: SamplingParams): await self.engine.abort(request_id) raise e except Exception as e: - raise InferenceException() from e + raise InferenceException(model_name=self.model) from e async def generate(self, prompt: str, sampling_params: SamplingParams): """ diff --git a/aana/exceptions/general.py b/aana/exceptions/general.py index 019af797..3902122d 100644 --- a/aana/exceptions/general.py +++ b/aana/exceptions/general.py @@ -64,7 +64,7 @@ class InferenceException(AanaException): model_name -- name of the model """ - def __init__(self, model_name=""): + def __init__(self, model_name): """ Initialize the exception. diff --git a/aana/utils/general.py b/aana/utils/general.py index c18a4a4a..5390e519 100644 --- a/aana/utils/general.py +++ b/aana/utils/general.py @@ -1,55 +1,8 @@ -import pickle -from typing import Any, TypeVar +from typing import TypeVar from pydantic import BaseModel -def load_options(s: str, ignore_errors: bool = True) -> Any: - """ - Load options from a string. - - The string is assumed to be a pickled object encoded in latin1. - If the string cannot be unpickled, return the string itself if ignore_errors is True, - otherwise raise an exception. - - The function is used to pass options using Ray's user_config. - user_config accepts only JSON serializable objects, so we need to encode the options. - - Args: - s (str): string to be unpickled - ignore_errors (bool): if True, return the string itself if it cannot be unpickled, otherwise raise an exception - - Returns:z - unpickled object or the string itself if ignore_errors is True - """ - try: - b = s.encode("latin1") - return pickle.loads(b) - except Exception as e: - if ignore_errors: - return s - raise e - - -def encode_options(options: Any) -> str: - """ - Encode options as a string. - - The string is a pickled object encoded in latin1. - - The function is used to pass options using Ray's user_config. - user_config accepts only JSON serializable objects, so we need to encode the options. - - Args: - options (Any): options to be encoded - - Returns: - str: options encoded as a string - """ - b = pickle.dumps(options) - return b.decode("latin1") - - OptionType = TypeVar("OptionType", bound=BaseModel) From 517472b7f3cbdc0a5086f9dc10363830fdccd5cb Mon Sep 17 00:00:00 2001 From: movchan74 Date: Fri, 27 Oct 2023 16:22:13 +0000 Subject: [PATCH 4/7] Applied suggestion to merge_options and test it --- aana/tests/test_merge_options.py | 51 ++++++++++++++++++++++++++++++++ aana/utils/general.py | 4 +-- 2 files changed, 53 insertions(+), 2 deletions(-) create mode 100644 aana/tests/test_merge_options.py diff --git a/aana/tests/test_merge_options.py b/aana/tests/test_merge_options.py new file mode 100644 index 00000000..929a0c90 --- /dev/null +++ b/aana/tests/test_merge_options.py @@ -0,0 +1,51 @@ +from typing import Optional +import pytest +from pydantic import BaseModel + +from aana.utils.general import merged_options + + +class MyOptions(BaseModel): + field1: str + field2: Optional[int] = None + field3: bool + + +def test_merged_options_same_type(): + """ + Test merged_options with options of the same type as default_options + """ + default = MyOptions(field1="default1", field2=2, field3=True) + to_merge = MyOptions(field1="merge1", field2=None, field3=False) + merged = merged_options(default, to_merge) + + assert merged.field1 == "merge1" + assert ( + merged.field2 == 2 + ) # Should retain value from default_options as it's None in options + assert merged.field3 == False + + +def test_merged_options_none(): + """ + Test merged_options with options=None + """ + default = MyOptions(field1="default1", field2=2, field3=True) + merged = merged_options(default, None) + + assert merged.dict() == default.dict() + + +def test_merged_options_type_mismatch(): + """ + Test merged_options with options of a different type from default_options + """ + + class AnotherOptions(BaseModel): + another_field: str + + default = MyOptions(field1="default1", field2=2, field3=True) + to_merge = AnotherOptions(another_field="test") + + with pytest.raises(AssertionError): + merged_options(default, to_merge) diff --git a/aana/utils/general.py b/aana/utils/general.py index 5390e519..7b0c4d28 100644 --- a/aana/utils/general.py +++ b/aana/utils/general.py @@ -19,11 +19,11 @@ def merged_options(default_options: OptionType, options: OptionType) -> OptionTy """ # if options is None, return default_options if options is None: - return default_options + return default_options.copy() # options and default_options have to be of the same type assert type(default_options) == type(options) default_options_dict = default_options.dict() for k, v in options.dict().items(): if v is not None: default_options_dict[k] = v - return options.__class__(**default_options_dict) + return options.__class__.parse_obj(default_options_dict) From c4a267b81b53d4e090be3c7783922aa113b1521a Mon Sep 17 00:00:00 2001 From: movchan74 Date: Mon, 30 Oct 2023 16:38:23 +0000 Subject: [PATCH 5/7] Better description and validation for sampling params. --- aana/models/pydantic/sampling_params.py | 51 ++++++++++++++++++++---- aana/tests/test_sampling_params.py | 52 +++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 8 deletions(-) create mode 100644 aana/tests/test_sampling_params.py diff --git a/aana/models/pydantic/sampling_params.py b/aana/models/pydantic/sampling_params.py index 7e595ac6..dd3265c6 100644 --- a/aana/models/pydantic/sampling_params.py +++ b/aana/models/pydantic/sampling_params.py @@ -1,6 +1,6 @@ from typing import Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, validator class SamplingParams(BaseModel): @@ -8,18 +8,53 @@ class SamplingParams(BaseModel): A model for sampling parameters of LLM. Attributes: - temperature (float): The temperature. - top_p (float): Top-p. - top_k (int): Top-k. + temperature (float): Float that controls the randomness of the sampling. Lower + values make the model more deterministic, while higher values make + the model more random. Zero means greedy sampling. + top_p (float): Float that controls the cumulative probability of the top tokens + to consider. Must be in (0, 1]. Set to 1 to consider all tokens. + top_k (int): Integer that controls the number of top tokens to consider. Set + to -1 to consider all tokens. max_tokens (int): The maximum number of tokens to generate. """ - temperature: Optional[float] = Field(default=None, description="The temperature.") - top_p: Optional[float] = Field(default=None, description="Top-p.") - top_k: Optional[int] = Field(default=None, description="Top-k.") + temperature: Optional[float] = Field( + default=None, + ge=0.0, + description=( + "Float that controls the randomness of the sampling. " + "Lower values make the model more deterministic, " + "while higher values make the model more random. " + "Zero means greedy sampling." + ), + ) + top_p: Optional[float] = Field( + default=None, + gt=0.0, + le=1.0, + description=( + "Float that controls the cumulative probability of the top tokens to consider. " + "Must be in (0, 1]. Set to 1 to consider all tokens." + ), + ) + top_k: Optional[int] = Field( + default=None, + description=( + "Integer that controls the number of top tokens to consider. " + "Set to -1 to consider all tokens." + ), + ) max_tokens: Optional[int] = Field( - default=None, description="The maximum number of tokens to generate." + default=None, ge=1, description="The maximum number of tokens to generate." ) + @validator("top_k", always=True, pre=True) + def check_top_k(cls, v): + if v is None: + return v + if v < -1 or v == 0: + raise ValueError(f"top_k must be -1 (disable), or at least 1, got {v}.") + return v + class Config: schema_extra = {"description": "Sampling parameters for generating text."} diff --git a/aana/tests/test_sampling_params.py b/aana/tests/test_sampling_params.py new file mode 100644 index 00000000..f26c03a3 --- /dev/null +++ b/aana/tests/test_sampling_params.py @@ -0,0 +1,52 @@ +import pytest + +from aana.models.pydantic.sampling_params import SamplingParams + +def test_valid_sampling_params(): + """ + Test valid sampling parameters. + """ + params = SamplingParams(temperature=0.5, top_p=0.9, top_k=10, max_tokens=50) + assert params.temperature == 0.5 + assert params.top_p == 0.9 + assert params.top_k == 10 + assert params.max_tokens == 50 + + # Test valid params with default values (None) + params = SamplingParams() + assert params.temperature is None + assert params.top_p is None + assert params.top_k is None + assert params.max_tokens is None + +def test_invalid_temperature(): + """ + Test invalid temperature values. + """ + with pytest.raises(ValueError): + SamplingParams(temperature=-1.0) + +def test_invalid_top_p(): + """ + Test invalid top_p values. + """ + with pytest.raises(ValueError): + SamplingParams(top_p=0.0) + with pytest.raises(ValueError): + SamplingParams(top_p=1.1) + +def test_invalid_top_k(): + """ + Test invalid top_k values. + """ + with pytest.raises(ValueError): + SamplingParams(top_k=0) + with pytest.raises(ValueError): + SamplingParams(top_k=-2) + +def test_invalid_max_tokens(): + """ + Test invalid max_tokens values. + """ + with pytest.raises(ValueError): + SamplingParams(max_tokens=0) From 981f05bcb63878555cb9f7503f5efae786821ac6 Mon Sep 17 00:00:00 2001 From: movchan74 Date: Mon, 30 Oct 2023 16:39:39 +0000 Subject: [PATCH 6/7] Improved test for vllm --- .../tests/deployments/test_vllm_deployment.py | 33 ++++++++++++++----- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/aana/tests/deployments/test_vllm_deployment.py b/aana/tests/deployments/test_vllm_deployment.py index 0b344f67..009a31e2 100644 --- a/aana/tests/deployments/test_vllm_deployment.py +++ b/aana/tests/deployments/test_vllm_deployment.py @@ -7,6 +7,8 @@ from aana.configs.deployments import deployments from aana.models.pydantic.sampling_params import SamplingParams +ALLOWED_LEVENSTEIN_ERROR_RATE = 0.1 + def expected_output(name): if name == "vllm_deployment_llama2_7b_chat": @@ -28,6 +30,25 @@ def ray_setup(deployment): return handle +def compare_texts(expected_text: str, text: str): + """ + Compare two texts using Levenshtein distance. + The error rate is allowed to be less than ALLOWED_LEVENSTEIN_ERROR_RATE. + + Args: + expected_text (str): the expected text + text (str): the actual text + + Raises: + AssertionError: if the error rate is too high + """ + dist = rapidfuzz.distance.Levenshtein.distance(text, expected_text) + assert dist < len(expected_text) * ALLOWED_LEVENSTEIN_ERROR_RATE, ( + expected_text, + text, + ) + + @pytest.mark.asyncio async def test_vllm_deployments(): for name, deployment in deployments.items(): @@ -40,10 +61,7 @@ async def test_vllm_deployments(): ) text = output["text"] expected_text = expected_output(name) - dist = rapidfuzz.distance.Levenshtein.distance(text, expected_text) - assert ( - dist <= len(expected_text) * 0.1 - ) # Allow 10% difference in case of randomness + compare_texts(expected_text, text) # test generate_stream method stream = handle.options(stream=True).generate_stream.remote( @@ -55,8 +73,7 @@ async def test_vllm_deployments(): chunk = await chunk text += chunk["text"] expected_text = expected_output(name) - dist = rapidfuzz.distance.Levenshtein.distance(text, expected_text) - assert dist <= len(expected_text) * 0.1 + compare_texts(expected_text, text) # test generate_batch method output = await handle.generate_batch.remote( @@ -68,8 +85,6 @@ async def test_vllm_deployments(): ) texts = output["texts"] expected_text = expected_output(name) - print(texts) for text in texts: - dist = rapidfuzz.distance.Levenshtein.distance(text, expected_text) - assert dist <= len(expected_text) * 0.1 + compare_texts(expected_text, text) From 656eec4bc9be6966f3e82c10e0a2ba6ee5ed1f1a Mon Sep 17 00:00:00 2001 From: movchan74 Date: Wed, 1 Nov 2023 10:26:03 +0000 Subject: [PATCH 7/7] Use BaseException from Mobius Pipeline as base for all Aana exceptions --- aana/api/app.py | 40 ++++++++------------------ aana/exceptions/general.py | 59 ++------------------------------------ mobius-pipeline | 2 +- 3 files changed, 15 insertions(+), 86 deletions(-) diff --git a/aana/api/app.py b/aana/api/app.py index 867be7a5..48f9946a 100644 --- a/aana/api/app.py +++ b/aana/api/app.py @@ -2,11 +2,10 @@ from typing import Union from fastapi import FastAPI, Request from fastapi.responses import JSONResponse -from mobius_pipeline.exceptions import PipelineException +from mobius_pipeline.exceptions import BaseException from pydantic import ValidationError from ray.exceptions import RayTaskError -from aana.exceptions.general import AanaException app = FastAPI() @@ -31,17 +30,17 @@ async def validation_exception_handler(request: Request, exc: ValidationError): def custom_exception_handler( - request: Request, exc_raw: Union[PipelineException, AanaException, RayTaskError] + request: Request, exc_raw: Union[BaseException, RayTaskError] ): """ This handler is used to handle custom exceptions raised in the application. - PipelineException is the exception raised by the Mobius Pipeline. - AanaException is the exception raised by the Aana application. + BaseException is the base exception for all the exceptions + from the Mobius Pipeline and Aana application. Sometimes custom exception are wrapped into RayTaskError so we need to handle that as well. Args: request (Request): The request object - exc_raw (Union[PipelineException, AanaException, RayTaskError]): The exception raised + exc_raw (Union[BaseException, RayTaskError]): The exception raised Returns: JSONResponse: JSON response with the error details. The response contains the following fields: @@ -50,15 +49,15 @@ def custom_exception_handler( data: The additional data returned by the exception that can be used to identify the error (e.g. image path, url, model name etc.) stacktrace: The stacktrace of the exception. """ - # a PipelineException or AanaException can be wrapped into a RayTaskError + # a BaseException can be wrapped into a RayTaskError if isinstance(exc_raw, RayTaskError): # str(e) returns whole stack trace # if exception is a RayTaskError # let's use it to get the stack trace stacktrace = str(exc_raw) # get the original exception - exc: Union[PipelineException, AanaException] = exc_raw.cause - assert isinstance(exc, (PipelineException, AanaException)) + exc: BaseException = exc_raw.cause + assert isinstance(exc, BaseException) else: # if it is not a RayTaskError # then we need to get the stack trace @@ -84,29 +83,14 @@ def custom_exception_handler( ) -@app.exception_handler(PipelineException) -async def pipeline_exception_handler(request: Request, exc: PipelineException): +@app.exception_handler(BaseException) +async def pipeline_exception_handler(request: Request, exc: BaseException): """ - This handler is used to handle exceptions raised by the Mobius Pipeline. + This handler is used to handle exceptions raised by the Mobius Pipeline and Aana application. Args: request (Request): The request object - exc (PipelineException): The exception raised - - Returns: - JSONResponse: JSON response with the error details - """ - return custom_exception_handler(request, exc) - - -@app.exception_handler(AanaException) -async def aana_exception_handler(request: Request, exc: AanaException): - """ - This handler is used to handle exceptions raised by the Aana application. - - Args: - request (Request): The request object - exc (AanaException): The exception raised + exc (BaseException): The exception raised Returns: JSONResponse: JSON response with the error details diff --git a/aana/exceptions/general.py b/aana/exceptions/general.py index 3902122d..558ac51e 100644 --- a/aana/exceptions/general.py +++ b/aana/exceptions/general.py @@ -1,63 +1,8 @@ from typing import Any, Dict +from mobius_pipeline.exceptions import BaseException -class AanaException(Exception): - """ - Base class for SDK exceptions. - """ - - extra = {} - - def __str__(self) -> str: - """ - Return a string representation of the exception. - - String is defined as follows: - ``` - (extra_key1=extra_value1, extra_key2=extra_value2, ...) - ``` - """ - class_name = self.__class__.__name__ - extra_str = "" - for key, value in self.extra.items(): - extra_str += f", {key}={value}" - return f"{class_name}({extra_str})" - - def get_data(self) -> Dict[str, Any]: - """ - Get the data to be returned to the client. - - Returns: - Dict[str, Any]: data to be returned to the client - """ - data = self.extra.copy() - return data - - def add_extra(self, key: str, value: Any): - """ - Add extra data to the exception. - - This data will be returned to the user as part of the response. - - How to use: in the exception handler, add the extra data to the exception and raise it again. - - Example: - ``` - try: - ... - except AanaException as e: - e.add_extra('extra_key', 'extra_value') - raise e - ``` - - Args: - key (str): key of the extra data - value (Any): value of the extra data - """ - self.extra[key] = value - - -class InferenceException(AanaException): +class InferenceException(BaseException): """Exception raised when there is an error during inference. Attributes: diff --git a/mobius-pipeline b/mobius-pipeline index 7600ba01..8bdf633a 160000 --- a/mobius-pipeline +++ b/mobius-pipeline @@ -1 +1 @@ -Subproject commit 7600ba015f4d2a47db30f51f63939aaadfda6949 +Subproject commit 8bdf633aaa9227b732b56a096ae04c6ebe4e8060