Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: optimized board queries #6931

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion invokeai/app/api/routers/boards.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pydantic import BaseModel, Field

from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.board_records.board_records_common import BoardChanges
from invokeai.app.services.board_records.board_records_common import BoardChanges, UncategorizedImageCounts
from invokeai.app.services.boards.boards_common import BoardDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults

Expand Down Expand Up @@ -146,3 +146,14 @@ async def list_all_board_image_names(
board_id,
)
return image_names


@boards_router.get(
    "/uncategorized/counts",
    operation_id="get_uncategorized_image_counts",
    response_model=UncategorizedImageCounts,
)
async def get_uncategorized_image_counts() -> UncategorizedImageCounts:
    """Gets the image and asset counts for uncategorized images (images with no board association)."""

    # Delegate to the board records service; the route itself holds no logic.
    board_records = ApiDependencies.invoker.services.board_records
    return board_records.get_uncategorized_image_counts()
7 changes: 6 additions & 1 deletion invokeai/app/services/board_records/board_records_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod

from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecord
from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecord, UncategorizedImageCounts
from invokeai.app.services.shared.pagination import OffsetPaginatedResults


Expand Down Expand Up @@ -48,3 +48,8 @@ def get_many(
def get_all(self, include_archived: bool = False) -> list[BoardRecord]:
"""Gets all board records."""
pass

@abstractmethod
def get_uncategorized_image_counts(self) -> UncategorizedImageCounts:
    """Gets the image and asset counts for uncategorized images (images with no board association)."""
    ...
23 changes: 17 additions & 6 deletions invokeai/app/services/board_records/board_records_common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Optional, Union
from typing import Any, Optional, Union

from pydantic import BaseModel, Field

Expand All @@ -26,21 +26,25 @@ class BoardRecord(BaseModelExcludeNull):
"""Whether or not the board is archived."""
is_private: Optional[bool] = Field(default=None, description="Whether the board is private.")
"""Whether the board is private."""
image_count: int = Field(description="The number of images in the board.")
asset_count: int = Field(description="The number of assets in the board.")


def deserialize_board_record(board_dict: dict) -> BoardRecord:
def deserialize_board_record(board_dict: dict[str, Any]) -> BoardRecord:
"""Deserializes a board record."""

# Retrieve all the values, setting "reasonable" defaults if they are not present.

board_id = board_dict.get("board_id", "unknown")
board_name = board_dict.get("board_name", "unknown")
cover_image_name = board_dict.get("cover_image_name", "unknown")
cover_image_name = board_dict.get("cover_image_name", None)
created_at = board_dict.get("created_at", get_iso_timestamp())
updated_at = board_dict.get("updated_at", get_iso_timestamp())
deleted_at = board_dict.get("deleted_at", get_iso_timestamp())
archived = board_dict.get("archived", False)
is_private = board_dict.get("is_private", False)
image_count = board_dict.get("image_count", 0)
asset_count = board_dict.get("asset_count", 0)

return BoardRecord(
board_id=board_id,
Expand All @@ -51,6 +55,8 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
deleted_at=deleted_at,
archived=archived,
is_private=is_private,
image_count=image_count,
asset_count=asset_count,
)


Expand All @@ -63,19 +69,24 @@ class BoardChanges(BaseModel, extra="forbid"):
class BoardRecordNotFoundException(Exception):
"""Raised when an board record is not found."""

def __init__(self, message="Board record not found"):
def __init__(self, message: str = "Board record not found"):
super().__init__(message)


class BoardRecordSaveException(Exception):
"""Raised when an board record cannot be saved."""

def __init__(self, message="Board record not saved"):
def __init__(self, message: str = "Board record not saved"):
super().__init__(message)


class BoardRecordDeleteException(Exception):
"""Raised when an board record cannot be deleted."""

def __init__(self, message="Board record not deleted"):
def __init__(self, message: str = "Board record not deleted"):
super().__init__(message)


class UncategorizedImageCounts(BaseModel):
    # Pair of counters describing images that are not associated with any board.
    # Both fields are required integers; the descriptions are exposed via the model schema.
    image_count: int = Field(
        description="The number of uncategorized images.",
    )
    asset_count: int = Field(
        description="The number of uncategorized assets.",
    )
200 changes: 142 additions & 58 deletions invokeai/app/services/board_records/board_records_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,121 @@
BoardRecordDeleteException,
BoardRecordNotFoundException,
BoardRecordSaveException,
UncategorizedImageCounts,
deserialize_board_record,
)
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.util.misc import uuid_string

# Shared SELECT ... FROM ... JOIN fragment used by the query-builder functions in this module.
# It selects the board columns plus three derived columns:
#   - image_count: non-intermediate images in the 'general' category
#   - asset_count: non-intermediate images in the 'control', 'mask', 'user' or 'other' categories
#   - cover_image_name: the most recently created non-intermediate image of the board (correlated subquery)
# The fragment is deliberately unterminated: each builder appends its own WHERE / GROUP BY /
# ORDER BY / LIMIT clauses and the closing semicolon.
# Boards are LEFT JOINed to board_images and images, so a board without images still yields a row.
_BASE_BOARD_RECORD_QUERY = """
-- This query retrieves board records, joining with the board_images and images tables to get image counts and cover image names.
-- It is not a complete query, as it is missing a GROUP BY or WHERE clause (and is unterminated).
SELECT b.board_id,
b.board_name,
b.created_at,
b.updated_at,
b.archived,
-- Count the number of images in the board, alias image_count
COUNT(
CASE
WHEN i.image_category in ('general') -- "Images" are images in the 'general' category
AND i.is_intermediate = 0 THEN 1 -- Intermediates are not counted
END
) AS image_count,
-- Count the number of assets in the board, alias asset_count
COUNT(
CASE
WHEN i.image_category in ('control', 'mask', 'user', 'other') -- "Assets" are images in any of the other categories ('control', 'mask', 'user', 'other')
AND i.is_intermediate = 0 THEN 1 -- Intermediates are not counted
END
) AS asset_count,
-- Get the name of the the most recent image in the board, alias cover_image_name
(
SELECT bi.image_name
FROM board_images bi
JOIN images i ON bi.image_name = i.image_name
WHERE bi.board_id = b.board_id
AND i.is_intermediate = 0 -- Intermediates cannot be cover images
ORDER BY i.created_at DESC -- Sort by created_at to get the most recent image
LIMIT 1
) AS cover_image_name
FROM boards b
LEFT JOIN board_images bi ON b.board_id = bi.board_id
LEFT JOIN images i ON bi.image_name = i.image_name
"""


def get_paginated_list_board_records_queries(include_archived: bool) -> str:
    """Builds the SQL used to fetch one page of board records.

    The returned statement contains two positional placeholders, in order: limit, then offset.

    Args:
        include_archived: When False, archived boards are filtered out of the results.

    Returns:
        The complete, terminated SQL statement.
    """

    # An empty filter leaves archived boards in the result set.
    archived_condition = "" if include_archived else "WHERE b.archived = 0"

    # The optional WHERE clause has to precede GROUP BY, hence the ordering below.
    return f"""
{_BASE_BOARD_RECORD_QUERY}
{archived_condition}
GROUP BY b.board_id,
b.board_name,
b.created_at,
b.updated_at
ORDER BY b.created_at DESC
LIMIT ? OFFSET ?;
"""


def get_total_boards_count_query(include_archived: bool) -> str:
    """Gets a query to retrieve the total count of board records.

    The `boards` table is aliased as `b` so that the archived filter (`b.archived`) resolves.
    Without the alias, SQLite rejects the filtered query with "no such column: b.archived".

    Args:
        include_archived: Whether to include archived board records in the count.

    Returns:
        A query to retrieve the total count of board records. It takes no parameters.
    """

    query = "SELECT COUNT(*) FROM boards b"
    if not include_archived:
        query += " WHERE b.archived = 0"

    return f"{query};"


def get_list_all_board_records_query(include_archived: bool) -> str:
    """Builds the SQL used to fetch every board record, newest first.

    Args:
        include_archived: When False, archived boards are filtered out of the results.

    Returns:
        The complete, terminated SQL statement. It takes no parameters.
    """

    if include_archived:
        archived_condition = ""
    else:
        archived_condition = "WHERE b.archived = 0"

    # The optional WHERE clause is placed ahead of GROUP BY.
    query = f"""
{_BASE_BOARD_RECORD_QUERY}
{archived_condition}
GROUP BY b.board_id,
b.board_name,
b.created_at,
b.updated_at
ORDER BY b.created_at DESC;
"""
    return query


def get_board_record_query() -> str:
    """Gets a query to retrieve a board record. The query has a placeholder for the board_id.

    The base query contains aggregate COUNT columns. In SQLite, an aggregate query without a
    GROUP BY clause always returns exactly one row, even when no rows match the WHERE clause;
    that row would carry NULL board columns and zero counts. Grouping by the board id makes the
    query return zero rows for an unknown board_id, so callers can detect a missing board by
    checking for an empty result.

    Returns:
        A query to retrieve a single board record, with one positional placeholder (board_id).
    """

    return f"{_BASE_BOARD_RECORD_QUERY} WHERE b.board_id = ? GROUP BY b.board_id;"


class SqliteBoardRecordStorage(BoardRecordStorageBase):
_conn: sqlite3.Connection
Expand Down Expand Up @@ -76,11 +185,7 @@ def get(
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM boards
WHERE board_id = ?;
""",
get_board_record_query(),
(board_id,),
)

Expand All @@ -92,7 +197,7 @@ def get(
self._lock.release()
if result is None:
raise BoardRecordNotFoundException
return BoardRecord(**dict(result))
return deserialize_board_record(dict(result))

def update(
self,
Expand Down Expand Up @@ -149,45 +254,15 @@ def get_many(
try:
self._lock.acquire()

# Build base query
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY created_at DESC
LIMIT ? OFFSET ?;
"""

# Determine archived filter condition
if include_archived:
archived_filter = ""
else:
archived_filter = "WHERE archived = 0"

final_query = base_query.format(archived_filter=archived_filter)
main_query = get_paginated_list_board_records_queries(include_archived=include_archived)

# Execute query to fetch boards
self._cursor.execute(final_query, (limit, offset))
self._cursor.execute(main_query, (limit, offset))

result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]

# Determine count query
if include_archived:
count_query = """
SELECT COUNT(*)
FROM boards;
"""
else:
count_query = """
SELECT COUNT(*)
FROM boards
WHERE archived = 0;
"""

# Execute count query
self._cursor.execute(count_query)

total_query = get_total_boards_count_query(include_archived=include_archived)
self._cursor.execute(total_query)
count = cast(int, self._cursor.fetchone()[0])

return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
Expand All @@ -201,30 +276,39 @@ def get_many(
def get_all(self, include_archived: bool = False) -> list[BoardRecord]:
try:
self._lock.acquire()

base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY created_at DESC
"""

if include_archived:
archived_filter = ""
else:
archived_filter = "WHERE archived = 0"

final_query = base_query.format(archived_filter=archived_filter)

self._cursor.execute(final_query)

query = get_list_all_board_records_query(include_archived=include_archived)
self._cursor.execute(query)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]

return boards

except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()

def get_uncategorized_image_counts(self) -> UncategorizedImageCounts:
    """Gets the image and asset counts for images that have no board association.

    Intermediate images are excluded from both counts.

    Returns:
        The counts of uncategorized images and assets. A count is 0 when no matching rows exist.
    """
    query = """
-- Get the count of uncategorized images and assets.
SELECT
CASE
WHEN i.image_category = 'general' THEN 'image_count' -- "Images" are images in the 'general' category
ELSE 'asset_count' -- "Assets" are images in any of the other categories ('control', 'mask', 'user', 'other')
END AS category_type,
COUNT(*) AS unassigned_count
FROM images i
LEFT JOIN board_images bi ON i.image_name = bi.image_name
WHERE bi.board_id IS NULL -- Uncategorized images have no board association
AND i.is_intermediate = 0 -- Omit intermediates from the counts
GROUP BY category_type; -- Group by category_type alias, as derived from the image_category column earlier
"""
    # Acquire before entering the try block so that the finally clause never releases a lock
    # that was not successfully acquired.
    self._lock.acquire()
    try:
        self._cursor.execute(query)
        # Each row is a (category_type, unassigned_count) pair.
        counts = dict(self._cursor.fetchall())
    finally:
        self._lock.release()

    # The GROUP BY only yields rows for categories that actually have images, so either key
    # (or both) may be absent; treat a missing key as a count of zero.
    return UncategorizedImageCounts(
        image_count=counts.get("image_count", 0),
        asset_count=counts.get("asset_count", 0),
    )
21 changes: 3 additions & 18 deletions invokeai/app/services/boards/boards_common.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,8 @@
from typing import Optional

from pydantic import Field

from invokeai.app.services.board_records.board_records_common import BoardRecord


# TODO(psyche): BoardDTO is now identical to BoardRecord. We should consider removing it.
class BoardDTO(BoardRecord):
"""Deserialized board record with cover image URL and image count."""

cover_image_name: Optional[str] = Field(description="The name of the board's cover image.")
"""The URL of the thumbnail of the most recent image in the board."""
image_count: int = Field(description="The number of images in the board.")
"""The number of images in the board."""

"""Deserialized board record."""

def board_record_to_dto(board_record: BoardRecord, cover_image_name: Optional[str], image_count: int) -> BoardDTO:
"""Converts a board record to a board DTO."""
return BoardDTO(
**board_record.model_dump(exclude={"cover_image_name"}),
cover_image_name=cover_image_name,
image_count=image_count,
)
pass
Loading