Headers passing #16

Open · wants to merge 3 commits into base: master
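This PR changes KServe's model handlers (`predict`, `explain`, and the data-plane / REST v1 paths) so that they return a `(response, response_headers)` tuple instead of a bare response, allowing headers produced by a model to be propagated back to the HTTP response. A minimal sketch of a user-defined model written against that convention is shown below; the class name, header value, and server setup are illustrative and not part of the diff:

from typing import Dict, Tuple

from kserve import Model, ModelServer


class EchoModel(Model):
    # Illustrative model that attaches a custom response header, assuming the
    # tuple-returning handler convention introduced in this PR.
    def __init__(self, name: str):
        super().__init__(name)
        self.ready = True

    def predict(self, payload: Dict, headers: Dict[str, str] = None) -> Tuple[Dict, Dict[str, str]]:
        # Echo the instances back as predictions, together with the headers
        # that should be returned to the caller.
        response_headers = {"my-header": "sample"}
        return {"predictions": payload["instances"]}, response_headers


if __name__ == "__main__":
    ModelServer().start([EchoModel("echo-model")])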
6 changes: 4 additions & 2 deletions python/alibiexplainer/alibiexplainer/explainer.py
@@ -72,10 +72,12 @@ def _predict_fn(self, arr: Union[np.ndarray, List]) -> np.ndarray:
else:
instances.append(req_data)
loop = asyncio.get_running_loop() # type: ignore
resp = loop.run_until_complete(self.predict({"instances": instances}))
resp, response_headers = loop.run_until_complete(
self.predict({"instances": instances}))
return np.array(resp["predictions"])

def explain(self, payload: Dict, headers: Dict[str, str] = None) -> Any:
response_headers = {'my-header': 'sample'}
if (
self.method is ExplainerMethod.anchor_tabular
or self.method is ExplainerMethod.anchor_images
@@ -84,6 +86,6 @@ def explain(self, payload: Dict, headers: Dict[str, str] = None) -> Any:
explanation = self.wrapper.explain(payload["instances"])
explanationAsJsonStr = explanation.to_json()
logging.info("Explanation: %s", explanationAsJsonStr)
return json.loads(explanationAsJsonStr)
return json.loads(explanationAsJsonStr), response_headers

raise NotImplementedError
2 changes: 1 addition & 1 deletion python/alibiexplainer/tests/utils.py
@@ -28,5 +28,5 @@ def predict_fn(self, arr: Union[np.ndarray, List]) -> np.ndarray:
instances.append(req_data.tolist())
else:
instances.append(req_data)
resp = self.clf.predict({"instances": instances})
resp, response_headers = self.clf.predict({"instances": instances})
return np.array(resp["predictions"])
9 changes: 6 additions & 3 deletions python/artexplainer/artserver/model.py
@@ -49,11 +49,12 @@ def _predict(self, x):
scoring_data = {'instances': input_image.tolist()}

loop = asyncio.get_running_loop()
resp = loop.run_until_complete(self.predict(scoring_data))
resp, response_headers = loop.run_until_complete(self.predict(scoring_data))
prediction = np.array(resp["predictions"])
return [1 if x == prediction else 0 for x in range(0, self.nb_classes)]

def explain(self, payload: Dict, headers: Dict[str, str] = None) -> Dict:
response_headers = {}
image = payload["instances"][0]
label = payload["instances"][1]
try:
@@ -74,7 +75,9 @@ def explain(self, payload: Dict, headers: Dict[str, str] = None) -> Dict:
adv_preds = np.argmax(classifier.predict(x_adv))
l2_error = np.linalg.norm(np.reshape(x_adv[0] - inputs, [-1]))

return {"explanations": {"adversarial_example": x_adv.tolist(), "L2 error": l2_error.tolist(),
"adversarial_prediction": adv_preds.tolist(), "prediction": preds.tolist()}}
return (
{"explanations": {"adversarial_example": x_adv.tolist(), "L2 error": l2_error.tolist(),
"adversarial_prediction": adv_preds.tolist(), "prediction": preds.tolist()}},
response_headers)
except Exception as e:
raise Exception("Failed to explain %s" % e)
5 changes: 3 additions & 2 deletions python/custom_model/model.py
@@ -84,6 +84,7 @@ def preprocess(self, payload: Union[Dict, InferRequest], headers: Dict[str, str]
return input_tensor.unsqueeze(0)

def predict(self, input_tensor: torch.Tensor, headers: Dict[str, str] = None) -> Union[Dict, InferResponse]:
response_headers = {}
output = self.model(input_tensor)
torch.nn.functional.softmax(output, dim=1)
values, top_5 = torch.topk(output, 5)
@@ -92,9 +93,9 @@ def predict(self, input_tensor: torch.Tensor, headers: Dict[str, str] = None) ->
infer_output = InferOutput(name="output-0", shape=list(values.shape), datatype="FP32", data=result)
infer_response = InferResponse(model_name=self.name, infer_outputs=[infer_output], response_id=response_id)
if "request-type" in headers and headers["request-type"] == "v1":
return {"predictions": result}
return {"predictions": result}, response_headers
else:
return infer_response
return infer_response, response_headers


parser = argparse.ArgumentParser(parents=[model_server.parser])
3 changes: 2 additions & 1 deletion python/custom_model/model_grpc.py
@@ -57,6 +57,7 @@ def preprocess(self, payload: InferRequest, headers: Dict[str, str] = None) -> t
return torch.Tensor(np_array)

def predict(self, input_tensor: torch.Tensor, headers: Dict[str, str] = None) -> Dict:
response_headers = {}
output = self.model(input_tensor)
torch.nn.functional.softmax(output, dim=1)
values, top_5 = torch.topk(output, 5)
@@ -75,7 +76,7 @@ def predict(self, input_tensor: torch.Tensor, headers: Dict[str, str] = None) ->
"shape": list(values.shape)
}
]}
return response
return response, response_headers


if __name__ == "__main__":
3 changes: 2 additions & 1 deletion python/huggingfaceserver/huggingfaceserver/model.py
@@ -196,6 +196,7 @@ async def generate(self, generate_request: GenerateRequest, headers: Dict[str, s

async def predict(self, input_batch: Union[BatchEncoding, InferRequest], context: Dict[str, Any] = None) \
-> Union[Tensor, InferResponse]:
response_headers = {}
if self.predictor_host:
# when predictor_host is provided, serialize the tensor and send to optimized model serving runtime
# like NVIDIA triton inference server
@@ -208,7 +209,7 @@ async def predict(self, input_batch: Union[BatchEncoding, InferRequest], context
outputs = self.model.generate(**input_batch)
else:
outputs = self.model(**input_batch).logits
return outputs
return outputs, response_headers
except Exception as e:
raise InferenceError(str(e))

12 changes: 6 additions & 6 deletions python/huggingfaceserver/huggingfaceserver/test_model.py
@@ -27,16 +27,16 @@ def test_t5():
model.load()

request = "translate this to germany"
response = asyncio.run(model({"instances": [request, request]}, headers={}))
response, response_headers = asyncio.run(model({"instances": [request, request]}, headers={}))
assert response == {"predictions": ['Das ist für Deutschland', 'Das ist für Deutschland']}


def test_bert():
model = HuggingfaceModel("bert-base-uncased", {"model_id": "bert-base-uncased", "do_lower_case": True})
model.load()

response = asyncio.run(model({"instances": ["The capital of France is [MASK].",
"The capital of [MASK] is paris."]}, headers={}))
response, response_headers = asyncio.run(model({"instances": ["The capital of France is [MASK].",
"The capital of [MASK] is paris."]}, headers={}))
assert response == {"predictions": ["paris", "france"]}


@@ -51,7 +51,7 @@ def test_bert_predictor_host(httpx_mock: HTTPXMock):
predictor_host="localhost:8081", predictor_protocol="v2"))
model.load()

response = asyncio.run(model({"instances": ["The capital of France is [MASK]."]}, headers={}))
response, response_headers = asyncio.run(model({"instances": ["The capital of France is [MASK]."]}, headers={}))
assert response == {"predictions": ["[PAD]"]}


@@ -62,7 +62,7 @@ def test_bert_sequence_classification():
model.load()

request = "Hello, my dog is cute."
response = asyncio.run(model({"instances": [request, request]}, headers={}))
response, response_headers = asyncio.run(model({"instances": [request, request]}, headers={}))
assert response == {"predictions": [1, 1]}


@@ -73,7 +73,7 @@ def test_bert_token_classification():
model.load()

request = "HuggingFace is a company based in Paris and New York"
response = asyncio.run(model({"instances": [request, request]}, headers={}))
response, response_headers = asyncio.run(model({"instances": [request, request]}, headers={}))
assert response == {"predictions": [[[0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
[[0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]}

60 changes: 39 additions & 21 deletions python/kserve/kserve/model.py
@@ -15,7 +15,7 @@
import inspect
import time
from enum import Enum
from typing import Dict, List, Union, Optional, AsyncIterator, Any
from typing import Dict, List, Tuple, Union, Optional, AsyncIterator, Any

import grpc
import httpx
@@ -135,14 +135,18 @@ async def __call__(self, body: Union[Dict, CloudEvent, InferRequest],
if verb == InferenceVerb.EXPLAIN:
with EXPLAIN_HIST_TIME.labels(**prom_labels).time():
start = time.time()
response = (await self.explain(payload, headers)) if inspect.iscoroutinefunction(self.explain) \
else self.explain(payload, headers)
if inspect.iscoroutinefunction(self.explain):
response, response_headers = (await self.explain(payload, headers))
else:
response, response_headers = self.explain(payload, headers)
explain_ms = get_latency_ms(start, time.time())
elif verb == InferenceVerb.PREDICT:
with PREDICT_HIST_TIME.labels(**prom_labels).time():
start = time.time()
response = (await self.predict(payload, headers)) if inspect.iscoroutinefunction(self.predict) \
else self.predict(payload, headers)
if inspect.iscoroutinefunction(self.predict):
response, response_headers = (await self.predict(payload, headers))
else:
response, response_headers = self.predict(payload, headers)
predict_ms = get_latency_ms(start, time.time())
else:
raise NotImplementedError
@@ -158,7 +162,7 @@ async def __call__(self, body: Union[Dict, CloudEvent, InferRequest],
f"explain_ms: {explain_ms}, predict_ms: {predict_ms}, "
f"postprocess_ms: {postprocess_ms}")

return response
return response, response_headers

@property
def _http_client(self):
@@ -174,10 +178,12 @@ def _grpc_client(self):
port = 443 if self.use_ssl else 80
self.predictor_host = f"{self.predictor_host}:{port}"
if self.use_ssl:
_channel = grpc.aio.secure_channel(self.predictor_host, grpc.ssl_channel_credentials())
_channel = grpc.aio.secure_channel(
self.predictor_host, grpc.ssl_channel_credentials())
else:
_channel = grpc.aio.insecure_channel(self.predictor_host)
self._grpc_client_stub = grpc_predict_v2_pb2_grpc.GRPCInferenceServiceStub(_channel)
self._grpc_client_stub = grpc_predict_v2_pb2_grpc.GRPCInferenceServiceStub(
_channel)
return self._grpc_client_stub

def validate(self, payload):
@@ -252,9 +258,11 @@ async def postprocess(self, result: Union[Dict, InferResponse], headers: Dict[st

async def _http_predict(self, payload: Union[Dict, InferRequest], headers: Dict[str, str] = None) -> Dict:
protocol = "https" if self.use_ssl else "http"
predict_url = PREDICTOR_URL_FORMAT.format(protocol, self.predictor_host, self.name)
predict_url = PREDICTOR_URL_FORMAT.format(
protocol, self.predictor_host, self.name)
if self.protocol == PredictorProtocol.REST_V2.value:
predict_url = PREDICTOR_V2_URL_FORMAT.format(protocol, self.predictor_host, self.name)
predict_url = PREDICTOR_V2_URL_FORMAT.format(
protocol, self.predictor_host, self.name)

# Adjusting headers. Inject content type if not exist.
# Also, removing host, as the header is the one passed to transformer and contains transformer's host
@@ -283,11 +291,13 @@ async def _http_predict(self, payload: Union[Dict, InferRequest], headers: Dict[
if "error" in error_message:
error_message = error_message["error"]
message = message.format(response, error_message=error_message)
raise HTTPStatusError(message, request=response.request, response=response)
return orjson.loads(response.content)
raise HTTPStatusError(
message, request=response.request, response=response)
return orjson.loads(response.content), response.headers

async def _grpc_predict(self, payload: Union[ModelInferRequest, InferRequest], headers: Dict[str, str] = None) \
-> ModelInferResponse:
response_headers = {}
if isinstance(payload, InferRequest):
payload = payload.to_grpc()
async_result = await self._grpc_client.ModelInfer(
@@ -297,12 +307,11 @@ async def _grpc_predict(self, payload: Union[ModelInferRequest, InferRequest], h
('response_type', 'grpc_v2'),
('x-request-id', headers.get('x-request-id', '')))
)
return async_result
return async_result, response_headers

async def predict(self, payload: Union[Dict, InferRequest, ModelInferRequest],
headers: Dict[str, str] = None) -> Union[Dict, InferResponse]:
""" The `predict` handler can be overridden for performing the inference.
By default, the predict handler makes call to predictor for the inference step.
headers: Dict[str, str] = None) -> Tuple[Union[Dict, InferResponse], Dict]:
"""

Args:
payload: Model inputs passed from `preprocess` handler.
@@ -317,12 +326,21 @@ async def predict(self, payload: Union[Dict, InferRequest, ModelInferRequest],
if not self.predictor_host:
raise NotImplementedError("Could not find predictor_host.")
if self.protocol == PredictorProtocol.GRPC_V2.value:
res = await self._grpc_predict(payload, headers)
return InferResponse.from_grpc(res)
response_content, response_headers = await self._grpc_predict(payload, headers)
return InferResponse.from_grpc(response_content), response_headers
else:
res = await self._http_predict(payload, headers)
response_content, response_headers = await self._http_predict(payload, headers)
# Drop 'Content-Length' from the predictor's response headers, since the
# body is re-serialized before being returned to the caller.
if 'Content-Length' in response_headers:
del response_headers['Content-Length']

# return an InferResponse if this is REST V2, otherwise just return the dictionary
return InferResponse.from_rest(self.name, res) if is_v2(PredictorProtocol(self.protocol)) else res
if is_v2(PredictorProtocol(self.protocol)):
return InferResponse.from_rest(self.name, response_content), response_headers
else:
return response_content, response_headers

async def generate(self, payload: GenerateRequest,
headers: Dict[str, str] = None) -> Union[GenerateResponse, AsyncIterator[Any]]:
@@ -358,4 +376,4 @@ async def explain(self, payload: Dict, headers: Dict[str, str] = None) -> Dict:
)

response.raise_for_status()
return orjson.loads(response.content)
return orjson.loads(response.content), response.headers
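A transformer that relies on the base `Model.predict` (forwarding to `predictor_host`) does not need to unpack anything itself; with the change above, `__call__` handles the tuple internally and propagates the predictor's response headers. A hedged sketch of such a transformer; the class name, preprocessing logic, and predictor address are assumptions, not part of this diff:

from typing import Dict

from kserve import Model, ModelServer


class SimpleTransformer(Model):
    # Illustrative transformer: preprocess locally, then let the base
    # Model.predict forward the request to the predictor and return
    # (response, response_headers) as introduced in this PR.
    def __init__(self, name: str, predictor_host: str):
        super().__init__(name)
        self.predictor_host = predictor_host
        self.ready = True

    async def preprocess(self, payload: Dict, headers: Dict[str, str] = None) -> Dict:
        # Example-only transformation: scale raw pixel values to [0, 1].
        instances = [[v / 255.0 for v in row] for row in payload["instances"]]
        return {"instances": instances}


if __name__ == "__main__":
    ModelServer().start([SimpleTransformer("my-model", predictor_host="predictor.example:8080")])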
21 changes: 12 additions & 9 deletions python/kserve/kserve/protocol/dataplane.py
@@ -266,7 +266,10 @@ def decode_cloudevent(self, body) -> Tuple[Union[Dict, InferRequest], Dict]:
return decoded_body, attributes

def encode(self, model_name, response, headers, req_attributes: Dict) -> Tuple[Dict, Dict[str, str]]:
response_headers = {}
response_headers = headers
if not headers:
response_headers = {}

# if we received a cloudevent, then also return a cloudevent
is_cloudevent = False
is_binary_cloudevent = False
@@ -317,12 +320,12 @@ async def infer(
# call model locally or remote model workers
model = self.get_model(model_name)
if isinstance(model, RayServeSyncHandle):
response = ray.get(model.remote(request, headers=headers))
response, response_headers = ray.get(model.remote(request, headers=headers))
elif isinstance(model, (RayServeHandle, DeploymentHandle)):
response = await model.remote(request, headers=headers)
response, response_headers = await model.remote(request, headers=headers)
else:
response = await model(request, headers=headers)
return response, headers
response, response_headers = await model(request, headers=headers)
return response, response_headers

async def generate(
self,
@@ -368,9 +371,9 @@ async def explain(self, model_name: str,
# call model locally or remote model workers
model = self.get_model(model_name)
if isinstance(model, RayServeSyncHandle):
response = ray.get(model.remote(request, verb=InferenceVerb.EXPLAIN))
response, response_headers = ray.get(model.remote(request, verb=InferenceVerb.EXPLAIN))
elif isinstance(model, (RayServeHandle, DeploymentHandle)):
response = await model.remote(request, verb=InferenceVerb.EXPLAIN)
response, response_headers = await model.remote(request, verb=InferenceVerb.EXPLAIN)
else:
response = await model(request, verb=InferenceVerb.EXPLAIN)
return response, headers
response, response_headers = await model(request, verb=InferenceVerb.EXPLAIN)
return response, response_headers
19 changes: 11 additions & 8 deletions python/kserve/kserve/protocol/rest/v1_endpoints.py
@@ -14,6 +14,7 @@
from typing import Optional, Union, Dict, List

from fastapi import Request, Response
from fastapi.responses import JSONResponse

from kserve.errors import ModelNotReady
from ..dataplane import DataPlane
@@ -78,11 +79,12 @@ async def predict(self, model_name: str, request: Request) -> Union[Response, Di
headers=headers)
response, response_headers = self.dataplane.encode(model_name=model_name,
response=response,
headers=headers, req_attributes=req_attributes)
headers=response_headers, req_attributes=req_attributes)

if not isinstance(response, dict):
return Response(content=response, headers=response_headers)
return response
if isinstance(response, dict):
return JSONResponse(content=response, headers=response_headers)

return Response(content=response, headers=response_headers)

async def explain(self, model_name: str, request: Request) -> Union[Response, Dict]:
"""Explain handler.
@@ -108,8 +110,9 @@ async def explain(self, model_name: str, request: Request) -> Union[Response, Di
headers=headers)
response, response_headers = self.dataplane.encode(model_name=model_name,
response=response,
headers=headers, req_attributes=req_attributes)
headers=response_headers, req_attributes=req_attributes)

if isinstance(response, dict):
return JSONResponse(content=response, headers=response_headers)

if not isinstance(response, dict):
return Response(content=response, headers=response_headers)
return response
return Response(content=response, headers=response_headers)
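With the v1 endpoint now attaching `response_headers` to both `JSONResponse` and `Response`, a header set by a model (such as the `my-header` value hardcoded in the alibiexplainer change above) should be visible to clients. An illustrative client-side check; the host, port, and model name are assumptions:

import requests

# Assumes a server started with the EchoModel sketch above, listening on the
# default KServe HTTP port.
resp = requests.post(
    "http://localhost:8080/v1/models/echo-model:predict",
    json={"instances": [[1.0, 2.0, 3.0]]},
)
print(resp.json())                    # {"predictions": [[1.0, 2.0, 3.0]]}
print(resp.headers.get("my-header"))  # "sample" if the model attached it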