From cbc6b33d88bd5b7484230fe367885b8c50a1f92b Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 18 Jun 2024 17:02:00 -0600 Subject: [PATCH 1/4] RouteInfoFromBackend: Forward get_route_info and ROUTE_INFO_HEADER_KEY from backend Signed-off-by: Gabe Goodhart --- .../text_generation/peft_tgis_remote.py | 13 +------- .../text_generation/text_generation_tgis.py | 11 +------ .../toolkit/text_generation/tgis_utils.py | 33 +++---------------- .../text_generation/test_tgis_utils.py | 5 ++- 4 files changed, 10 insertions(+), 52 deletions(-) diff --git a/caikit_nlp/modules/text_generation/peft_tgis_remote.py b/caikit_nlp/modules/text_generation/peft_tgis_remote.py index 67921bd3..c6067e15 100644 --- a/caikit_nlp/modules/text_generation/peft_tgis_remote.py +++ b/caikit_nlp/modules/text_generation/peft_tgis_remote.py @@ -43,7 +43,6 @@ from ...toolkit.text_generation.tgis_utils import ( GENERATE_FUNCTION_TGIS_ARGS, TGISGenerationClient, - get_route_info, ) from ...toolkit.verbalizer_utils import render_verbalizer from . import PeftPromptTuning @@ -354,15 +353,5 @@ def _register_model_connection_with_context( a context override provided. """ if self._tgis_backend: - if route_info := get_route_info(context): - log.debug( - " Registering remote model connection with context " - "override: 'hostname: %s'", - route_info, - ) - self._tgis_backend.register_model_connection( - self.base_model_name, - {"hostname": route_info}, - fill_with_defaults=True, - ) + self._tgis_backend.handle_runtime_context(self.base_model_name, context) self._model_loaded = True diff --git a/caikit_nlp/modules/text_generation/text_generation_tgis.py b/caikit_nlp/modules/text_generation/text_generation_tgis.py index b1c344ba..2ee32a7b 100644 --- a/caikit_nlp/modules/text_generation/text_generation_tgis.py +++ b/caikit_nlp/modules/text_generation/text_generation_tgis.py @@ -45,7 +45,6 @@ from ...toolkit.text_generation.tgis_utils import ( GENERATE_FUNCTION_TGIS_ARGS, TGISGenerationClient, - get_route_info, ) from .text_generation_local import TextGeneration @@ -362,13 +361,5 @@ def _register_model_connection_with_context( a context override provided. """ if self._tgis_backend: - if route_info := get_route_info(context): - log.debug( - " Registering remote model connection with context " - "override: 'hostname: %s'", - route_info, - ) - self._tgis_backend.register_model_connection( - self.model_name, {"hostname": route_info}, fill_with_defaults=True - ) + self._tgis_backend.handle_runtime_context(self.model_name, context) self._model_loaded = True diff --git a/caikit_nlp/toolkit/text_generation/tgis_utils.py b/caikit_nlp/toolkit/text_generation/tgis_utils.py index 627bea1e..a8e55c2b 100644 --- a/caikit_nlp/toolkit/text_generation/tgis_utils.py +++ b/caikit_nlp/toolkit/text_generation/tgis_utils.py @@ -35,6 +35,7 @@ TokenStreamDetails, ) from caikit.interfaces.runtime.data_model import RuntimeServerContextType +from caikit_tgis_backend import TGISBackend from caikit_tgis_backend.protobufs import generation_pb2 import alog @@ -87,7 +88,9 @@ } # HTTP Header / gRPC Metadata key used to identify a route override -ROUTE_INFO_HEADER_KEY = "x-route-info" +# (forwarded for API compatibility) +ROUTE_INFO_HEADER_KEY = TGISBackend.ROUTE_INFO_HEADER_KEY +get_route_info = TGISBackend.get_route_info def raise_caikit_core_exception(rpc_error: grpc.RpcError): @@ -688,31 +691,3 @@ def unary_tokenize( return TokenizationResults( token_count=response.token_count, ) - - -def get_route_info( - context: Optional[RuntimeServerContextType], -) -> Optional[str]: - """ - Returns a tuple `(True, x-route-info)` from context if "x-route-info" was found in - the headers/metadata. - - Otherwise returns a tuple `(False, None)` if "x-route-info" was not found in the - context or if context is None. - """ - if context is None: - return None - - if isinstance(context, grpc.ServicerContext): - route_info = dict(context.invocation_metadata()).get(ROUTE_INFO_HEADER_KEY) - if route_info: - return route_info - elif isinstance(context, fastapi.Request): - route_info = context.headers.get(ROUTE_INFO_HEADER_KEY) - if route_info: - return route_info - else: - error.log_raise( - "", - ValueError(f"context is of an unsupported type: {type(context)}"), - ) diff --git a/tests/toolkit/text_generation/test_tgis_utils.py b/tests/toolkit/text_generation/test_tgis_utils.py index 2e97a5e8..696a9acf 100644 --- a/tests/toolkit/text_generation/test_tgis_utils.py +++ b/tests/toolkit/text_generation/test_tgis_utils.py @@ -133,6 +133,9 @@ def test_TGISGenerationClient_rpc_errors(status_code, method): assert isinstance(rpc_err, grpc.RpcError) +# NOTE: This test is preserved in caikit-nlp despite being duplicated in +# caikit-tgis-backend so that we guarantee that the functionality is accessible +# in a version-compatible way here. @pytest.mark.parametrize( argnames=["context", "route_info"], argvalues=[ @@ -168,7 +171,7 @@ def test_TGISGenerationClient_rpc_errors(status_code, method): ) def test_get_route_info(context: RuntimeServerContextType, route_info: Optional[str]): if not isinstance(context, (fastapi.Request, grpc.ServicerContext, type(None))): - with pytest.raises(ValueError): + with pytest.raises(TypeError): tgis_utils.get_route_info(context) else: actual_route_info = tgis_utils.get_route_info(context) From 1818a306a28d7561a3da47af53d156a25f1a8c32 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 19 Jun 2024 15:29:38 -0600 Subject: [PATCH 2/4] RouteInfoFromBackend: Bump caikit-tgis-backend Signed-off-by: Gabe Goodhart --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 4d0d7e44..872d4c3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ classifiers=[ ] dependencies = [ "caikit[runtime-grpc,runtime-http]>=0.26.27,<0.27.0", - "caikit-tgis-backend>=0.1.33,<0.2.0", + "caikit-tgis-backend>=0.1.34,<0.2.0", # TODO: loosen dependencies "grpcio>=1.62.2", # explicitly pin grpc dependencies to a recent version to avoid pip backtracking "grpcio-reflection>=1.62.2", From 0e23e10bfbbf3b4e3ed99b89459ba15b4f881e06 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 19 Jun 2024 15:37:05 -0600 Subject: [PATCH 3/4] RouteInfoFromBackend: Bump caikit for context registration in backend Signed-off-by: Gabe Goodhart --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 872d4c3e..14be273f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ classifiers=[ "License :: OSI Approved :: Apache Software License" ] dependencies = [ - "caikit[runtime-grpc,runtime-http]>=0.26.27,<0.27.0", + "caikit[runtime-grpc,runtime-http]>=0.26.34,<0.27.0", "caikit-tgis-backend>=0.1.34,<0.2.0", # TODO: loosen dependencies "grpcio>=1.62.2", # explicitly pin grpc dependencies to a recent version to avoid pip backtracking From ff7f05682e56db424c048138d3d876e7ea7bbb30 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 19 Jun 2024 15:39:44 -0600 Subject: [PATCH 4/4] RouteInfoFromBackend: Remove unused imports Signed-off-by: Gabe Goodhart --- caikit_nlp/toolkit/text_generation/tgis_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/caikit_nlp/toolkit/text_generation/tgis_utils.py b/caikit_nlp/toolkit/text_generation/tgis_utils.py index a8e55c2b..392455c0 100644 --- a/caikit_nlp/toolkit/text_generation/tgis_utils.py +++ b/caikit_nlp/toolkit/text_generation/tgis_utils.py @@ -14,10 +14,9 @@ """This file is for helper functions related to TGIS.""" # Standard -from typing import Iterable, Optional +from typing import Iterable # Third Party -import fastapi import grpc # First Party @@ -34,7 +33,6 @@ TokenizationResults, TokenStreamDetails, ) -from caikit.interfaces.runtime.data_model import RuntimeServerContextType from caikit_tgis_backend import TGISBackend from caikit_tgis_backend.protobufs import generation_pb2 import alog