Skip to content

Commit

Permalink
Merge pull request #109 from chrisburr/split-user-info
Browse files Browse the repository at this point in the history
Split UserInfo so it can be used in diracx.core
  • Loading branch information
chaen authored Sep 27, 2023
2 parents a974eb5 + 8026630 commit 34c630e
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 18 deletions.
7 changes: 7 additions & 0 deletions src/diracx/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,10 @@ class SetJobStatusReturn(BaseModel):
start_exec_time: datetime | None = Field(alias="StartExecTime")
end_exec_time: datetime | None = Field(alias="EndExecTime")
last_update_time: datetime | None = Field(alias="LastUpdateTime")


class UserInfo(BaseModel):
sub: str # dirac generated vo:sub
preferred_username: str
dirac_group: str
vo: str
22 changes: 9 additions & 13 deletions src/diracx/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
ExpiredFlowError,
PendingAuthorizationError,
)
from diracx.core.models import TokenResponse
from diracx.core.models import TokenResponse, UserInfo
from diracx.core.properties import (
PROXY_MANAGEMENT,
SecurityProperty,
Expand Down Expand Up @@ -82,7 +82,7 @@ def has_properties(expression: UnevaluatedProperty | SecurityProperty):
)

async def require_property(
user: Annotated[UserInfo, Depends(verify_dirac_access_token)]
user: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)]
):
if not evaluator(user.properties):
raise HTTPException(status.HTTP_403_FORBIDDEN)
Expand Down Expand Up @@ -164,18 +164,14 @@ class AuthInfo(BaseModel):
properties: list[SecurityProperty]


class UserInfo(AuthInfo):
# dirac generated vo:sub
sub: str
preferred_username: str
dirac_group: str
vo: str
class AuthorizedUserInfo(AuthInfo, UserInfo):
pass


async def verify_dirac_access_token(
authorization: Annotated[str, Depends(oidc_scheme)],
settings: AuthSettings,
) -> UserInfo:
) -> AuthorizedUserInfo:
"""Verify dirac user token and return a UserInfo class
Used for each API endpoint
"""
Expand Down Expand Up @@ -204,7 +200,7 @@ async def verify_dirac_access_token(
detail="Invalid JWT",
) from None

return UserInfo(
return AuthorizedUserInfo(
bearer_token=raw_token,
token_id=token["jti"],
properties=token["dirac_properties"],
Expand Down Expand Up @@ -876,7 +872,7 @@ async def get_oidc_token_info_from_refresh_flow(
@router.get("/refresh-tokens")
async def get_refresh_tokens(
auth_db: AuthDB,
user_info: Annotated[UserInfo, Depends(verify_dirac_access_token)],
user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)],
) -> list:
subject: str | None = user_info.sub
if PROXY_MANAGEMENT in user_info.properties:
Expand All @@ -889,7 +885,7 @@ async def get_refresh_tokens(
@router.delete("/refresh-tokens/{jti}")
async def revoke_refresh_token(
auth_db: AuthDB,
user_info: Annotated[UserInfo, Depends(verify_dirac_access_token)],
user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)],
jti: str,
) -> str:
res = await auth_db.get_refresh_token(jti)
Expand Down Expand Up @@ -1006,7 +1002,7 @@ class UserInfoResponse(TypedDict):

@router.get("/userinfo")
async def userinfo(
user_info: Annotated[UserInfo, Depends(verify_dirac_access_token)]
user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)]
) -> UserInfoResponse:
return {
"sub": user_info.sub,
Expand Down
10 changes: 5 additions & 5 deletions src/diracx/routers/job_manager/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
set_job_status,
)

from ..auth import UserInfo, has_properties, verify_dirac_access_token
from ..auth import AuthorizedUserInfo, has_properties, verify_dirac_access_token
from ..dependencies import JobDB, JobLoggingDB
from ..fastapi_classes import DiracxRouter

Expand Down Expand Up @@ -105,7 +105,7 @@ async def submit_bulk_jobs(
job_definitions: Annotated[list[str], Body(example=EXAMPLE_JDLS["Simple JDL"])],
job_db: JobDB,
job_logging_db: JobLoggingDB,
user_info: Annotated[UserInfo, Depends(verify_dirac_access_token)],
user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)],
) -> list[InsertedJob]:
from DIRAC.Core.Utilities.ClassAd.ClassAdLight import ClassAd
from DIRAC.Core.Utilities.ReturnValues import returnValueOrRaise
Expand All @@ -116,7 +116,7 @@ async def submit_bulk_jobs(
)

class DiracxJobPolicy(JobPolicy):
def __init__(self, user_info: UserInfo, allInfo: bool = True):
def __init__(self, user_info: AuthorizedUserInfo, allInfo: bool = True):
self.userName = user_info.preferred_username
self.userGroup = user_info.dirac_group
self.userProperties = user_info.properties
Expand Down Expand Up @@ -353,7 +353,7 @@ async def get_job_status_history_bulk(
async def search(
config: Annotated[Config, Depends(ConfigSource.create)],
job_db: JobDB,
user_info: Annotated[UserInfo, Depends(verify_dirac_access_token)],
user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)],
page: int = 0,
per_page: int = 100,
body: Annotated[
Expand Down Expand Up @@ -385,7 +385,7 @@ async def search(
async def summary(
config: Annotated[Config, Depends(ConfigSource.create)],
job_db: JobDB,
user_info: Annotated[UserInfo, Depends(verify_dirac_access_token)],
user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)],
body: JobSummaryParams,
):
"""Show information suitable for plotting"""
Expand Down

0 comments on commit 34c630e

Please sign in to comment.