Skip to content

Commit

Permalink
[python] Fix backport of rolling batch non-streaming non-200 error co…
Browse files Browse the repository at this point in the history
…de support (#2478)
  • Loading branch information
davidthomas426 authored Oct 22, 2024
1 parent 975c50d commit c807fe0
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 3 deletions.
84 changes: 83 additions & 1 deletion engines/python/setup/djl_python/output_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
import json
import logging
import time
from typing import Union, Callable
from typing import Dict, Union, Callable

from typing_extensions import deprecated

from djl_python.request_io import Token, TextGenerationOutput, RequestOutput
from djl_python.utils import serving_backport_for_non_streaming_http_error_codes_enabled


def _json_output_formatter(request_output: RequestOutput):
Expand All @@ -26,6 +27,10 @@ def _json_output_formatter(request_output: RequestOutput):
:return: formatted output
"""
if serving_backport_for_non_streaming_http_error_codes_enabled():
return _json_output_formatter_backport_for_non_streaming_http_error_codes(
request_output)

best_sequence = request_output.sequences[
request_output.best_sequence_index]

Expand Down Expand Up @@ -65,6 +70,83 @@ def _json_output_formatter(request_output: RequestOutput):
return json_encoded_str


def _json_output_formatter_backport_for_non_streaming_http_error_codes(
request_output: TextGenerationOutput):
"""
json output formatter that allows non-streaming requests to return non-200 error codes on error.
Backported from djl-serving v0.29.0.
:return: formatted output
"""

def _get_last_token(seq):
if seq._last_token_index:
return seq.tokens[seq._last_token_index]
return None

def _get_generated_text(sequence, request_output):
parameters = request_output.input.parameters
generated_text = request_output.input.input_text if parameters.get(
"return_full_text") else ""
for token in sequence.tokens:
generated_text += token.text
return generated_text

def _get_details_dict(request_output: TextGenerationOutput,
include_tokens: bool = True) -> Dict:
parameters = request_output.input.parameters
best_sequence = request_output.sequences[
request_output.best_sequence_index]
if parameters.get("details", request_output.input.tgi_compat):
final_dict = {
"finish_reason": best_sequence.finish_reason,
"generated_tokens": len(best_sequence.tokens),
"inputs": request_output.input.input_text,
}

if include_tokens:
final_dict["tokens"] = request_output.get_tokens_as_dict()

if parameters.get("decoder_input_details"):
final_dict[
"prefill"] = request_output.get_prompt_tokes_as_dict()
return final_dict
elif best_sequence.finish_reason == "error":
return {"finish_reason": best_sequence.finish_reason}
else:
return {}

best_sequence = request_output.sequences[
request_output.best_sequence_index]
# TODO: Fix this so it is not required. Right now, this call is needed to
# advance the token iterator, which is needed for rolling batch to work properly
next_token, _, _ = best_sequence.get_next_token()
if not request_output.finished:
return ""
details = _get_details_dict(request_output, include_tokens=True)
if details.get("finish_reason") == "error":
final_token = _get_last_token(best_sequence)
# In non-streaming, request either succeeds or fails so do not provide the
# partial generation response that may exist
result = {
"generated_text": None,
"error": getattr(final_token, "error_msg", "error"),
"code": 400,
"details": details,
}
return json.dumps(result, ensure_ascii=False)
generated_text = _get_generated_text(best_sequence, request_output)
result = {
"generated_text": generated_text,
}
if details:
result["details"] = details
if request_output.input.tgi_compat:
result = [result]
return json.dumps(result, ensure_ascii=False)


def _jsonlines_output_formatter(request_output: RequestOutput):
"""
jsonlines output formatter
Expand Down
4 changes: 4 additions & 0 deletions engines/python/setup/djl_python/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from djl_python.output_formatter import get_output_formatter, _json_output_formatter, sse_response_formatter, \
adapt_legacy_output_formatter
from djl_python.request_io import Token, TextGenerationOutput, TextInput, RequestOutput
from djl_python.utils import serving_backport_for_non_streaming_http_error_codes_enabled


class Request(object):
Expand Down Expand Up @@ -114,6 +115,9 @@ def set_next_token(self,
self.request_output.set_finish_reason(finish_reason)
self.request_output.prompt_tokens_details = prompt_tokens_details
self.last_token = last_token
if (last_token and
serving_backport_for_non_streaming_http_error_codes_enabled()):
self.request_output.finished = True

def get_next_token(self) -> str:
"""
Expand Down
27 changes: 26 additions & 1 deletion engines/python/setup/djl_python/rolling_batch/rolling_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from djl_python.properties_manager.properties import Properties
from djl_python.request import Request
from djl_python.request_io import Token
from djl_python.utils import serving_backport_for_non_streaming_http_error_codes_enabled

FINISH_REASON_MAPPER = ["length", "eos_token", "stop_sequence"]

Expand Down Expand Up @@ -44,14 +45,31 @@ def stop_on_any_exception(func):
def try_catch_handling(self, *args, **kwargs):
try:
return func(self, *args, **kwargs)
except Exception:
except Exception as e:
logging.exception("Rolling batch inference error")
for request in self.active_requests:
token = Token(-1, "", -1, True)
request.set_next_token(token,
last_token=True,
finish_reason="error")
if serving_backport_for_non_streaming_http_error_codes_enabled(
):
request.error_message = str(e)
request.error_code = 424
response = self.postprocess_results()
if (serving_backport_for_non_streaming_http_error_codes_enabled()
and isinstance(response, list)):
# In case postprocess_results implementation doesn't set response "error"
# or "code", set it the same as we did on the request objects above.
# Note: We may want to forward-port this. Only downside is if we want
# `postprocess_results` to be able to "handle" the error, i.e., to
# intentionally not propagate the error_message and error_code above.
# But that's still doable by setting `error` to "" and `code` to 200.
for res in response:
if res.get("error", None) is None:
res["error"] = str(e)
if res.get("code", None) is None:
res["code"] = 424
self.reset()
return response

Expand Down Expand Up @@ -161,6 +179,13 @@ def postprocess_results(self) -> list[dict]:
"last": req.is_last_token(),
"content_type": req.get_content_type()
}
if serving_backport_for_non_streaming_http_error_codes_enabled():
error_message = getattr(req, "error_message", None)
error_code = getattr(req, "error_code", None)
if error_message is not None:
res["error"] = error_message
if error_code is not None:
res["code"] = error_code
req.reset_next_token()
results.append(res)

Expand Down
5 changes: 5 additions & 0 deletions engines/python/setup/djl_python/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import logging
import os
from typing import Union, Callable, Any, List

from djl_python.inputs import Input
Expand Down Expand Up @@ -192,3 +193,7 @@ def apply_profiling(self, *args, **kwargs):
return result

return apply_profiling


def serving_backport_for_non_streaming_http_error_codes_enabled():
    """
    Report whether the non-streaming HTTP error code backport is switched on.

    The switch is read from the SERVING_BACKPORT_FOR_NON_STREAMING_HTTP_ERROR_CODES
    environment variable on every call.

    :return: False when the variable is unset, empty, or set to an explicit
        "off" spelling ("0", "false", "no", "off", case-insensitive); True for
        any other value.
    """
    raw_value = os.getenv(
        "SERVING_BACKPORT_FOR_NON_STREAMING_HTTP_ERROR_CODES")
    if raw_value is None:
        return False
    # Any non-empty string is truthy in Python, so returning the raw value
    # would turn the feature on for "false" or "0". Parse it instead.
    return raw_value.strip().lower() not in ("", "0", "false", "no", "off")
Original file line number Diff line number Diff line change
Expand Up @@ -374,10 +374,20 @@ void addResponse(byte[] json, Map<String, String> properties) {

if (code != null) {
Map<String, Object> map = new ConcurrentHashMap<>(2);
map.put("code", Integer.parseInt(code));
int httpStatusCode = Integer.parseInt(code);
map.put("code", httpStatusCode);
if (error != null) {
map.put("error", error);
}
if (isBackportForNonStreamingHttpErrorCodes) {
// Update http status code and any error message to the values here, so
// that non-streaming case can return non-200 on errors encountered during
// inference.
output.setCode(httpStatusCode);
if (error != null) {
output.setMessage(error);
}
}
byte[] buffer = JsonUtils.GSON.toJson(map).getBytes(StandardCharsets.UTF_8);
data.appendContent(buffer, true);
} else {
Expand Down

0 comments on commit c807fe0

Please sign in to comment.