diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml new file mode 100644 index 00000000..d512876b --- /dev/null +++ b/.github/workflows/python-package.yml @@ -0,0 +1,38 @@ +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://docs.github.com/en/ actions/automating-builds-and-tests/building-and-testing-python + +name: Python package + +on: + push: + branches: + - '*' # Runs on push to any branch + pull_request: + branches: + - '*' # Runs on pull requests to any branch + workflow_dispatch: # Allows for manual triggering + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.10"] + + steps: + - name: Checkout code + uses: actions/checkout@v3 + with: + token: ${{ secrets.GH_MOBIUS_PIPELINE_TOKEN }} + submodules: recursive + - name: Bootstrap poetry + run: | + curl -sSL https://install.python-poetry.org | python - -y + - name: Update PATH + run: echo "$HOME/.local/bin" >> $GITHUB_PATH + - name: Install dependencies + run: poetry install + - name: Test with pytest + run: poetry run pytest diff --git a/.vscode/settings.json b/.vscode/settings.json index d99f2f30..c603b8d2 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -2,5 +2,10 @@ "[python]": { "editor.defaultFormatter": "ms-python.black-formatter" }, - "python.formatting.provider": "none" + "python.formatting.provider": "none", + "python.testing.pytestArgs": [ + "aana" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true } \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index a95f4d81..1b237b74 100644 --- a/Dockerfile +++ b/Dockerfile @@ -23,11 +23,15 @@ COPY . /app # Install the package with poetry RUN sh install.sh +# Prepare the startup script +RUN chmod +x startup.sh + # Disable buffering for stdout and stderr to get the logs in real time ENV PYTHONUNBUFFERED=1 # Expose the desired port EXPOSE 8000 -# Set the command to run the SDK when the container starts -CMD ["poetry", "run", "serve", "run", "--port", "8000", "--host", "0.0.0.0", "aana.main:server"] +# Run the startup script, TARGET can be set with environment variables +ENV TARGET=llama2 +CMD ["/app/startup.sh"] diff --git a/README.md b/README.md index eb57eec4..9665ef26 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ +[![Python package](https://github.com/mobiusml/aana_sdk/actions/workflows/python-package.yml/badge.svg)](https://github.com/mobiusml/aana_sdk/actions/workflows/python-package.yml) + # Aana Aana is a multi-model SDK for deploying and serving machine learning models. @@ -28,9 +30,11 @@ sh install.sh 5. Run the SDK. ```bash -CUDA_VISIBLE_DEVICES=0 poetry run serve run --port 8000 --host 0.0.0.0 aana.main:server +CUDA_VISIBLE_DEVICES=0 poetry run aana --port 8000 --host 0.0.0.0 --target llama2 ``` +The target parameter specifies the set of endpoints to deploy. + The first run might take a while because the models will be downloaded from Google Drive and cached. Once you see `Deployed Serve app successfully.` in the logs, the server is ready to accept requests. @@ -66,9 +70,11 @@ docker build -t aana:0.1.0 . 4. Run the Docker container. ```bash -docker run --rm --init -p 8000:8000 --gpus all -e CUDA_VISIBLE_DEVICES=0 -v aana_cache:/root/.aana -v aana_hf_cache:/root/.cache/huggingface --name aana_instance aana:0.1.0 +docker run --rm --init -p 8000:8000 --gpus all -e TARGET="llama2" -e CUDA_VISIBLE_DEVICES=0 -v aana_cache:/root/.aana -v aana_hf_cache:/root/.cache/huggingface --name aana_instance aana:0.1.0 ``` +Use the environment variable TARGET to specify the set of endpoints to deploy. + The first run might take a while because the models will be downloaded from Google Drive and cached. The models will be stored in the `aana_cache` volume. The HuggingFace models will be stored in the `aana_hf_cache` volume. If you want to remove the cached models, remove the volume. Once you see `Deployed Serve app successfully.` in the logs, the server is ready to accept requests. diff --git a/aana/api/api_generation.py b/aana/api/api_generation.py new file mode 100644 index 00000000..2ed87ef1 --- /dev/null +++ b/aana/api/api_generation.py @@ -0,0 +1,364 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Dict, Tuple, Type, Any, List, Optional +from fastapi import FastAPI, File, Form, UploadFile +from mobius_pipeline.pipeline.pipeline import Pipeline +from mobius_pipeline.node.socket import Socket +from pydantic import Field, create_model, BaseModel, parse_raw_as +from aana.api.responses import AanaJSONResponse + +from aana.exceptions.general import MultipleFileUploadNotAllowed +from aana.models.pydantic.exception_response import ExceptionResponseModel + + +async def run_pipeline( + pipeline: Pipeline, data: Dict, required_outputs: List[str] +) -> Dict[str, float]: + """ + This function is used to run a Mobius Pipeline. + It creates a container from the data, runs the pipeline and returns the output. + + Args: + pipeline (Pipeline): The pipeline to run. + data (dict): The data to create the container from. + required_outputs (List[str]): The required outputs of the pipeline. + + Returns: + dict[str, float]: The output of the pipeline and the execution time of the pipeline. + """ + + # create a container from the data + container = pipeline.parse_dict(data) + + # run the pipeline + output, execution_time = await pipeline.run( + container, required_outputs, return_execution_time=True + ) + output["execution_time"] = execution_time + return output + + +@dataclass +class OutputFilter: + """ + Class used to represent an output filter. + + The output filter is a parameter that will be added to the request + and will allow to choose subset of `outputs` to return. + + Attributes: + name (str): Name of the output filter. + description (str): Description of the output filter. + """ + + name: str + description: str + + +@dataclass +class FileUploadField: + """ + Class used to represent a file upload field. + + Attributes: + name (str): Name of the field. + description (str): Description of the field. + """ + + name: str + description: str + + +@dataclass +class Endpoint: + """ + Class used to represent an endpoint. + + Attributes: + name (str): Name of the endpoint. + path (str): Path of the endpoint. + summary (str): Description of the endpoint that will be shown in the API documentation. + outputs (List[str]): List of required outputs from the pipeline that should be returned + by the endpoint. + output_filter (Optional[OutputFilter]): The parameter will be added to the request and + will allow to choose subset of `outputs` to return. + streaming (bool): Whether the endpoint outputs a stream of data. + """ + + name: str + path: str + summary: str + outputs: List[str] + output_filter: Optional[OutputFilter] = None + streaming: bool = False + + def generate_model_name(self, suffix: str) -> str: + """ + Generate a Pydantic model name based on a given suffix. + + Parameters: + suffix (str): Suffix for the model name (e.g. "Request", "Response"). + + Returns: + str: Generated model name. + """ + return "".join([word.capitalize() for word in self.name.split("_")]) + suffix + + def socket_to_field(self, socket: Socket) -> Tuple[Any, Any]: + """ + Convert a socket to a Pydantic field. + + Parameters: + socket (Socket): Socket to convert. + + Returns: + Tuple[Any, Field]: Tuple of the socket's data model and a Pydantic field. + """ + data_model = socket.data_model + + # if data model is None or Any, set it to Any + if data_model is None or data_model == Any: + data_model = Any + return (data_model, Field(None)) + + # check if any of the fields are required + if any(field.required for field in data_model.__fields__.values()): + return (data_model, ...) + + return (data_model, data_model()) + + def get_fields(self, sockets: List[Socket]) -> Dict[str, Tuple[Any, Any]]: + """ + Generate fields for the Pydantic model based on the provided sockets. + + Parameters: + sockets (List[Socket]): List of sockets. + + Returns: + Dict[str, Tuple[Any, Field]]: Dictionary of fields for the Pydantic model. + """ + fields = {} + for socket in sockets: + field = self.socket_to_field(socket) + fields[socket.name] = field + return fields + + def get_file_upload_field( + self, input_sockets: List[Socket] + ) -> Optional[FileUploadField]: + """ + Get the file upload field for the endpoint. + + Parameters: + input_sockets (List[Socket]): List of input sockets. + + Returns: + Optional[FileUploadField]: File upload field or None if not found. + + Raises: + MultipleFileUploadNotAllowed: If multiple inputs require file upload. + """ + + file_upload_field = None + for socket in input_sockets: + data_model = socket.data_model + + # skip sockets with no data model + if data_model is None or data_model == Any: + continue + + # check if pydantic model has file_upload field and it's set to True + file_upload_enabled = getattr(data_model.Config, "file_upload", False) + file_upload_description = getattr( + data_model.Config, "file_upload_description", "" + ) + + if file_upload_enabled and file_upload_field is None: + file_upload_field = FileUploadField( + name=socket.name, description=file_upload_description + ) + elif file_upload_enabled and file_upload_field is not None: + # raise an exception if multiple inputs require file upload + raise MultipleFileUploadNotAllowed(socket.name) + return file_upload_field + + def get_output_filter_field(self) -> Optional[Tuple[Any, Any]]: + """ + Get the output filter field for the endpoint. + + Returns: + Optional[Tuple[Any, Field]]: Output filter field or None if not found. + """ + if not self.output_filter: + return None + + description = self.output_filter.description + outputs_enum_name = self.generate_model_name("Outputs") + outputs_enum = Enum( # type: ignore + outputs_enum_name, [(output, output) for output in self.outputs], type=str + ) + field = (Optional[List[outputs_enum]], Field(None, description=description)) + return field + + def get_request_model(self, input_sockets: List[Socket]) -> Type[BaseModel]: + """ + Generate a Pydantic model for the request. + + Parameters: + input_sockets (List[Socket]): List of input sockets. + + Returns: + Type[BaseModel]: Pydantic model for the request. + """ + model_name = self.generate_model_name("Request") + input_fields = self.get_fields(input_sockets) + output_filter_field = self.get_output_filter_field() + if output_filter_field and self.output_filter: + input_fields[self.output_filter.name] = output_filter_field + RequestModel = create_model(model_name, **input_fields) + return RequestModel + + def get_response_model(self, output_sockets: List[Socket]) -> Type[BaseModel]: + """ + Generate a Pydantic model for the response. + + Parameters: + output_sockets (List[Socket]): List of output sockets. + + Returns: + Type[BaseModel]: Pydantic model for the response. + """ + model_name = self.generate_model_name("Response") + output_fields = self.get_fields(output_sockets) + ResponseModel = create_model(model_name, **output_fields) + return ResponseModel + + def create_endpoint_func( + self, + pipeline: Pipeline, + RequestModel: Type[BaseModel], + file_upload_field: Optional[FileUploadField] = None, + ): + async def route_func_body(body: str, files: Optional[List[UploadFile]] = None): + # parse form data as a pydantic model and validate it + data = parse_raw_as(RequestModel, body) + + # if the input requires file upload, add the files to the data + if file_upload_field and files: + files_as_bytes = [await file.read() for file in files] + getattr(data, file_upload_field.name).set_files(files_as_bytes) + + # We have to do this instead of data.dict() because + # data.dict() will convert all nested models to dicts + # and we want to keep them as pydantic models + data_dict = {} + for field_name in data.__fields__: + field_value = getattr(data, field_name) + # check if it has a method convert_to_entities + # if it does, call it to convert the model to an entity + if hasattr(field_value, "convert_to_entity"): + field_value = field_value.convert_to_entity() + data_dict[field_name] = field_value + + if self.output_filter: + requested_outputs = data_dict.get(self.output_filter.name, None) + else: + requested_outputs = None + + # if user requested specific outputs, use them + if requested_outputs: + # get values for requested outputs because it's a list of enums + requested_outputs = [output.value for output in requested_outputs] + outputs = requested_outputs + # otherwise use the required outputs from the config (all outputs endpoint provides) + else: + outputs = self.outputs + + # remove the output filter parameter from the data + if self.output_filter and self.output_filter.name in data_dict: + del data_dict[self.output_filter.name] + + # run the pipeline + output = await run_pipeline(pipeline, data_dict, outputs) + return AanaJSONResponse(content=output) + + if file_upload_field: + files = File(None, description=file_upload_field.description) + else: + files = None + + async def route_func(body: str = Form(...), files=files): + return await route_func_body(body=body, files=files) + + return route_func + + def register( + self, app: FastAPI, pipeline: Pipeline, custom_schemas: Dict[str, Dict] + ): + """ + Register an endpoint to the FastAPI app and add schemas to the custom schemas dictionary. + + Parameters: + app (FastAPI): FastAPI app to register the endpoint to. + pipeline (Pipeline): Pipeline to register the endpoint to. + custom_schemas (Dict[str, Dict]): Dictionary of custom schemas. + """ + input_sockets, output_sockets = pipeline.get_sockets(self.outputs) + RequestModel = self.get_request_model(input_sockets) + ResponseModel = self.get_response_model(output_sockets) + file_upload_field = self.get_file_upload_field(input_sockets) + route_func = self.create_endpoint_func( + pipeline=pipeline, + RequestModel=RequestModel, + file_upload_field=file_upload_field, + ) + app.post( + self.path, + summary=self.summary, + name=self.name, + operation_id=self.name, + response_model=ResponseModel, + responses={ + 400: {"model": ExceptionResponseModel}, + }, + )(route_func) + custom_schemas[self.name] = RequestModel.schema() + + +def add_custom_schemas_to_openapi_schema( + openapi_schema: Dict[str, Any], custom_schemas: Dict[str, Any] +) -> Dict[str, Any]: + """ + Add custom schemas to the openapi schema. + + File upload is that FastAPI doesn't support Pydantic models in multipart requests. + There is a discussion about it on FastAPI discussion forum. + See https://github.com/tiangolo/fastapi/discussions/8406 + The topic starter suggests a workaround. + The workaround is to use Forms instead of Pydantic models in the endpoint definition and + then convert the Forms to Pydantic models in the endpoint itself + using parse_raw_as function from Pydantic. + Since Pydantic model isn't used in the endpoint definition, + the API documentation will not be generated automatically. + So the workaround also suggests updating the API documentation manually + by overriding the openapi method of a FastAPI application. + + Args: + openapi_schema (dict): The openapi schema. + custom_schemas (dict): The custom schemas. + + Returns: + dict: The openapi schema with the custom schemas added. + """ + + if "definitions" not in openapi_schema: + openapi_schema["definitions"] = {} + for schema_name, schema in custom_schemas.items(): + # if we have a definitions then we need to move them out to the top level of the schema + if "definitions" in schema: + openapi_schema["definitions"].update(schema["definitions"]) + del schema["definitions"] + openapi_schema["components"]["schemas"][f"Body_{schema_name}"]["properties"][ + "body" + ] = schema + return openapi_schema diff --git a/aana/api/app.py b/aana/api/app.py index 48f9946a..69606db7 100644 --- a/aana/api/app.py +++ b/aana/api/app.py @@ -1,11 +1,11 @@ import traceback from typing import Union from fastapi import FastAPI, Request -from fastapi.responses import JSONResponse from mobius_pipeline.exceptions import BaseException from pydantic import ValidationError from ray.exceptions import RayTaskError - +from aana.api.responses import AanaJSONResponse +from aana.models.pydantic.exception_response import ExceptionResponseModel app = FastAPI() @@ -22,10 +22,14 @@ async def validation_exception_handler(request: Request, exc: ValidationError): Returns: JSONResponse: JSON response with the error details """ - # TODO: Structure the error response so that it is consistent with the other error responses - return JSONResponse( + + return AanaJSONResponse( status_code=422, - content={"detail": exc.errors()}, + content=ExceptionResponseModel( + error="ValidationError", + message="Validation error", + data=exc.errors(), + ).dict(), ) @@ -72,14 +76,11 @@ def custom_exception_handler( error = exc.__class__.__name__ # get the message of the exception message = str(exc) - return JSONResponse( + return AanaJSONResponse( status_code=400, - content={ - "error": error, - "message": message, - "data": data, - "stacktrace": stacktrace, - }, + content=ExceptionResponseModel( + error=error, message=message, data=data, stacktrace=stacktrace + ).dict(), ) @@ -116,7 +117,9 @@ async def ray_task_error_handler(request: Request, exc: RayTaskError): error = exc.__class__.__name__ stacktrace = traceback.format_exc() - return JSONResponse( + return AanaJSONResponse( status_code=400, - content={"error": error, "message": str(exc), "stacktrace": stacktrace}, + content=ExceptionResponseModel( + error=error, message=str(exc), stacktrace=stacktrace + ).dict(), ) diff --git a/aana/api/request_handler.py b/aana/api/request_handler.py index f62314ce..b93b13be 100644 --- a/aana/api/request_handler.py +++ b/aana/api/request_handler.py @@ -1,38 +1,12 @@ -from typing import Dict, List, Tuple +from typing import Any, Dict, List from ray import serve +from fastapi.openapi.utils import get_openapi from mobius_pipeline.pipeline import Pipeline +from aana.api.api_generation import Endpoint, add_custom_schemas_to_openapi_schema from aana.api.app import app from aana.api.responses import AanaJSONResponse -from aana.configs.pipeline import nodes -from aana.models.pydantic.llm_request import LLMRequest - - -async def run_pipeline( - pipeline: Pipeline, data: Dict, required_outputs: List[str] -) -> Tuple[Dict, Dict[str, float]]: - """ - This function is used to run a Mobius Pipeline. - It creates a container from the data, runs the pipeline and returns the output. - - Args: - pipeline (Pipeline): The pipeline to run. - data (dict): The data to create the container from. - required_outputs (List[str]): The required outputs of the pipeline. - - Returns: - tuple[dict, dict[str, float]]: The output of the pipeline and the execution time of the pipeline. - """ - - # create a container from the data - container = pipeline.parse_dict(data) - - # run the pipeline - output, execution_time = await pipeline.run( - container, required_outputs, return_execution_time=True - ) - return output, execution_time @serve.deployment(route_prefix="/", num_replicas=1, ray_actor_options={"num_cpus": 0.1}) @@ -40,37 +14,64 @@ async def run_pipeline( class RequestHandler: """This class is used to handle requests to the Aana application.""" - def __init__(self, deployments: Dict): + ready = False + + def __init__( + self, + endpoints: List[Endpoint], + pipeline_nodes: List[Dict[str, Any]], + context: Dict[str, Any], + ): """ Args: deployments (Dict): The dictionary of deployments. - It is passed to the context to the pipeline so the pipeline can access the deployments handles. + It is passed to the context to the pipeline + so the pipeline can access the deployments handles. + """ + + self.context = context + self.endpoints = endpoints + self.pipeline = Pipeline(pipeline_nodes, context) + + self.custom_schemas: Dict[str, Dict] = {} + for endpoint in self.endpoints: + endpoint.register( + app=app, pipeline=self.pipeline, custom_schemas=self.custom_schemas + ) + + app.openapi = self.custom_openapi + self.ready = True + + def custom_openapi(self) -> Dict[str, Any]: + if app.openapi_schema: + return app.openapi_schema + # TODO: populate title and version from package info + openapi_schema = get_openapi(title="Aana", version="0.1.0", routes=app.routes) + openapi_schema = add_custom_schemas_to_openapi_schema( + openapi_schema=openapi_schema, custom_schemas=self.custom_schemas + ) + app.openapi_schema = openapi_schema + return app.openapi_schema + + def get_context(self): """ - self.context = { - "deployments": deployments, - } - self.pipeline = Pipeline(nodes, self.context) + Returns: + dict: The context of the pipeline. + """ + return self.context - @app.post("/llm/generate") - async def generate_llm(self, llm_request: LLMRequest) -> AanaJSONResponse: + @app.get("/api/ready") + async def is_ready(self): """ - The endpoint for running the LLM. - It is running the pipeline with the given prompt and sampling parameters. - This is here as an example and will be replace with automatic endpoint generation. + The endpoint for checking if the application is ready. - Args: - llm_request (LLMRequest): The LLM request. It contains the prompt and sampling parameters. + Real reason for this endpoint is to make automatic endpoint generation work. + If RequestHandler doesn't have any endpoints defined manually, + then the automatic endpoint generation doesn't work. + #TODO: Find a better solution for this. Returns: - AanaJSONResponse: The response containing the output of the pipeline and the execution time. + AanaJSONResponse: The response containing the ready status. """ - prompt = llm_request.prompt - sampling_params = llm_request.sampling_params - output, execution_time = await run_pipeline( - self.pipeline, - {"prompt": prompt, "sampling_params": sampling_params}, - ["vllm_llama2_7b_chat_output"], - ) - output["execution_time"] = execution_time - return AanaJSONResponse(content=output) + return AanaJSONResponse(content={"ready": self.ready}) diff --git a/aana/configs/build.py b/aana/configs/build.py new file mode 100644 index 00000000..8cd724db --- /dev/null +++ b/aana/configs/build.py @@ -0,0 +1,83 @@ +from typing import Dict + +from mobius_pipeline.node.node_definition import NodeDefinition +from mobius_pipeline.pipeline.output_graph import OutputGraph + + +def get_configuration(target: str, endpoints, nodes, deployments) -> Dict: + """ + Returns the configuration for the specified target. + + A target is a set of endpoints that are to be deployed together. + + All the targets are defined in the endpoints.py file. + + The function finds: + - which endpoints are to be deployed + - which nodes are to be used in the pipeline + - which Ray Deployments need to be deployed + + Args: + target (str): The name of the target to be deployed. + endpoints (Dict): The dictionary of endpoints. + nodes (List): The list of nodes. + deployments (Dict): The dictionary of Ray Deployments. + + Returns: + Dict: The dictionary with the configuration. + The configuration contains 3 keys: + - endpoints + - nodes + - deployments + """ + + # Check if target is valid + if target not in endpoints: + raise ValueError( + f"Invalid target: {target}. Valid targets: {', '.join(endpoints.keys())}" + ) + + # Find the endpoints that are to be deployed + target_endpoints = endpoints[target] + + # Target endpoints require the following outputs + endpoint_outputs = [] + for endpoint in target_endpoints: + endpoint_outputs += endpoint.outputs + + # Build the output graph for the whole pipeline + node_definitions = [NodeDefinition.from_dict(node_dict) for node_dict in nodes] + outputs_graph = OutputGraph(node_definitions) + + # Find what inputs are required for the endpoint outputs + inputs = outputs_graph.find_input_nodes(endpoint_outputs) + # Target outputs are the inputs + subgraph of the pipeline + # that is required to generate the outputs for endpoints + target_outputs = inputs + outputs_graph.find_subgraph(inputs, endpoint_outputs) + + # Now we have the target outputs, we can find the nodes that generate them. + # Find the nodes that generate the target outputs + target_nodes = [] + for node in nodes: + node_output_names = [output["name"] for output in node["outputs"]] + if any([output in node_output_names for output in target_outputs]): + target_nodes.append(node) + + # Now we have the target nodes, we can find the Ray Deployments that they use. + # Find the Ray Deployments that are used by the target nodes + target_deployment_names = set() + for node in target_nodes: + if node["type"] == "ray_deployment": + target_deployment_names.add(node["deployment_name"]) + + target_deployments = {} + for deployment_name in target_deployment_names: + if deployment_name not in deployments: + raise ValueError(f"Deployment {deployment_name} is not defined.") + target_deployments[deployment_name] = deployments[deployment_name] + + return { + "endpoints": target_endpoints, + "nodes": target_nodes, + "deployments": target_deployments, + } diff --git a/aana/configs/deployments.py b/aana/configs/deployments.py index bb8e1535..7c66724e 100644 --- a/aana/configs/deployments.py +++ b/aana/configs/deployments.py @@ -1,7 +1,6 @@ from aana.deployments.vllm_deployment import VLLMConfig, VLLMDeployment from aana.models.pydantic.sampling_params import SamplingParams - deployments = { "vllm_deployment_llama2_7b_chat": VLLMDeployment.options( num_replicas=1, @@ -17,4 +16,19 @@ ), ).dict(), ), + "vllm_deployment_zephyr_7b_beta": VLLMDeployment.options( + num_replicas=1, + max_concurrent_queries=1000, + ray_actor_options={"num_gpus": 0.5}, + user_config=VLLMConfig( + model="TheBloke/zephyr-7B-beta-AWQ", + dtype="auto", + quantization="awq", + gpu_memory_utilization=0.9, + max_model_len=512, + default_sampling_params=SamplingParams( + temperature=1.0, top_p=1.0, top_k=-1, max_tokens=256 + ), + ).dict(), + ), } diff --git a/aana/configs/endpoints.py b/aana/configs/endpoints.py new file mode 100644 index 00000000..147544ec --- /dev/null +++ b/aana/configs/endpoints.py @@ -0,0 +1,21 @@ +from aana.api.api_generation import Endpoint + + +endpoints = { + "llama2": [ + Endpoint( + name="llm_generate", + path="/llm/generate", + summary="Generate text using LLaMa2 7B Chat", + outputs=["vllm_llama2_7b_chat_output"], + ) + ], + "zephyr": [ + Endpoint( + name="zephyr_generate", + path="/llm/generate", + summary="Generate text using Zephyr 7B Beta", + outputs=["vllm_zephyr_7b_beta_output"], + ) + ], +} diff --git a/aana/configs/pipeline.py b/aana/configs/pipeline.py index cd3e79a4..57bc31ae 100644 --- a/aana/configs/pipeline.py +++ b/aana/configs/pipeline.py @@ -13,6 +13,8 @@ # sampling_params: SamplingParams # vllm_llama2_7b_chat_output_stream: str # vllm_llama2_7b_chat_output: str +# vllm_zephyr_7b_beta_output_stream: str +# vllm_zephyr_7b_beta_output: str nodes = [ @@ -81,4 +83,48 @@ } ], }, + { + "name": "vllm_stream_zephyr_7b_beta", + "type": "ray_deployment", + "deployment_name": "vllm_deployment_zephyr_7b_beta", + "data_type": "generator", + "generator_path": "prompt", + "method": "generate_stream", + "inputs": [ + {"name": "prompt", "key": "prompt", "path": "prompt"}, + { + "name": "sampling_params", + "key": "sampling_params", + "path": "sampling_params", + }, + ], + "outputs": [ + { + "name": "vllm_zephyr_7b_beta_output_stream", + "key": "text", + "path": "vllm_zephyr_7b_beta_output_stream", + } + ], + }, + { + "name": "vllm_zephyr_7b_beta", + "type": "ray_deployment", + "deployment_name": "vllm_deployment_zephyr_7b_beta", + "method": "generate", + "inputs": [ + {"name": "prompt", "key": "prompt", "path": "prompt"}, + { + "name": "sampling_params", + "key": "sampling_params", + "path": "sampling_params", + }, + ], + "outputs": [ + { + "name": "vllm_zephyr_7b_beta_output", + "key": "text", + "path": "vllm_zephyr_7b_beta_output", + } + ], + }, ] diff --git a/aana/deployments/vllm_deployment.py b/aana/deployments/vllm_deployment.py index 2f279f03..1175c783 100644 --- a/aana/deployments/vllm_deployment.py +++ b/aana/deployments/vllm_deployment.py @@ -23,6 +23,7 @@ class VLLMConfig(BaseModel): quantization (str): the quantization method (optional, default: None) gpu_memory_utilization (float): the GPU memory utilization. default_sampling_params (SamplingParams): the default sampling parameters. + max_model_len (int): the maximum generated text length in tokens (optional, default: None) """ model: str @@ -30,6 +31,7 @@ class VLLMConfig(BaseModel): quantization: Optional[str] = Field(default=None) gpu_memory_utilization: float default_sampling_params: SamplingParams + max_model_len: Optional[int] = Field(default=None) @serve.deployment @@ -52,6 +54,7 @@ async def apply_config(self, config: Dict[str, Any]): - quantization: the quantization method (optional, default: None) - gpu_memory_utilization: the GPU memory utilization. - default_sampling_params: the default sampling parameters. + - max_model_len: the maximum generated text length in tokens (optional, default: None) Args: config (dict): the configuration of the deployment @@ -66,6 +69,7 @@ async def apply_config(self, config: Dict[str, Any]): dtype=config_obj.dtype, quantization=config_obj.quantization, gpu_memory_utilization=config_obj.gpu_memory_utilization, + max_model_len=config_obj.max_model_len, ) # TODO: check if the model is already loaded. diff --git a/aana/exceptions/general.py b/aana/exceptions/general.py index 558ac51e..ba0e4a3a 100644 --- a/aana/exceptions/general.py +++ b/aana/exceptions/general.py @@ -26,3 +26,25 @@ def __reduce__(self): # See https://bugs.python.org/issue32696#msg310963 for more info # TODO: check if there is a better way to do this return (self.__class__, (self.model_name,)) + + +class MultipleFileUploadNotAllowed(BaseException): + """ + Exception raised when multiple inputs require file upload. + + Attributes: + input_name -- name of the input + """ + + def __init__(self, input_name: str): + """ + Initialize the exception. + + Args: + input_name (str): name of the input that caused the exception + """ + self.input_name = input_name + super().__init__() + + def __reduce__(self): + return (self.__class__, (self.input_name,)) diff --git a/aana/main.py b/aana/main.py index 648ca499..65db8053 100644 --- a/aana/main.py +++ b/aana/main.py @@ -1,9 +1,65 @@ -from aana.api.request_handler import RequestHandler -from aana.configs.deployments import deployments +import argparse +import sys +import time +import traceback -# TODO: add build system to only serve the deployment if it's needed -binded_deployments = {} -for name, deployment in deployments.items(): - binded_deployments[name] = deployment.bind() -server = RequestHandler.bind(binded_deployments) +def run(): + arg_parser = argparse.ArgumentParser() + arg_parser.add_argument("--port", type=int, default=8000) + arg_parser.add_argument("--host", type=str, default="0.0.0.0") + arg_parser.add_argument( + "--target", + type=str, + required=True, + help="Specify the set of endpoints to be deployed.", + ) + args = arg_parser.parse_args() + + import ray + from ray import serve + from aana.api.request_handler import RequestHandler + from aana.configs.build import get_configuration + from aana.configs.endpoints import endpoints as all_endpoints + from aana.configs.pipeline import nodes as all_nodes + from aana.configs.deployments import deployments as all_deployments + + configuration = get_configuration( + args.target, + endpoints=all_endpoints, + nodes=all_nodes, + deployments=all_deployments, + ) + endpoints = configuration["endpoints"] + pipeline_nodes = configuration["nodes"] + deployments = configuration["deployments"] + + # Setup ray environment and serve + ray.init(ignore_reinit_error=True) + + context = { + "deployments": { + name: deployment.bind() for name, deployment in deployments.items() + } + } + try: + server = RequestHandler.bind(endpoints, pipeline_nodes, context) + serve.run(server, port=args.port, host=args.host) + # TODO: add logging + print("Deployed Serve app successfully.") + while True: + time.sleep(10) + + except KeyboardInterrupt: + print("Got KeyboardInterrupt, shutting down...") + serve.shutdown() + sys.exit() + + except Exception: + traceback.print_exc() + print( + "Received unexpected error, see console logs for more details. Shutting " + "down..." + ) + serve.shutdown() + sys.exit() diff --git a/aana/models/pydantic/exception_response.py b/aana/models/pydantic/exception_response.py new file mode 100644 index 00000000..dadcacd6 --- /dev/null +++ b/aana/models/pydantic/exception_response.py @@ -0,0 +1,22 @@ +from typing import Dict, Optional +from pydantic import BaseModel, Extra + + +class ExceptionResponseModel(BaseModel): + """ + This class is used to represent an exception response for 400 errors. + + Attributes: + error (str): The error that occurred. + message (str): The message of the error. + data (Optional[Dict]): The extra data that helps to debug the error. + stacktrace (Optional[str]): The stacktrace of the error. + """ + + error: str + message: str + data: Optional[Dict] = None + stacktrace: Optional[str] = None + + class Config: + extra = Extra.forbid diff --git a/aana/models/pydantic/llm_request.py b/aana/models/pydantic/llm_request.py deleted file mode 100644 index 6b0c784c..00000000 --- a/aana/models/pydantic/llm_request.py +++ /dev/null @@ -1,24 +0,0 @@ -from typing import Optional -from pydantic import BaseModel, Extra, Field - -from aana.models.pydantic.prompt import Prompt -from aana.models.pydantic.sampling_params import SamplingParams - - -class LLMRequest(BaseModel): - """ - This class is used to represent a request to LLM. - - Attributes: - prompt (Prompt): A prompt to LLM. - sampling_params (SamplingParams): Sampling parameters for generating text. - """ - - prompt: Prompt = Field(..., description="A prompt to LLM.") - sampling_params: Optional[SamplingParams] = Field( - None, description="Sampling parameters for generating text." - ) - - class Config: - extra = Extra.forbid - schema_extra = {"description": "A request to LLM."} diff --git a/aana/tests/__init__.py b/aana/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/aana/tests/deployments/test_vllm_deployment.py b/aana/tests/deployments/test_vllm_deployment.py index 009a31e2..231de644 100644 --- a/aana/tests/deployments/test_vllm_deployment.py +++ b/aana/tests/deployments/test_vllm_deployment.py @@ -6,6 +6,7 @@ from aana.configs.deployments import deployments from aana.models.pydantic.sampling_params import SamplingParams +from aana.tests.utils import is_gpu_available ALLOWED_LEVENSTEIN_ERROR_RATE = 0.1 @@ -16,6 +17,11 @@ def expected_output(name): " Elon Musk is a South African-born entrepreneur, inventor, and business magnate. " "He is best known for his revolutionary ideas" ) + if name == "vllm_deployment_zephyr_7b_beta": + return ( + "\n\nElon Musk is an entrepreneur, business magnate, and investor. " + "He is the founder, CEO, and Chief Designer of SpaceX" + ) else: raise ValueError(f"Unknown deployment name: {name}") @@ -49,6 +55,7 @@ def compare_texts(expected_text: str, text: str): ) +@pytest.mark.skipif(not is_gpu_available(), reason="GPU is not available") @pytest.mark.asyncio async def test_vllm_deployments(): for name, deployment in deployments.items(): diff --git a/aana/tests/test_api_generation.py b/aana/tests/test_api_generation.py new file mode 100644 index 00000000..ccfb2b03 --- /dev/null +++ b/aana/tests/test_api_generation.py @@ -0,0 +1,184 @@ +from typing import Any, Optional +from unittest.mock import Mock +from mobius_pipeline.node.socket import Socket +from mobius_pipeline.pipeline.pipeline import Pipeline +import pytest + +from pydantic import BaseModel, Field, Extra + +from aana.api.api_generation import Endpoint +from aana.exceptions.general import MultipleFileUploadNotAllowed + + +class InputModel(BaseModel): + input: str = Field(..., description="Input text") + + class Config: + extra = Extra.forbid + + +class FileUploadModel(BaseModel): + content: Optional[bytes] = Field( + None, + description="The content in bytes. Set this field to 'file' to upload files to the endpoint.", + ) + + def set_files(self, files): + if files: + if isinstance(files, list): + files = files[0] + self.content = files + + class Config: + extra = Extra.forbid + file_upload = True + file_upload_description = "Upload image files." + + +class OutputModel(BaseModel): + output: str = Field(..., description="Output text") + + class Config: + extra = Extra.forbid + + +def test_get_request_model(): + """Test the get_request_model function.""" + + endpoint = Endpoint( + name="test_endpoint", + summary="Test endpoint", + path="/test_endpoint", + outputs=["output"], + ) + + input_sockets = [ + Socket(name="input", path="input", key="input", data_model=InputModel), + Socket( + name="input_without_datamodel", + path="input_without_datamodel", + key="input_without_datamodel", + ), + ] + + RequestModel = endpoint.get_request_model(input_sockets) + + # Check that the request model named correctly + assert RequestModel.__name__ == "TestEndpointRequest" + + # Check that the request model has the correct fields + assert RequestModel.__fields__.keys() == {"input", "input_without_datamodel"} + + # Check that the request fields have the correct types + assert RequestModel.__fields__["input"].type_ == InputModel + assert RequestModel.__fields__["input_without_datamodel"].type_ == Any + + +def test_get_response_model(): + """Test the get_response_model function.""" + + endpoint = Endpoint( + name="test_endpoint", + summary="Test endpoint", + path="/test_endpoint", + outputs=["output", "output_without_datamodel"], + ) + + output_sockets = [ + Socket(name="output", path="output", key="output", data_model=OutputModel), + Socket( + name="output_without_datamodel", + path="output_without_datamodel", + key="output_without_datamodel", + ), + ] + + ResponseModel = endpoint.get_response_model(output_sockets) + + # Check that the response model named correctly + assert ResponseModel.__name__ == "TestEndpointResponse" + + # Check that the response model has the correct fields + assert ResponseModel.__fields__.keys() == {"output", "output_without_datamodel"} + + # Check that the response fields have the correct types + assert ResponseModel.__fields__["output"].type_ == OutputModel + assert ResponseModel.__fields__["output_without_datamodel"].type_ == Any + + endpoint_with_one_output = Endpoint( + name="test_endpoint", + summary="Test endpoint", + path="/test_endpoint", + outputs=["output"], + ) + + output_sockets = [ + Socket(name="output", path="output", key="output", data_model=OutputModel), + ] + + ResponseModel = endpoint_with_one_output.get_response_model(output_sockets) + + # Check that the response model named correctly + assert ResponseModel.__name__ == "TestEndpointResponse" + + # Check that the response model has the correct fields + assert ResponseModel.__fields__.keys() == {"output"} + + # Check that the response fields have the correct types + assert ResponseModel.__fields__["output"].type_ == OutputModel + + +def test_get_file_upload_field(): + """Test the get_file_upload_field function.""" + + endpoint = Endpoint( + name="test_endpoint", + summary="Test endpoint", + path="/test_endpoint", + outputs=["output"], + ) + + input_sockets = [ + Socket( + name="input", + path="input", + key="input", + data_model=FileUploadModel, + ), + ] + + file_upload_field = endpoint.get_file_upload_field(input_sockets) + + # Check that the file upload field named correctly + assert file_upload_field.name == "input" + + # Check that the file upload field has the correct description + assert file_upload_field.description == "Upload image files." + +def test_get_file_upload_field_multiple_file_uploads(): + """Test the get_file_upload_field function with multiple file uploads.""" + + endpoint = Endpoint( + name="test_endpoint", + summary="Test endpoint", + path="/test_endpoint", + outputs=["output"], + ) + + input_sockets = [ + Socket( + name="input", + path="input", + key="input", + data_model=FileUploadModel, + ), + Socket( + name="input2", + path="input2", + key="input2", + data_model=FileUploadModel, + ), + ] + + with pytest.raises(MultipleFileUploadNotAllowed): + endpoint.get_file_upload_field(input_sockets) diff --git a/aana/tests/test_app.py b/aana/tests/test_app.py new file mode 100644 index 00000000..d3175625 --- /dev/null +++ b/aana/tests/test_app.py @@ -0,0 +1,88 @@ +import json +import random +import pytest +import ray +from ray import serve +import requests + +from aana.api.api_generation import Endpoint +from aana.api.request_handler import RequestHandler + + +@serve.deployment +class Lowercase: + """Lowercase class is a Ray Serve deployment class that takes a text + and returns the lowercase version of it. + """ + + async def lower(self, text): + return {"text": [t.lower() for t in text]} + + +nodes = [ + { + "name": "text", + "type": "input", + "inputs": [], + "outputs": [{"name": "text", "key": "text", "path": "texts.[*].text"}], + }, + { + "name": "lowercase", + "type": "ray_deployment", + "deployment_name": "Lowercase", + "method": "lower", + "inputs": [{"name": "text", "key": "text", "path": "texts.[*].text"}], + "outputs": [ + { + "name": "lowercase_text", + "key": "text", + "path": "texts.[*].lowercase_text", + } + ], + }, +] + +context = { + "deployments": { + "Lowercase": Lowercase.bind(), + } +} + + +endpoints = [ + Endpoint( + name="lowercase", + path="/lowercase", + summary="Lowercase text", + outputs=["lowercase_text"], + ) +] + + +@pytest.fixture(scope="session") +def ray_setup(): + # Setup ray environment and serve + ray.init(ignore_reinit_error=True) + server = RequestHandler.bind(endpoints, nodes, context) + # random port from 30000 to 40000 + port = random.randint(30000, 40000) + handle = serve.run(server, port=port) + return handle, port + + +def test_app(ray_setup): + handle, port = ray_setup + + # Check that the server is ready + response = requests.get(f"http://localhost:{port}/api/ready") + assert response.status_code == 200 + assert response.json() == {"ready": True} + + # Test lowercase endpoint + data = {"text": ["Hello World!", "This is a test."]} + response = requests.post( + f"http://localhost:{port}/lowercase", data={"body": json.dumps(data)} + ) + assert response.status_code == 200 + lowercase_text = response.json().get("lowercase_text") + assert lowercase_text == ["hello world!", "this is a test."] diff --git a/aana/tests/test_build.py b/aana/tests/test_build.py new file mode 100644 index 00000000..0e069f3d --- /dev/null +++ b/aana/tests/test_build.py @@ -0,0 +1,169 @@ +from mobius_pipeline.exceptions import OutputNotFoundException +import pytest + +from aana.api.api_generation import Endpoint +from aana.configs.build import get_configuration + +nodes = [ + { + "name": "text", + "type": "input", + "inputs": [], + "outputs": [{"name": "text", "key": "text", "path": "texts.[*].text"}], + }, + { + "name": "number", + "type": "input", + "inputs": [], + "outputs": [{"name": "number", "key": "number", "path": "numbers.[*].number"}], + }, + { + "name": "lowercase", + "type": "ray_deployment", + "deployment_name": "Lowercase", + "method": "lower", + "inputs": [{"name": "text", "key": "text", "path": "texts.[*].text"}], + "outputs": [ + { + "name": "lowercase_text", + "key": "text", + "path": "texts.[*].lowercase_text", + } + ], + }, + { + "name": "uppercase", + "type": "ray_deployment", + "deployment_name": "Uppercase", + "method": "upper", + "inputs": [{"name": "text", "key": "text", "path": "texts.[*].text"}], + "outputs": [ + { + "name": "uppercase_text", + "key": "text", + "path": "texts.[*].uppercase_text", + } + ], + }, + { + "name": "capitalize", + "type": "ray_deployment", + "deployment_name": "Capitalize", + "method": "capitalize", + "inputs": [{"name": "text", "key": "text", "path": "texts.[*].text"}], + "outputs": [ + { + "name": "capitalize_text", + "key": "text", + "path": "texts.[*].capitalize_text", + } + ], + }, +] + +# don't define the deployment for "Capitalize" to test if get_configuration raises an error +deployments = {"Lowercase": "Lowercase", "Uppercase": "Uppercase"} + +endpoints = { + "lowercase": [ + Endpoint( + name="lowercase", + path="/lowercase", + summary="Lowercase text", + outputs=["lowercase_text"], + ) + ], + "uppercase": [ + Endpoint( + name="uppercase", + path="/uppercase", + summary="Uppercase text", + outputs=["uppercase_text"], + ) + ], + "both": [ + Endpoint( + name="lowercase", + path="/lowercase", + summary="Lowercase text", + outputs=["lowercase_text"], + ), + Endpoint( + name="uppercase", + path="/uppercase", + summary="Uppercase text", + outputs=["uppercase_text"], + ), + ], + "non_existent": [ + Endpoint( + name="non_existent", + path="/non_existent", + summary="Non existent endpoint", + outputs=["non_existent"], + ) + ], + "capitalize": [ + Endpoint( + name="capitalize", + path="/capitalize", + summary="Capitalize text", + outputs=["capitalize_text"], + ) + ], +} + + +@pytest.mark.parametrize( + "target, expected_nodes, expected_deployments", + [ + ("lowercase", ["text", "lowercase"], {"Lowercase": "Lowercase"}), + ("uppercase", ["text", "uppercase"], {"Uppercase": "Uppercase"}), + ( + "both", + ["text", "lowercase", "uppercase"], + {"Lowercase": "Lowercase", "Uppercase": "Uppercase"}, + ), + ], +) +def test_get_configuration_success(target, expected_nodes, expected_deployments): + """ + Test if get_configuration returns the correct configuration for various targets + """ + configuration = get_configuration(target, endpoints, nodes, deployments) + assert configuration["endpoints"] == endpoints[target] + + node_names = [node["name"] for node in configuration["nodes"]] + for expected_node in expected_nodes: + assert expected_node in node_names + + assert configuration["deployments"] == expected_deployments + + +def test_get_configuration_invalid_target(): + """ + Test if get_configuration raises an error if the target is invalid + """ + + with pytest.raises(ValueError): + get_configuration("invalid_target", endpoints, nodes, deployments) + + +def test_get_configuration_non_existent_output(): + """ + Test if get_configuration raises an error + if one of the target endpoints has a non-existent output. + """ + + with pytest.raises(OutputNotFoundException): + get_configuration("non_existent", endpoints, nodes, deployments) + + +def test_get_configuration_not_defined_deployment(): + """ + Test if get_configuration raises an error + if one of the target nodes uses a deployment that is not defined. + """ + + with pytest.raises(ValueError): + get_configuration("capitalize", endpoints, nodes, deployments) diff --git a/aana/tests/utils.py b/aana/tests/utils.py new file mode 100644 index 00000000..159e620e --- /dev/null +++ b/aana/tests/utils.py @@ -0,0 +1,11 @@ +def is_gpu_available() -> bool: + """ + Check if a GPU is available. + + Returns: + bool: True if a GPU is available, False otherwise. + """ + import torch + + # TODO: find the way to check if GPU is available without importing torch + return torch.cuda.is_available() diff --git a/notebooks/demo.ipynb b/notebooks/demo.ipynb index 9f3c4878..737a2436 100644 --- a/notebooks/demo.ipynb +++ b/notebooks/demo.ipynb @@ -12,45 +12,46 @@ }, { "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "data = {\n", - " 'prompt': '[INST] Who is Elon Musk? [/INST]',\n", - " 'sampling_params' : {\n", - " 'temperature': 0.9,\n", - " }\n", - "}\n", - "\n", - "url = 'http://127.0.0.1:8000/llm/generate'" - ] - }, - { - "cell_type": "code", - "execution_count": 7, + "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'vllm_llama2_7b_chat_output': ' Elon Musk is a South African-born entrepreneur, inventor, and business magnate who is best known for being the CEO of SpaceX and Tesla, Inc. He is one of the most successful and influential entrepreneurs of the 21st century, known for his innovative ideas, visionary leadership, and his ability to bring those ideas to life.\\nMusk was born on June 28, 1971, in Pretoria, South Africa. He developed an interest in computing and programming at an early age and taught himself computer programming. He moved to Canada in 1992 to attend college, and later transferred to the University of Pennsylvania, where he graduated with a degree in economics and physics.\\nAfter college, Musk moved to California to pursue a career in technology and entrepreneurship. He co-founded his first company, Zip2, which provided online content publishing software for news organizations. In 1999, he co-founded X.com, which later became PayPal, an online payment system that was acquired by eBay for $1.5 billion in 2002.\\nIn 2',\n", + "{'vllm_llama2_7b_chat_output': ' Elon Musk is a South African-born entrepreneur, inventor, and business magnate. He is best known for his involvement in revolutionizing multiple industries through his companies, including transportation, energy, and space exploration. Here are some key facts about Elon Musk:\\n\\n1. Early Life and Education: Musk was born on June 28, 1971, in Pretoria, South Africa. He developed an interest in computing and programming at an early age and taught himself computer programming. He moved to Canada in 1992 to attend college, and later transferred to the University of Pennsylvania, where he graduated with a degree in economics and physics.\\n2. Entrepreneurial Career: Musk co-founded his first company, Zip2, which provided online content publishing software for news organizations. In 1999, he co-founded X.com, which later became PayPal, an online payment system that was acquired by eBay for $1.5 billion in 2002.\\n3. SpaceX: In 2002, Musk founded SpaceX, a private aerospace manufacturer and',\n", " 'execution_time': {'prompt': 0,\n", " 'sampling_params': 0,\n", " 'vllm_stream_llama2_7b_chat': 0,\n", - " 'vllm_llama2_7b_chat': 4.421537399291992}}" + " 'vllm_llama2_7b_chat': 3.8096983432769775}}" ] }, - "execution_count": 7, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "data = {\n", + " 'prompt': '[INST] Who is Elon Musk? [/INST]',\n", + " 'sampling_params' : {\n", + " 'temperature': 0.9,\n", + " }\n", + "}\n", + "\n", + "url = 'http://127.0.0.1:8000/llm/generate'\n", + "\n", + "# response = requests.post(url, data={'body': json.dumps(data)})\n", "response = requests.post(url, json=data)\n", "response.json()" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, diff --git a/pyproject.toml b/pyproject.toml index 446b0d59..b4d6a4a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,3 +29,6 @@ build-backend = "poetry.core.masonry.api" [tool.pytest.ini_options] norecursedirs = "mobius-pipeline" + +[tool.poetry.scripts] +aana = "aana.main:run" diff --git a/startup.sh b/startup.sh new file mode 100644 index 00000000..04c365cb --- /dev/null +++ b/startup.sh @@ -0,0 +1,3 @@ +#!/bin/bash +# TODO: pass arguments to the docker to set target instead of environment variable +poetry run aana --port 8000 --host 0.0.0.0 --target $TARGET