From 97ee77cd661181db1dd1c2904f749f507869e91d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Frederik=20Hvilsh=C3=B8j?= Date: Thu, 19 Dec 2024 12:11:11 +0100 Subject: [PATCH] fix: gcp editor agents to comply with application/json --- encord_agents/gcp/wrappers.py | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/encord_agents/gcp/wrappers.py b/encord_agents/gcp/wrappers.py index 249d022..f87bf37 100644 --- a/encord_agents/gcp/wrappers.py +++ b/encord_agents/gcp/wrappers.py @@ -1,4 +1,5 @@ import logging +import re from contextlib import ExitStack from functools import wraps from typing import Any, Callable @@ -26,6 +27,14 @@ def generate_response() -> Response: return response +ALLOWED_ORIGINS = [ + r"^https:\/\/app.encord.com$", + r"^https:\/\/dev.encord.com$", + r"^https:\/\/staging.encord.com$", + r"^https:\/\/cord-ai-development--[\w\d]+-[\w\d]+\.web.app$", +] + + def editor_agent( *, label_row_metadata_include_args: LabelRowMetadataIncludeArgs | None = None, @@ -52,7 +61,30 @@ def context_wrapper_inner(func: AgentFunction) -> Callable[[Request], Response]: @wraps(func) def wrapper(request: Request) -> Response: - frame_data = FrameData.model_validate_json(request.data) + # Set CORS headers for the preflight request + if request.method == "OPTIONS": + # Allows GET requests from any origin with the Content-Type + # header and caches preflight response for an 3600s + response = make_response("") + + if not any(re.fullmatch(o, request.origin) for o in ALLOWED_ORIGINS): + response.status_code = 403 + return response + + headers = { + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "POST", + "Access-Control-Allow-Headers": "Content-Type", + "Access-Control-Max-Age": "3600", + } + response.headers.update(headers) + response.status_code = 204 + return response + + if request.is_json: + frame_data = FrameData.model_validate(request.get_json()) + else: + frame_data = FrameData.model_validate_json(request.get_data()) logging.info(f"Request: {frame_data}") client = get_user_client()