Fixed lint issue, e2e and kserve test failures.
Signed-off-by: Andrews Arokiam <[email protected]>

Fixed alibi test and added response headers in gRPC predict.

Signed-off-by: Andrews Arokiam <[email protected]>
andyi2it committed Jan 29, 2024
1 parent ff64367 commit 2c379a8
Showing 11 changed files with 114 additions and 77 deletions.
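The common thread across the changed files is the handler contract: predict and explain now return a (response, response_headers) pair instead of a bare response. A minimal sketch of a custom model under that contract, assuming the kserve Model base class and mirroring the dummy models in the updated tests:

    from typing import Dict, Tuple

    from kserve import Model


    class EchoModel(Model):
        # Illustrative model; the name and trivial body are assumptions, only the
        # (response, response_headers) return shape comes from this commit.
        def __init__(self, name: str):
            super().__init__(name)
            self.ready = True

        async def predict(self, request: Dict, headers: Dict[str, str] = None) -> Tuple[Dict, Dict]:
            response_headers = {}  # filled in by the handler, returned to the caller
            return {"predictions": request["instances"]}, response_headers

        async def explain(self, request: Dict, headers: Dict[str, str] = None) -> Tuple[Dict, Dict]:
            response_headers = {}
            return {"predictions": request["instances"]}, response_headers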
7 changes: 4 additions & 3 deletions python/alibiexplainer/alibiexplainer/explainer.py
@@ -72,11 +72,12 @@ def _predict_fn(self, arr: Union[np.ndarray, List]) -> np.ndarray:
else:
instances.append(req_data)
loop = asyncio.get_running_loop() # type: ignore
resp, response_headers = loop.run_until_complete(self.predict({"instances": instances}))
resp, response_headers = loop.run_until_complete(
self.predict({"instances": instances}))
return np.array(resp["predictions"])

def explain(self, payload: Dict, headers: Dict[str, str] = None) -> Any:
response_headers = {'my-header':'sample'}
response_headers = {'my-header': 'sample'}
if (
self.method is ExplainerMethod.anchor_tabular
or self.method is ExplainerMethod.anchor_images
@@ -85,6 +86,6 @@ def explain(self, payload: Dict, headers: Dict[str, str] = None) -> Any:
explanation = self.wrapper.explain(payload["instances"])
explanationAsJsonStr = explanation.to_json()
logging.info("Explanation: %s", explanationAsJsonStr)
return json.loads(explanationAsJsonStr),response_headers
return json.loads(explanationAsJsonStr), response_headers

raise NotImplementedError
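Downstream callers unpack the same pair. The explainer's internal prediction function is one of them: the inherited Model.predict forwards the request to the predictor and now hands back headers alongside the body. A hedged sketch of that call pattern, with illustrative class and method names, shown with a plain await where the explainer itself drives the call through an event loop:

    import numpy as np

    from kserve import Model


    class HeaderAwareExplainer(Model):
        async def _predict_fn(self, instances):
            # Model.predict now returns (response, response_headers); only the
            # body is needed to build the array handed to the explainer.
            resp, response_headers = await self.predict({"instances": instances})
            return np.array(resp["predictions"])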
2 changes: 1 addition & 1 deletion python/alibiexplainer/tests/utils.py
@@ -28,5 +28,5 @@ def predict_fn(self, arr: Union[np.ndarray, List]) -> np.ndarray:
instances.append(req_data.tolist())
else:
instances.append(req_data)
resp = self.clf.predict({"instances": instances})
resp, response_headers = self.clf.predict({"instances": instances})
return np.array(resp["predictions"])
6 changes: 4 additions & 2 deletions python/artexplainer/artserver/model.py
@@ -75,7 +75,9 @@ def explain(self, payload: Dict, headers: Dict[str, str] = None) -> Dict:
adv_preds = np.argmax(classifier.predict(x_adv))
l2_error = np.linalg.norm(np.reshape(x_adv[0] - inputs, [-1]))

return {"explanations": {"adversarial_example": x_adv.tolist(), "L2 error": l2_error.tolist(),
"adversarial_prediction": adv_preds.tolist(), "prediction": preds.tolist()}}, response_headers
return (
{"explanations": {"adversarial_example": x_adv.tolist(), "L2 error": l2_error.tolist(),
"adversarial_prediction": adv_preds.tolist(), "prediction": preds.tolist()}},
response_headers)
except Exception as e:
raise Exception("Failed to explain %s" % e)
26 changes: 15 additions & 11 deletions python/kserve/kserve/model.py
@@ -15,7 +15,7 @@
import inspect
import time
from enum import Enum
from typing import Dict, List, Union, Optional, AsyncIterator, Any
from typing import Dict, List, Tuple, Union, Optional, AsyncIterator, Any

import grpc
import httpx
@@ -135,14 +135,18 @@ async def __call__(self, body: Union[Dict, CloudEvent, InferRequest],
if verb == InferenceVerb.EXPLAIN:
with EXPLAIN_HIST_TIME.labels(**prom_labels).time():
start = time.time()
response, response_headers = (await self.explain(payload, headers)) if inspect.iscoroutinefunction(self.explain) \
else self.explain(payload, headers)
if inspect.iscoroutinefunction(self.explain):
response, response_headers = (await self.explain(payload, headers))
else:
response, response_headers = self.explain(payload, headers)
explain_ms = get_latency_ms(start, time.time())
elif verb == InferenceVerb.PREDICT:
with PREDICT_HIST_TIME.labels(**prom_labels).time():
start = time.time()
response, response_headers = (await self.predict(payload, headers)) if inspect.iscoroutinefunction(self.predict) \
else self.predict(payload, headers)
if inspect.iscoroutinefunction(self.predict):
response, response_headers = (await self.predict(payload, headers))
else:
response, response_headers = self.predict(payload, headers)
predict_ms = get_latency_ms(start, time.time())
else:
raise NotImplementedError
@@ -293,6 +297,7 @@ async def _http_predict(self, payload: Union[Dict, InferRequest], headers: Dict[

async def _grpc_predict(self, payload: Union[ModelInferRequest, InferRequest], headers: Dict[str, str] = None) \
-> ModelInferResponse:
response_headers = {}
if isinstance(payload, InferRequest):
payload = payload.to_grpc()
async_result = await self._grpc_client.ModelInfer(
@@ -302,12 +307,11 @@ async def _grpc_predict(self, payload: Union[ModelInferRequest, InferRequest], h
('response_type', 'grpc_v2'),
('x-request-id', headers.get('x-request-id', '')))
)
return async_result
return async_result, response_headers

async def predict(self, payload: Union[Dict, InferRequest, ModelInferRequest],
headers: Dict[str, str] = None) -> Union[Dict, InferResponse]:
""" The `predict` handler can be overridden for performing the inference.
By default, the predict handler makes call to predictor for the inference step.
headers: Dict[str, str] = None) -> Tuple[Union[Dict, InferResponse], Dict]:
"""
Args:
payload: Model inputs passed from `preprocess` handler.
@@ -322,8 +326,8 @@ async def predict(self, payload: Union[Dict, InferRequest, ModelInferRequest],
if not self.predictor_host:
raise NotImplementedError("Could not find predictor_host.")
if self.protocol == PredictorProtocol.GRPC_V2.value:
response_content = await self._grpc_predict(payload, headers)
return InferResponse.from_grpc(response_content)
response_content, response_headers = await self._grpc_predict(payload, headers)
return InferResponse.from_grpc(response_content), response_headers
else:
response_content, response_headers = await self._http_predict(payload, headers)
response_headers = {}
16 changes: 8 additions & 8 deletions python/kserve/kserve/protocol/dataplane.py
@@ -320,12 +320,12 @@ async def infer(
# call model locally or remote model workers
model = self.get_model(model_name)
if isinstance(model, RayServeSyncHandle):
response = ray.get(model.remote(request, headers=headers))
response, response_headers = ray.get(model.remote(request, headers=headers))
elif isinstance(model, (RayServeHandle, DeploymentHandle)):
response = await model.remote(request, headers=headers)
response, response_headers = await model.remote(request, headers=headers)
else:
response = await model(request, headers=headers)
return response, headers
response, response_headers = await model(request, headers=headers)
return response, response_headers

async def generate(
self,
@@ -371,9 +371,9 @@ async def explain(self, model_name: str,
# call model locally or remote model workers
model = self.get_model(model_name)
if isinstance(model, RayServeSyncHandle):
response = ray.get(model.remote(request, verb=InferenceVerb.EXPLAIN))
response, response_headers = ray.get(model.remote(request, verb=InferenceVerb.EXPLAIN))
elif isinstance(model, (RayServeHandle, DeploymentHandle)):
response = await model.remote(request, verb=InferenceVerb.EXPLAIN)
response, response_headers = await model.remote(request, verb=InferenceVerb.EXPLAIN)
else:
response = await model(request, verb=InferenceVerb.EXPLAIN)
return response, headers
response, response_headers = await model(request, verb=InferenceVerb.EXPLAIN)
return response, response_headers
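On the serving path, the data plane now propagates whatever headers the model handler produced instead of echoing the request headers back. A hedged sketch of the caller side, assuming a DataPlane instance with the model already registered; the function and argument names are assumptions for illustration:

    async def call_predict(dataplane, model_name: str):
        # dataplane.infer returns (response, response_headers) after this change;
        # response_headers carries headers produced by the model's predict handler.
        response, response_headers = await dataplane.infer(
            model_name,
            {"instances": [[1, 2]]},
            headers={"x-request-id": "abc-123"},
        )
        return response, response_headers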
90 changes: 60 additions & 30 deletions python/kserve/test/test_server.py
@@ -79,15 +79,17 @@ def load(self):
self.ready = True

async def predict(self, request, headers=None):
response_headers = {}
if isinstance(request, InferRequest):
inputs = get_predict_input(request)
infer_response = get_predict_response(request, inputs, self.name)
return infer_response
return infer_response, response_headers
else:
return {"predictions": request["instances"]}
return {"predictions": request["instances"]}, response_headers

async def explain(self, request, headers=None):
return {"predictions": request["instances"]}
response_headers = {}
return {"predictions": request["instances"]}, response_headers


@serve.deployment
@@ -101,15 +103,17 @@ def load(self):
self.ready = True

async def predict(self, request, headers=None):
response_headers = {}
if isinstance(request, InferRequest):
inputs = get_predict_input(request)
infer_response = get_predict_response(request, inputs, self.name)
return infer_response
return infer_response, response_headers
else:
return {"predictions": request["instances"]}
return {"predictions": request["instances"]}, response_headers

async def explain(self, request, headers=None):
return {"predictions": request["instances"]}
response_headers = {}
return {"predictions": request["instances"]}, response_headers


class DummyCEModel(Model):
@@ -122,10 +126,12 @@ def load(self):
self.ready = True

async def predict(self, request, headers=None):
return {"predictions": request["instances"]}
response_headers = headers
return {"predictions": request["instances"]}, response_headers

async def explain(self, request, headers=None):
return {"predictions": request["instances"]}
response_headers = headers
return {"predictions": request["instances"]}, response_headers


class DummyAvroCEModel(Model):
@@ -154,12 +160,18 @@ def preprocess(self, request, headers: Dict[str, str] = None):
return self._parserequest(request)

async def predict(self, request, headers=None):
return {"predictions": [[request['name'], request['favorite_number'],
request['favorite_color']]]}
response_headers = headers
return (
{"predictions": [
[request['name'], request['favorite_number'], request['favorite_color']]]},
response_headers)

async def explain(self, request, headers=None):
return {"predictions": [[request['name'], request['favorite_number'],
request['favorite_color']]]}
response_headers = headers
return (
{"predictions": [
[request['name'], request['favorite_number'], request['favorite_color']]]},
response_headers)


class DummyModelRepository(ModelRepository):
@@ -214,7 +226,8 @@ def app(self):
model.load()
server = ModelServer()
server.register_model(model)
rest_server = RESTServer(server.dataplane, server.model_repository_extension)
rest_server = RESTServer(
server.dataplane, server.model_repository_extension)
return rest_server.create_application()

@pytest.fixture(scope="class")
@@ -233,7 +246,8 @@ def test_model_v1(self, http_server_client):
def test_unknown_model_v1(self, http_server_client):
resp = http_server_client.get('/v1/models/InvalidModel')
assert resp.status_code == 404
assert resp.json() == {"error": "Model with name InvalidModel does not exist."}
assert resp.json() == {
"error": "Model with name InvalidModel does not exist."}

def test_list_models_v1(self, http_server_client):
resp = http_server_client.get('/v1/models')
@@ -356,7 +370,8 @@ def app(self): # pylint: disable=no-self-use

server = ModelServer()
server.register_model_handle("TestModel", handle)
rest_server = RESTServer(server.dataplane, server.model_repository_extension)
rest_server = RESTServer(
server.dataplane, server.model_repository_extension)
return rest_server.create_application()

@pytest.fixture(scope='class')
@@ -410,7 +425,8 @@ def app(self): # pylint: disable=no-self-use
model = DummyModel("TestModel")
server = ModelServer()
server.register_model(model)
rest_server = RESTServer(server.dataplane, server.model_repository_extension)
rest_server = RESTServer(
server.dataplane, server.model_repository_extension)
return rest_server.create_application()

@pytest.fixture(scope='class')
@@ -429,7 +445,8 @@ def app(self): # pylint: disable=no-self-use
model.load()
server = ModelServer()
server.register_model(model)
rest_server = RESTServer(server.dataplane, server.model_repository_extension)
rest_server = RESTServer(
server.dataplane, server.model_repository_extension)
return rest_server.create_application()

@pytest.fixture(scope='class')
@@ -472,7 +489,8 @@ def test_predict_custom_ce_attributes(self, http_server_client):

def test_predict_merge_structured_ce_attributes(self, http_server_client):
with mock.patch.dict(os.environ, {"CE_MERGE": "true"}):
event = dummy_cloud_event({"instances": [[1, 2]]}, add_extension=True)
event = dummy_cloud_event(
{"instances": [[1, 2]]}, add_extension=True)
headers, body = to_structured(event)

resp = http_server_client.post('/v1/models/TestModel:predict', headers=headers, content=body)
@@ -485,12 +503,14 @@ def test_predict_merge_structured_ce_attributes(self, http_server_client):
assert body["data"] == {"predictions": [[1, 2]]}
assert body['source'] == "io.kserve.inference.TestModel"
assert body['type'] == "io.kserve.inference.response"
assert body["custom-extension"] == "custom-value" # Added by add_extension=True in dummy_cloud_event
# Added by add_extension=True in dummy_cloud_event
assert body["custom-extension"] == "custom-value"
assert body['time'] > "2021-01-28T21:04:43.144141+00:00"

def test_predict_merge_binary_ce_attributes(self, http_server_client):
with mock.patch.dict(os.environ, {"CE_MERGE": "true"}):
event = dummy_cloud_event({"instances": [[1, 2]]}, set_contenttype=True, add_extension=True)
event = dummy_cloud_event(
{"instances": [[1, 2]]}, set_contenttype=True, add_extension=True)
headers, body = to_binary(event)

resp = http_server_client.post('/v1/models/TestModel:predict', headers=headers, content=body)
@@ -507,7 +527,8 @@ def test_predict_merge_binary_ce_attributes(self, http_server_client):
assert resp.content == b'{"predictions": [[1, 2]]}'

def test_predict_ce_binary_dict(self, http_server_client):
event = dummy_cloud_event({"instances": [[1, 2]]}, set_contenttype=True)
event = dummy_cloud_event(
{"instances": [[1, 2]]}, set_contenttype=True)
headers, body = to_binary(event)

resp = http_server_client.post('/v1/models/TestModel:predict', headers=headers, content=body)
@@ -522,7 +543,8 @@ def test_predict_ce_binary_dict(self, http_server_client):
assert resp.content == b'{"predictions": [[1, 2]]}'

def test_predict_ce_binary_bytes(self, http_server_client):
event = dummy_cloud_event(b'{"instances":[[1,2]]}', set_contenttype=True)
event = dummy_cloud_event(
b'{"instances":[[1,2]]}', set_contenttype=True)
headers, body = to_binary(event)
resp = http_server_client.post('/v1/models/TestModel:predict', headers=headers, content=body)

@@ -548,7 +570,8 @@ def test_predict_ce_bytes_bad_format_exception(self, http_server_client):
assert error_regex.match(response["error"]) is not None

def test_predict_ce_bytes_bad_hex_format_exception(self, http_server_client):
event = dummy_cloud_event(b'0\x80\x80\x06World!\x00\x00', set_contenttype=True)
event = dummy_cloud_event(
b'0\x80\x80\x06World!\x00\x00', set_contenttype=True)
headers, body = to_binary(event)

resp = http_server_client.post('/v1/models/TestModel:predict', headers=headers, content=body)
Expand All @@ -568,7 +591,8 @@ def app(self): # pylint: disable=no-self-use
model.load()
server = ModelServer()
server.register_model(model)
rest_server = RESTServer(server.dataplane, server.model_repository_extension)
rest_server = RESTServer(
server.dataplane, server.model_repository_extension)
return rest_server.create_application()

@pytest.fixture(scope='class')
@@ -585,7 +609,8 @@ def test_predict_ce_avro_binary(self, http_server_client):
writer.write(msg, encoder)
data = bytes_writer.getvalue()

event = dummy_cloud_event(data, set_contenttype=True, contenttype="application/avro")
event = dummy_cloud_event(
data, set_contenttype=True, contenttype="application/avro")
# Creates the HTTP request representation of the CloudEvent in binary content mode
headers, body = to_binary(event)
resp = http_server_client.post('/v1/models/TestModel:predict', headers=headers, content=body)
@@ -604,8 +629,10 @@ class TestTFHttpServerLoadAndUnLoad:

@pytest.fixture(scope="class")
def app(self): # pylint: disable=no-self-use
server = ModelServer(registered_models=DummyModelRepository(test_load_success=True))
rest_server = RESTServer(server.dataplane, server.model_repository_extension)
server = ModelServer(
registered_models=DummyModelRepository(test_load_success=True))
rest_server = RESTServer(
server.dataplane, server.model_repository_extension)
return rest_server.create_application()

@pytest.fixture(scope='class')
@@ -626,8 +653,10 @@ def test_unload(self, http_server_client):
class TestTFHttpServerLoadAndUnLoadFailure:
@pytest.fixture(scope="class")
def app(self): # pylint: disable=no-self-use
server = ModelServer(registered_models=DummyModelRepository(test_load_success=False))
rest_server = RESTServer(server.dataplane, server.model_repository_extension)
server = ModelServer(
registered_models=DummyModelRepository(test_load_success=False))
rest_server = RESTServer(
server.dataplane, server.model_repository_extension)
return rest_server.create_application()

@pytest.fixture(scope='class')
@@ -649,7 +678,8 @@ def app(self): # pylint: disable=no-self-use
model = DummyModel("TestModel")
server = ModelServer()
server.register_model(model)
rest_server = RESTServer(server.dataplane, server.model_repository_extension)
rest_server = RESTServer(
server.dataplane, server.model_repository_extension)
return rest_server.create_application()

@pytest.fixture(scope='class')