Skip to content

Commit

Permalink
Optimized async generation response when worker queue is full
Browse files Browse the repository at this point in the history
  • Loading branch information
konieshadow committed Jan 25, 2024
1 parent b5d57f8 commit 69ab0f9
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 81 deletions.
102 changes: 67 additions & 35 deletions fooocusapi/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

from fooocusapi.args import args
from fooocusapi.models import *
from fooocusapi.api_utils import generation_output, req_to_params
from fooocusapi.api_utils import req_to_params, generate_async_output, generate_streaming_output, generate_image_result_output
import fooocusapi.file_utils as file_utils
from fooocusapi.parameters import GenerationFinishReason, ImageGenerationResult
from fooocusapi.task_queue import TaskType
from fooocusapi.task_queue import QueueTask, TaskType
from fooocusapi.worker import process_generate, task_queue, process_top
from fooocusapi.models_v2 import *
from fooocusapi.img_utils import base64_to_stream, read_input_image
Expand Down Expand Up @@ -58,31 +58,63 @@
}


def call_worker(req: Text2ImgRequest, accept: str):
task_type = TaskType.text_2_img
def get_task_type(req: Text2ImgRequest) -> TaskType:
if isinstance(req, ImgUpscaleOrVaryRequest) or isinstance(req, ImgUpscaleOrVaryRequestJson):
task_type = TaskType.img_uov
return TaskType.img_uov
elif isinstance(req, ImgPromptRequest) or isinstance(req, ImgPromptRequestJson):
task_type = TaskType.img_prompt
return TaskType.img_prompt
elif isinstance(req, ImgInpaintOrOutpaintRequest) or isinstance(req, ImgInpaintOrOutpaintRequestJson):
task_type = TaskType.img_inpaint_outpaint
return TaskType.img_inpaint_outpaint
else:
return TaskType.text_2_img


def call_worker(req: Text2ImgRequest, accept: str) -> Tuple[QueueTask | None, List[ImageGenerationResult] | None]:
task_type = get_task_type(req)
params = req_to_params(req)
queue_task = task_queue.add_task(
task_type, {'params': params.__dict__, 'accept': accept, 'require_base64': req.require_base64},
webhook_url=req.webhook_url)

if queue_task is None:
print("[Task Queue] The task queue has reached limit")
results = [ImageGenerationResult(im=None, seed=0,
return None, [ImageGenerationResult(im=None, seed='',
finish_reason=GenerationFinishReason.queue_is_full)]
elif req.async_process:
work_executor.submit(process_generate, queue_task, params)
results = queue_task
return queue_task, None
else:
results = process_generate(queue_task, params)

return results
return queue_task, results


def build_generation_response(req: Text2ImgRequest,
                              streaming_output: bool,
                              task: QueueTask | None,
                              results: List[ImageGenerationResult] | None) -> Response | AsyncJobResponse | List[GeneratedImageResult]:
    """Build the HTTP response for a generation request.

    Args:
        req: The generation request; ``async_process`` and ``require_base64``
            are read from it.
        streaming_output: True when the caller asked for the raw image bytes
            instead of a JSON payload.
        task: The queued task, or None when adding to the worker queue failed.
        results: Generation results, or None when none are available yet
            (e.g. the job was only submitted for async processing).

    Returns:
        An ``AsyncJobResponse`` for async requests, a raw ``Response`` for
        streaming requests, otherwise the list of ``GeneratedImageResult``.
    """
    # A successfully queued async job has no results yet. Report the job
    # itself before looking at streaming, otherwise the empty result list
    # would be turned into a bare HTTP 500 even though the job was accepted.
    if req.async_process and task is not None:
        return generate_async_output(task)

    if streaming_output:
        return generate_streaming_output([] if results is None else results)

    job_result: List[GeneratedImageResult] = []
    if results is not None:
        job_result = generate_image_result_output(results, req.require_base64)

    if task is None and req.async_process:
        # Adding to the worker queue failed: answer with an error-stage job
        # that carries whatever results describe the failure.
        return AsyncJobResponse(job_id='',
                                job_type=get_task_type(req),
                                job_stage=AsyncJobStage.error,
                                job_progress=0,
                                job_status=None,
                                job_step_preview=None,
                                job_result=job_result)

    return job_result


def stop_worker():
Expand Down Expand Up @@ -112,8 +144,8 @@ def text2img_generation(req: Text2ImgRequest, accept: str = Header(None),
else:
streaming_output = False

results = call_worker(req, accept)
return generation_output(results, streaming_output, req.require_base64)
task, results = call_worker(req, accept)
return build_generation_response(req, streaming_output, task, results)


@app.post("/v2/generation/text-to-image-with-ip", response_model=List[GeneratedImageResult] | AsyncJobResponse, responses=img_generate_responses)
Expand Down Expand Up @@ -145,8 +177,8 @@ def text_to_img_with_ip(req: Text2ImgRequestWithPrompt,

req.image_prompts = image_prompts_files

results = call_worker(req, accept)
return generation_output(results, streaming_output, req.require_base64)
task, results = call_worker(req, accept)
return build_generation_response(req, streaming_output, task, results)


@app.post("/v1/generation/image-upscale-vary", response_model=List[GeneratedImageResult] | AsyncJobResponse, responses=img_generate_responses)
Expand All @@ -163,8 +195,8 @@ def img_upscale_or_vary(input_image: UploadFile, req: ImgUpscaleOrVaryRequest =
else:
streaming_output = False

results = call_worker(req, accept)
return generation_output(results, streaming_output, req.require_base64)
task, results = call_worker(req, accept)
return build_generation_response(req, streaming_output, task, results)


@app.post("/v2/generation/image-upscale-vary", response_model=List[GeneratedImageResult] | AsyncJobResponse, responses=img_generate_responses)
Expand Down Expand Up @@ -195,8 +227,8 @@ def img_upscale_or_vary_v2(req: ImgUpscaleOrVaryRequestJson,
image_prompts_files.append(default_image_promt)
req.image_prompts = image_prompts_files

results = call_worker(req, accept)
return generation_output(results, streaming_output, req.require_base64)
task, results = call_worker(req, accept)
return build_generation_response(req, streaming_output, task, results)


@app.post("/v1/generation/image-inpait-outpaint", response_model=List[GeneratedImageResult] | AsyncJobResponse, responses=img_generate_responses)
Expand All @@ -213,8 +245,8 @@ def img_inpaint_or_outpaint(input_image: UploadFile, req: ImgInpaintOrOutpaintRe
else:
streaming_output = False

results = call_worker(req, accept)
return generation_output(results, streaming_output, req.require_base64)
task, results = call_worker(req, accept)
return build_generation_response(req, streaming_output, task, results)


@app.post("/v2/generation/image-inpait-outpaint", response_model=List[GeneratedImageResult] | AsyncJobResponse, responses=img_generate_responses)
Expand Down Expand Up @@ -247,8 +279,8 @@ def img_inpaint_or_outpaint_v2(req: ImgInpaintOrOutpaintRequestJson,
image_prompts_files.append(default_image_promt)
req.image_prompts = image_prompts_files

results = call_worker(req, accept)
return generation_output(results, streaming_output, req.require_base64)
task, results = call_worker(req, accept)
return build_generation_response(req, streaming_output, task, results)


@app.post("/v1/generation/image-prompt", response_model=List[GeneratedImageResult] | AsyncJobResponse, responses=img_generate_responses)
Expand All @@ -266,8 +298,8 @@ def img_prompt(cn_img1: Optional[UploadFile] = File(None),
else:
streaming_output = False

results = call_worker(req, accept)
return generation_output(results, streaming_output, req.require_base64)
task, results = call_worker(req, accept)
return build_generation_response(req, streaming_output, task, results)


@app.post("/v2/generation/image-prompt", response_model=List[GeneratedImageResult] | AsyncJobResponse, responses=img_generate_responses)
Expand Down Expand Up @@ -304,22 +336,22 @@ def img_prompt(req: ImgPromptRequestJson,

req.image_prompts = image_prompts_files

results = call_worker(req, accept)
return generation_output(results, streaming_output, req.require_base64)
task, results = call_worker(req, accept)
return build_generation_response(req, streaming_output, task, results)


@app.get("/v1/generation/query-job", response_model=AsyncJobResponse, description="Query async generation job")
def query_job(req: QueryJobRequest = Depends()):
queue_task = task_queue.get_task(req.job_id, True)
if queue_task is None:
return JSONResponse(content=AsyncJobResponse(job_id="",
job_type="Not Found",
job_stage="ERROR",
job_progress=0,
job_status="Job not found"), status_code=404)

return generation_output(queue_task, streaming_output=False, require_base64=False,
require_step_preivew=req.require_step_preivew)
result = AsyncJobResponse(job_id="",
job_type=TaskType.not_found,
job_stage=AsyncJobStage.error,
job_progress=0,
job_status="Job not found")
content = result.model_dump_json()
return Response(content=content, media_type='application/json', status_code=404)
return generate_async_output(queue_task)


@app.get("/v1/generation/job-queue", response_model=JobQueueInfo, description="Query job queue info")
Expand Down
90 changes: 45 additions & 45 deletions fooocusapi/api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,54 +159,54 @@ def req_to_params(req: Text2ImgRequest) -> ImageGenerationParams:
)


def generation_output(results: QueueTask | List[ImageGenerationResult], streaming_output: bool, require_base64: bool, require_step_preivew: bool=False) -> Response | List[GeneratedImageResult] | AsyncJobResponse:
if isinstance(results, QueueTask):
task = results
job_stage = AsyncJobStage.running
job_result = None
if task.start_millis == 0:
job_stage = AsyncJobStage.waiting
if task.is_finished:
if task.finish_with_error:
job_stage = AsyncJobStage.error
else:
if task.task_result != None:
job_stage = AsyncJobStage.success
task_result_require_base64 = False
if 'require_base64' in task.req_param and task.req_param['require_base64']:
task_result_require_base64 = True

job_result = generation_output(task.task_result, False, task_result_require_base64)
job_step_preview = None if not require_step_preivew else task.task_step_preview
return AsyncJobResponse(job_id=task.job_id,
job_type=task.type,
job_stage=job_stage,
job_progress=task.finish_progress,
job_status=task.task_status,
job_step_preview=job_step_preview,
job_result=job_result)

if streaming_output:
if len(results) == 0:
return Response(status_code=500)
result = results[0]
if result.finish_reason == GenerationFinishReason.queue_is_full:
return Response(status_code=409, content=result.finish_reason.value)
elif result.finish_reason == GenerationFinishReason.user_cancel:
return Response(status_code=400, content=result.finish_reason.value)
elif result.finish_reason == GenerationFinishReason.error:
return Response(status_code=500, content=result.finish_reason.value)
bytes = output_file_to_bytesimg(results[0].im)
return Response(bytes, media_type='image/png')
else:
results = [GeneratedImageResult(
base64=output_file_to_base64img(
item.im) if require_base64 else None,
def generate_async_output(task: QueueTask) -> AsyncJobResponse:
    """Describe the current state of a queued task as an ``AsyncJobResponse``.

    The job stage is derived from the task:
      * ``error``   - the task finished with an error;
      * ``success`` - the task finished and has a result;
      * ``waiting`` - ``start_millis`` is still 0;
      * ``running`` - anything else.

    ``job_result`` is only populated for the ``success`` stage. The task's
    step preview is always forwarded in ``job_step_preview``.
    """
    job_result = None

    if task.is_finished and task.finish_with_error:
        job_stage = AsyncJobStage.error
    elif task.is_finished and task.task_result is not None:
        job_stage = AsyncJobStage.success
        # Honour the base64 preference recorded with the original request.
        task_result_require_base64 = bool(
            'require_base64' in task.req_param and task.req_param['require_base64'])
        job_result = generate_image_result_output(task.task_result, task_result_require_base64)
    elif task.start_millis == 0:
        job_stage = AsyncJobStage.waiting
    else:
        job_stage = AsyncJobStage.running

    return AsyncJobResponse(job_id=task.job_id,
                            job_type=task.type,
                            job_stage=job_stage,
                            job_progress=task.finish_progress,
                            job_status=task.task_status,
                            job_step_preview=task.task_step_preview,
                            job_result=job_result)


def generate_streaming_output(results: List[ImageGenerationResult]) -> Response:
    """Build a raw HTTP response from the first generation result.

    Status codes:
      * 500 with no body when there are no results;
      * 409 when the queue was full;
      * 400 when the user cancelled;
      * 500 when generation failed;
    otherwise the first result's image is returned as ``image/png``.
    Failure responses carry the finish reason's value as the body.
    """
    if not results:
        return Response(status_code=500)

    # Only the first result is used for streaming output.
    result = results[0]
    if result.finish_reason == GenerationFinishReason.queue_is_full:
        return Response(status_code=409, content=result.finish_reason.value)
    elif result.finish_reason == GenerationFinishReason.user_cancel:
        return Response(status_code=400, content=result.finish_reason.value)
    elif result.finish_reason == GenerationFinishReason.error:
        return Response(status_code=500, content=result.finish_reason.value)

    # Named image_bytes so the builtin `bytes` is not shadowed.
    image_bytes = output_file_to_bytesimg(result.im)
    return Response(image_bytes, media_type='image/png')


def generate_image_result_output(results: List[ImageGenerationResult], require_base64: bool) -> List[GeneratedImageResult]:
    """Convert generation results into ``GeneratedImageResult`` items.

    Each output item carries the served URL, seed and finish reason of the
    corresponding input; the base64 payload is only produced when
    ``require_base64`` is true, otherwise it is None.
    """
    converted = []
    for item in results:
        encoded = output_file_to_base64img(item.im) if require_base64 else None
        converted.append(GeneratedImageResult(
            base64=encoded,
            url=get_file_serve_url(item.im),
            seed=item.seed,
            finish_reason=item.finish_reason))
    return converted


class QueueReachLimitException(Exception):
Expand Down
2 changes: 1 addition & 1 deletion fooocusapi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ class AsyncJobResponse(BaseModel):
class JobQueueInfo(BaseModel):
running_size: int = Field(description="The current running and waiting job count")
finished_size: int = Field(description="Finished job cound (after auto clean)")
last_job_id: str = Field(description="Last submit generation job id")
last_job_id: str | None = Field(description="Last submit generation job id")


# TODO: May need more detailed fields; will add them later when someone needs them
Expand Down

0 comments on commit 69ab0f9

Please sign in to comment.