diff --git a/python/src/aiconfig/editor/example_aiconfig_model_registry.py b/python/src/aiconfig/editor/example_aiconfig_model_registry.py new file mode 100644 index 000000000..6b9b4a01f --- /dev/null +++ b/python/src/aiconfig/editor/example_aiconfig_model_registry.py @@ -0,0 +1,21 @@ +""" + Use this file to register model parsers that don't ship with aiconfig. + - Make sure your package is installed in the same environment as aiconfig. + - You must define a function `register_model_parsers() -> None` in this file. + - You should call `AIConfigRuntime.register_model_parser` in that function. + + See example below. +""" + + +# from aiconfig import AIConfigRuntime +# from llama import LlamaModelParser + + +def register_model_parsers() -> None: + # Example: + # model_path = "/path/to/my/local/llama/model" + # llama_model_parser = LlamaModelParser(model_path) + # AIConfigRuntime.register_model_parser(llama_model_parser, "llama-2-7b-chat") + # You can remove this `pass` once your function is implemented (see above). 
+ pass diff --git a/python/src/aiconfig/editor/server/server.py b/python/src/aiconfig/editor/server/server.py index 87e84dbb2..d62a9b7da 100644 --- a/python/src/aiconfig/editor/server/server.py +++ b/python/src/aiconfig/editor/server/server.py @@ -34,16 +34,16 @@ class EditServerConfig(core_utils.Record): aiconfig_path: Optional[str] = None log_level: str | int = "INFO" server_mode: str - parsers_module_path: Optional[str] = None + parsers_module_path: str = "aiconfig_model_registry.py" @dataclass class ServerState: - aiconfig_runtime: AIConfigRuntime | None = None + aiconfig: AIConfigRuntime | None = None @dataclass(frozen=True) -class HttpPOSTResponse: +class HttpPostResponse: message: str output: str | None = None code: int = 200 @@ -103,7 +103,7 @@ def _register_user_model_parsers(user_register_fn: Callable[[], None]) -> Result return core_utils.ErrWithTraceback(e) -def _load_user_module_from_path_and_register_model_parsers(path_to_module: str) -> HttpPOSTResponse: +def _load_user_parser_module(path_to_module: str) -> Result[None, str]: LOGGER.info(f"Importing parsers module from {path_to_module}") res_user_module = _import_module_from_path(path_to_module) register_result = ( @@ -111,17 +111,22 @@ def _load_user_module_from_path_and_register_model_parsers(path_to_module: str) # .and_then(_register_user_model_parsers) ) + return register_result + + +def _get_http_response_load_user_parser_module(path_to_module: str) -> HttpPostResponse: + register_result = _load_user_parser_module(path_to_module) match register_result: case Ok(_): msg = f"Successfully registered model parsers from {path_to_module}" LOGGER.info(msg) - return HttpPOSTResponse( + return HttpPostResponse( message=msg, ) case Err(e): msg = f"Failed to register model parsers from {path_to_module}: {e}" LOGGER.error(msg) - return HttpPOSTResponse( + return HttpPostResponse( message=msg, code=400, ) @@ -137,74 +142,84 @@ def home(): @app.route("/api/load_model_parser_module", methods=["POST"]) def 
load_model_parser_module(): - def _run_with_path(path: str) -> HttpPOSTResponse: - return _load_user_module_from_path_and_register_model_parsers(path) + def _run_with_path(path: str) -> HttpPostResponse: + return _get_http_response_load_user_parser_module(path) - return _http_response_with_path(_run_with_path).to_flask_format() + result = _http_response_with_path(_run_with_path) + return result.to_flask_format() -def _http_response_with_path(path_fn: Callable[[str], HttpPOSTResponse]) -> HttpPOSTResponse: +def _get_validated_request_path(raw_path: str) -> Result[str, str]: + if not raw_path: + return Err("No path provided") + resolved = _resolve_path(raw_path) + if not os.path.isfile(resolved): + return Err(f"File does not exist: {resolved}") + return Ok(resolved) + + +def _http_response_with_path(path_fn: Callable[[str], HttpPostResponse]) -> HttpPostResponse: request_json = request.get_json() - path = request_json["path"] - if not path: - return HttpPOSTResponse(message="No path provided", code=400) + path = request_json.get("path", None) - resolved = _resolve_path(path) - if not os.path.isfile(resolved): - return HttpPOSTResponse(message=f"File does not exist: {path}", code=400) - return path_fn(resolved) + validated_path = _get_validated_request_path(path) + match validated_path: + case Ok(path): + return path_fn(path) + case Err(e): + return HttpPostResponse(message=e, code=400) @app.route("/api/load", methods=["POST"]) def load(): - def _run_with_path(path: str) -> HttpPOSTResponse: + def _run_with_path(path: str) -> HttpPostResponse: LOGGER.info(f"Loading AIConfig from {path}") - ss = _get_server_state(app) + state = _get_server_state(app) try: - ss.aiconfig_runtime = AIConfigRuntime.load(path) # type: ignore - return HttpPOSTResponse(message="Done") + state.aiconfig = AIConfigRuntime.load(path) # type: ignore + return HttpPostResponse(message="Done") except Exception as e: - return HttpPOSTResponse(message=f"

Failed to load AIConfig from {path}: {e}", code=400) + return HttpPostResponse(message=f"

Failed to load AIConfig from {path}: {e}", code=400) return _http_response_with_path(_run_with_path).to_flask_format() @app.route("/api/save", methods=["POST"]) def save(): - def _run_with_path(path: str) -> HttpPOSTResponse: + def _run_with_path(path: str) -> HttpPostResponse: LOGGER.info(f"Saving AIConfig to {path}") - ss = _get_server_state(app) + state = _get_server_state(app) try: - ss.aiconfig_runtime.save(path) # type: ignore - return HttpPOSTResponse(message="Done") + state.aiconfig.save(path) # type: ignore + return HttpPostResponse(message="Done") except Exception as e: err: Err[str] = core_utils.ErrWithTraceback(e) LOGGER.error(f"Failed to save AIConfig to {path}: {err}") - return HttpPOSTResponse(message=f"

Failed to save AIConfig to {path}: {err}", code=400) + return HttpPostResponse(message=f"

Failed to save AIConfig to {path}: {err}", code=400) return _http_response_with_path(_run_with_path).to_flask_format() @app.route("/api/create", methods=["POST"]) def create(): - ss = _get_server_state(app) - ss.aiconfig_runtime = AIConfigRuntime.create() # type: ignore + state = _get_server_state(app) + state.aiconfig = AIConfigRuntime.create() # type: ignore return {"message": "Done"}, 200 @app.route("/api/run", methods=["POST"]) async def run(): - ss = _get_server_state(app) + state = _get_server_state(app) request_json = request.get_json() prompt_name = request_json.get("prompt_name", None) stream = request_json.get("stream", True) LOGGER.info(f"Running prompt: {prompt_name}, {stream=}") inference_options = InferenceOptions(stream=stream) try: - result = await ss.aiconfig_runtime.run(prompt_name, options=inference_options) # type: ignore + result = await state.aiconfig.run(prompt_name, options=inference_options) # type: ignore LOGGER.debug(f"Result: {result=}") result_text = str( - ss.aiconfig_runtime.get_output_text(prompt_name) # type: ignore + state.aiconfig.get_output_text(prompt_name) # type: ignore # if isinstance(result, list) # @@ -219,11 +234,11 @@ async def run(): @app.route("/api/add_prompt", methods=["POST"]) def add_prompt(): - ss = _get_server_state(app) + state = _get_server_state(app) request_json = request.get_json() try: LOGGER.info(f"Adding prompt: {request_json}") - ss.aiconfig_runtime.add_prompt(**request_json) # type: ignore + state.aiconfig.add_prompt(**request_json) # type: ignore return {"message": "Done"}, 200 except Exception as e: err: Err[str] = core_utils.ErrWithTraceback(e) @@ -248,21 +263,26 @@ def run_backend_server(edit_config: EditServerConfig) -> Result[int, str]: return Ok(0) -def _init_server_state(app: Flask, edit_config: EditServerConfig) -> None: - if edit_config.parsers_module_path is not None: - _load_user_module_from_path_and_register_model_parsers(edit_config.parsers_module_path) +def 
_load_user_parser_module_if_exists(parsers_module_path: str) -> None: + _get_validated_request_path(parsers_module_path).and_then(_load_user_parser_module).map_or_else( + lambda e: LOGGER.warning(f"Failed to load parsers module: {e}"), # type: ignore + lambda _: LOGGER.info(f"Loaded parsers module from {parsers_module_path}"), # type: ignore + ) - LOGGER.info("Initializing server state") + + +def _init_server_state(app: Flask, edit_config: EditServerConfig) -> None: assert edit_config.server_mode in {"debug_servers", "debug_backend", "prod"} - ss = _get_server_state(app) + LOGGER.info("Initializing server state") + _load_user_parser_module_if_exists(edit_config.parsers_module_path) + state = _get_server_state(app) - assert ss.aiconfig_runtime is None + assert state.aiconfig is None if edit_config.aiconfig_path: LOGGER.info(f"Loading AIConfig from {edit_config.aiconfig_path}") aiconfig_runtime = AIConfigRuntime.load(edit_config.aiconfig_path) # type: ignore - ss.aiconfig_runtime = aiconfig_runtime + state.aiconfig = aiconfig_runtime LOGGER.info(f"Loaded AIConfig from {edit_config.aiconfig_path}") else: aiconfig_runtime = AIConfigRuntime.create() # type: ignore - ss.aiconfig_runtime = aiconfig_runtime + state.aiconfig = aiconfig_runtime LOGGER.info("Created new AIConfig")