diff --git a/engines/python/setup/djl_python/output_formatter.py b/engines/python/setup/djl_python/output_formatter.py index 9ae941971..f7c67de6e 100644 --- a/engines/python/setup/djl_python/output_formatter.py +++ b/engines/python/setup/djl_python/output_formatter.py @@ -13,11 +13,12 @@ import json import logging import time -from typing import Union, Callable +from typing import Dict, Union, Callable from typing_extensions import deprecated from djl_python.request_io import Token, TextGenerationOutput, RequestOutput +from djl_python.utils import serving_backport_for_non_streaming_http_error_codes_enabled def _json_output_formatter(request_output: RequestOutput): @@ -26,6 +27,10 @@ def _json_output_formatter(request_output: RequestOutput): :return: formatted output """ + if serving_backport_for_non_streaming_http_error_codes_enabled(): + return _json_output_formatter_backport_for_non_streaming_http_error_codes( + request_output) + best_sequence = request_output.sequences[ request_output.best_sequence_index] @@ -65,6 +70,83 @@ def _json_output_formatter(request_output: RequestOutput): return json_encoded_str +def _json_output_formatter_backport_for_non_streaming_http_error_codes( + request_output: TextGenerationOutput): + """ + json output formatter that allows non-streaming requests to return non-200 error codes on error. + + Backported from djl-serving v0.29.0. 
+
+    :return: formatted output
+    """
+
+    def _get_last_token(seq):
+        if seq._last_token_index:
+            return seq.tokens[seq._last_token_index]
+        return None
+
+    def _get_generated_text(sequence, request_output):
+        parameters = request_output.input.parameters
+        generated_text = request_output.input.input_text if parameters.get(
+            "return_full_text") else ""
+        for token in sequence.tokens:
+            generated_text += token.text
+        return generated_text
+
+    def _get_details_dict(request_output: TextGenerationOutput,
+                          include_tokens: bool = True) -> Dict:
+        parameters = request_output.input.parameters
+        best_sequence = request_output.sequences[
+            request_output.best_sequence_index]
+        if parameters.get("details", request_output.input.tgi_compat):
+            final_dict = {
+                "finish_reason": best_sequence.finish_reason,
+                "generated_tokens": len(best_sequence.tokens),
+                "inputs": request_output.input.input_text,
+            }
+
+            if include_tokens:
+                final_dict["tokens"] = request_output.get_tokens_as_dict()
+
+            if parameters.get("decoder_input_details"):
+                final_dict[
+                    "prefill"] = request_output.get_prompt_tokens_as_dict()
+            return final_dict
+        elif best_sequence.finish_reason == "error":
+            return {"finish_reason": best_sequence.finish_reason}
+        else:
+            return {}
+
+    best_sequence = request_output.sequences[
+        request_output.best_sequence_index]
+    # TODO: Fix this so it is not required. 
Right now, this call is needed to + # advance the token iterator, which is needed for rolling batch to work properly + next_token, _, _ = best_sequence.get_next_token() + if not request_output.finished: + return "" + details = _get_details_dict(request_output, include_tokens=True) + if details.get("finish_reason") == "error": + final_token = _get_last_token(best_sequence) + # In non-streaming, request either succeeds or fails so do not provide the + # partial generation response that may exist + result = { + "generated_text": None, + "error": getattr(final_token, "error_msg", "error"), + "code": 400, + "details": details, + } + return json.dumps(result, ensure_ascii=False) + generated_text = _get_generated_text(best_sequence, request_output) + result = { + "generated_text": generated_text, + } + if details: + result["details"] = details + if request_output.input.tgi_compat: + result = [result] + return json.dumps(result, ensure_ascii=False) + + def _jsonlines_output_formatter(request_output: RequestOutput): """ jsonlines output formatter diff --git a/engines/python/setup/djl_python/request.py b/engines/python/setup/djl_python/request.py index aa9a37ecc..f1d44fa4b 100644 --- a/engines/python/setup/djl_python/request.py +++ b/engines/python/setup/djl_python/request.py @@ -16,6 +16,7 @@ from djl_python.output_formatter import get_output_formatter, _json_output_formatter, sse_response_formatter, \ adapt_legacy_output_formatter from djl_python.request_io import Token, TextGenerationOutput, TextInput, RequestOutput +from djl_python.utils import serving_backport_for_non_streaming_http_error_codes_enabled class Request(object): @@ -114,6 +115,9 @@ def set_next_token(self, self.request_output.set_finish_reason(finish_reason) self.request_output.prompt_tokens_details = prompt_tokens_details self.last_token = last_token + if (last_token and + serving_backport_for_non_streaming_http_error_codes_enabled()): + self.request_output.finished = True def get_next_token(self) -> str: 
""" diff --git a/engines/python/setup/djl_python/rolling_batch/rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/rolling_batch.py index 680fbebcd..2dbc62dce 100644 --- a/engines/python/setup/djl_python/rolling_batch/rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/rolling_batch.py @@ -17,6 +17,7 @@ from djl_python.properties_manager.properties import Properties from djl_python.request import Request from djl_python.request_io import Token +from djl_python.utils import serving_backport_for_non_streaming_http_error_codes_enabled FINISH_REASON_MAPPER = ["length", "eos_token", "stop_sequence"] @@ -44,14 +45,31 @@ def stop_on_any_exception(func): def try_catch_handling(self, *args, **kwargs): try: return func(self, *args, **kwargs) - except Exception: + except Exception as e: logging.exception("Rolling batch inference error") for request in self.active_requests: token = Token(-1, "", -1, True) request.set_next_token(token, last_token=True, finish_reason="error") + if serving_backport_for_non_streaming_http_error_codes_enabled( + ): + request.error_message = str(e) + request.error_code = 424 response = self.postprocess_results() + if (serving_backport_for_non_streaming_http_error_codes_enabled() + and isinstance(response, list)): + # In case postprocess_results implementation doesn't set response "error" + # or "code", set it the same as we did on the request objects above. + # Note: We may want to forward-port this. Only downside is if we want + # `postprocess_results` to be able to "handle" the error, i.e., to + # intentionally not propagate the error_message and error_code above. + # But that's still doable by setting `error` to "" and `code` to 200. 
+ for res in response: + if res.get("error", None) is None: + res["error"] = str(e) + if res.get("code", None) is None: + res["code"] = 424 self.reset() return response @@ -161,6 +179,13 @@ def postprocess_results(self) -> list[dict]: "last": req.is_last_token(), "content_type": req.get_content_type() } + if serving_backport_for_non_streaming_http_error_codes_enabled(): + error_message = getattr(req, "error_message", None) + error_code = getattr(req, "error_code", None) + if error_message is not None: + res["error"] = error_message + if error_code is not None: + res["code"] = error_code req.reset_next_token() results.append(res) diff --git a/engines/python/setup/djl_python/utils.py b/engines/python/setup/djl_python/utils.py index 286d2b984..a892ded90 100644 --- a/engines/python/setup/djl_python/utils.py +++ b/engines/python/setup/djl_python/utils.py @@ -11,6 +11,7 @@ # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for # the specific language governing permissions and limitations under the License. 
import logging +import os from typing import Union, Callable, Any, List from djl_python.inputs import Input @@ -192,3 +193,7 @@ def apply_profiling(self, *args, **kwargs): return result return apply_profiling + + +def serving_backport_for_non_streaming_http_error_codes_enabled(): + return os.getenv("SERVING_BACKPORT_FOR_NON_STREAMING_HTTP_ERROR_CODES") diff --git a/engines/python/src/main/java/ai/djl/python/engine/RollingBatch.java b/engines/python/src/main/java/ai/djl/python/engine/RollingBatch.java index 75ad7c4a2..18266ff70 100644 --- a/engines/python/src/main/java/ai/djl/python/engine/RollingBatch.java +++ b/engines/python/src/main/java/ai/djl/python/engine/RollingBatch.java @@ -374,10 +374,20 @@ void addResponse(byte[] json, Map properties) { if (code != null) { Map map = new ConcurrentHashMap<>(2); - map.put("code", Integer.parseInt(code)); + int httpStatusCode = Integer.parseInt(code); + map.put("code", httpStatusCode); if (error != null) { map.put("error", error); } + if (isBackportForNonStreamingHttpErrorCodes) { + // Update http status code and any error message to the values here, so + // that non-streaming case can return non-200 on errors encountered during + // inference. + output.setCode(httpStatusCode); + if (error != null) { + output.setMessage(error); + } + } byte[] buffer = JsonUtils.GSON.toJson(map).getBytes(StandardCharsets.UTF_8); data.appendContent(buffer, true); } else {