Skip to content

Commit

Permalink
feat: Authorization
Browse files Browse the repository at this point in the history
- Added dependencies in requirements and pyproject
- Moved to enumerated resource types and actions for authz_context
- Added reset authz_context
  • Loading branch information
tazarov committed Oct 19, 2023
1 parent a4ee5e5 commit 19daf90
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 58 deletions.
25 changes: 23 additions & 2 deletions chromadb/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
Generic,
Union,
)
from attr import dataclass
from dataclasses import dataclass

from overrides import EnforceOverrides, override
from pydantic import SecretStr
Expand Down Expand Up @@ -260,7 +260,28 @@ def get_user_identity(self, credentials: AbstractCredentials[T]) \

# --- AuthZ ---#

# TODO move this to basic impl
class AuthzResourceTypes(str, Enum):
DB = "db"
COLLECTION = "collection"


class AuthzResourceActions(str, Enum):
LIST_COLLECTIONS = "list_collections"
GET_COLLECTION = "get_collection"
CREATE_COLLECTION = "create_collection"
GET_OR_CREATE_COLLECTION = "get_or_create_collection"
DELETE_COLLECTION = "delete_collection"
UPDATE_COLLECTION = "update_collection"
ADD = "add"
DELETE = "delete"
GET = "get"
QUERY = "query"
PEEK = "peek"
COUNT = "count"
UPDATE = "update"
UPSERT = "upsert"
RESET = "reset"


@dataclass
class AuthzUser:
Expand Down
1 change: 0 additions & 1 deletion chromadb/auth/authz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def authorize(self, context: AuthorizationContext) \
context.resource.type,
context.action.id)

print(_authz_tuple)
policy_decision = False
if _authz_tuple in self._authz_tuples:
policy_decision = True
Expand Down
68 changes: 32 additions & 36 deletions chromadb/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import chromadb
from chromadb.api.models.Collection import Collection
from chromadb.api.types import GetResult, QueryResult
from chromadb.auth import (AuthzDynamicParams,
from chromadb.auth import (AuthzDynamicParams, AuthzResourceActions, AuthzResourceTypes,
DynamicAuthzResource)
from chromadb.auth.fastapi import (
FastAPIChromaAuthMiddleware,
Expand Down Expand Up @@ -243,19 +243,19 @@ def heartbeat(self) -> Dict[str, int]:
def version(self) -> str:
return self._api.get_version()


@authz_context(action="list_collections",
@authz_context(action=AuthzResourceActions.LIST_COLLECTIONS,
resource=DynamicAuthzResource(
id="*",
type="db"))
type=AuthzResourceTypes.DB,
))
@trace_method("FastAPI.list_collections", OpenTelemetryGranularity.OPERATION)
def list_collections(self) -> Sequence[Collection]:
return self._api.list_collections()

@authz_context(action="create_collection",
@authz_context(action=AuthzResourceActions.CREATE_COLLECTION,
resource=DynamicAuthzResource(
id="*",
type="db"
type=AuthzResourceTypes.DB,
))
@trace_method("FastAPI.create_collection", OpenTelemetryGranularity.OPERATION)
def create_collection(self, collection: CreateCollection) -> Collection:
Expand All @@ -265,21 +265,21 @@ def create_collection(self, collection: CreateCollection) -> Collection:
get_or_create=collection.get_or_create,
)


@authz_context(action="get_collection",
@authz_context(action=AuthzResourceActions.GET_COLLECTION,
resource=DynamicAuthzResource(
id=AuthzDynamicParams.from_function_kwargs(
arg_name="collection_name"),
type="db",
type=AuthzResourceTypes.DB,
))
@trace_method("FastAPI.get_collection", OpenTelemetryGranularity.OPERATION)
def get_collection(self, collection_name: str) -> Collection:
return self._api.get_collection(collection_name)
@authz_context(action="update_collection",

@authz_context(action=AuthzResourceActions.UPDATE_COLLECTION,
resource=DynamicAuthzResource(
id=AuthzDynamicParams.from_function_kwargs(
arg_name="collection_name"),
type="db",
type=AuthzResourceTypes.DB,
))
@trace_method("FastAPI.update_collection", OpenTelemetryGranularity.OPERATION)
def update_collection(
Expand All @@ -291,24 +291,21 @@ def update_collection(
new_metadata=collection.new_metadata,
)


@authz_context(action="delete_collection",
@authz_context(action=AuthzResourceActions.DELETE_COLLECTION,
resource=DynamicAuthzResource(
id=AuthzDynamicParams.from_function_kwargs(
arg_name="collection_name"),
type="db",
type=AuthzResourceTypes.DB,
))


@trace_method("FastAPI.delete_collection", OpenTelemetryGranularity.OPERATION)
def delete_collection(self, collection_name: str) -> None:
return self._api.delete_collection(collection_name)

@authz_context(action="add",
@authz_context(action=AuthzResourceActions.ADD,
resource=DynamicAuthzResource(
id=AuthzDynamicParams.from_function_kwargs(
arg_name="collection_id"),
type="collection",
type=AuthzResourceTypes.COLLECTION,
))
@trace_method("FastAPI.add", OpenTelemetryGranularity.OPERATION)
def add(self, collection_id: str, add: AddEmbedding) -> None:
Expand All @@ -324,12 +321,11 @@ def add(self, collection_id: str, add: AddEmbedding) -> None:
raise HTTPException(status_code=500, detail=str(e))
return result


@authz_context(action="update",
@authz_context(action=AuthzResourceActions.UPDATE,
resource=DynamicAuthzResource(
id=AuthzDynamicParams.from_function_kwargs(
arg_name="collection_id"),
type="collection",
type=AuthzResourceTypes.COLLECTION,
))
@trace_method("FastAPI.update", OpenTelemetryGranularity.OPERATION)
def update(self, collection_id: str, add: UpdateEmbedding) -> None:
Expand All @@ -341,12 +337,11 @@ def update(self, collection_id: str, add: UpdateEmbedding) -> None:
metadatas=add.metadatas,
)


@authz_context(action="upsert",
@authz_context(action=AuthzResourceActions.UPSERT,
resource=DynamicAuthzResource(
id=AuthzDynamicParams.from_function_kwargs(
arg_name="collection_id"),
type="collection",
type=AuthzResourceTypes.COLLECTION,
))
@trace_method("FastAPI.upsert", OpenTelemetryGranularity.OPERATION)
def upsert(self, collection_id: str, upsert: AddEmbedding) -> None:
Expand All @@ -358,12 +353,11 @@ def upsert(self, collection_id: str, upsert: AddEmbedding) -> None:
metadatas=upsert.metadatas,
)


@authz_context(action="get",
@authz_context(action=AuthzResourceActions.GET,
resource=DynamicAuthzResource(
id=AuthzDynamicParams.from_function_kwargs(
arg_name="collection_id"),
type="collection",
type=AuthzResourceTypes.COLLECTION,
))
@trace_method("FastAPI.get", OpenTelemetryGranularity.OPERATION)
def get(self, collection_id: str, get: GetEmbedding) -> GetResult:
Expand All @@ -378,12 +372,11 @@ def get(self, collection_id: str, get: GetEmbedding) -> GetResult:
include=get.include,
)


@authz_context(action="delete",
@authz_context(action=AuthzResourceActions.DELETE,
resource=DynamicAuthzResource(
id=AuthzDynamicParams.from_function_kwargs(
arg_name="collection_id"),
type="collection",
type=AuthzResourceTypes.COLLECTION,
))
@trace_method("FastAPI.delete", OpenTelemetryGranularity.OPERATION)
def delete(self, collection_id: str, delete: DeleteEmbedding) -> List[UUID]:
Expand All @@ -394,26 +387,29 @@ def delete(self, collection_id: str, delete: DeleteEmbedding) -> List[UUID]:
where_document=delete.where_document,
)


@authz_context(action="count",
@authz_context(action=AuthzResourceActions.COUNT,
resource=DynamicAuthzResource(
id=AuthzDynamicParams.from_function_kwargs(
arg_name="collection_id"),
type="collection",
type=AuthzResourceTypes.COLLECTION,
))
@trace_method("FastAPI.count", OpenTelemetryGranularity.OPERATION)
def count(self, collection_id: str) -> int:
return self._api._count(_uuid(collection_id))

@authz_context(action=AuthzResourceActions.RESET,
resource=DynamicAuthzResource(
id="*",
type=AuthzResourceTypes.DB,
))
def reset(self) -> bool:
return self._api.reset()


@authz_context(action="query",
@authz_context(action=AuthzResourceActions.QUERY,
resource=DynamicAuthzResource(
id=AuthzDynamicParams.from_function_kwargs(
arg_name="collection_id"),
type="collection",
type=AuthzResourceTypes.COLLECTION,
))
@trace_method("FastAPI.get_nearest_neighbors", OpenTelemetryGranularity.OPERATION)
def get_nearest_neighbors(
Expand Down
27 changes: 9 additions & 18 deletions examples/basic_functionality/authz/authz.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 23,
"metadata": {},
"outputs": [
{
Expand All @@ -13,23 +13,14 @@
]
},
{
"ename": "Exception",
"evalue": "{\"error\":\"AuthorizationError\",\"message\":\"Unauthorized\"}",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mHTTPError\u001b[0m Traceback (most recent call last)",
"File \u001b[0;32m~/experiments/chroma-experiments/chroma-authz/chromadb/api/fastapi.py:468\u001b[0m, in \u001b[0;36mraise_chroma_error\u001b[0;34m(resp)\u001b[0m\n\u001b[1;32m 467\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m--> 468\u001b[0m resp\u001b[39m.\u001b[39;49mraise_for_status()\n\u001b[1;32m 469\u001b[0m \u001b[39mexcept\u001b[39;00m requests\u001b[39m.\u001b[39mHTTPError:\n",
"File \u001b[0;32m~/experiments/chroma-experiments/chroma-authz/venv/lib/python3.11/site-packages/requests/models.py:1021\u001b[0m, in \u001b[0;36mResponse.raise_for_status\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1020\u001b[0m \u001b[39mif\u001b[39;00m http_error_msg:\n\u001b[0;32m-> 1021\u001b[0m \u001b[39mraise\u001b[39;00m HTTPError(http_error_msg, response\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m)\n",
"\u001b[0;31mHTTPError\u001b[0m: 403 Client Error: Forbidden for url: http://localhost:8000/api/v1/collections/511c1965-2b74-45f9-8aed-cc8b567f8127/add",
"\nDuring handling of the above exception, another exception occurred:\n",
"\u001b[0;31mException\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m/Users/tazarov/experiments/chroma-experiments/chroma-authz/examples/basic_functionality/authz/authz.ipynb Cell 1\u001b[0m line \u001b[0;36m1\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/tazarov/experiments/chroma-experiments/chroma-authz/examples/basic_functionality/authz/authz.ipynb#W0sZmlsZQ%3D%3D?line=8'>9</a>\u001b[0m client\u001b[39m.\u001b[39mlist_collections()\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/tazarov/experiments/chroma-experiments/chroma-authz/examples/basic_functionality/authz/authz.ipynb#W0sZmlsZQ%3D%3D?line=9'>10</a>\u001b[0m collection \u001b[39m=\u001b[39m client\u001b[39m.\u001b[39mget_or_create_collection(\u001b[39m\"\u001b[39m\u001b[39mtest_collection\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m---> <a href='vscode-notebook-cell:/Users/tazarov/experiments/chroma-experiments/chroma-authz/examples/basic_functionality/authz/authz.ipynb#W0sZmlsZQ%3D%3D?line=11'>12</a>\u001b[0m collection\u001b[39m.\u001b[39;49madd(documents\u001b[39m=\u001b[39;49m[\u001b[39m\"\u001b[39;49m\u001b[39mtest\u001b[39;49m\u001b[39m\"\u001b[39;49m],ids\u001b[39m=\u001b[39;49m[\u001b[39m\"\u001b[39;49m\u001b[39m1\u001b[39;49m\u001b[39m\"\u001b[39;49m])\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/tazarov/experiments/chroma-experiments/chroma-authz/examples/basic_functionality/authz/authz.ipynb#W0sZmlsZQ%3D%3D?line=12'>13</a>\u001b[0m collection\u001b[39m.\u001b[39mget()\n",
"File \u001b[0;32m~/experiments/chroma-experiments/chroma-authz/chromadb/api/models/Collection.py:100\u001b[0m, in \u001b[0;36mCollection.add\u001b[0;34m(self, ids, embeddings, metadatas, documents)\u001b[0m\n\u001b[1;32m 77\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"Add embeddings to the data store.\u001b[39;00m\n\u001b[1;32m 78\u001b[0m \u001b[39mArgs:\u001b[39;00m\n\u001b[1;32m 79\u001b[0m \u001b[39m ids: The ids of the embeddings you wish to add\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 93\u001b[0m \n\u001b[1;32m 94\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 96\u001b[0m ids, embeddings, metadatas, documents \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_validate_embedding_set(\n\u001b[1;32m 97\u001b[0m ids, embeddings, metadatas, documents\n\u001b[1;32m 98\u001b[0m )\n\u001b[0;32m--> 100\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_client\u001b[39m.\u001b[39;49m_add(ids, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mid, embeddings, metadatas, documents)\n",
"File \u001b[0;32m~/experiments/chroma-experiments/chroma-authz/chromadb/api/fastapi.py:340\u001b[0m, in \u001b[0;36mFastAPI._add\u001b[0;34m(self, ids, collection_id, embeddings, metadatas, documents)\u001b[0m\n\u001b[1;32m 338\u001b[0m validate_batch(batch, {\u001b[39m\"\u001b[39m\u001b[39mmax_batch_size\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmax_batch_size})\n\u001b[1;32m 339\u001b[0m resp \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_submit_batch(batch, \u001b[39m\"\u001b[39m\u001b[39m/collections/\u001b[39m\u001b[39m\"\u001b[39m \u001b[39m+\u001b[39m \u001b[39mstr\u001b[39m(collection_id) \u001b[39m+\u001b[39m \u001b[39m\"\u001b[39m\u001b[39m/add\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m--> 340\u001b[0m raise_chroma_error(resp)\n\u001b[1;32m 341\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mTrue\u001b[39;00m\n",
"File \u001b[0;32m~/experiments/chroma-experiments/chroma-authz/chromadb/api/fastapi.py:470\u001b[0m, in \u001b[0;36mraise_chroma_error\u001b[0;34m(resp)\u001b[0m\n\u001b[1;32m 468\u001b[0m resp\u001b[39m.\u001b[39mraise_for_status()\n\u001b[1;32m 469\u001b[0m \u001b[39mexcept\u001b[39;00m requests\u001b[39m.\u001b[39mHTTPError:\n\u001b[0;32m--> 470\u001b[0m \u001b[39mraise\u001b[39;00m (\u001b[39mException\u001b[39;00m(resp\u001b[39m.\u001b[39mtext))\n",
"\u001b[0;31mException\u001b[0m: {\"error\":\"AuthorizationError\",\"message\":\"Unauthorized\"}"
]
"data": {
"text/plain": [
"{'ids': ['1'], 'embeddings': None, 'metadatas': [None], 'documents': ['test']}"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
Expand Down
3 changes: 2 additions & 1 deletion examples/basic_functionality/authz/authz.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ roles_mapping:
admin:
actions:
[
db:reset,
db:list_collections,
db:get_collection,
db:create_collection,
Expand Down Expand Up @@ -83,7 +84,7 @@ roles_mapping:
resources: ["<UUID>"] #not yet supported
users:
- id: [email protected]
role: db_read
role: admin
tokens:
- token: test-token-admin
secret: my_api_secret # not yet supported
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ dependencies = [
'typer >= 0.9.0',
'kubernetes>=28.1.0',
'tenacity>=8.2.3',
'PyYAML>=6.0.0',
]

[tool.black]
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ tqdm>=4.65.0
typer>=0.9.0
typing_extensions>=4.5.0
uvicorn[standard]==0.18.3
PyYAML>=6.0.0

0 comments on commit 19daf90

Please sign in to comment.