Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add authentication middleware for routes and ALLOWED_HANDLES restrictions #53

Merged
merged 6 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ PRIVATE_KEY=-----BEGIN RSA PRIVATE KEY-----\nYOUR KEY HERE WITH \n INCLUDED=\n--
# Auth configs
NEXTAUTH_SECRET=bad-secret
NEXTAUTH_URL=http://localhost:3000
ALLOWED_HANDLES=

# Go to https://smee.io/new set this to the URL that you are redirected to.
WEBHOOK_PROXY_URL=
Expand All @@ -19,3 +20,7 @@ LOG_LEVEL=debug

# Used for settings various configuration in the app
NODE_ENV=development

# Used for GHEC configs
PUBLIC_ORG=
PRIVATE_ORG=
11 changes: 11 additions & 0 deletions env.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ export const env = createEnv({
NODE_ENV: z.string().optional().default('development'),
PUBLIC_ORG: z.string().optional(),
PRIVATE_ORG: z.string().optional(),
// Custom validation for a comma separated list of strings
// ex: ajhenry,github,ahpook
ALLOWED_HANDLES: z
.string()
.optional()
.default('')
.refine((val) => {
if (val === '') return true
return val.split(',').every((handle) => handle.trim().length > 0)
}, 'Invalid comma separated list of GitHub handles'),
},
/*
* Environment variables available on the client (and server).
Expand All @@ -48,6 +58,7 @@ export const env = createEnv({
NODE_ENV: process.env.NODE_ENV,
PUBLIC_ORG: process.env.PUBLIC_ORG,
PRIVATE_ORG: process.env.PRIVATE_ORG,
ALLOWED_HANDLES: process.env.ALLOWED_HANDLES,
},
skipValidation: process.env.SKIP_ENV_VALIDATIONS === 'true',
})
33 changes: 32 additions & 1 deletion src/pages/api/auth/[...nextauth].ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { personalOctokit } from 'bot/octokit'
import NextAuth, { AuthOptions } from 'next-auth'
import NextAuth, { AuthOptions, Profile } from 'next-auth'
import GitHub from 'next-auth/providers/github'
import { logger } from '../../../utils/logger'

Expand Down Expand Up @@ -52,6 +52,37 @@ export const nextAuthOptions: AuthOptions = {
},
},
callbacks: {
signIn: async (params) => {
authLogger.debug('Sign in callback')

const profile = params.profile as Profile & { login?: string }
const allowedHandles = (
process.env.ALLOWED_HANDLES?.split(',') ?? []
).filter((handle) => handle !== '')

if (allowedHandles.length === 0) {
authLogger.info(
'No allowed handles specified via ALLOWED_HANDLES, allowing all users.',
)
return true
}

if (!profile?.login) {
return false
}

authLogger.debug('Trying to sign in with handle:', profile.login)

if (allowedHandles.includes(profile.login)) {
return true
}

authLogger.warn(
`User "${profile.login}" is not in the allowed handles list`,
)

return false
},
session: async ({ session, token }) => {
authLogger.debug('Session callback')

Expand Down
3 changes: 2 additions & 1 deletion src/pages/api/trpc/[trpc].ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import * as trpcNext from '@trpc/server/adapters/next'
import { createContext } from 'server/trpc'
import { appRouter } from '../../../server/routers/_app'

// export API handler
// @see https://trpc.io/docs/server/adapters
export default trpcNext.createNextApiHandler({
router: appRouter,
createContext: () => ({}),
createContext,
})
25 changes: 25 additions & 0 deletions src/server/lib/auth.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import { TRPCError } from '@trpc/server'
import { personalOctokit } from 'bot/octokit'
import { logger } from 'utils/logger'

const middlewareLogger = logger.getSubLogger({ name: 'middleware' })

export const checkGitHubAuth = async (accessToken: string | undefined) => {
if (!accessToken) {
middlewareLogger.error('No access token provided')
throw new TRPCError({ code: 'UNAUTHORIZED' })
}

// Check validity of token
const octokit = personalOctokit(accessToken)
try {
const user = await octokit.rest.users.getAuthenticated()
if (!user) {
middlewareLogger.error('No user found')
throw new TRPCError({ code: 'UNAUTHORIZED' })
}
} catch (error) {
middlewareLogger.error('Error checking github auth', error)
throw new TRPCError({ code: 'UNAUTHORIZED' })
}
}
13 changes: 13 additions & 0 deletions src/server/middleware/auth.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import { checkGitHubAuth } from 'server/lib/auth'
import { Middleware } from 'server/trpc'

export const verifyAuth: Middleware = async (opts) => {
const { ctx } = opts

// Verify valid github session
checkGitHubAuth(ctx.session?.user?.accessToken)

return opts.next({
ctx,
})
}
2 changes: 2 additions & 0 deletions src/server/routers/_app.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import { router } from '../trpc'
import { gitRouter } from './git'
import { healthRouter } from './health'
import { octokitRouter } from './octokit'
import { reposRouter } from './repos'

export const appRouter = router({
git: gitRouter,
octokit: octokitRouter,
repos: reposRouter,
health: healthRouter,
})

// export type definition of API
Expand Down
9 changes: 9 additions & 0 deletions src/server/routers/health.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Simple health check endpoint
import { procedure, router } from '../trpc'

export const healthRouter = router({
// Queries
ping: procedure.query(async () => {
return 'pong'
}),
})
19 changes: 17 additions & 2 deletions src/server/trpc.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,28 @@ import { createTRPCNext } from '@trpc/next'
import type { AppRouter } from './routers/_app'

import { initTRPC } from '@trpc/server'
import { CreateNextContextOptions } from '@trpc/server/adapters/next'
import { getServerSession } from 'next-auth'
import { nextAuthOptions } from 'pages/api/auth/[...nextauth]'
import { verifyAuth } from './middleware/auth'

export const createContext = async (opts: CreateNextContextOptions) => {
const session = await getServerSession(opts.req, opts.res, nextAuthOptions)

return {
session,
}
}

// Avoid exporting the entire t-object
// since it's not very descriptive.
// For instance, the use of a t variable
// is common in i18n libraries.
const t = initTRPC.create()
const t = initTRPC.context<typeof createContext>().create()
// Base router and procedure helpers
export const router = t.router
export const procedure = t.procedure
export type Middleware = Parameters<(typeof t.procedure)['use']>[0]
export const procedure = t.procedure.use(verifyAuth)

function getBaseUrl() {
if (typeof window !== 'undefined')
Expand All @@ -25,6 +39,7 @@ function getBaseUrl() {
// assume localhost
return `http://localhost:${process.env.PORT ?? 3000}`
}

export const trpc = createTRPCNext<AppRouter>({
config(_) {
return {
Expand Down
9 changes: 9 additions & 0 deletions src/types/next-auth.d.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import { DefaultSession } from 'next-auth'

declare module 'next-auth' {
interface Session {
user: {
accessToken: string | undefined
} & DefaultSession['user']
}
}
3 changes: 3 additions & 0 deletions test/octomock/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,9 @@ let mockFunctions = {
saveState: jest.fn(),
getState: jest.fn(),
},
users: {
getAuthenticated: jest.fn(),
},
},
}

Expand Down
63 changes: 63 additions & 0 deletions test/server/auth.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import { healthRouter } from '../../src/server/routers/health'
import { Octomock } from '../octomock'
import { createTestContext } from '../utils/auth'
const om = new Octomock()

jest.mock('../../src/bot/octokit', () => ({
personalOctokit: () => om.getOctokitImplementation(),
}))

describe('Git router', () => {
beforeEach(() => {
om.resetMocks()
jest.resetAllMocks()
})

it('should allow users that are authenticated', async () => {
const caller = healthRouter.createCaller(createTestContext())

om.mockFunctions.rest.users.getAuthenticated.mockResolvedValue({
status: 200,
data: {
login: 'test-user',
},
})

const res = await caller.ping()

expect(res).toEqual('pong')

expect(om.mockFunctions.rest.users.getAuthenticated).toHaveBeenCalledTimes(
1,
)
})

it('should throw on invalid sessions', async () => {
const caller = healthRouter.createCaller(
createTestContext({
user: {
name: 'fake-username',
email: 'fake-email',
image: 'fake-image',
accessToken: 'bad-token',
},
expires: new Date('2030-01-01').toISOString(),
}),
)

om.mockFunctions.rest.users.getAuthenticated.mockResolvedValue({
status: 401,
data: {
message: 'Bad credentials',
},
})

await caller.ping().catch((error) => {
expect(error.code).toContain('UNAUTHORIZED')
})

expect(om.mockFunctions.rest.users.getAuthenticated).toHaveBeenCalledTimes(
1,
)
})
})
11 changes: 8 additions & 3 deletions test/server/git.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ jest.mock('simple-git', () => {
})

import * as config from '../../src/bot/config'
import * as auth from '../../src/server/lib/auth'
import { gitRouter } from '../../src/server/routers/git'
import { Octomock } from '../octomock'
import { createTestContext } from '../utils/auth'
const om = new Octomock()

jest.mock('../../src/bot/config')
Expand All @@ -24,6 +26,7 @@ jest.mock('../../src/bot/octokit', () => ({
appOctokit: () => om.getOctokitImplementation(),
installationOctokit: () => om.getOctokitImplementation(),
}))
jest.mock('../../src/server/lib/auth')

const fakeForkRepo = {
status: 200,
Expand Down Expand Up @@ -70,14 +73,16 @@ const repoNotFound = {
},
}

jest.spyOn(auth, 'checkGitHubAuth').mockResolvedValue()

describe('Git router', () => {
beforeEach(() => {
om.resetMocks()
jest.resetAllMocks()
})

it('should create a mirror when repo does not exist exist', async () => {
const caller = gitRouter.createCaller({})
const caller = gitRouter.createCaller(createTestContext())

const configSpy = jest.spyOn(config, 'getConfig').mockResolvedValue({
publicOrg: 'github',
Expand Down Expand Up @@ -116,7 +121,7 @@ describe('Git router', () => {
})

it('should throw an error when repo already exists', async () => {
const caller = gitRouter.createCaller({})
const caller = gitRouter.createCaller(createTestContext())

const configSpy = jest.spyOn(config, 'getConfig').mockResolvedValue({
publicOrg: 'github',
Expand Down Expand Up @@ -150,7 +155,7 @@ describe('Git router', () => {
})

it('should cleanup repos when there is an error', async () => {
const caller = gitRouter.createCaller({})
const caller = gitRouter.createCaller(createTestContext())

const configSpy = jest.spyOn(config, 'getConfig').mockResolvedValue({
publicOrg: 'github',
Expand Down
18 changes: 18 additions & 0 deletions test/utils/auth.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import { Session } from 'next-auth'
import { createContext } from '../../src/server/trpc'

export const createTestContext = (
session?: Session,
): Awaited<ReturnType<typeof createContext>> => {
return {
session: session ?? {
user: {
name: 'fake-username',
email: 'fake-email',
image: 'fake-image',
accessToken: 'fake-access-token',
},
expires: new Date('2030-01-01').toISOString(),
},
}
}