Skip to content

Commit

Permalink
[AIC-py][editor] add default model parsers file (#567)
Browse files Browse the repository at this point in the history
[AIC-py][editor] add default model parsers file

---
Stack created with [Sapling](https://sapling-scm.com). Best reviewed
with
[ReviewStack](https://reviewstack.dev/lastmile-ai/aiconfig/pull/567).
* #573
* __->__ #567
  • Loading branch information
jonathanlastmileai authored Dec 21, 2023
2 parents 1460cfd + 992999c commit 17cab10
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 42 deletions.
21 changes: 21 additions & 0 deletions python/src/aiconfig/editor/example_aiconfig_model_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""
Use this file to register model parsers that don't ship with aiconfig.
- Make sure your package is installed in the same environment as aiconfig.
- You must define a function `register_model_parsers() -> None` in this file.
- You should call `AIConfigRuntime.register_model_parser` in that function.
See example below.
"""


# from aiconfig import AIConfigRuntime
# from llama import LlamaModelParser


def register_model_parsers() -> None:
    """Register user-defined model parsers with aiconfig.

    Replace the example below with real registrations, e.g.:

        model_path = "/path/to/my/local/llama/model"
        llama_model_parser = LlamaModelParser(model_path)
        AIConfigRuntime.register_model_parser(llama_model_parser, "llama-2-7b-chat")
    """
    # Intentionally a no-op until the user adds their own registrations.
    pass
104 changes: 62 additions & 42 deletions python/src/aiconfig/editor/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,16 @@ class EditServerConfig(core_utils.Record):
aiconfig_path: Optional[str] = None
log_level: str | int = "INFO"
server_mode: str
parsers_module_path: Optional[str] = None
parsers_module_path: str = "aiconfig_model_registry.py"


@dataclass
class ServerState:
    # The AIConfig the editor server is currently operating on.
    # None until /api/load or /api/create (or _init_server_state) sets it.
    aiconfig: AIConfigRuntime | None = None


@dataclass(frozen=True)
class HttpPOSTResponse:
class HttpPostResponse:
message: str
output: str | None = None
code: int = 200
Expand Down Expand Up @@ -103,25 +103,30 @@ def _register_user_model_parsers(user_register_fn: Callable[[], None]) -> Result
return core_utils.ErrWithTraceback(e)


def _load_user_module_from_path_and_register_model_parsers(path_to_module: str) -> HttpPOSTResponse:
def _load_user_parser_module(path_to_module: str) -> Result[None, str]:
    """Import the user's parsers module from *path_to_module* and run its
    `register_model_parsers` hook, returning Ok(None) or an error string."""
    LOGGER.info(f"Importing parsers module from {path_to_module}")
    imported = _import_module_from_path(path_to_module)
    # Chain: module -> its register fn -> invoke it; any step short-circuits on Err.
    return imported.and_then(_load_register_fn_from_user_module).and_then(_register_user_model_parsers)


def _get_http_response_load_user_parser_module(path_to_module: str) -> HttpPostResponse:
register_result = _load_user_parser_module(path_to_module)
match register_result:
case Ok(_):
msg = f"Successfully registered model parsers from {path_to_module}"
LOGGER.info(msg)
return HttpPOSTResponse(
return HttpPostResponse(
message=msg,
)
case Err(e):
msg = f"Failed to register model parsers from {path_to_module}: {e}"
LOGGER.error(msg)
return HttpPOSTResponse(
return HttpPostResponse(
message=msg,
code=400,
)
Expand All @@ -137,74 +142,84 @@ def home():

@app.route("/api/load_model_parser_module", methods=["POST"])
def load_model_parser_module():
    """POST endpoint: load/register model parsers from the module at request 'path'."""

    def _run_with_path(path: str) -> HttpPostResponse:
        return _get_http_response_load_user_parser_module(path)

    result = _http_response_with_path(_run_with_path)
    return result.to_flask_format()


def _http_response_with_path(path_fn: Callable[[str], HttpPOSTResponse]) -> HttpPOSTResponse:
def _get_validated_request_path(raw_path: str) -> Result[str, str]:
    """Validate a user-supplied path: must be non-empty and resolve to an
    existing file. Returns Ok(resolved path) or Err(message)."""
    if not raw_path:
        return Err("No path provided")

    resolved_path = _resolve_path(raw_path)
    if os.path.isfile(resolved_path):
        return Ok(resolved_path)
    return Err(f"File does not exist: {resolved_path}")


def _http_response_with_path(path_fn: Callable[[str], HttpPostResponse]) -> HttpPostResponse:
request_json = request.get_json()
path = request_json["path"]
if not path:
return HttpPOSTResponse(message="No path provided", code=400)
path = request_json.get("path", None)

resolved = _resolve_path(path)
if not os.path.isfile(resolved):
return HttpPOSTResponse(message=f"File does not exist: {path}", code=400)
return path_fn(resolved)
validated_path = _get_validated_request_path(path)
match validated_path:
case Ok(path):
return path_fn(path)
case Err(e):
return HttpPostResponse(message=e, code=400)


@app.route("/api/load", methods=["POST"])
def load():
    """POST endpoint: load an AIConfig from the file at request 'path' into server state."""

    def _run_with_path(path: str) -> HttpPostResponse:
        LOGGER.info(f"Loading AIConfig from {path}")
        state = _get_server_state(app)
        try:
            state.aiconfig = AIConfigRuntime.load(path)  # type: ignore
            return HttpPostResponse(message="Done")
        except Exception as e:
            return HttpPostResponse(message=f"<p>Failed to load AIConfig from {path}: {e}", code=400)

    return _http_response_with_path(_run_with_path).to_flask_format()


@app.route("/api/save", methods=["POST"])
def save():
    """POST endpoint: save the current AIConfig to the file at request 'path'."""

    def _run_with_path(path: str) -> HttpPostResponse:
        LOGGER.info(f"Saving AIConfig to {path}")
        state = _get_server_state(app)
        try:
            state.aiconfig.save(path)  # type: ignore
            return HttpPostResponse(message="Done")
        except Exception as e:
            err: Err[str] = core_utils.ErrWithTraceback(e)
            LOGGER.error(f"Failed to save AIConfig to {path}: {err}")
            return HttpPostResponse(message=f"<p>Failed to save AIConfig to {path}: {err}", code=400)

    return _http_response_with_path(_run_with_path).to_flask_format()


@app.route("/api/create", methods=["POST"])
def create():
    """POST endpoint: replace server state with a brand-new empty AIConfig."""
    state = _get_server_state(app)
    state.aiconfig = AIConfigRuntime.create()  # type: ignore
    return {"message": "Done"}, 200


@app.route("/api/run", methods=["POST"])
async def run():
ss = _get_server_state(app)
state = _get_server_state(app)
request_json = request.get_json()
prompt_name = request_json.get("prompt_name", None)
stream = request_json.get("stream", True)
LOGGER.info(f"Running prompt: {prompt_name}, {stream=}")
inference_options = InferenceOptions(stream=stream)
try:
result = await ss.aiconfig_runtime.run(prompt_name, options=inference_options) # type: ignore
result = await state.aiconfig.run(prompt_name, options=inference_options) # type: ignore
LOGGER.debug(f"Result: {result=}")
result_text = str(
ss.aiconfig_runtime.get_output_text(prompt_name) # type: ignore
state.aiconfig.get_output_text(prompt_name) # type: ignore
#
if isinstance(result, list)
#
Expand All @@ -219,11 +234,11 @@ async def run():

@app.route("/api/add_prompt", methods=["POST"])
def add_prompt():
ss = _get_server_state(app)
state = _get_server_state(app)
request_json = request.get_json()
try:
LOGGER.info(f"Adding prompt: {request_json}")
ss.aiconfig_runtime.add_prompt(**request_json) # type: ignore
state.aiconfig.add_prompt(**request_json) # type: ignore
return {"message": "Done"}, 200
except Exception as e:
err: Err[str] = core_utils.ErrWithTraceback(e)
Expand All @@ -248,21 +263,26 @@ def run_backend_server(edit_config: EditServerConfig) -> Result[int, str]:
return Ok(0)


def _load_user_parser_module_if_exists(parsers_module_path: str) -> None:
    """Best-effort load of the user's model-parser module.

    Logs a warning and continues if the path is missing/invalid or the module
    fails to load — the server should still start without user parsers.
    """
    _get_validated_request_path(parsers_module_path).and_then(_load_user_parser_module).map_or_else(
        lambda e: LOGGER.warning(f"Failed to load parsers module: {e}"),  # type: ignore
        # BUG FIX: previously referenced `edit_config.parsers_module_path`, but
        # `edit_config` is not in scope here — it raised NameError on success.
        lambda _: LOGGER.info(f"Loaded parsers module from {parsers_module_path}"),  # type: ignore
    )

LOGGER.info("Initializing server state")

def _init_server_state(app: Flask, edit_config: EditServerConfig) -> None:
    """Initialize server state: register user parsers, then load or create the AIConfig.

    Loads from `edit_config.aiconfig_path` when set; otherwise creates a fresh AIConfig.
    Must be called exactly once — asserts state is not already initialized.
    """
    assert edit_config.server_mode in {"debug_servers", "debug_backend", "prod"}
    LOGGER.info("Initializing server state")
    _load_user_parser_module_if_exists(edit_config.parsers_module_path)
    state = _get_server_state(app)

    # Guard against double initialization clobbering a live config.
    assert state.aiconfig is None
    if edit_config.aiconfig_path:
        LOGGER.info(f"Loading AIConfig from {edit_config.aiconfig_path}")
        aiconfig_runtime = AIConfigRuntime.load(edit_config.aiconfig_path)  # type: ignore
        state.aiconfig = aiconfig_runtime
        LOGGER.info(f"Loaded AIConfig from {edit_config.aiconfig_path}")
    else:
        aiconfig_runtime = AIConfigRuntime.create()  # type: ignore
        state.aiconfig = aiconfig_runtime
        LOGGER.info("Created new AIConfig")

0 comments on commit 17cab10

Please sign in to comment.