Skip to content

Commit

Permalink
feat: Introduce idp identifier and email columns
Browse files Browse the repository at this point in the history
This commit primarily adds two columns to the user table,
the idp identifier and the email. The idp identifier is used to map
the idp user to the local user. Therefore, it is critical that the idp
claim used for this is unique per user. The following breaking changes
are introduced to the values.yaml file: The `jwt.usernameClaim' is
now moved to `claimMapping.username' and specifies the identity token
claim used for the username column. The `claimMapping.idpIdentifier'
is added, which specifies the identity token claim used for the
new idp identifier column and must be unique within the idp.
The `claimMapping.email` is added, which specifies the identity token
claim used for the new email column. The breaking changes are
detailed in the PR description and in the release notes.
  • Loading branch information
dominik003 committed Jul 29, 2024
1 parent 8175f3f commit b1dcc31
Show file tree
Hide file tree
Showing 20 changed files with 275 additions and 120 deletions.
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ helm-deploy:
--set docker.registry.sessions=$(CAPELLACOLLAB_SESSIONS_REGISTRY) \
--set docker.tag=$(DOCKER_TAG) \
--set mocks.oauth=True \
--set authentication.claimMapping.username=sub \
--set authentication.endpoints.authorization=https://localhost/default/authorize \
--set development=$(DEVELOPMENT_MODE) \
--set cluster.ingressClassName=traefik \
--set cluster.ingressNamespace=kube-system \
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# SPDX-FileCopyrightText: Copyright DB InfraGO AG and contributors
# SPDX-License-Identifier: Apache-2.0

"""Add IdP identifier and email columns
Revision ID: 028c72ddfd20
Revises: a1e59021e0d0
Create Date: 2024-07-22 14:49:47.575605
"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "028c72ddfd20"
down_revision = "a1e59021e0d0"
branch_labels = None
depends_on = None


def upgrade():
op.add_column(
"users", sa.Column("idp_identifier", sa.String(), nullable=True)
)

t_users = sa.Table("users", sa.MetaData(), autoload_with=op.get_bind())

users = op.get_bind().execute(sa.select(t_users))
for user in users:
op.get_bind().execute(
sa.update(t_users)
.where(t_users.c.id == user.id)
.values(idp_identifier=user.name)
)

op.alter_column("users", "idp_identifier", nullable=False)

op.add_column("users", sa.Column("email", sa.String(), nullable=True))
op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True)
op.create_index(
op.f("ix_users_idp_identifier"),
"users",
["idp_identifier"],
unique=True,
)
22 changes: 8 additions & 14 deletions backend/capellacollab/config/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ class AuthOauthClientConfig(BaseConfig):
default="default", description="The authentication provider client ID."
)
secret: str = pydantic.Field(
default=None, description="The authentication provider client secret."
default="", description="The authentication provider client secret."
)


Expand All @@ -247,24 +247,18 @@ class AuthOauthEndpointsConfig(BaseConfig):
)


class JWTConfig(BaseConfig):
username_claim: str = pydantic.Field(
default="sub",
description="Specifies the key in the JWT payload where the username is stored.",
examples=["sub", "aud", "preferred_username"],
)


class GeneralAuthenticationConfig(BaseConfig):
jwt: JWTConfig = JWTConfig()
class ClaimMappingConfig(BaseConfig):
identifier: str = pydantic.Field(default="sub")
username: str = pydantic.Field(default="sub")
email: str | None = pydantic.Field(default="email")


class AuthenticationConfig(GeneralAuthenticationConfig):
class AuthenticationConfig(BaseConfig):
endpoints: AuthOauthEndpointsConfig = AuthOauthEndpointsConfig()
audience: str = pydantic.Field(default="default")
issuer: str = pydantic.Field(default="http://localhost:8083/default")
mapping: ClaimMappingConfig = ClaimMappingConfig()
scopes: list[str] = pydantic.Field(
default=["openid", "offline_access"],
default=["openid", "profile", "offline_access"],
description="List of scopes that application neeeds to access the required attributes.",
)
client: AuthOauthClientConfig = AuthOauthClientConfig()
Expand Down
19 changes: 16 additions & 3 deletions backend/capellacollab/core/authentication/api_key_cookie.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

log = logging.getLogger(__name__)

auth_config = config.authentication


class JWTConfigBorg:
_shared_state: dict[str, str] = {}
Expand Down Expand Up @@ -48,9 +50,6 @@ async def __call__(self, request: fastapi.Request) -> str:
token_decoded = self.validate_token(token)
return self.get_username(token_decoded)

def get_username(self, token_decoded: dict[str, str]) -> str:
return token_decoded[config.authentication.jwt.username_claim].strip()

def validate_token(self, token: str) -> dict[str, t.Any]:
try:
signing_key = self.jwt_config.jwks_client.get_signing_key_from_jwt(
Expand All @@ -74,3 +73,17 @@ def validate_token(self, token: str) -> dict[str, t.Any]:
except jwt_exceptions.PyJWTError:
log.exception("JWT validation failed", exc_info=True)
raise exceptions.JWTValidationFailed()

@classmethod
def get_username(cls, token_decoded: dict[str, str]) -> str:
return token_decoded[auth_config.mapping.username].strip()

@classmethod
def get_idp_identifier(cls, token_decoded: dict[str, str]) -> str:
return token_decoded[auth_config.mapping.identifier].strip()

@classmethod
def get_email(cls, token_decoded: dict[str, str]) -> str | None:
if auth_config.mapping.email:
return token_decoded.get(auth_config.mapping.email, None)
return None
28 changes: 20 additions & 8 deletions backend/capellacollab/core/authentication/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,10 @@ async def api_get_token(
token_request.code, token_request.code_verifier
)

user = validate_id_token(
db, tokens["id_token"], provider_config, token_request.nonce
validated_id_token = validate_id_token(
tokens["id_token"], provider_config, None
)
user = create_or_update_user(db, validated_id_token)

update_token_cookies(
response, tokens["id_token"], tokens.get("refresh_token", None), user
Expand All @@ -80,7 +81,11 @@ async def api_refresh_token(

tokens = provider.refresh_token(refresh_token)

user = validate_id_token(db, tokens["id_token"], provider_config, None)
validated_id_token = validate_id_token(
tokens["id_token"], provider_config, None
)
user = create_or_update_user(db, validated_id_token)

update_token_cookies(
response, tokens["id_token"], tokens.get("refresh_token", None), user
)
Expand Down Expand Up @@ -110,11 +115,10 @@ async def validate_token(


def validate_id_token(
db: orm.Session,
id_token: str,
provider_config: oidc_provider.AbstractOIDCProviderConfig,
nonce: str | None,
) -> users_models.DatabaseUser:
) -> dict[str, str]:
validated_id_token = api_key_cookie.JWTAPIKeyCookie(
provider_config
).validate_token(id_token)
Expand All @@ -125,13 +129,21 @@ def validate_id_token(
if provider_config.get_client_id() not in validated_id_token["aud"]:
raise exceptions.UnauthenticatedError()

username = api_key_cookie.JWTAPIKeyCookie(provider_config).get_username(
return validated_id_token


def create_or_update_user(
db: orm.Session, validated_id_token: dict[str, str]
) -> users_models.DatabaseUser:
username = api_key_cookie.JWTAPIKeyCookie.get_username(validated_id_token)
idp_identifier = api_key_cookie.JWTAPIKeyCookie.get_idp_identifier(
validated_id_token
)
email = api_key_cookie.JWTAPIKeyCookie.get_email(validated_id_token)

user = users_crud.get_user_by_name(db, username)
user = users_crud.get_user_by_idp_identifier(db, idp_identifier)
if not user:
user = users_crud.create_user(db, username)
user = users_crud.create_user(db, username, idp_identifier, email)
events_crud.create_user_creation_event(db, user)

users_crud.update_last_login(db, user)
Expand Down
1 change: 1 addition & 0 deletions backend/capellacollab/core/database/migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def initialize_admin_user(db: orm.Session):
admin_user = users_crud.create_user(
db=db,
username=config.initial.admin,
idp_identifier=config.initial.admin,
role=users_models.Role.ADMIN,
)
events_crud.create_user_creation_event(db, admin_user)
Expand Down
16 changes: 7 additions & 9 deletions backend/capellacollab/core/logging/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

from capellacollab import config
from capellacollab.core.authentication import injectables as auth_injectables
from capellacollab.core.authentication import oidc_provider

LOGGING_LEVEL = config.config.logging.level

Expand Down Expand Up @@ -75,18 +74,17 @@ async def dispatch(


class AttachUserNameMiddleware(base.BaseHTTPMiddleware):
def __init__(self, app):
super().__init__(app)
self.wellknown_oidc_provider_config = (
oidc_provider.WellKnownOIDCProviderConfig()
)

async def dispatch(
self, request: fastapi.Request, call_next: base.RequestResponseEndpoint
self,
request: fastapi.Request,
call_next: base.RequestResponseEndpoint,
):
try:
oidc_provider_config = (
await auth_injectables.get_oidc_provider_config()
)
username = await auth_injectables.get_username(
request, self.wellknown_oidc_provider_config
request, oidc_provider_config
)
except fastapi.HTTPException:
username = "anonymous"
Expand Down
27 changes: 26 additions & 1 deletion backend/capellacollab/users/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sqlalchemy as sa
from sqlalchemy import orm

from capellacollab.core import database
from capellacollab.users import models


Expand All @@ -28,6 +29,16 @@ def get_user_by_id(
).scalar_one_or_none()


def get_user_by_idp_identifier(
db: orm.Session, idp_identifier: str
) -> models.DatabaseUser | None:
return db.execute(
sa.select(models.DatabaseUser).where(
models.DatabaseUser.idp_identifier == idp_identifier
)
).scalar_one_or_none()


def get_users(db: orm.Session) -> abc.Sequence[models.DatabaseUser]:
return db.execute(sa.select(models.DatabaseUser)).scalars().all()

Expand All @@ -45,10 +56,16 @@ def get_admin_users(db: orm.Session) -> abc.Sequence[models.DatabaseUser]:


def create_user(
db: orm.Session, username: str, role: models.Role = models.Role.USER
db: orm.Session,
username: str,
idp_identifier: str,
email: str | None = None,
role: models.Role = models.Role.USER,
) -> models.DatabaseUser:
user = models.DatabaseUser(
name=username,
idp_identifier=idp_identifier,
email=email,
role=role,
created=datetime.datetime.now(datetime.UTC),
projects=[],
Expand All @@ -60,6 +77,14 @@ def create_user(
return user


def update_user(
db: orm.Session, user: models.DatabaseUser, patch_user: models.PatchUser
) -> models.DatabaseUser:
database.patch_database_with_pydantic_object(user, patch_user)
db.commit()
return user


def update_role_of_user(
db: orm.Session, user: models.DatabaseUser, role: models.Role
) -> models.DatabaseUser:
Expand Down
10 changes: 10 additions & 0 deletions backend/capellacollab/users/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,13 @@ def __init__(self, user_id: int):
),
err_code="NO_PROJECTS_IN_COMMON",
)


class RoleUpdateRequiresReasonError(core_exceptions.BaseError):
def __init__(self):
super().__init__(
status_code=status.HTTP_403_FORBIDDEN,
title="No reason provided",
reason=("You must provide a reason for updating the users roles."),
err_code="ROLE_UPDATE_REQUIRES_REASON",
)
22 changes: 18 additions & 4 deletions backend/capellacollab/users/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ class Role(enum.Enum):
class BaseUser(core_pydantic.BaseModel):
id: int
name: str
idp_identifier: str
role: Role


class User(BaseUser):
id: int
created: datetime.datetime | None = None
last_login: datetime.datetime | None = None

Expand All @@ -44,13 +44,18 @@ class User(BaseUser):
)


class PatchUserRoleRequest(core_pydantic.BaseModel):
role: Role
reason: str
class PatchUser(core_pydantic.BaseModel):
name: str | None = None
idp_identifier: str | None = None
email: str | None = None
role: Role | None = None
reason: str | None = None


class PostUser(core_pydantic.BaseModel):
name: str
idp_identifier: str
email: str | None = None
role: Role
reason: str

Expand All @@ -62,8 +67,17 @@ class DatabaseUser(database.Base):
init=False, primary_key=True, index=True
)

idp_identifier: orm.Mapped[str] = orm.mapped_column(
unique=True, index=True
)
name: orm.Mapped[str] = orm.mapped_column(unique=True, index=True)

role: orm.Mapped[Role]

email: orm.Mapped[str | None] = orm.mapped_column(
default=None, unique=True, index=True
)

created: orm.Mapped[datetime.datetime | None] = orm.mapped_column(
default=datetime.datetime.now(datetime.UTC)
)
Expand Down
Loading

0 comments on commit b1dcc31

Please sign in to comment.