Skip to content

Commit

Permalink
Applied requests from PR
Browse files Browse the repository at this point in the history
  • Loading branch information
movchan74 committed Nov 6, 2023
1 parent 8c351b1 commit 1bb6df4
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 25 deletions.
21 changes: 9 additions & 12 deletions aana/api/api_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class Endpoint:

def generate_model_name(self, suffix: str) -> str:
"""
Generate a Pydantic model name based oon a given suffix.
Generate a Pydantic model name based on a given suffix.
Parameters:
suffix (str): Suffix for the model name (e.g. "Request", "Response").
Expand Down Expand Up @@ -237,7 +237,6 @@ def create_endpoint_func(
self,
pipeline: Pipeline,
RequestModel: Type[BaseModel],
ResponseModel: Type[BaseModel],
file_upload_field: Optional[FileUploadField] = None,
):
async def route_func_body(body: str, files: Optional[List[UploadFile]] = None):
Expand Down Expand Up @@ -293,13 +292,16 @@ async def route_func(body: str = Form(...), files=files):

return route_func

def register(self, app: FastAPI, pipeline: Pipeline):
def register(
self, app: FastAPI, pipeline: Pipeline, custom_schemas: Dict[str, Dict]
):
"""
Register an endpoint to the FastAPI app.
Register an endpoint to the FastAPI app and add schemas to the custom schemas dictionary.
Parameters:
app (FastAPI): FastAPI app to register the endpoint to.
pipeline (Pipeline): Pipeline to register the endpoint to.
custom_schemas (Dict[str, Dict]): Dictionary of custom schemas.
"""
input_sockets, output_sockets = pipeline.get_sockets(self.outputs)
RequestModel = self.get_request_model(input_sockets)
Expand All @@ -308,7 +310,6 @@ def register(self, app: FastAPI, pipeline: Pipeline):
route_func = self.create_endpoint_func(
pipeline=pipeline,
RequestModel=RequestModel,
ResponseModel=ResponseModel,
file_upload_field=file_upload_field,
)
app.post(
Expand All @@ -321,11 +322,7 @@ def register(self, app: FastAPI, pipeline: Pipeline):
400: {"model": ExceptionResponseModel},
},
)(route_func)

def get_request_schema(self, pipeline: Pipeline):
input_sockets, _ = pipeline.get_sockets(self.outputs)
RequestModel = self.get_request_model(input_sockets)
return RequestModel.schema()
custom_schemas[self.name] = RequestModel.schema()


def add_custom_schemas_to_openapi_schema(
Expand Down Expand Up @@ -354,11 +351,11 @@ def add_custom_schemas_to_openapi_schema(
dict: The openapi schema with the custom schemas added.
"""

if "definitions" not in openapi_schema:
openapi_schema["definitions"] = {}
for schema_name, schema in custom_schemas.items():
# if we have a definitions then we need to move them out to the top level of the schema
if "definitions" in schema:
if "definitions" not in openapi_schema:
openapi_schema["definitions"] = {}
openapi_schema["definitions"].update(schema["definitions"])
del schema["definitions"]
openapi_schema["components"]["schemas"][f"Body_{schema_name}"]["properties"][
Expand Down
12 changes: 6 additions & 6 deletions aana/api/request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,19 @@ def __init__(
"""
Args:
deployments (Dict): The dictionary of deployments.
It is passed to the context to the pipeline so the pipeline can access the deployments handles.
It is passed to the context to the pipeline
so the pipeline can access the deployments handles.
"""

self.context = context
self.endpoints = endpoints
self.pipeline = Pipeline(pipeline_nodes, context)

self.custom_schemas = {}
self.custom_schemas: Dict[str, Dict] = {}
for endpoint in self.endpoints:
endpoint.register(app=app, pipeline=self.pipeline)
# get schema for endpoint to add to openapi schema
schema = endpoint.get_request_schema(self.pipeline)
self.custom_schemas[endpoint.name] = schema
endpoint.register(
app=app, pipeline=self.pipeline, custom_schemas=self.custom_schemas
)

app.openapi = self.custom_openapi
self.ready = True
Expand Down
9 changes: 4 additions & 5 deletions aana/configs/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,9 @@

from mobius_pipeline.node.node_definition import NodeDefinition
from mobius_pipeline.pipeline.output_graph import OutputGraph
from aana.configs.endpoints import endpoints
from aana.configs.pipeline import nodes
from aana.configs.deployments import deployments


def get_configuration(target: str, endpoints=endpoints, nodes=nodes, deployments=deployments) -> Dict:
def get_configuration(target: str, endpoints, nodes, deployments) -> Dict:
"""
Returns the configuration for the specified target.
Expand Down Expand Up @@ -36,7 +33,9 @@ def get_configuration(target: str, endpoints=endpoints, nodes=nodes, deployments

# Check if target is valid
if target not in endpoints:
raise ValueError(f"Invalid target: {target}. Valid targets: {', '.join(endpoints.keys())}")
raise ValueError(
f"Invalid target: {target}. Valid targets: {', '.join(endpoints.keys())}"
)

# Find the endpoints that are to be deployed
target_endpoints = endpoints[target]
Expand Down
13 changes: 11 additions & 2 deletions aana/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,16 @@ def run():
from ray import serve
from aana.api.request_handler import RequestHandler
from aana.configs.build import get_configuration
from aana.configs.endpoints import endpoints as all_endpoints
from aana.configs.pipeline import nodes as all_nodes
from aana.configs.deployments import deployments as all_deployments

configuration = get_configuration(args.target)
configuration = get_configuration(
args.target,
endpoints=all_endpoints,
nodes=all_nodes,
deployments=all_deployments,
)
endpoints = configuration["endpoints"]
pipeline_nodes = configuration["nodes"]
deployments = configuration["deployments"]
Expand All @@ -36,7 +44,8 @@ def run():
}
try:
server = RequestHandler.bind(endpoints, pipeline_nodes, context)
handle = serve.run(server, port=args.port, host=args.host)
serve.run(server, port=args.port, host=args.host)
# TODO: add logging
print("Deployed Serve app successfully.")
while True:
time.sleep(10)
Expand Down

0 comments on commit 1bb6df4

Please sign in to comment.