Skip to content

Commit

Permalink
feat: add include args for label row listing (#24)
Browse files Browse the repository at this point in the history
* docs: add documentation around task runner

* Apply suggestions from code review

Co-authored-by: Eloy Pérez Torres <[email protected]>

* fix: use correct project for getting label rows in urnner + docstring

* feat: add include args for label row listing

Things to keep in mind:
1. Update doc string
2. Consistent naming
3. Make sure it works for all six cases [gcp, fastapi, task] x [with, without]

* docs: optional include argument for listing label rows on runner

* refactor: include args to be common for all modes

* docs: fix docstring for fastapi include args

* fix: require named arg for stage label listing

---------

Co-authored-by: Eloy Pérez Torres <[email protected]>
  • Loading branch information
frederik-encord and eloy-encord authored Dec 4, 2024
1 parent 3f6e167 commit 6ac9a26
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 20 deletions.
14 changes: 14 additions & 0 deletions docs/task_agents/runner.md
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,21 @@ def execute(self, refresh_every = None):
...
```
### Optional arguments
When you wrap a function with the `@runner.stage(...)` wrapper, you can add include a [`label_row_metadata_include_args: LabelRowMetadataIncludeArgs`](../reference/core.md#encord_agents.core.data_model.LabelRowMetadataIncludeArgs) argument which will be passed on to the Encord Project's [`list_label_row_v2` method](https://docs.encord.com/sdk-documentation/sdk-references/project#list-label-rows-v2){ target="\_blank", rel="noopener noreferrer" }. This is useful to, e.g., be able to _read_ the client metadata associated to a task.
Notice, if you need to update the metadata, you will have to use the `dep_storage_item` dependencies.
Here is an example:
```python
args = LabelRowMetadataIncludeArgs(
include_client_metadata=True,
)
@runner.stage("<my_stage_name>", label_row_metadata_include_args=args)
def my_agent(lr: LabelRowV2):
lr.client_metadata # will now be populated
```
## Dependencies
Expand Down
14 changes: 14 additions & 0 deletions encord_agents/core/data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,20 @@
from encord_agents.core.vision import DATA_TYPES, b64_encode_image


class LabelRowMetadataIncludeArgs(BaseModel):
"""
Warning, including metadata via label rows is good for _reading_ metadata
**not** for writing to the metadata.
If you need to write to metadata, use the `dep_storage_item` dependencies instead.
"""

include_workflow_graph_node: bool = True
include_client_metadata: bool = False
include_images_data: bool = False
include_all_label_branches: bool = False


class FrameData(BaseModel):
"""
Holds the data sent from the Encord Label Editor at the time of triggering the agent.
Expand Down
9 changes: 6 additions & 3 deletions encord_agents/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from encord.objects.ontology_labels_impl import LabelRowV2
from encord.user_client import EncordUserClient

from encord_agents.core.data_model import FrameData
from encord_agents.core.data_model import FrameData, LabelRowMetadataIncludeArgs
from encord_agents.core.settings import Settings

from .video import get_frame
Expand All @@ -30,7 +30,9 @@ def get_user_client() -> EncordUserClient:
return EncordUserClient.create_with_ssh_private_key(ssh_private_key=settings.ssh_key, **kwargs)


def get_initialised_label_row(frame_data: FrameData) -> LabelRowV2:
def get_initialised_label_row(
frame_data: FrameData, include_args: LabelRowMetadataIncludeArgs | None = None
) -> LabelRowV2:
"""
Get an initialised label row from the frame_data information.
Expand All @@ -46,7 +48,8 @@ def get_initialised_label_row(frame_data: FrameData) -> LabelRowV2:
"""
user_client = get_user_client()
project = user_client.get_project(str(frame_data.project_hash))
matched_lrs = project.list_label_rows_v2(data_hashes=[frame_data.data_hash])
include_args = include_args or LabelRowMetadataIncludeArgs()
matched_lrs = project.list_label_rows_v2(data_hashes=[frame_data.data_hash], **include_args.model_dump())
num_matches = len(matched_lrs)
if num_matches > 1:
raise Exception(f"Non unique match: matched {num_matches} label rows!")
Expand Down
43 changes: 43 additions & 0 deletions encord_agents/fastapi/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def my_agent(
from encord.user_client import EncordUserClient
from numpy.typing import NDArray

from encord_agents.core.data_model import LabelRowMetadataIncludeArgs
from encord_agents.core.dependencies.shares import DataLookup
from encord_agents.core.vision import crop_to_object

Expand Down Expand Up @@ -75,6 +76,48 @@ def my_route(
return get_user_client()


def dep_label_row_with_include_args(
label_row_metadata_include_args: LabelRowMetadataIncludeArgs | None = None,
) -> Callable[[FrameData], LabelRowV2]:
"""
Dependency to provide an initialized label row.
**Example:**
```python
from encord_agents.core.data_model import LabelRowMetadataIncludeArgs
from encord_agents.fastapi.depencencies import dep_label_row_with_include_args
...
include_args = LabelRowMetadataIncludeArgs(
include_client_metadata=True,
include_workflow_graph_node=True,
)
@app.post("/my-route")
def my_route(
lr: Annotated[LabelRowV2, Depends(dep_label_row_with_include_args(include_args))]
):
assert lr.is_labelling_initialised # will work
assert lr.client_metadata # will be available if set already
```
Args:
frame_data: the frame data from the route. This parameter is automatically injected
if it's a part of your route (see example above)
Returns:
The initialized label row.
"""

def wrapper(frame_data: Annotated[FrameData, Form()]) -> LabelRowV2:
return get_initialised_label_row(frame_data, label_row_metadata_include_args)

return wrapper


def dep_label_row(frame_data: Annotated[FrameData, Form()]) -> LabelRowV2:
"""
Dependency to provide an initialized label row.
Expand Down
17 changes: 15 additions & 2 deletions encord_agents/gcp/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from flask import Request, Response, make_response

from encord_agents import FrameData
from encord_agents.core.data_model import LabelRowMetadataIncludeArgs
from encord_agents.core.dependencies.models import Context
from encord_agents.core.dependencies.utils import get_dependant, solve_dependencies
from encord_agents.core.utils import get_user_client
Expand All @@ -25,12 +26,21 @@ def generate_response() -> Response:
return response


def editor_agent() -> Callable[[AgentFunction], Callable[[Request], Response]]:
def editor_agent(
*, label_row_metadata_include_args: LabelRowMetadataIncludeArgs | None = None
) -> Callable[[AgentFunction], Callable[[Request], Response]]:
"""
Wrapper to make resources available for gcp editor agents.
The editor agents are intended to be used via dependency injections.
You can learn more via out [documentation](https://agents-docs.encord.com).
Args:
label_row_metadata_include_args: arguments to overwrite default arguments
on `project.list_label_rows_v2()`.
Returns:
A wrapped function suitable for gcp functions.
"""

def context_wrapper_inner(func: AgentFunction) -> Callable:
Expand All @@ -46,7 +56,10 @@ def wrapper(request: Request) -> Response:

label_row: LabelRowV2 | None = None
if dependant.needs_label_row:
label_row = project.list_label_rows_v2(data_hashes=[str(frame_data.data_hash)])[0]
include_args = label_row_metadata_include_args or LabelRowMetadataIncludeArgs()
label_row = project.list_label_rows_v2(
data_hashes=[str(frame_data.data_hash)], **include_args.model_dump()
)[0]
label_row.initialise_labels(include_signed_url=True)

context = Context(project=project, label_row=label_row, frame_data=frame_data)
Expand Down
79 changes: 64 additions & 15 deletions encord_agents/tasks/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from contextlib import ExitStack
from datetime import datetime, timedelta
from typing import Callable, Iterable, Optional, cast
from typing_extensions import Annotated
from uuid import UUID

import rich
Expand All @@ -18,7 +17,9 @@
from rich.panel import Panel
from tqdm.auto import tqdm
from typer import Abort, Option
from typing_extensions import Annotated

from encord_agents.core.data_model import LabelRowMetadataIncludeArgs
from encord_agents.core.dependencies.models import Context, DecoratedCallable, Dependant
from encord_agents.core.dependencies.utils import get_dependant, solve_dependencies
from encord_agents.core.utils import get_user_client
Expand All @@ -29,12 +30,17 @@

class RunnerAgent:
def __init__(
self, identity: str | UUID, callable: Callable[..., TaskAgentReturn], printable_name: str | None = None
self,
identity: str | UUID,
callable: Callable[..., TaskAgentReturn],
printable_name: str | None = None,
label_row_metadata_include_args: LabelRowMetadataIncludeArgs | None = None,
):
self.identity = identity
self.printable_name = printable_name or identity
self.callable = callable
self.dependant: Dependant = get_dependant(func=callable)
self.label_row_metadata_include_args = label_row_metadata_include_args

def __repr__(self) -> str:
return f'RunnerAgent("{self.printable_name}")'
Expand Down Expand Up @@ -85,12 +91,12 @@ def __init__(self, project_hash: str | None = None):
"""
Initialize the runner with an optional project hash.
The `project_hash` will allow stricter stage validation.
The `project_hash` will allow stricter stage validation.
If left unspecified, errors will first be raised during execution of the runner.
Args:
project_hash: The project hash that the runner applies to.
Can be left unspecified to be able to reuse same runner on multiple projects.
"""
self.project_hash = self.verify_project_hash(project_hash) if project_hash else None
Expand Down Expand Up @@ -118,10 +124,25 @@ def validate_project(project: Project | None):
len([s for s in project.workflow.stages if s.stage_type == WorkflowStageType.AGENT]) > 0
), f"Provided project does not have any agent stages in it's workflow. {PROJECT_MUSTS}"

def _add_stage_agent(self, identity: str | UUID, func: Callable[..., TaskAgentReturn], printable_name: str | None):
self.agents.append(RunnerAgent(identity=identity, callable=func, printable_name=printable_name))
def _add_stage_agent(
self,
identity: str | UUID,
func: Callable[..., TaskAgentReturn],
printable_name: str | None,
label_row_metadata_include_args: LabelRowMetadataIncludeArgs | None,
):
self.agents.append(
RunnerAgent(
identity=identity,
callable=func,
printable_name=printable_name,
label_row_metadata_include_args=label_row_metadata_include_args,
)
)

def stage(self, stage: str | UUID) -> Callable[[DecoratedCallable], DecoratedCallable]:
def stage(
self, stage: str | UUID, *, label_row_metadata_include_args: LabelRowMetadataIncludeArgs | None = None
) -> Callable[[DecoratedCallable], DecoratedCallable]:
r"""
Decorator to associate a function with an agent stage.
Expand Down Expand Up @@ -191,6 +212,8 @@ def my_func(
Args:
stage: The name or uuid of the stage that the function should be
associated with.
label_row_metadata_include_args: Arguments to be passed to
`project.list_label_rows_v2(...)`
Returns:
The decorated function.
Expand Down Expand Up @@ -221,7 +244,7 @@ def my_func(
)

def decorator(func: DecoratedCallable) -> DecoratedCallable:
self._add_stage_agent(stage, func, printable_name)
self._add_stage_agent(stage, func, printable_name, label_row_metadata_include_args)
return func

return decorator
Expand Down Expand Up @@ -273,10 +296,21 @@ def get_stage_names(valid_stages: list[AgentStage], join_str: str = ", "):
def __call__(
self,
# num_threads: int = 1,
refresh_every: Annotated[Optional[int], Option(help="Fetch task statuses from the Encord Project every `refresh_every` seconds. If `None`, the runner will exit once task queue is empty.")] = None,
num_retries: Annotated[int, Option(help="If an agent fails on a task, how many times should the runner retry it?")] = 3,
task_batch_size: Annotated[int, Option(help="Number of tasks for which labels are loaded into memory at once.")] = 300,
project_hash: Annotated[Optional[str], Option(help="The project hash if not defined at runner instantiation.")] = None,
refresh_every: Annotated[
Optional[int],
Option(
help="Fetch task statuses from the Encord Project every `refresh_every` seconds. If `None`, the runner will exit once task queue is empty."
),
] = None,
num_retries: Annotated[
int, Option(help="If an agent fails on a task, how many times should the runner retry it?")
] = 3,
task_batch_size: Annotated[
int, Option(help="Number of tasks for which labels are loaded into memory at once.")
] = 300,
project_hash: Annotated[
Optional[str], Option(help="The project hash if not defined at runner instantiation.")
] = None,
):
"""
Run your task agent `runner(...)`.
Expand Down Expand Up @@ -392,10 +426,16 @@ def {fn_name}(...):
if len(batch) == task_batch_size:
batch_lrs = [None] * len(batch)
if runner_agent.dependant.needs_label_row:
include_args = (
runner_agent.label_row_metadata_include_args or LabelRowMetadataIncludeArgs()
)
label_rows = {
UUID(lr.data_hash): lr
for lr in project.list_label_rows_v2(data_hashes=[t.data_hash for t in batch])
for lr in project.list_label_rows_v2(
data_hashes=[t.data_hash for t in batch], **include_args.model_dump()
)
}
print([lr.backing_item_uuid for lr in label_rows.values()])
batch_lrs = [label_rows.get(t.data_hash) for t in batch]
with project.create_bundle() as lr_bundle:
for lr in batch_lrs:
Expand All @@ -418,14 +458,23 @@ def {fn_name}(...):
if runner_agent.dependant.needs_label_row:
label_rows = {
UUID(lr.data_hash): lr
for lr in project.list_label_rows_v2(data_hashes=[t.data_hash for t in batch])
for lr in project.list_label_rows_v2(
data_hashes=[t.data_hash for t in batch],
**(
runner_agent.label_row_metadata_include_args.model_dump()
if runner_agent.label_row_metadata_include_args
else {}
),
)
}
print("I am here")
print([lr.backing_item_uuid for lr in label_rows.values()])
batch_lrs = [label_rows[t.data_hash] for t in batch]
with project.create_bundle() as lr_bundle:
for lr in batch_lrs:
if lr:
lr.initialise_labels(bundle=lr_bundle)
self._execute_tasks(zip(batch, batch_lrs), runner_agent, num_retries, pbar=pbar)
self._execute_tasks(project, zip(batch, batch_lrs), runner_agent, num_retries, pbar=pbar)
except (PrintableError, AssertionError) as err:
if self.was_called_from_cli:
panel = Panel(err.args[0], width=None)
Expand Down

0 comments on commit 6ac9a26

Please sign in to comment.