Skip to content

Commit

Permalink
feat: hash device_code and code
Browse files Browse the repository at this point in the history
  • Loading branch information
aldbr committed Feb 19, 2024
1 parent 0eea030 commit b622745
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 11 deletions.
27 changes: 19 additions & 8 deletions diracx-db/src/diracx/db/sql/auth/db.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import hashlib
import secrets
from datetime import datetime
from uuid import uuid4
Expand Down Expand Up @@ -63,7 +64,7 @@ async def get_device_flow(self, device_code: str, max_validity: int):
),
).with_for_update()
stmt = stmt.where(
DeviceFlows.device_code == device_code,
DeviceFlows.device_code == hashlib.sha256(device_code.encode()).hexdigest(),
)
res = dict((await self.conn.execute(stmt)).one()._mapping)

Expand All @@ -74,7 +75,10 @@ async def get_device_flow(self, device_code: str, max_validity: int):
# Update the status to Done before returning
await self.conn.execute(
update(DeviceFlows)
.where(DeviceFlows.device_code == device_code)
.where(
DeviceFlows.device_code
== hashlib.sha256(device_code.encode()).hexdigest()
)
.values(status=FlowStatus.DONE)
)
return res
Expand Down Expand Up @@ -119,14 +123,17 @@ async def insert_device_flow(
secrets.choice(USER_CODE_ALPHABET)
for _ in range(DeviceFlows.user_code.type.length) # type: ignore
)
# user_code = "2QRKPY"
device_code = secrets.token_urlsafe()

# Hash the the device_code to avoid leaking information
hashed_device_code = hashlib.sha256(device_code.encode()).hexdigest()

stmt = insert(DeviceFlows).values(
client_id=client_id,
scope=scope,
audience=audience,
user_code=user_code,
device_code=device_code,
device_code=hashed_device_code,
)
try:
await self.conn.execute(stmt)
Expand Down Expand Up @@ -172,7 +179,10 @@ async def authorization_flow_insert_id_token(
:raises: AuthorizationError if no such uuid or status not pending
"""

# Hash the code to avoid leaking information
code = secrets.token_urlsafe()
hashed_code = hashlib.sha256(code.encode()).hexdigest()

stmt = update(AuthorizationFlows)

stmt = stmt.where(
Expand All @@ -181,7 +191,7 @@ async def authorization_flow_insert_id_token(
AuthorizationFlows.creation_time > substract_date(seconds=max_validity),
)

stmt = stmt.values(id_token=id_token, code=code, status=FlowStatus.READY)
stmt = stmt.values(id_token=id_token, code=hashed_code, status=FlowStatus.READY)
res = await self.conn.execute(stmt)

if res.rowcount != 1:
Expand All @@ -190,15 +200,16 @@ async def authorization_flow_insert_id_token(
stmt = select(AuthorizationFlows.code, AuthorizationFlows.redirect_uri)
stmt = stmt.where(AuthorizationFlows.uuid == uuid)
row = (await self.conn.execute(stmt)).one()
return row.code, row.redirect_uri
return code, row.redirect_uri

async def get_authorization_flow(self, code: str, max_validity: int):
hashed_code = hashlib.sha256(code.encode()).hexdigest()
# The with_for_update
# prevents that the token is retrieved
# multiple time concurrently
stmt = select(AuthorizationFlows).with_for_update()
stmt = stmt.where(
AuthorizationFlows.code == code,
AuthorizationFlows.code == hashed_code,
AuthorizationFlows.creation_time > substract_date(seconds=max_validity),
)

Expand All @@ -208,7 +219,7 @@ async def get_authorization_flow(self, code: str, max_validity: int):
# Update the status to Done before returning
await self.conn.execute(
update(AuthorizationFlows)
.where(AuthorizationFlows.code == code)
.where(AuthorizationFlows.code == hashed_code)
.values(status=FlowStatus.DONE)
)

Expand Down
4 changes: 2 additions & 2 deletions diracx-db/src/diracx/db/sql/auth/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class DeviceFlows(Base):
client_id = Column(String(255))
scope = Column(String(1024))
audience = Column(String(255))
device_code = Column(String(128), unique=True) # hash it ?
device_code = Column(String(128), unique=True) # Should be a hash
id_token = NullColumn(JSON())


Expand All @@ -61,7 +61,7 @@ class AuthorizationFlows(Base):
code_challenge = Column(String(255))
code_challenge_method = Column(String(8))
redirect_uri = Column(String(255))
code = NullColumn(String(255)) # hash it ?
code = NullColumn(String(255)) # Should be a hash
id_token = NullColumn(JSON())


Expand Down
2 changes: 1 addition & 1 deletion diracx-routers/src/diracx/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ async def initiate_device_flow(
"user_code": user_code,
"device_code": device_code,
"verification_uri_complete": f"{verification_uri}?user_code={user_code}",
"verification_uri": str(request.url.replace(query={})),
"verification_uri": verification_uri,
"expires_in": settings.device_flow_expiration_seconds,
}

Expand Down

0 comments on commit b622745

Please sign in to comment.