From b9061e26f4aa71ac6edd16095e040bab46ef5863 Mon Sep 17 00:00:00 2001 From: Gabriel Massadas Date: Tue, 10 Dec 2024 20:15:44 +0000 Subject: [PATCH] Add Cloudflare Access middleware --- .changeset/sharp-moles-knock.md | 5 + package.json | 1 + packages/cloudflare-access/CHANGELOG.md | 0 packages/cloudflare-access/README.md | 65 +++++ packages/cloudflare-access/package.json | 47 ++++ packages/cloudflare-access/src/index.test.ts | 251 +++++++++++++++++++ packages/cloudflare-access/src/index.ts | 164 ++++++++++++ packages/cloudflare-access/tsconfig.json | 10 + packages/cloudflare-access/vitest.config.ts | 8 + yarn.lock | 12 + 10 files changed, 563 insertions(+) create mode 100644 .changeset/sharp-moles-knock.md create mode 100644 packages/cloudflare-access/CHANGELOG.md create mode 100644 packages/cloudflare-access/README.md create mode 100644 packages/cloudflare-access/package.json create mode 100644 packages/cloudflare-access/src/index.test.ts create mode 100644 packages/cloudflare-access/src/index.ts create mode 100644 packages/cloudflare-access/tsconfig.json create mode 100644 packages/cloudflare-access/vitest.config.ts diff --git a/.changeset/sharp-moles-knock.md b/.changeset/sharp-moles-knock.md new file mode 100644 index 000000000..d7f65abd2 --- /dev/null +++ b/.changeset/sharp-moles-knock.md @@ -0,0 +1,5 @@ +--- +'@hono/cloudflare-access': minor +--- + +Initial release diff --git a/package.json b/package.json index ac12e8afa..a3e75b17b 100644 --- a/package.json +++ b/package.json @@ -40,6 +40,7 @@ "build:casbin": "yarn workspace @hono/casbin build", "build:ajv-validator": "yarn workspace @hono/ajv-validator build", "build:tsyringe": "yarn workspace @hono/tsyringe build", + "build:cloudflare-access": "yarn workspace @hono/cloudflare-access build", "build": "run-p 'build:*'", "lint": "eslint 'packages/**/*.{ts,tsx}'", "lint:fix": "eslint --fix 'packages/**/*.{ts,tsx}'", diff --git a/packages/cloudflare-access/CHANGELOG.md b/packages/cloudflare-access/CHANGELOG.md new file mode 100644 index 000000000..e69de29bb diff --git a/packages/cloudflare-access/README.md b/packages/cloudflare-access/README.md new file mode 100644 index 000000000..e1d4d3abf --- /dev/null +++ b/packages/cloudflare-access/README.md @@ -0,0 +1,65 @@ +# Cloudflare Access middleware for Hono + +This is a [Cloudflare Access](https://www.cloudflare.com/zero-trust/products/access/) third-party middleware +for [Hono](https://github.com/honojs/hono). + +This middleware can be used to validate that your application is being served behind Cloudflare Access by verifying the +JWT received, User details from the JWT are also available inside the request context. + +This middleware will also ensure the Access policy serving the application is from a +specific [Access Team](https://developers.cloudflare.com/cloudflare-one/faq/getting-started-faq/#whats-a-team-domainteam-name). + +## Usage + +```ts +import { cloudflareAccess } from '@hono/cloudflare-access' +import { Hono } from 'hono' + +const app = new Hono() + +app.use('*', cloudflareAccess('my-access-team-name')) +app.get('/', (c) => c.text('foo')) + +export default app +``` + +## Access JWT payload + +```ts +import { cloudflareAccess, CloudflareAccessVariables } from '@hono/cloudflare-access' +import { Hono } from 'hono' + +type myVariables = { + user: number +} + +const app = new Hono<{ Variables: myVariables & CloudflareAccessVariables }>() + +app.use('*', cloudflareAccess('my-access-team-name')) +app.get('/', (c) => { + const payload = c.get('accessPayload') + + return c.text(`You just authenticated with the email ${payload.email}`) +}) + +export default app +``` + + +## Errors throw by the middleware + +| Error | HTTP Code | +|--------------------------------------------------------------------------------------------------------|-----------| +| Authentication error: Missing bearer token | 401 | +| Authentication error: Unable to decode Bearer token | 401 | +| Authentication error: Token is expired | 401 | +| Authentication error: Expected team name {your-team-name}, but received ${different-team-signed-token} | 401 | +| Authentication error: Invalid Token | 401 | + +## Author + +Gabriel Massadas + +## License + +MIT diff --git a/packages/cloudflare-access/package.json b/packages/cloudflare-access/package.json new file mode 100644 index 000000000..702164770 --- /dev/null +++ b/packages/cloudflare-access/package.json @@ -0,0 +1,47 @@ +{ + "name": "@hono/cloudflare-access", + "version": "0.0.0", + "description": "A third-party Cloudflare Access auth middleware for Hono", + "type": "module", + "module": "dist/index.js", + "types": "dist/index.d.ts", + "files": [ + "dist" + ], + "scripts": { + "test": "vitest --run", + "build": "tsup ./src/index.ts --format esm,cjs --dts", + "publint": "publint", + "release": "yarn build && yarn test && yarn publint && yarn publish" + }, + "exports": { + ".": { + "import": { + "types": "./dist/index.d.ts", + "default": "./dist/index.js" + }, + "require": { + "types": "./dist/index.d.cts", + "default": "./dist/index.cjs" + } + } + }, + "license": "MIT", + "publishConfig": { + "registry": "https://registry.npmjs.org", + "access": "public" + }, + "repository": { + "type": "git", + "url": "https://github.com/honojs/middleware.git" + }, + "homepage": "https://github.com/honojs/middleware", + "peerDependencies": { + "hono": "*" + }, + "devDependencies": { + "hono": "^4.4.12", + "tsup": "^8.1.0", + "vitest": "^1.6.0" + } +} diff --git a/packages/cloudflare-access/src/index.test.ts b/packages/cloudflare-access/src/index.test.ts new file mode 100644 index 000000000..851e6adc7 --- /dev/null +++ b/packages/cloudflare-access/src/index.test.ts @@ -0,0 +1,251 @@ +import { Hono } from 'hono' +import { cloudflareAccess } from '../src' +import { describe, expect, it, vi } from 'vitest' + +import crypto from 'crypto'; +import { promisify } from 'util'; + +const generateKeyPair = promisify(crypto.generateKeyPair); + +interface KeyPairResult { + publicKey: string; + privateKey: string; +} + +interface JWK { + kid: string; + kty: string; + alg: string; + use: string; + e: string; + n: string; +} + +async function generateJWTKeyPair(): Promise { + try { + const { publicKey, privateKey } = await generateKeyPair('rsa', { + modulusLength: 2048, + publicKeyEncoding: { + type: 'spki', + format: 'pem' + }, + privateKeyEncoding: { + type: 'pkcs8', + format: 'pem' + } + }); + + return { + publicKey, + privateKey + }; + } catch (error) { + throw new Error(`Failed to generate key pair: ${(error as Error).message}`); + } +} + +function generateKeyThumbprint(modulusBase64: string): string { + const hash = crypto.createHash('sha256'); + hash.update(Buffer.from(modulusBase64, 'base64')); + return hash.digest('hex'); +} + +function publicKeyToJWK(publicKey: string): JWK { + // Convert PEM to key object + const keyObject = crypto.createPublicKey(publicKey); + + // Export the key in JWK format + const jwk = keyObject.export({ format: 'jwk' }); + + // Generate key ID using the modulus + const kid = generateKeyThumbprint(jwk.n as string); + + return { + kid, + kty: 'RSA', + alg: 'RS256', + use: 'sig', + e: jwk.e as string, + n: jwk.n as string, + }; +} + + +function base64URLEncode(str: string): string { + return Buffer.from(str) + .toString('base64') + .replace(/\+/g, '-') + .replace(/\//g, '_') + .replace(/=/g, ''); +} + +function generateJWT(privateKey: string, payload: Record, expiresIn: number = 3600): string { + // Create header + const header = { + alg: 'RS256', + typ: 'JWT' + }; + + // Add expiration to payload + const now = Math.floor(Date.now() / 1000); + const fullPayload = { + ...payload, + iat: now, + exp: now + expiresIn + }; + + // Encode header and payload + const encodedHeader = base64URLEncode(JSON.stringify(header)); + const encodedPayload = base64URLEncode(JSON.stringify(fullPayload)); + + // Create signature + const signatureInput = `${encodedHeader}.${encodedPayload}`; + const signer = crypto.createSign('RSA-SHA256'); + signer.update(signatureInput); + const signature = signer.sign(privateKey); + // @ts-ignore + const encodedSignature = base64URLEncode(signature); + + // Combine all parts + return `${encodedHeader}.${encodedPayload}.${encodedSignature}`; +} + + +describe('Cloudflare Access middleware', async () => { + const keyPair1 = await generateJWTKeyPair(); + const keyPair2 = await generateJWTKeyPair(); + const keyPair3 = await generateJWTKeyPair(); + + vi.stubGlobal('fetch', async () => { + return Response.json({ + keys: [ + publicKeyToJWK(keyPair1.publicKey), + publicKeyToJWK(keyPair2.publicKey), + ], + }) + }) + + const app = new Hono() + + app.use('/*', cloudflareAccess('my-cool-team-name')) + app.get('/hello-behind-access', (c) => c.text('foo')) + app.get('/access-payload', (c) => c.json(c.get('accessPayload'))) + + it('Should be throw Missing bearer token when nothing is sent', async () => { + const res = await app.request('http://localhost/hello-behind-access') + expect(res).not.toBeNull() + expect(res.status).toBe(401) + expect(await res.text()).toBe('Authentication error: Missing bearer token') + }) + + it('Should be throw Unable to decode Bearer token when sending garbage', async () => { + const res = await app.request('http://localhost/hello-behind-access', { + headers: { + 'cf-access-jwt-assertion': 'asdasdasda' + } + }) + expect(res).not.toBeNull() + expect(res.status).toBe(401) + expect(await res.text()).toBe('Authentication error: Unable to decode Bearer token') + }) + + it('Should be throw Token is expired when sending expired token', async () => { + const token = generateJWT(keyPair1.privateKey, { + sub: '1234567890', + }, -3600); + + const res = await app.request('http://localhost/hello-behind-access', { + headers: { + 'cf-access-jwt-assertion': token + } + }) + expect(res).not.toBeNull() + expect(res.status).toBe(401) + expect(await res.text()).toBe('Authentication error: Token is expired') + }) + + it('Should be throw Expected team name x, but received y when sending invalid iss', async () => { + const token = generateJWT(keyPair1.privateKey, { + sub: '1234567890', + iss: 'https://different-team.cloudflareaccess.com', + }); + + const res = await app.request('http://localhost/hello-behind-access', { + headers: { + 'cf-access-jwt-assertion': token + } + }) + expect(res).not.toBeNull() + expect(res.status).toBe(401) + expect(await res.text()).toBe('Authentication error: Expected team name https://my-cool-team-name.cloudflareaccess.com, but received https://different-team.cloudflareaccess.com') + }) + + it('Should be throw Invalid token when sending token signed with private key not in the allowed list', async () => { + const token = generateJWT(keyPair3.privateKey, { + sub: '1234567890', + iss: 'https://my-cool-team-name.cloudflareaccess.com', + }); + + const res = await app.request('http://localhost/hello-behind-access', { + headers: { + 'cf-access-jwt-assertion': token + } + }) + expect(res).not.toBeNull() + expect(res.status).toBe(401) + expect(await res.text()).toBe('Authentication error: Invalid Token') + }) + + it('Should work when sending everything correctly', async () => { + const token = generateJWT(keyPair1.privateKey, { + sub: '1234567890', + iss: 'https://my-cool-team-name.cloudflareaccess.com', + }); + + const res = await app.request('http://localhost/hello-behind-access', { + headers: { + 'cf-access-jwt-assertion': token + } + }) + expect(res).not.toBeNull() + expect(res.status).toBe(200) + expect(await res.text()).toBe('foo') + }) + + it('Should work with tokens signed by the 2ยบ key in the public keys list', async () => { + const token = generateJWT(keyPair2.privateKey, { + sub: '1234567890', + iss: 'https://my-cool-team-name.cloudflareaccess.com', + }); + + const res = await app.request('http://localhost/hello-behind-access', { + headers: { + 'cf-access-jwt-assertion': token + } + }) + expect(res).not.toBeNull() + expect(res.status).toBe(200) + expect(await res.text()).toBe('foo') + }) + + it('Should be able to retrieve the JWT payload from Hono context', async () => { + const token = generateJWT(keyPair1.privateKey, { + sub: '1234567890', + iss: 'https://my-cool-team-name.cloudflareaccess.com', + }); + + const res = await app.request('http://localhost/access-payload', { + headers: { + 'cf-access-jwt-assertion': token + } + }) + expect(res).not.toBeNull() + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ + "sub":"1234567890", + "iss":"https://my-cool-team-name.cloudflareaccess.com", + "iat":expect.any(Number), + "exp":expect.any(Number) + }) + }) +}) diff --git a/packages/cloudflare-access/src/index.ts b/packages/cloudflare-access/src/index.ts new file mode 100644 index 000000000..cd69ff3ce --- /dev/null +++ b/packages/cloudflare-access/src/index.ts @@ -0,0 +1,164 @@ +import { createMiddleware } from 'hono/factory' +import { Context } from 'hono' + +export type CloudflareAccessPayload = { + aud: string[], + email: string, + exp: number, + iat: number, + nbf: number, + iss: string, + type: string, + identity_nonce: string, + sub: string, + country: string, +} + +export type CloudflareAccessVariables = { + accessPayload: CloudflareAccessPayload +} + +type DecodedToken = { + header: object + payload: CloudflareAccessPayload + signature: string + raw: { header?: string; payload?: string; signature?: string } +} + +declare module 'hono' { + interface ContextVariableMap { + accessPayload: CloudflareAccessPayload + } +} + +export const cloudflareAccess = (accessTeamName: string) => { + // This var will hold already imported jwt keys, this reduces the load of importing the key on every request + let cacheKeys: Record = {} + let cacheExpiration = 0 + + return createMiddleware(async (c, next) => { + const encodedToken = getJwt(c) + if (encodedToken === null) return c.text('Authentication error: Missing bearer token', 401) + + // Load jwt keys if they are not in memory or already expired + if (Object.keys(cacheKeys).length === 0 || Math.floor(Date.now() / 1000) < cacheExpiration) { + const publicKeys = await getPublicKeys(accessTeamName) + cacheKeys = publicKeys.keys + cacheExpiration = publicKeys.cacheExpiration + } + + // Decode Token + let token + try { + token = decodeJwt(encodedToken) + } catch (err) { + return c.text('Authentication error: Unable to decode Bearer token', 401) + } + + // Is the token expired? + const expiryDate = new Date(token.payload.exp * 1000) + const currentDate = new Date(Date.now()) + if (expiryDate <= currentDate) return c.text('Authentication error: Token is expired', 401) + + // Check is token is valid against at least one public key? + if (!(await isValidJwtSignature(token, cacheKeys))) + return c.text('Authentication error: Invalid Token', 401) + + // Is signed from the correct team? + const expectedIss = `https://${accessTeamName}.cloudflareaccess.com` + if (token.payload?.iss !== expectedIss) + return c.text( + `Authentication error: Expected team name ${expectedIss}, but received ${token.payload?.iss}`, + 401 + ) + + c.set('accessPayload', token.payload) + await next() + }) +} + +async function getPublicKeys(accessTeamName: string) { + const jwtUrl = `https://${accessTeamName}.cloudflareaccess.com/cdn-cgi/access/certs` + + const result = await fetch(jwtUrl, { + method: 'GET', + // @ts-ignore + cf: { + // Dont cache error responses + cacheTtlByStatus: { '200-299': 30, '300-599': 0 }, + }, + }) + + const data: any = await result.json() + + // Because we keep CryptoKey's in memory between requests, we need to make sure they are refreshed once in a while + let cacheExpiration = Math.floor(Date.now() / 1000) + 3600 // 1h + + const importedKeys: Record = {} + for (const key of data.keys) { + importedKeys[key.kid] = await crypto.subtle.importKey( + 'jwk', + key, + { + name: 'RSASSA-PKCS1-v1_5', + hash: 'SHA-256', + }, + false, + ['verify'] + ) + } + + return { + keys: importedKeys, + cacheExpiration: cacheExpiration, + } +} + +function getJwt(c: Context) { + const authHeader = c.req.header('cf-access-jwt-assertion') + if (!authHeader) { + return null + } + return authHeader.trim() +} + +function decodeJwt(token: string): DecodedToken { + const parts = token.split('.') + if (parts.length !== 3) { + throw new Error('Invalid token') + } + + const header = JSON.parse(atob(parts[0] as string)) + const payload = JSON.parse(atob(parts[1] as string)) + const signature = atob((parts[2] as string).replace(/_/g, '/').replace(/-/g, '+')) + + return { + header: header, + payload: payload, + signature: signature, + raw: { header: parts[0], payload: parts[1], signature: parts[2] }, + } +} + +async function isValidJwtSignature(token: DecodedToken, keys: Record) { + const encoder = new TextEncoder() + const data = encoder.encode([token.raw.header, token.raw.payload].join('.')) + + const signature = new Uint8Array(Array.from(token.signature).map((c) => c.charCodeAt(0))) + + for (const key of Object.values(keys)) { + const isValid = await validateSingleKey(key, signature, data) + + if (isValid) return true + } + + return false +} + +async function validateSingleKey( + key: CryptoKey, + signature: Uint8Array, + data: Uint8Array +): Promise { + return crypto.subtle.verify('RSASSA-PKCS1-v1_5', key, signature, data) +} diff --git a/packages/cloudflare-access/tsconfig.json b/packages/cloudflare-access/tsconfig.json new file mode 100644 index 000000000..acfcd8430 --- /dev/null +++ b/packages/cloudflare-access/tsconfig.json @@ -0,0 +1,10 @@ +{ + "extends": "../../tsconfig.json", + "compilerOptions": { + "rootDir": "./src", + "outDir": "./dist", + }, + "include": [ + "src/**/*.ts" + ], +} \ No newline at end of file diff --git a/packages/cloudflare-access/vitest.config.ts b/packages/cloudflare-access/vitest.config.ts new file mode 100644 index 000000000..17b54e485 --- /dev/null +++ b/packages/cloudflare-access/vitest.config.ts @@ -0,0 +1,8 @@ +/// +import { defineConfig } from 'vitest/config' + +export default defineConfig({ + test: { + globals: true, + }, +}) diff --git a/yarn.lock b/yarn.lock index f18bd75fd..a5d5d156e 100644 --- a/yarn.lock +++ b/yarn.lock @@ -2505,6 +2505,18 @@ __metadata: languageName: unknown linkType: soft +"@hono/cloudflare-access@workspace:packages/cloudflare-access": + version: 0.0.0-use.local + resolution: "@hono/cloudflare-access@workspace:packages/cloudflare-access" + dependencies: + hono: "npm:^4.4.12" + tsup: "npm:^8.1.0" + vitest: "npm:^1.6.0" + peerDependencies: + hono: "*" + languageName: unknown + linkType: soft + "@hono/conform-validator@workspace:packages/conform-validator": version: 0.0.0-use.local resolution: "@hono/conform-validator@workspace:packages/conform-validator"