Skip to content

Commit

Permalink
[BUG][SEC]: Fixes AuthN server-side providers (#2099)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
- Added random jitter to mitigate timing attacks on early exist
conditions
	- Added generic AuthError for early exist conditions
	- Added checks for missing headers(as this was throwing key error)
- Made it possible to use the TokenTransportHeader Enum when defining
client-side auth(docs also updated)
- Improved validation of wrong Token transport headers with a friendlier
user message
- Broader exception handling now logs only exception type and line
number to prevent unintentional information disclosure(A02:
2021-Cryptographic Failures - owasp top10)
	- Aligned both token and basic auth logic flow with exist conditions
- Fixed an issue with replacement of auth headers Basic and Bearer with
regex which aligns well with case-insensitivity of http headers as well
as properly checking that the respective string is at the beginning of
the header value
- Basic auth username check is not early exit case - this simplifies the
logic for pwd check
- Basic auth splitting on ':' is now works for only the first split this
will prevent unexpected exceptions in case the user adds more than a
single ':' in the base64-encoded header value
- Basic auth - decoded username and password are converted to string to
prevent http header injections

## Test plan
*How are these changes tested?*

- [x] Tests pass locally with `pytest` for python, `yarn test` for js,
`cargo test` for rust

## Documentation Changes

Docs PR TBD
  • Loading branch information
tazarov authored May 2, 2024
1 parent d278a35 commit 31c7f9d
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 38 deletions.
4 changes: 2 additions & 2 deletions bin/ts-integration-test
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ EOF
;;
token)
cat <<EOF > .chroma_env
CHROMA_AUTH_TOKEN_TRANSPORT_HEADER="AUTHORIZATION"
CHROMA_AUTH_TOKEN_TRANSPORT_HEADER="Authorization"
CHROMA_SERVER_AUTHN_CREDENTIALS="test-token"
CHROMA_SERVER_AUTHN_PROVIDER="chromadb.auth.token_authn.TokenAuthenticationServerProvider"
EOF
;;
xtoken)
cat <<EOF > .chroma_env
CHROMA_AUTH_TOKEN_TRANSPORT_HEADER="X_CHROMA_TOKEN"
CHROMA_AUTH_TOKEN_TRANSPORT_HEADER="X-Chroma-Token"
CHROMA_SERVER_AUTHN_CREDENTIALS="test-token"
CHROMA_SERVER_AUTHN_PROVIDER="chromadb.auth.token_authn.TokenAuthenticationServerProvider"
EOF
Expand Down
5 changes: 2 additions & 3 deletions chromadb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
"UpdateCollectionMetadata",
"QueryResult",
"GetResult",
"TokenTransportHeader",
]

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -250,9 +251,7 @@ def CloudClient(
"chromadb.auth.token_authn.TokenAuthClientProvider"
)
settings.chroma_client_auth_credentials = api_key
settings.chroma_auth_token_transport_header = (
TokenTransportHeader.X_CHROMA_TOKEN.name
)
settings.chroma_auth_token_transport_header = TokenTransportHeader.X_CHROMA_TOKEN

return ClientCreator(tenant=tenant, database=database, settings=settings)

Expand Down
4 changes: 4 additions & 0 deletions chromadb/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
S = TypeVar("S")


class AuthError(Exception):
pass


ClientAuthHeaders = Dict[str, SecretStr]


Expand Down
48 changes: 37 additions & 11 deletions chromadb/auth/basic_authn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import base64
import random
import re
import time
import traceback

import bcrypt
import logging

Expand All @@ -11,6 +16,7 @@
ServerAuthenticationProvider,
ClientAuthProvider,
ClientAuthHeaders,
AuthError,
)
from chromadb.config import System
from chromadb.telemetry.opentelemetry import (
Expand All @@ -24,6 +30,8 @@

__all__ = ["BasicAuthenticationServerProvider", "BasicAuthClientProvider"]

AUTHORIZATION_HEADER = "Authorization"


class BasicAuthClientProvider(ClientAuthProvider):
"""
Expand All @@ -43,7 +51,7 @@ def authenticate(self) -> ClientAuthHeaders:
f"{self._creds.get_secret_value()}".encode("utf-8")
).decode("utf-8")
return {
"Authorization": SecretStr(f"Basic {encoded}"),
AUTHORIZATION_HEADER: SecretStr(f"Basic {encoded}"),
}


Expand Down Expand Up @@ -94,25 +102,43 @@ def __init__(self, system: System) -> None:
@override
def authenticate_or_raise(self, headers: Headers) -> UserIdentity:
try:
_auth_header = headers["Authorization"]
_auth_header = _auth_header.replace("Basic ", "")
if AUTHORIZATION_HEADER not in headers:
raise AuthError(AUTHORIZATION_HEADER + " header not found")
_auth_header = headers[AUTHORIZATION_HEADER]
_auth_header = re.sub(r"^Basic ", "", _auth_header)
_auth_header = _auth_header.strip()

base64_decoded = base64.b64decode(_auth_header).decode("utf-8")
username, password = base64_decoded.split(":")
if not username or not password:
raise HTTPException(status_code=401, detail="Unauthorized")
if ":" not in base64_decoded:
raise AuthError("Invalid Authorization header format")
username, password = base64_decoded.split(":", 1)
username = str(username) # convert to string to prevent header injection
password = str(password) # convert to string to prevent header injection
if username not in self._creds:
raise AuthError("Invalid username or password")

_usr_check = username in self._creds
_pwd_check = bcrypt.checkpw(
password.encode("utf-8"),
self._creds[username].get_secret_value().encode("utf-8"),
)
if _usr_check and _pwd_check:
return UserIdentity(user_id=username)

if not _pwd_check:
raise AuthError("Invalid username or password")
return UserIdentity(user_id=username)
except AuthError as e:
logger.error(
f"BasicAuthenticationServerProvider.authenticate failed: {repr(e)}"
)
except Exception as e:
tb = traceback.extract_tb(e.__traceback__)
# Get the last call stack
last_call_stack = tb[-1]
line_number = last_call_stack.lineno
filename = last_call_stack.filename
logger.error(
"BasicAuthenticationServerProvider.authenticate " f"failed: {repr(e)}"
"BasicAuthenticationServerProvider.authenticate failed: "
f"Failed to authenticate {type(e).__name__} at {filename}:{line_number}"
)
time.sleep(
random.uniform(0.001, 0.005)
) # add some jitter to avoid timing attacks
raise HTTPException(status_code=403, detail="Forbidden")
84 changes: 64 additions & 20 deletions chromadb/auth/token_authn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import logging
import random
import re
import string

import time
import traceback
from enum import Enum
from starlette.datastructures import Headers
from typing import cast, Dict, List, Optional, TypedDict, TypeVar
Expand All @@ -15,6 +18,7 @@
ClientAuthProvider,
ClientAuthHeaders,
UserIdentity,
AuthError,
)
from chromadb.config import System
from chromadb.telemetry.opentelemetry import (
Expand All @@ -26,10 +30,14 @@

logger = logging.getLogger(__name__)

__all__ = ["TokenAuthenticationServerProvider", "TokenAuthClientProvider"]
__all__ = [
"TokenAuthenticationServerProvider",
"TokenAuthClientProvider",
"TokenTransportHeader",
]


class TokenTransportHeader(Enum):
class TokenTransportHeader(str, Enum):
"""
Accceptable token transport headers.
"""
Expand All @@ -42,15 +50,28 @@ class TokenTransportHeader(Enum):
X_CHROMA_TOKEN = "X-Chroma-Token"


valid_token_chars = set(string.digits + string.ascii_letters + string.punctuation)


def _check_token(token: str) -> None:
token_str = str(token)
if not all(
c in string.digits + string.ascii_letters + string.punctuation
for c in token_str
):
if not all(c in valid_token_chars for c in token_str):
raise ValueError(
"Invalid token. Must contain \
only ASCII letters and digits."
"Invalid token. Must contain only ASCII letters, digits, and punctuation."
)


allowed_token_headers = [
TokenTransportHeader.AUTHORIZATION.value,
TokenTransportHeader.X_CHROMA_TOKEN.value,
]


def _check_allowed_token_headers(token_header: str) -> None:
if token_header not in allowed_token_headers:
raise ValueError(
f"Invalid token transport header: {token_header}. "
f"Must be one of {allowed_token_headers}"
)


Expand All @@ -71,9 +92,12 @@ def __init__(self, system: System) -> None:
_check_token(self._token.get_secret_value())

if system.settings.chroma_auth_token_transport_header:
self._token_transport_header = TokenTransportHeader[
str(system.settings.chroma_auth_token_transport_header)
]
_check_allowed_token_headers(
system.settings.chroma_auth_token_transport_header
)
self._token_transport_header = TokenTransportHeader(
system.settings.chroma_auth_token_transport_header
)
else:
self._token_transport_header = TokenTransportHeader.AUTHORIZATION

Expand Down Expand Up @@ -119,9 +143,12 @@ def __init__(self, system: System) -> None:
super().__init__(system)
self._settings = system.settings
if system.settings.chroma_auth_token_transport_header:
self._token_transport_header = TokenTransportHeader[
str(system.settings.chroma_auth_token_transport_header)
]
_check_allowed_token_headers(
system.settings.chroma_auth_token_transport_header
)
self._token_transport_header = TokenTransportHeader(
system.settings.chroma_auth_token_transport_header
)
else:
self._token_transport_header = TokenTransportHeader.AUTHORIZATION

Expand Down Expand Up @@ -166,26 +193,43 @@ def __init__(self, system: System) -> None:
@override
def authenticate_or_raise(self, headers: Headers) -> UserIdentity:
try:
if self._token_transport_header.value not in headers:
raise AuthError(
f"Authorization header '{self._token_transport_header.value}' not found"
)
token = headers[self._token_transport_header.value]
if self._token_transport_header == TokenTransportHeader.AUTHORIZATION:
if not token.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Unauthorized")
token = token.replace("Bearer ", "")
raise AuthError("Bearer not found in Authorization header")
token = re.sub(r"^Bearer ", "", token)

token = token.strip()
_check_token(token)

if token not in self._token_user_mapping:
raise HTTPException(status_code=401, detail="Unauthorized")
raise AuthError("Invalid credentials: Token not found}")

user_identity = UserIdentity(
user_id=self._token_user_mapping[token]["id"],
tenant=self._token_user_mapping[token]["tenant"],
databases=self._token_user_mapping[token]["databases"],
)
return user_identity
except AuthError as e:
logger.debug(
f"TokenAuthenticationServerProvider.authenticate failed: {repr(e)}"
)
except Exception as e:
tb = traceback.extract_tb(e.__traceback__)
# Get the last call stack
last_call_stack = tb[-1]
line_number = last_call_stack.lineno
filename = last_call_stack.filename
logger.debug(
"TokenAuthenticationServerProvider.authenticate " f"failed: {repr(e)}"
"TokenAuthenticationServerProvider.authenticate failed: "
f"Failed to authenticate {type(e).__name__} at {filename}:{line_number}"
)
raise HTTPException(status_code=403, detail="Forbidden")
time.sleep(
random.uniform(0.001, 0.005)
) # add some jitter to avoid timing attacks
raise HTTPException(status_code=403, detail="Forbidden")
11 changes: 10 additions & 1 deletion chromadb/test/auth/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import string

from chromadb import TokenTransportHeader
from chromadb.test.property.strategies import collection_name


Expand Down Expand Up @@ -41,7 +42,15 @@ def random_token(draw: st.DrawFn) -> str:

@st.composite
def random_token_transport_header(draw: st.DrawFn) -> Optional[str]:
return draw(st.sampled_from(["AUTHORIZATION", "X_CHROMA_TOKEN", None]))
return draw(
st.sampled_from(
[
TokenTransportHeader.AUTHORIZATION,
TokenTransportHeader.X_CHROMA_TOKEN,
None,
]
)
)


@st.composite
Expand Down
2 changes: 1 addition & 1 deletion chromadb/test/client/test_cloud_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from chromadb.test.conftest import _await_server, _run_server, find_free_port

TOKEN_TRANSPORT_HEADER = TokenTransportHeader.X_CHROMA_TOKEN.name
TOKEN_TRANSPORT_HEADER = TokenTransportHeader.X_CHROMA_TOKEN
TEST_CLOUD_HOST = "localhost"


Expand Down

0 comments on commit 31c7f9d

Please sign in to comment.