diff --git a/caikit_nlp/modules/text_generation/peft_tgis_remote.py b/caikit_nlp/modules/text_generation/peft_tgis_remote.py index 67921bd3..c6067e15 100644 --- a/caikit_nlp/modules/text_generation/peft_tgis_remote.py +++ b/caikit_nlp/modules/text_generation/peft_tgis_remote.py @@ -43,7 +43,6 @@ from ...toolkit.text_generation.tgis_utils import ( GENERATE_FUNCTION_TGIS_ARGS, TGISGenerationClient, - get_route_info, ) from ...toolkit.verbalizer_utils import render_verbalizer from . import PeftPromptTuning @@ -354,15 +353,5 @@ def _register_model_connection_with_context( a context override provided. """ if self._tgis_backend: - if route_info := get_route_info(context): - log.debug( - " Registering remote model connection with context " - "override: 'hostname: %s'", - route_info, - ) - self._tgis_backend.register_model_connection( - self.base_model_name, - {"hostname": route_info}, - fill_with_defaults=True, - ) + self._tgis_backend.handle_runtime_context(self.base_model_name, context) self._model_loaded = True diff --git a/caikit_nlp/modules/text_generation/text_generation_tgis.py b/caikit_nlp/modules/text_generation/text_generation_tgis.py index b1c344ba..2ee32a7b 100644 --- a/caikit_nlp/modules/text_generation/text_generation_tgis.py +++ b/caikit_nlp/modules/text_generation/text_generation_tgis.py @@ -45,7 +45,6 @@ from ...toolkit.text_generation.tgis_utils import ( GENERATE_FUNCTION_TGIS_ARGS, TGISGenerationClient, - get_route_info, ) from .text_generation_local import TextGeneration @@ -362,13 +361,5 @@ def _register_model_connection_with_context( a context override provided. """ if self._tgis_backend: - if route_info := get_route_info(context): - log.debug( - " Registering remote model connection with context " - "override: 'hostname: %s'", - route_info, - ) - self._tgis_backend.register_model_connection( - self.model_name, {"hostname": route_info}, fill_with_defaults=True - ) + self._tgis_backend.handle_runtime_context(self.model_name, context) self._model_loaded = True diff --git a/caikit_nlp/toolkit/text_generation/tgis_utils.py b/caikit_nlp/toolkit/text_generation/tgis_utils.py index 627bea1e..392455c0 100644 --- a/caikit_nlp/toolkit/text_generation/tgis_utils.py +++ b/caikit_nlp/toolkit/text_generation/tgis_utils.py @@ -14,10 +14,9 @@ """This file is for helper functions related to TGIS.""" # Standard -from typing import Iterable, Optional +from typing import Iterable # Third Party -import fastapi import grpc # First Party @@ -34,7 +33,7 @@ TokenizationResults, TokenStreamDetails, ) -from caikit.interfaces.runtime.data_model import RuntimeServerContextType +from caikit_tgis_backend import TGISBackend from caikit_tgis_backend.protobufs import generation_pb2 import alog @@ -87,7 +86,9 @@ } # HTTP Header / gRPC Metadata key used to identify a route override -ROUTE_INFO_HEADER_KEY = "x-route-info" +# (forwarded for API compatibility) +ROUTE_INFO_HEADER_KEY = TGISBackend.ROUTE_INFO_HEADER_KEY +get_route_info = TGISBackend.get_route_info def raise_caikit_core_exception(rpc_error: grpc.RpcError): @@ -688,31 +689,3 @@ def unary_tokenize( return TokenizationResults( token_count=response.token_count, ) - - -def get_route_info( - context: Optional[RuntimeServerContextType], -) -> Optional[str]: - """ - Returns a tuple `(True, x-route-info)` from context if "x-route-info" was found in - the headers/metadata. - - Otherwise returns a tuple `(False, None)` if "x-route-info" was not found in the - context or if context is None. - """ - if context is None: - return None - - if isinstance(context, grpc.ServicerContext): - route_info = dict(context.invocation_metadata()).get(ROUTE_INFO_HEADER_KEY) - if route_info: - return route_info - elif isinstance(context, fastapi.Request): - route_info = context.headers.get(ROUTE_INFO_HEADER_KEY) - if route_info: - return route_info - else: - error.log_raise( - "", - ValueError(f"context is of an unsupported type: {type(context)}"), - ) diff --git a/pyproject.toml b/pyproject.toml index 4d0d7e44..14be273f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,8 +14,8 @@ classifiers=[ "License :: OSI Approved :: Apache Software License" ] dependencies = [ - "caikit[runtime-grpc,runtime-http]>=0.26.27,<0.27.0", - "caikit-tgis-backend>=0.1.33,<0.2.0", + "caikit[runtime-grpc,runtime-http]>=0.26.34,<0.27.0", + "caikit-tgis-backend>=0.1.34,<0.2.0", # TODO: loosen dependencies "grpcio>=1.62.2", # explicitly pin grpc dependencies to a recent version to avoid pip backtracking "grpcio-reflection>=1.62.2", diff --git a/tests/toolkit/text_generation/test_tgis_utils.py b/tests/toolkit/text_generation/test_tgis_utils.py index 2e97a5e8..696a9acf 100644 --- a/tests/toolkit/text_generation/test_tgis_utils.py +++ b/tests/toolkit/text_generation/test_tgis_utils.py @@ -133,6 +133,9 @@ def test_TGISGenerationClient_rpc_errors(status_code, method): assert isinstance(rpc_err, grpc.RpcError) +# NOTE: This test is preserved in caikit-nlp despite being duplicated in +# caikit-tgis-backend so that we guarantee that the functionality is accessible +# in a version-compatible way here. @pytest.mark.parametrize( argnames=["context", "route_info"], argvalues=[ @@ -168,7 +171,7 @@ def test_TGISGenerationClient_rpc_errors(status_code, method): ) def test_get_route_info(context: RuntimeServerContextType, route_info: Optional[str]): if not isinstance(context, (fastapi.Request, grpc.ServicerContext, type(None))): - with pytest.raises(ValueError): + with pytest.raises(TypeError): tgis_utils.get_route_info(context) else: actual_route_info = tgis_utils.get_route_info(context)