Skip to content

Commit

Permalink
Added missing response_headers variable in predict call.
Browse files Browse the repository at this point in the history
Signed-off-by: Andrews Arokiam <[email protected]>
  • Loading branch information
andyi2it committed Jan 29, 2024
1 parent 2c379a8 commit 5794dd2
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 15 deletions.
3 changes: 2 additions & 1 deletion python/kserve/test/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ def test_infer_parameters_v2(self, http_server_client):

input_data = json.dumps(req.to_rest()).encode('utf-8')
with patch.object(DummyModel, 'predict', new_callable=mock.Mock) as mock_predict:
response_headers = {}
mock_predict.return_value = InferResponse(model_name="TestModel", response_id="123",
parameters={
"test-str": "dummy",
Expand All @@ -347,7 +348,7 @@ def test_infer_parameters_v2(self, http_server_client):
"test-str": "dummy",
"test-bool": True,
"test-int": 100
})])
})]), response_headers
resp = http_server_client.post('/v2/models/TestModel/infer', content=input_data)
mock_predict.assert_called_with(req, mock.ANY)

Expand Down
24 changes: 11 additions & 13 deletions python/lgbserver/lgbserver/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,39 +47,37 @@ def test_model():
'petal_width_(cm)': {0: 0.2}, 'sepal_length_(cm)': {0: 5.1}}

response, response_headers = model.predict({"inputs": [request, request]})
assert numpy.argmax(response["predictions"][0]) == 2
assert numpy.argmax(response["predictions"][0]) == 0

response, response_headers = model.predict(
{"instances": [request, request]})
assert numpy.argmax(response["predictions"][0]) == 2
response, response_headers = model.predict({"instances": [request, request]})
assert numpy.argmax(response["predictions"][0]) == 0

request = [
{'sepal_width_(cm)': 3.5}, {'petal_length_(cm)': 1.4},
{'petal_width_(cm)': 0.2}, {'sepal_length_(cm)': 5.1}
]
response = model.predict({"inputs": [request, request]})
response, response_headers = model.predict({"inputs": [request, request]})
assert numpy.argmax(response["predictions"][0]) == 0

response = model.predict({"instances": [request, request]})
response, response_headers = model.predict({"instances": [request, request]})
assert numpy.argmax(response["predictions"][0]) == 0

request = [
{'sepal_width_(cm)': 3.5}, {'petal_length_(cm)': 1.4},
{'petal_width_(cm)': 0.2}
]
response = model.predict({"inputs": [request, request]})
response, response_headers = model.predict({"inputs": [request, request]})
assert numpy.argmax(response["predictions"][0]) == 0

response = model.predict({"instances": [request, request]})
response, response_headers = model.predict({"instances": [request, request]})
assert numpy.argmax(response["predictions"][0]) == 0

# test v2 handler
infer_input = InferInput(name="input-0", shape=[2, 4], datatype="FP32",
data=[[6.8, 2.8, 4.8, 1.6], [6.0, 3.4, 4.5, 1.6]])
infer_request = InferRequest(
model_name="model", infer_inputs=[infer_input])
infer_request = InferRequest(model_name="model", infer_inputs=[infer_input])
infer_response, response_headers = model.predict(infer_request)
assert infer_response.to_rest()["outputs"] == \
[{'name': 'output-0', 'shape': [2, 3], 'datatype': 'FP64',
'data': [3.7899802486733807e-06, 0.9996982074114203, 0.00029800260833088297,
5.2172911836629736e-05, 0.99973341723876, 0.000214409849403366]}]
[{'name': 'output-0', 'shape': [2, 3], 'datatype': 'FP64',
'data': [3.7899802486733807e-06, 0.9996982074114203, 0.00029800260833088297,
5.2172911836629736e-05, 0.99973341723876, 0.000214409849403366]}]
3 changes: 2 additions & 1 deletion python/test_resources/graph/success_200_isvc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def load(self):
self.ready = True

def predict(self, payload: Union[Dict, InferRequest, ModelInferRequest], headers) -> Dict:
return {"message": "SUCCESS"}
response_headers = {}
return {"message": "SUCCESS"}, response_headers


parser = argparse.ArgumentParser(parents=[kserve.model_server.parser])
Expand Down

0 comments on commit 5794dd2

Please sign in to comment.