diff --git a/docs/code_examples/fastapi/frame_classification.py b/docs/code_examples/fastapi/frame_classification.py index 25d43f9..eba5831 100644 --- a/docs/code_examples/fastapi/frame_classification.py +++ b/docs/code_examples/fastapi/frame_classification.py @@ -4,13 +4,13 @@ from anthropic import Anthropic from encord.objects.ontology_labels_impl import LabelRowV2 from fastapi import Depends, FastAPI, Form -from fastapi.middleware.cors import CORSMiddleware from numpy.typing import NDArray from typing_extensions import Annotated from encord_agents.core.data_model import Frame from encord_agents.core.ontology import OntologyDataModel from encord_agents.core.utils import get_user_client +from encord_agents.fastapi.cors import EncordCORSMiddleware from encord_agents.fastapi.dependencies import ( FrameData, dep_label_row, @@ -19,10 +19,7 @@ # Initialize FastAPI app app = FastAPI() -app.add_middleware( - CORSMiddleware, - allow_origins=["*", "https://app.encord.com"], -) +app.add_middleware(EncordCORSMiddleware) # Setup project and data model client = get_user_client() @@ -47,7 +44,7 @@ @app.post("/frame_classification") async def classify_frame( - frame_data: Annotated[FrameData, Form()], + frame_data: FrameData, lr: Annotated[LabelRowV2, Depends(dep_label_row)], content: Annotated[NDArray[np.uint8], Depends(dep_single_frame)], ): diff --git a/docs/code_examples/fastapi/object_classification.py b/docs/code_examples/fastapi/object_classification.py index ebf30a4..6b373c1 100644 --- a/docs/code_examples/fastapi/object_classification.py +++ b/docs/code_examples/fastapi/object_classification.py @@ -2,13 +2,13 @@ from anthropic import Anthropic from encord.objects.ontology_labels_impl import LabelRowV2 -from fastapi import Depends, FastAPI, Form -from fastapi.middleware.cors import CORSMiddleware +from fastapi import Depends, FastAPI from typing_extensions import Annotated from encord_agents.core.data_model import InstanceCrop from encord_agents.core.ontology import 
OntologyDataModel from encord_agents.core.utils import get_user_client +from encord_agents.fastapi.cors import EncordCORSMiddleware from encord_agents.fastapi.dependencies import ( FrameData, dep_label_row, @@ -17,10 +17,7 @@ # Initialize FastAPI app app = FastAPI() -app.add_middleware( - CORSMiddleware, - allow_origins=["*", "https://app.encord.com"], -) +app.add_middleware(EncordCORSMiddleware) # User client and ontology setup client = get_user_client() @@ -49,7 +46,7 @@ @app.post("/object_classification") async def classify_objects( - frame_data: Annotated[FrameData, Form()], + frame_data: FrameData, lr: Annotated[LabelRowV2, Depends(dep_label_row)], crops: Annotated[ list[InstanceCrop], diff --git a/docs/editor_agents/examples/index.md b/docs/editor_agents/examples/index.md index 229bcef..7dead05 100644 --- a/docs/editor_agents/examples/index.md +++ b/docs/editor_agents/examples/index.md @@ -1047,7 +1047,7 @@ Let us go through the code section by section. First, we import dependencies and setup the FastAPI app with CORS middleware: -[main.py](../../code_examples/fastapi/frame_classification.py) lines:1-25 +[main.py](../../code_examples/fastapi/frame_classification.py) lines:1-22 The CORS middleware is crucial as it allows the Encord platform to make requests to your API. 
@@ -1055,19 +1055,19 @@ The CORS middleware is crucial as it allows the Encord platform to make requests Next, we set up the Project and create a data model based on the Ontology: -[main.py](../../code_examples/fastapi/frame_classification.py) lines:28-30 +[main.py](../../code_examples/fastapi/frame_classification.py) lines:25-27 We create the system prompt that tells Claude how to structure its response: -[main.py](../../code_examples/fastapi/frame_classification.py) lines:33-45 +[main.py](../../code_examples/fastapi/frame_classification.py) lines:30-42 Finally, we define the endpoint to handle the classification: -[main.py](../../code_examples/fastapi/frame_classification.py) lines:48-78 +[main.py](../../code_examples/fastapi/frame_classification.py) lines:45-75 The endpoint: @@ -1155,25 +1155,25 @@ Let's walk through the key components. First, we setup the FastAPI app and CORS middleware: -[main.py](../../code_examples/fastapi/object_classification.py) lines:1-23 +[main.py](../../code_examples/fastapi/object_classification.py) lines:1-20 Then we setup the client, Project, and extract the generic Ontology object: -[main.py](../../code_examples/fastapi/object_classification.py) lines:26-32 +[main.py](../../code_examples/fastapi/object_classification.py) lines:23-29 We create the data model and system prompt for Claude: -[main.py](../../code_examples/fastapi/object_classification.py) lines:34-47 +[main.py](../../code_examples/fastapi/object_classification.py) lines:32-44 Finally, we define our object classification endpoint: -[main.py](../../code_examples/fastapi/object_classification.py) lines:50-97 +[main.py](../../code_examples/fastapi/object_classification.py) lines:47-94 The endpoint: diff --git a/docs/editor_agents/fastapi.md b/docs/editor_agents/fastapi.md index e1edb2d..09a5a6c 100644 --- a/docs/editor_agents/fastapi.md +++ b/docs/editor_agents/fastapi.md @@ -37,19 +37,16 @@ from typing_extensions import Annotated from encord.objects.ontology_labels_impl 
import LabelRowV2 from encord_agents import FrameData from encord_agents.fastapi import dep_label_row +from encord_agents.fastapi.cors import EncordCORSMiddleware from fastapi import FastAPI, Depends, Form -from fastapi.middleware.cors import CORSMiddleware app = FastAPI() -app.add_middleware( - CORSMiddleware, - allow_origins=["*", "https://app.encord.com"], -) +app.add_middleware(EncordCORSMiddleware) @app.post("/my_agent") def my_agent( - frame_data: Annotated[FrameData, Form()], + frame_data: FrameData, label_row: Annotated[LabelRowV2, Depends(dep_label_row)], ): # ... Do your edits to the labels diff --git a/docs/reference/editor_agents.md b/docs/reference/editor_agents.md index e15634b..8cb61a5 100644 --- a/docs/reference/editor_agents.md +++ b/docs/reference/editor_agents.md @@ -9,7 +9,7 @@ ## FastAPI -::: encord_agents.fastapi.dependencies +::: encord_agents.fastapi options: show_if_no_docstring: false show_subodules: false diff --git a/encord_agents/cli/test.py b/encord_agents/cli/test.py index 3f1903b..1d0d576 100644 --- a/encord_agents/cli/test.py +++ b/encord_agents/cli/test.py @@ -77,8 +77,8 @@ def local( request = requests.Request( "POST", f"http://localhost:{port}{target}", - data=payload, - headers={"Content-type": "application/x-www-form-urlencoded"}, + json=payload, + headers={"Content-type": "application/json"}, ) prepped = request.prepare() @@ -96,7 +96,8 @@ def local( table.add_section() table.add_row("[green]Request[/green]") table.add_row("url", prepped.url) - table.add_row("data", prepped.body) # type: ignore + body_json_str = prepped.body.decode("utf-8") # type: ignore + table.add_row("data", body_json_str) table_headers = ", ".join([f"'{k}': '{v}'" for k, v in prepped.headers.items()]) table.add_row("headers", f"{{{table_headers}}}") @@ -115,7 +116,7 @@ def local( headers = ["'{0}: {1}'".format(k, v) for k, v in prepped.headers.items()] str_headers = " -H ".join(headers) - curl_command = f"curl -X {prepped.method} \\{os.linesep} -H 
{str_headers} \\{os.linesep} -d '{prepped.body!r}' \\{os.linesep} '{prepped.url}'" + curl_command = f"curl -X {prepped.method} \\{os.linesep} -H {str_headers} \\{os.linesep} -d '{body_json_str}' \\{os.linesep} '{prepped.url}'" table.add_row("curl", curl_command) rich.print(table) diff --git a/encord_agents/core/constants.py b/encord_agents/core/constants.py new file mode 100644 index 0000000..53f5f89 --- /dev/null +++ b/encord_agents/core/constants.py @@ -0,0 +1,3 @@ +ENCORD_DOMAIN_REGEX = ( + r"^https:\/\/(?:(?:cord-ai-development--[\w\d]+-[\w\d]+\.web\.app)|(?:(?:dev|staging|app)\.(us\.)?encord\.com))$" +) diff --git a/encord_agents/fastapi/cors.py b/encord_agents/fastapi/cors.py new file mode 100644 index 0000000..b474bd8 --- /dev/null +++ b/encord_agents/fastapi/cors.py @@ -0,0 +1,60 @@ +""" +Convenience method to easily extend FastAPI servers +with the appropriate CORS Middleware to allow +interactions from the Encord platform. +""" + +import typing + +try: + from fastapi.middleware.cors import CORSMiddleware + from starlette.types import ASGIApp +except ModuleNotFoundError: + print( + 'To use the `fastapi` dependencies, you must also install fastapi. `python -m pip install "fastapi[standard]"' + ) + exit() + +from encord_agents.core.constants import ENCORD_DOMAIN_REGEX + + +# Type checking does not work here because we do not require people to +# install fastapi, as they can use the package for, e.g., the task runner without fastapi. +class EncordCORSMiddleware(CORSMiddleware): # type: ignore [misc] + """ + Like a regular `fastapi.middleware.cors.CORSMiddleware` but matches against + the Encord origin by default. + + **Example:** + ```python + from fastapi import FastAPI + from encord_agents.fastapi.cors import EncordCORSMiddleware + + app = FastAPI() + app.add_middleware(EncordCORSMiddleware) + ``` + + The CORS middleware will allow POST requests from the Encord domain. 
+ """ + + def __init__( + self, + app: ASGIApp, + allow_origins: typing.Sequence[str] = (), + allow_methods: typing.Sequence[str] = ("POST",), + allow_headers: typing.Sequence[str] = (), + allow_credentials: bool = False, + allow_origin_regex: str = ENCORD_DOMAIN_REGEX, + expose_headers: typing.Sequence[str] = (), + max_age: int = 3600, + ) -> None: + super().__init__( + app, + allow_origins, + allow_methods, + allow_headers, + allow_credentials, + allow_origin_regex, + expose_headers, + max_age, + ) diff --git a/encord_agents/fastapi/dependencies.py b/encord_agents/fastapi/dependencies.py index b3e4af0..4721242 100644 --- a/encord_agents/fastapi/dependencies.py +++ b/encord_agents/fastapi/dependencies.py @@ -13,7 +13,7 @@ ... @app.post("/my-agent-route") def my_agent( - frame_data: Annotated[FrameData, Form()], + frame_data: FrameData, ): ... ``` @@ -117,7 +117,7 @@ def my_route( """ - def wrapper(frame_data: Annotated[FrameData, Form()]) -> LabelRowV2: + def wrapper(frame_data: FrameData) -> LabelRowV2: return get_initialised_label_row( frame_data, include_args=label_row_metadata_include_args, init_args=label_row_initialise_labels_args ) @@ -125,7 +125,7 @@ def wrapper(frame_data: Annotated[FrameData, Form()]) -> LabelRowV2: return wrapper -def dep_label_row(frame_data: Annotated[FrameData, Form()]) -> LabelRowV2: +def dep_label_row(frame_data: FrameData) -> LabelRowV2: """ Dependency to provide an initialized label row. @@ -154,9 +154,7 @@ def my_route( return get_initialised_label_row(frame_data) -def dep_single_frame( - lr: Annotated[LabelRowV2, Depends(dep_label_row)], frame_data: Annotated[FrameData, Form()] -) -> NDArray[np.uint8]: +def dep_single_frame(lr: Annotated[LabelRowV2, Depends(dep_label_row)], frame_data: FrameData) -> NDArray[np.uint8]: """ Dependency to inject the underlying asset of the frame data. 
@@ -266,9 +264,7 @@ def my_route( yield iter_video(asset) -def dep_project( - frame_data: Annotated[FrameData, Form()], client: Annotated[EncordUserClient, Depends(dep_client)] -) -> Project: +def dep_project(frame_data: FrameData, client: Annotated[EncordUserClient, Depends(dep_client)]) -> Project: r""" Dependency to provide an instantiated [Project](https://docs.encord.com/sdk-documentation/sdk-references/LabelRowV2){ target="\_blank", rel="noopener noreferrer" }. @@ -327,7 +323,7 @@ def dep_data_lookup(lookup: Annotated[DataLookup, Depends(_lookup_adapter)]) -> ... @app.post("/my-agent") def my_agent( - frame_data: Annotated[FrameData, Form()], + frame_data: FrameData, lookup: Annotated[DataLookup, Depends(dep_data_lookup)] ): # Client will authenticated and ready to use. diff --git a/encord_agents/fastapi/utils.py b/encord_agents/fastapi/utils.py index b63c34b..9c1cbaf 100644 --- a/encord_agents/fastapi/utils.py +++ b/encord_agents/fastapi/utils.py @@ -1,7 +1,5 @@ import os -from pydantic import ValidationError - from encord_agents.core.settings import Settings from encord_agents.core.utils import get_user_client from encord_agents.exceptions import PrintableError @@ -20,7 +18,7 @@ def verify_auth() -> None: on_startup=[verify_auth] ``` - This will make the server fail early if auth is not setup. + This will make the server fail early if auth is not set up. 
""" from datetime import datetime, timedelta diff --git a/encord_agents/gcp/wrappers.py b/encord_agents/gcp/wrappers.py index a5d428d..4f9c4f8 100644 --- a/encord_agents/gcp/wrappers.py +++ b/encord_agents/gcp/wrappers.py @@ -1,4 +1,5 @@ import logging +import re from contextlib import ExitStack from functools import wraps from typing import Any, Callable @@ -8,6 +9,7 @@ from flask import Request, Response, make_response from encord_agents import FrameData +from encord_agents.core.constants import ENCORD_DOMAIN_REGEX from encord_agents.core.data_model import LabelRowInitialiseLabelsArgs, LabelRowMetadataIncludeArgs from encord_agents.core.dependencies.models import Context from encord_agents.core.dependencies.utils import get_dependant, solve_dependencies @@ -49,10 +51,34 @@ def editor_agent( def context_wrapper_inner(func: AgentFunction) -> Callable[[Request], Response]: dependant = get_dependant(func=func) + cors_regex = re.compile(ENCORD_DOMAIN_REGEX) @wraps(func) def wrapper(request: Request) -> Response: - frame_data = FrameData.model_validate_json(orjson.dumps(request.form.to_dict())) + if request.method == "OPTIONS": + response = make_response("") + response.headers["Vary"] = "Origin" + + if not cors_regex.fullmatch(request.origin or ""): + response.status_code = 403 + return response + + headers = { + "Access-Control-Allow-Origin": request.origin, + "Access-Control-Allow-Methods": "POST", + "Access-Control-Allow-Headers": "Content-Type", + "Access-Control-Max-Age": "3600", + } + response.headers.update(headers) + response.status_code = 204 + return response + + # TODO: We'll remove FF from FE on Jan. 31 2025. + # At that point, only the if statement applies and the else should be removed. 
+ if request.is_json: + frame_data = FrameData.model_validate(request.get_json()) + else: + frame_data = FrameData.model_validate_json(request.get_data()) logging.info(f"Request: {frame_data}") client = get_user_client() diff --git a/tests/test_cors.py b/tests/test_cors.py new file mode 100644 index 0000000..969687d --- /dev/null +++ b/tests/test_cors.py @@ -0,0 +1,50 @@ +import re + +import pytest + +from encord_agents.core.constants import ENCORD_DOMAIN_REGEX + + +@pytest.fixture +def legal_origins() -> list[str]: + return [ + # Example development previews + "https://cord-ai-development--eb393d03-pccc0hqn.web.app", + "https://cord-ai-development--40816cb1-dij7k5yt.web.app", + "https://cord-ai-development--a3353fa9-0wf42o8h.web.app", + # Main deployment, + "https://app.encord.com", + "https://dev.encord.com", + "https://staging.encord.com", + # US Deployments, + "https://staging.us.encord.com", + "https://dev.us.encord.com", + "https://app.us.encord.com", + ] + + +@pytest.fixture +def illegal_origins() -> list[str]: + return [ + "https://google.com", + "https://test.encord.com", + "https://us.app.encord.com", + "https://app.encord.com.something-else.com", + "https://dev.encord.com.something-else.com", + "https://staging.encord.com.something-else.com", + ] + + +@pytest.fixture +def compiled_regex() -> re.Pattern[str]: + return re.compile(ENCORD_DOMAIN_REGEX) + + +def test_legal_domains_against_CORS_regex(legal_origins: list[str], compiled_regex: re.Pattern[str]) -> None: + for origin in legal_origins: + assert compiled_regex.fullmatch(origin), f"Origin should have been allowed: `{origin}`" + + +def test_illegal_domains_against_CORS_regex(illegal_origins: list[str], compiled_regex: re.Pattern[str]) -> None: + for origin in illegal_origins: + assert not compiled_regex.fullmatch(origin), f"Origin should _not_ have been allowed: `{origin}`"