Skip to content

Commit

Permalink
Merge pull request caikit#363 from mynhardtburger/lazy-model-connection
Browse files Browse the repository at this point in the history
Lazy model connection & x-route-info
  • Loading branch information
gabe-l-hart authored Jun 7, 2024
2 parents 4bf53fd + 527b455 commit 3b2f8fd
Show file tree
Hide file tree
Showing 7 changed files with 224 additions and 33 deletions.
59 changes: 51 additions & 8 deletions caikit_nlp/modules/text_generation/peft_tgis_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
"""This file contains a distributed backend implementation for leveraging the PEFT-trained
prompt vectors in TGIS generation requests.
"""

# Standard
from functools import cached_property
from typing import Iterable, List, Optional, Tuple, Union
import os

Expand All @@ -32,6 +34,7 @@
TokenizationResults,
)
from caikit.interfaces.nlp.tasks import TextGenerationTask, TokenizationTask
from caikit.interfaces.runtime.data_model import RuntimeServerContextType
from caikit_tgis_backend import TGISBackend
import alog

Expand All @@ -40,6 +43,7 @@
from ...toolkit.text_generation.tgis_utils import (
GENERATE_FUNCTION_TGIS_ARGS,
TGISGenerationClient,
get_route_info,
)
from ...toolkit.verbalizer_utils import render_verbalizer
from . import PeftPromptTuning
Expand Down Expand Up @@ -68,15 +72,14 @@ def __init__(
prompt_artifacts: Optional[List[str]] = None,
) -> None:
super().__init__()
# Configure the internal client
# NOTE: This is made optional for the cases where we do not need to execute `.run` function
# for example, bootstrapping a model to caikit format and saving.
self._client = None

self._tgis_backend = tgis_backend
if enable_backend:
error.type_check(
"<NLP33971947E>", TGISBackend, tgis_backend=self._tgis_backend
)
# get_client will also launch a local TGIS process and get the model
# loaded when using the local TGIS backend
self._client = tgis_backend.get_client(base_model_name)

# Tell the backend to load all of the available prompt files
if prompt_artifacts:
Expand Down Expand Up @@ -107,6 +110,14 @@ def __del__(self):
if tgis_backend and prompt_cache_id and model_id:
tgis_backend.unload_prompt_artifacts(model_id, prompt_cache_id)

@cached_property
def _client(self):
# Configure the internal client
# NOTE: This is made optional for the cases where we do not need to execute `.run` function
# for example, bootstrapping a model to caikit format and saving.
if self._tgis_backend:
return self._tgis_backend.get_client(self.base_model_name)

@classmethod
def load(cls, model_path: str, load_backend: BackendBase) -> "PeftPromptTuningTGIS":
"""Load a TGIS Peft Prompt Tuning distributed module. Note that we do not
Expand Down Expand Up @@ -182,7 +193,7 @@ def save(self, model_path: str):
)

# pylint: disable=duplicate-code
@TextGenerationTask.taskmethod()
@TextGenerationTask.taskmethod(context_arg="context")
def run(
self,
text: str,
Expand All @@ -206,6 +217,7 @@ def run(
generated_tokens: bool = True,
token_logprobs: bool = True,
token_ranks: bool = True,
context: Optional[RuntimeServerContextType] = None,
) -> GeneratedTextResult:
f"""Run inference against the model running in TGIS.
Expand All @@ -221,6 +233,8 @@ def run(
self.enable_backend,
"Backend must be configured and loaded with this module before executing `run` call.",
)
self._register_model_connection_with_context(context)

verbalized_text = render_verbalizer(self.verbalizer, {"input": text})
return self.tgis_generation_client.unary_generate(
text=verbalized_text,
Expand All @@ -244,7 +258,7 @@ def run(
stop_sequences=stop_sequences,
)

@TextGenerationTask.taskmethod(output_streaming=True)
@TextGenerationTask.taskmethod(output_streaming=True, context_arg="context")
def run_stream_out(
self,
text: str,
Expand All @@ -268,6 +282,7 @@ def run_stream_out(
generated_tokens: bool = True,
token_logprobs: bool = True,
token_ranks: bool = True,
context: Optional[RuntimeServerContextType] = None,
) -> Iterable[GeneratedTextStreamResult]:
f"""Run output stream inferencing against the model running in TGIS
Expand All @@ -283,6 +298,9 @@ def run_stream_out(
"Backend must be configured and loaded with this module \
before executing `run_stream_out` call.",
)

self._register_model_connection_with_context(context)

verbalized_text = render_verbalizer(self.verbalizer, {"input": text})
return self.tgis_generation_client.stream_generate(
text=verbalized_text,
Expand All @@ -306,10 +324,11 @@ def run_stream_out(
stop_sequences=stop_sequences,
)

@TokenizationTask.taskmethod()
@TokenizationTask.taskmethod(context_arg="context")
def run_tokenizer(
self,
text: str,
context: Optional[RuntimeServerContextType] = None,
) -> TokenizationResults:
"""Run tokenization task against the model running in TGIS.
Expand All @@ -320,6 +339,30 @@ def run_tokenizer(
TokenizationResults
The token count
"""

self._register_model_connection_with_context(context)

return self.tgis_generation_client.unary_tokenize(
text=text,
)

def _register_model_connection_with_context(
self, context: Optional[RuntimeServerContextType]
):
"""
Register a remote model connection with the configured TGISBackend if there is
a context override provided.
"""
if self._tgis_backend:
if route_info := get_route_info(context):
log.debug(
"<NLP10705560D> Registering remote model connection with context "
"override: 'hostname: %s'",
route_info,
)
self._tgis_backend.register_model_connection(
self.base_model_name,
{"hostname": route_info},
fill_with_defaults=True,
)
self._model_loaded = True
65 changes: 50 additions & 15 deletions caikit_nlp/modules/text_generation/text_generation_tgis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


# Standard
from functools import cached_property
from typing import Iterable, List, Optional, Tuple, Union
import os

Expand All @@ -30,6 +31,7 @@
TokenizationResults,
)
from caikit.interfaces.nlp.tasks import TextGenerationTask, TokenizationTask
from caikit.interfaces.runtime.data_model import RuntimeServerContextType
from caikit_tgis_backend import TGISBackend
import alog

Expand All @@ -43,6 +45,7 @@
from ...toolkit.text_generation.tgis_utils import (
GENERATE_FUNCTION_TGIS_ARGS,
TGISGenerationClient,
get_route_info,
)
from .text_generation_local import TextGeneration

Expand Down Expand Up @@ -86,28 +89,33 @@ def __init__(
# Set _model_loaded as False by default. This will only get set to True if
# we enable the tgis_backend and we are able to fetch the client successfully.
self._model_loaded = False
# Configure the internal client
# NOTE: This is made optional for the cases where we do not need to execute `.run` function
# for example, bootstrapping a model to caikit format and saving.
self._client = None
if tgis_backend:
self._client = tgis_backend.get_client(model_name)
# mark that the model is loaded so that we can unload it later
self._model_loaded = True
self.tgis_backend = tgis_backend

self._tgis_backend = tgis_backend
self._bos_token = bos_token
self._sep_token = sep_token
self._eos_token = eos_token
self._pad_token = pad_token
self.tgis_generation_client = TGISGenerationClient(
self.model_name, self._eos_token, self._client, self.PRODUCER_ID
)

def __del__(self):
# nothing to unload if we didn't finish loading
if self._model_loaded and self.tgis_backend:
self.tgis_backend.unload_model(self.model_name)
if self._model_loaded and self._tgis_backend:
self._tgis_backend.unload_model(self.model_name)

@cached_property
def _client(self):
# Lazily configure/create the internal tgis backend client
if self._tgis_backend:
return self._tgis_backend.get_client(self.model_name)

@cached_property
def tgis_generation_client(self):
# Lazily create the generation client
# This in turn calls self._client which also lazily gets the tgis backend client
return TGISGenerationClient(
self.model_name, self._eos_token, self._client, self.PRODUCER_ID
)

@classmethod
def bootstrap(cls, model_path: str, load_backend: Union[BackendBase, None] = None):
Expand Down Expand Up @@ -207,7 +215,7 @@ def save(self, model_path: str):
)

# pylint: disable=duplicate-code
@TextGenerationTask.taskmethod()
@TextGenerationTask.taskmethod(context_arg="context")
def run(
self,
text: str,
Expand All @@ -231,6 +239,7 @@ def run(
generated_tokens: bool = True,
token_logprobs: bool = True,
token_ranks: bool = True,
context: Optional[RuntimeServerContextType] = None,
) -> GeneratedTextResult:
f"""Run inference against the model running in TGIS.
Expand All @@ -240,6 +249,8 @@ def run(
GeneratedTextResult
Generated text result produced by TGIS.
"""
self._register_model_connection_with_context(context)

if self._model_loaded:
return self.tgis_generation_client.unary_generate(
text=text,
Expand All @@ -263,7 +274,7 @@ def run(
stop_sequences=stop_sequences,
)

@TextGenerationTask.taskmethod(output_streaming=True)
@TextGenerationTask.taskmethod(output_streaming=True, context_arg="context")
def run_stream_out(
self,
text: str,
Expand All @@ -287,6 +298,7 @@ def run_stream_out(
generated_tokens: bool = True,
token_logprobs: bool = True,
token_ranks: bool = True,
context: Optional[RuntimeServerContextType] = None,
) -> Iterable[GeneratedTextStreamResult]:
f"""Run output stream inferencing for text generation module.
Expand All @@ -295,6 +307,7 @@ def run_stream_out(
Returns:
Iterable[GeneratedTextStreamResult]
"""
self._register_model_connection_with_context(context)

if self._model_loaded:
return self.tgis_generation_client.stream_generate(
Expand All @@ -319,10 +332,11 @@ def run_stream_out(
stop_sequences=stop_sequences,
)

@TokenizationTask.taskmethod()
@TokenizationTask.taskmethod(context_arg="context")
def run_tokenizer(
self,
text: str,
context: Optional[RuntimeServerContextType] = None,
) -> TokenizationResults:
"""Run tokenization task against the model running in TGIS.
Expand All @@ -333,7 +347,28 @@ def run_tokenizer(
TokenizationResults
The token count
"""
self._register_model_connection_with_context(context)

if self._model_loaded:
return self.tgis_generation_client.unary_tokenize(
text=text,
)

def _register_model_connection_with_context(
self, context: Optional[RuntimeServerContextType]
):
"""
Register a remote model connection with the configured TGISBackend if there is
a context override provided.
"""
if self._tgis_backend:
if route_info := get_route_info(context):
log.debug(
"<NLP15770311D> Registering remote model connection with context "
"override: 'hostname: %s'",
route_info,
)
self._tgis_backend.register_model_connection(
self.model_name, {"hostname": route_info}, fill_with_defaults=True
)
self._model_loaded = True
39 changes: 36 additions & 3 deletions caikit_nlp/toolkit/text_generation/tgis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file is for helper functions related to TGIS.
"""
"""This file is for helper functions related to TGIS."""

# Standard
from typing import Iterable
from typing import Iterable, Optional

# Third Party
import fastapi
import grpc

# First Party
Expand All @@ -33,6 +34,7 @@
TokenizationResults,
TokenStreamDetails,
)
from caikit.interfaces.runtime.data_model import RuntimeServerContextType
from caikit_tgis_backend.protobufs import generation_pb2
import alog

Expand Down Expand Up @@ -84,6 +86,9 @@
grpc.StatusCode.UNAUTHENTICATED: CaikitCoreStatusCode.UNAUTHORIZED,
}

# HTTP Header / gRPC Metadata key used to identify a route override
ROUTE_INFO_HEADER_KEY = "x-route-info"


def raise_caikit_core_exception(rpc_error: grpc.RpcError):
"""Helper to wrap logic of converting from grpc.RpcError ->
Expand Down Expand Up @@ -683,3 +688,31 @@ def unary_tokenize(
return TokenizationResults(
token_count=response.token_count,
)


def get_route_info(
context: Optional[RuntimeServerContextType],
) -> Optional[str]:
"""
Returns a tuple `(True, x-route-info)` from context if "x-route-info" was found in
the headers/metadata.
Otherwise returns a tuple `(False, None)` if "x-route-info" was not found in the
context or if context is None.
"""
if context is None:
return None

if isinstance(context, grpc.ServicerContext):
route_info = dict(context.invocation_metadata()).get(ROUTE_INFO_HEADER_KEY)
if route_info:
return route_info
elif isinstance(context, fastapi.Request):
route_info = context.headers.get(ROUTE_INFO_HEADER_KEY)
if route_info:
return route_info
else:
error.log_raise(
"<NLP92615097E>",
ValueError(f"context is of an unsupported type: {type(context)}"),
)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ classifiers=[
]
dependencies = [
"caikit[runtime-grpc,runtime-http]>=0.26.27,<0.27.0",
"caikit-tgis-backend>=0.1.27,<0.2.0",
"caikit-tgis-backend>=0.1.33,<0.2.0",
# TODO: loosen dependencies
"grpcio>=1.62.2", # explicitly pin grpc dependencies to a recent version to avoid pip backtracking
"grpcio-reflection>=1.62.2",
Expand Down
Loading

0 comments on commit 3b2f8fd

Please sign in to comment.