Skip to content

Commit

Permalink
chore: add pre-commit and mypy to the project (#28)
Browse files Browse the repository at this point in the history
* chore: add pre-commit and mypy to the project

* fix: refactor base64type to new file to break circular import

* chore: add unittests to pre-commit
  • Loading branch information
frederik-encord authored Dec 4, 2024
1 parent d2ec936 commit 556c28a
Show file tree
Hide file tree
Showing 22 changed files with 453 additions and 130 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
dist/
encord-agents-unit-test-report.xml
# Created by https://www.toptal.com/developers/gitignore/api/vim,python,macos
# Edit at https://www.toptal.com/developers/gitignore?templates=vim,python,macos

Expand Down
31 changes: 31 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
repos:
- repo: local
hooks:
- id: ruff-fmt
name: ruff-fmt
entry: poetry run ruff format --config=pyproject.toml .
types_or: [python, pyi]
language: system
pass_filenames: false

- id: ruff-check
name: ruff-check
entry: poetry run ruff check --config=pyproject.toml --fix .
types_or: [python, pyi]
language: system
pass_filenames: false

- id: mypy
name: mypy
entry: poetry run mypy . --config-file=pyproject.toml
types_or: [python, pyi]
language: system
pass_filenames: false

- id: unittest
name: unittest
entry: poetry run pytest --cov=encord_agents --junitxml encord-agents-unit-test-report.xml -s tests
types_or: [python, pyi]
language: system
pass_filenames: false
default_stages: [pre-push]
2 changes: 1 addition & 1 deletion encord_agents/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .core.data_model import FrameData

__version__ = "v0.1.2"
__ALL__ = ["FrameData"]
__all__ = ["FrameData"]
2 changes: 1 addition & 1 deletion encord_agents/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
@app.callback(invoke_without_command=True)
def version(
version_: bool = typer.Option(False, "--version", "-v", "-V", help="Print the current version of Encord Agents"),
):
) -> None:
if version_:
import rich

Expand Down
4 changes: 2 additions & 2 deletions encord_agents/cli/print.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


@app.command(name="agent-nodes")
def print_agent_nodes(project_hash: str):
def print_agent_nodes(project_hash: str) -> None:
"""
Prints agent nodes from project.
Expand Down Expand Up @@ -50,7 +50,7 @@ def print_agent_nodes(project_hash: str):


@app.command(name="system-info")
def print_system_info():
def print_system_info() -> None:
"""
[bold]Prints[/bold] the information of the system for the purpose of bug reporting.
"""
Expand Down
6 changes: 3 additions & 3 deletions encord_agents/cli/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def local(
],
url: Annotated[str, Argument(help="Url copy/pasted from label editor")],
port: Annotated[int, Option(help="Local host port to hit")] = 8080,
):
) -> None:
"""Hit a localhost agents endpoint for testing an agent by copying the url from the Encord Label Editor over.
Given
Expand Down Expand Up @@ -114,8 +114,8 @@ def local(
table.add_row("label editor", editor_url)

headers = ["'{0}: {1}'".format(k, v) for k, v in prepped.headers.items()]
headers = " -H ".join(headers)
curl_command = f"curl -X {prepped.method} \\{os.linesep} -H {headers} \\{os.linesep} -d '{prepped.body}' \\{os.linesep} '{prepped.url}'"
str_headers = " -H ".join(headers)
curl_command = f"curl -X {prepped.method} \\{os.linesep} -H {str_headers} \\{os.linesep} -d '{prepped.body!r}' \\{os.linesep} '{prepped.url}'"
table.add_row("curl", curl_command)

rich.print(table)
8 changes: 5 additions & 3 deletions encord_agents/core/data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from encord_agents.core.vision import DATA_TYPES, b64_encode_image

Base64Formats = Literal[".jpeg", ".jpg", ".png"]


class LabelRowMetadataIncludeArgs(BaseModel):
"""
Expand Down Expand Up @@ -61,7 +63,7 @@ class Frame:
@overload
def b64_encoding(
self,
image_format: Literal[".jpeg", ".jpg", ".png"] = ".jpeg",
image_format: Base64Formats = ".jpeg",
output_format: Literal["raw", "url"] = "raw",
) -> str: ...

Expand All @@ -70,13 +72,13 @@ def b64_encoding(
self,
image_format: Literal[".jpeg", ".jpg", ".png"] = ".jpeg",
output_format: Literal["openai", "anthropic"] = "openai",
) -> dict: ...
) -> dict[str, str | dict[str, str]]: ...

def b64_encoding(
self,
image_format: Literal[".jpeg", ".jpg", ".png"] = ".jpeg",
output_format: Literal["url", "openai", "anthropic", "raw"] = "url",
) -> str | dict:
) -> str | dict[str, str | dict[str, str]]:
"""
Get a base64 representation of the image content.
Expand Down
2 changes: 1 addition & 1 deletion encord_agents/core/dependencies/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .models import Depends

__ALL__ = ["Depends"]
__all__ = ["Depends"]
12 changes: 8 additions & 4 deletions encord_agents/core/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class Settings(BaseSettings):

@field_validator("ssh_key_content")
@classmethod
def check_key_content(cls, content: str | None):
def check_key_content(cls, content: str | None) -> str | None:
if content is None:
return content

Expand All @@ -49,7 +49,7 @@ def check_key_content(cls, content: str | None):

@field_validator("ssh_key_file")
@classmethod
def check_path_expand_and_exists(cls, path: Path | None):
def check_path_expand_and_exists(cls, path: Path | None) -> Path | None:
if path is None:
return path

Expand All @@ -63,7 +63,7 @@ def check_path_expand_and_exists(cls, path: Path | None):
return path

@model_validator(mode="after")
def check_key(self):
def check_key(self: "Settings") -> "Settings":
if not any(map(bool, [self.ssh_key_content, self.ssh_key_file])):
raise PrintableError(
f"Must specify either `[blue]ENCORD_SSH_KEY_FILE[/blue]` or `[blue]ENCORD_SSH_KEY[/blue]` env variables. If you don't have an ssh key, please refere to our docs:{os.linesep}[magenta]https://docs.encord.com/platform-documentation/Annotate/annotate-api-keys#creating-keys-using-terminal-powershell[/magenta]"
Expand All @@ -80,4 +80,8 @@ def check_key(self):

@property
def ssh_key(self) -> str:
return self.ssh_key_content if self.ssh_key_content else self.ssh_key_file.read_text()
if self.ssh_key_content is None:
if self.ssh_key_file is None:
raise ValueError("Both ssh key content and ssh key file is None")
self.ssh_key_content = self.ssh_key_file.read_text()
return self.ssh_key_content
5 changes: 5 additions & 0 deletions encord_agents/core/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from typing import Literal

Base64Formats = Literal[".jpeg", ".jpg", ".png"]


5 changes: 3 additions & 2 deletions encord_agents/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def get_user_client() -> EncordUserClient:
An EncordUserClient authenticated with the credentials from the encord_agents.core.settings.Settings.
"""
settings = Settings() # type: ignore
settings = Settings()
kwargs: dict[str, Any] = {"domain": settings.domain} if settings.domain else {}
return EncordUserClient.create_with_ssh_private_key(ssh_private_key=settings.ssh_key, **kwargs)

Expand Down Expand Up @@ -166,4 +166,5 @@ def download_asset(lr: LabelRowV2, frame: int | None) -> Generator[Path, None, N
try:
yield file_path
finally:
[f.unlink(missing_ok=True) for f in files_to_unlink]
for to_unlink in files_to_unlink:
to_unlink.unlink(missing_ok=True)
59 changes: 13 additions & 46 deletions encord_agents/core/vision.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import base64
from typing import TypeAlias

import cv2
import numpy as np
from encord.objects.bitmask import BitmaskCoordinates
from encord.objects.coordinates import BoundingBoxCoordinates, PolygonCoordinates, RotatableBoundingBoxCoordinates
from numpy.typing import NDArray

CroppableCoordinates = (
from .types import Base64Formats

CroppableCoordinates: TypeAlias = (
BoundingBoxCoordinates | RotatableBoundingBoxCoordinates | BitmaskCoordinates | PolygonCoordinates
)

Expand All @@ -22,7 +25,7 @@ def rbb_to_poly(
rbb: RotatableBoundingBoxCoordinates,
img_width: int,
img_height: int,
) -> np.ndarray:
) -> NDArray[np.float32]:
x = rbb.top_left_x
y = rbb.top_left_y
w = rbb.width
Expand All @@ -33,7 +36,8 @@ def rbb_to_poly(
[(x + w) * img_width, y * img_height],
[(x + w) * img_width, (y + h) * img_height],
[x * img_width, (y + h) * img_height],
]
],
dtype=np.float32
)
angle = rbb.theta # [0; 360]
center = tuple(bbox_not_rotated.mean(0).tolist())
Expand All @@ -47,48 +51,11 @@ def rbb_to_poly(
mode="constant",
constant_values=1,
)
rotated_points: np.ndarray = points @ rotation_matrix.T
rotated_points = points @ rotation_matrix.T.astype(np.float32)
return rotated_points


def poly_to_rbb(
poly: np.ndarray,
img_width: int,
img_height: int,
) -> RotatableBoundingBoxCoordinates:
v1 = poly[1] - poly[0]
v1 = v1 / np.linalg.norm(v1, ord=2)
angle = np.degrees(np.arccos(v1[0]))

if not any(
[poly[0, 0] > poly[3, 0], poly[0, 0] == poly[3, 0] and poly[0, 1] < poly[3, 1]]
): # Initial points were rotated more than 180 degrees => Rotate backwards
angle = 360 - angle

center = poly.mean(axis=0)
rotation_matrix = cv2.getRotationMatrix2D(center, angle, scale=1.0)
points = np.pad(
poly,
[
(0, 0),
(0, 1),
],
mode="constant",
constant_values=1,
)
rotated_points = points @ rotation_matrix.T
x, y = rotated_points.min(0)
w, h = rotated_points.max(0) - rotated_points.min(0)
return RotatableBoundingBoxCoordinates(
top_left_x=float(x / img_width),
top_left_y=float(y / img_height),
width=float(w / img_width),
height=float(h / img_height),
theta=float(angle),
)


def crop_to_bbox(image: NDArray, bbox: BoundingBoxCoordinates) -> NDArray:
def crop_to_bbox(image: NDArray[np.uint8], bbox: BoundingBoxCoordinates) -> NDArray[np.uint8]:
img_height, img_width = image.shape[:2]
from_x = int(img_width * bbox.top_left_x + 0.5)
from_y = int(img_height * bbox.top_left_y + 0.5)
Expand All @@ -97,7 +64,7 @@ def crop_to_bbox(image: NDArray, bbox: BoundingBoxCoordinates) -> NDArray:
return image[from_y:to_y, from_x:to_x]


def poly_to_bbox(poly: PolygonCoordinates | NDArray) -> BoundingBoxCoordinates:
def poly_to_bbox(poly: PolygonCoordinates | NDArray[np.float32]) -> BoundingBoxCoordinates:
if isinstance(poly, PolygonCoordinates):
rel_coords = np.array([[v.x, v.y] for v in poly.values])
else:
Expand All @@ -111,7 +78,7 @@ def poly_to_bbox(poly: PolygonCoordinates | NDArray) -> BoundingBoxCoordinates:

def rbbox_to_surrounding_bbox(rbb: RotatableBoundingBoxCoordinates, img_w: int, img_h: int) -> BoundingBoxCoordinates:
abs_coords = rbb_to_poly(rbb, img_width=img_w, img_height=img_h)
rel_coords = abs_coords / np.array([[img_w, img_h]], dtype=float)
rel_coords = abs_coords / np.array([[img_w, img_h]], dtype=np.float32)
return poly_to_bbox(rel_coords)


Expand Down Expand Up @@ -143,6 +110,6 @@ def crop_to_object(image: NDArray[np.uint8], coordinates: CroppableCoordinates)
return crop_to_bbox(image, box)


def b64_encode_image(img: NDArray[np.uint8], format=".jpg"):
def b64_encode_image(img: NDArray[np.uint8], format: Base64Formats =".jpg") -> str:
_, encoded_image = cv2.imencode(format, img)
return base64.b64encode(encoded_image).decode("utf-8")
return base64.b64encode(encoded_image).decode("utf-8") # type: ignore
2 changes: 1 addition & 1 deletion encord_agents/fastapi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .dependencies import dep_client, dep_label_row, dep_single_frame
from .utils import verify_auth

__ALL__ = [
__all__ = [
"dep_single_frame",
"dep_label_row",
"dep_client",
Expand Down
4 changes: 2 additions & 2 deletions encord_agents/fastapi/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def my_route(
return get_initialised_label_row(frame_data)


def dep_single_frame(lr: Annotated[LabelRowV2, Depends(dep_label_row)], frame_data: Annotated[FrameData, Form()]):
def dep_single_frame(lr: Annotated[LabelRowV2, Depends(dep_label_row)], frame_data: Annotated[FrameData, Form()]) -> NDArray[np.uint8]:
"""
Dependency to inject the underlying asset of the frame data.
Expand Down Expand Up @@ -214,7 +214,7 @@ def my_route(
yield iter_video(asset)


def dep_project(frame_data: Annotated[FrameData, Form()], client: Annotated[EncordUserClient, Depends(dep_client)]):
def dep_project(frame_data: Annotated[FrameData, Form()], client: Annotated[EncordUserClient, Depends(dep_client)]) -> Project:
r"""
Dependency to provide an instantiated
[Project](https://docs.encord.com/sdk-documentation/sdk-references/LabelRowV2){ target="\_blank", rel="noopener noreferrer" }.
Expand Down
2 changes: 1 addition & 1 deletion encord_agents/fastapi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from encord_agents.exceptions import PrintableError


def verify_auth():
def verify_auth() -> None:
"""
FastAPI lifecycle start hook to fail early if ssh key is missing.
Expand Down
2 changes: 1 addition & 1 deletion encord_agents/gcp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

from .wrappers import editor_agent

__ALL__ = ["editor_agent", "Depends"]
__all__ = ["editor_agent", "Depends"]
2 changes: 1 addition & 1 deletion encord_agents/gcp/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def editor_agent(
A wrapped function suitable for gcp functions.
"""

def context_wrapper_inner(func: AgentFunction) -> Callable:
def context_wrapper_inner(func: AgentFunction) -> Callable[[Request], Response]:
dependant = get_dependant(func=func)

@wraps(func)
Expand Down
2 changes: 1 addition & 1 deletion encord_agents/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

from .runner import Runner

__ALL__ = ["Runner", "Depends"]
__all__ = ["Runner", "Depends"]
Loading

0 comments on commit 556c28a

Please sign in to comment.