Skip to content

Commit

Permalink
Merge pull request #740 from snyk/feat/oauth-jwt-retrieval
Browse files Browse the repository at this point in the history
feat: add jwt retrieval logic
  • Loading branch information
aarlaud authored Apr 16, 2024
2 parents 4e22850 + 3607ee7 commit 02d599e
Show file tree
Hide file tree
Showing 21 changed files with 244 additions and 15 deletions.
4 changes: 4 additions & 0 deletions config.universaltest.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
"BROKER_SERVER_URL": "https://broker2.dev.snyk.io",
"BROKER_HA_MODE_ENABLED": "false",
"BROKER_DISPATCHER_BASE_URL": "https://api.dev.snyk.io"
},
"oauth": {
"clientId": "${CLIENT_ID}",
"clientSecret": "${CLIENT_SECRET}"
}
},
"github": {
Expand Down
46 changes: 46 additions & 0 deletions lib/client/auth/oauth.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import { makeRequestToDownstream } from '../../common/http/request';
import { PostFilterPreparedRequest } from '../../common/relay/prepareRequest';
import { log as logger } from '../../logs/logger';
interface tokenExchangeResponse {
access_token: string;
expires_in: number;
scope: string;
token_type: string;
}

export async function fetchJwt(
apiHostname: string,
clientId: string,
clientSecret: string,
) {
try {
const data = {
grant_type: 'client_credentials',
client_id: clientId,
client_secret: clientSecret,
};
const formData = new URLSearchParams(data);

const request: PostFilterPreparedRequest = {
url: `${apiHostname}/oauth2/token`,
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
method: 'POST',
body: formData.toString(),
};
const oauthResponse = await makeRequestToDownstream(request);
if (oauthResponse.statusCode != 200) {
const errorBody = JSON.parse(oauthResponse.body);
throw new Error(
`${oauthResponse.statusCode}-${errorBody.error}:${errorBody.error_description}`,
);
}
const accessToken = JSON.parse(oauthResponse.body) as tokenExchangeResponse;
const jwt = accessToken.access_token;
const type = accessToken.token_type;
const expiresIn = accessToken.expires_in;

return { expiresIn: expiresIn, authHeader: `${type} ${jwt}` };
} catch (err) {
logger.error({ err }, 'Unable to retrieve JWT');
}
}
30 changes: 26 additions & 4 deletions lib/client/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import { loadAllFilters } from '../common/filter/filtersAsync';
import { ClientOpts, LoadedClientOpts } from '../common/types/options';
import { websocketConnectionSelectorMiddleware } from './routesHandler/websocketConnectionMiddlewares';
import { getClientConfigMetadata } from './utils/configHelpers';
import { fetchJwt } from './auth/oauth';

process.on('uncaughtException', (error) => {
if (error.message == 'read ECONNRESET') {
Expand Down Expand Up @@ -54,6 +55,17 @@ export const main = async (clientOpts: ClientOpts) => {
throw new Error('Unable to load filters');
}

if (
clientOpts.config.brokerClientConfiguration.common.oauth?.clientId &&
clientOpts.config.brokerClientConfiguration.common.oauth?.clientSecret
) {
loadedClientOpts.accessToken = await fetchJwt(
clientOpts.config.API_BASE_URL,
clientOpts.config.brokerClientConfiguration.common.oauth.clientId,
clientOpts.config.brokerClientConfiguration.common.oauth.clientSecret,
);
}

const globalIdentifyingMetadata: IdentifyingMetadata = {
capabilities: ['post-streams'],
clientId: brokerClientId,
Expand All @@ -66,10 +78,20 @@ export const main = async (clientOpts: ClientOpts) => {

let websocketConnections: WebSocketConnection[] = [];
if (clientOpts.config.universalBrokerEnabled) {
websocketConnections = createWebSockets(
loadedClientOpts,
globalIdentifyingMetadata,
);
const integrationsKeys = clientOpts.config.connections
? Object.keys(clientOpts.config.connections)
: [];
if (integrationsKeys.length < 1) {
logger.error(
{},
`No connection found. Please add connections to config.${process.env.SERVICE_ENV}.json.`,
);
} else {
websocketConnections = createWebSockets(
loadedClientOpts,
globalIdentifyingMetadata,
);
}
} else {
websocketConnections.push(
createWebSocket(
Expand Down
62 changes: 52 additions & 10 deletions lib/client/socket.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import { initializeSocketHandlers } from './socketHandlers/init';
import { LoadedClientOpts } from '../common/types/options';
import { translateIntegrationTypeToBrokerIntegrationType } from './utils/integrations';
import { maskToken } from '../common/utils/token';
import { fetchJwt } from './auth/oauth';

export const createWebSockets = (
clientOpts: LoadedClientOpts,
Expand Down Expand Up @@ -90,18 +91,27 @@ export const createWebSocket = (

// Will exponentially back-off from 0.5 seconds to a maximum of 20 minutes
// Retry for a total period of around 4.5 hours
const socketSettings = {
reconnect: {
factor: 1.5,
retries: 30,
max: 20 * 60 * 1000,
},
ping: parseInt(localClientOps.config.socketPingInterval) || 25000,
pong: parseInt(localClientOps.config.socketPongTimeout) || 10000,
timeout: parseInt(localClientOps.config.socketConnectTimeout) || 10000,
};

if (clientOpts.accessToken) {
socketSettings['transport'] = {
extraHeaders: {
Authorization: clientOpts.accessToken?.authHeader,
},
};
}
const websocket: WebSocketConnection = new Socket(
localClientOps.config.brokerServerUrlForSocket,
{
reconnect: {
factor: 1.5,
retries: 30,
max: 20 * 60 * 1000,
},
ping: parseInt(localClientOps.config.socketPingInterval) || 25000,
pong: parseInt(localClientOps.config.socketPongTimeout) || 10000,
timeout: parseInt(localClientOps.config.socketConnectTimeout) || 10000,
},
socketSettings,
);
websocket.socketVersion = 1;
websocket.socketType = 'client';
Expand All @@ -117,6 +127,38 @@ export const createWebSocket = (
websocket.clientConfig = identifyingMetadata.clientConfig;
websocket.role = identifyingMetadata.role;

if (clientOpts.accessToken) {
let timeoutHandlerId;
let timeoutHandler = async () => {};
timeoutHandler = async () => {
logger.debug({}, 'Refreshing oauth access token');
clearTimeout(timeoutHandlerId);
clientOpts.accessToken = await fetchJwt(
clientOpts.config.API_BASE_URL,
clientOpts.config.brokerClientConfiguration.common.oauth!.clientId,
clientOpts.config.brokerClientConfiguration.common.oauth!.clientSecret,
);

websocket.transport.extraHeaders['Authorization'] =
clientOpts.accessToken!.authHeader;
websocket.end();
websocket.open();
timeoutHandlerId = setTimeout(
timeoutHandler,
(clientOpts.accessToken!.expiresIn - 60) * 1000,
);
};

timeoutHandlerId = setTimeout(
timeoutHandler,
(clientOpts.accessToken!.expiresIn - 60) * 1000,
);
}

websocket.on('incoming::error', (e) => {
websocket.emit('error', { type: e.type, description: e.description });
});

logger.info(
{
url: localClientOps.config.brokerServerUrlForSocket,
Expand Down
4 changes: 3 additions & 1 deletion lib/client/types/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,12 @@ export interface WebSocketConnection {
socket: any;
destroy: any;
send: any;
end: any;
open: any;
emit: any;
capabilities?: any;
on: (string, any) => any;
readyState: any;
end: () => any;
}
// export interface WebSocketConnection {
// websocket: Connection;
Expand Down
1 change: 1 addition & 0 deletions lib/client/types/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ interface BrokerClient {

interface BrokerServer {
BROKER_SERVER_URL: string;
BROKER_SERVER_MANDATORY_AUTH_ENABLED?: boolean;
}

/**
Expand Down
8 changes: 8 additions & 0 deletions lib/common/types/options.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@ export interface ClientOpts {
filters: FiltersType | Map<string, FiltersType>;
serverId?: string;
connections?: Record<string, any>;
oauth?: {
clientId: string;
clientSecret: string;
};
accessToken?: {
authHeader: string;
expiresIn: number;
};
}

export interface ServerOpts {
Expand Down
50 changes: 50 additions & 0 deletions lib/server/socket.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import { SocketHandler } from './types/socket';
import { handleIoError } from './socketHandlers/errorHandler';
import { handleSocketConnection } from './socketHandlers/connectionHandler';
import { initConnectionHandler } from './socketHandlers/initHandlers';
import { maskToken } from '../common/utils/token';
import { log as logger } from '../logs/logger';

const socketConnections = new Map();

Expand All @@ -28,6 +30,54 @@ const socket = ({ server, loadedServerOpts }): SocketHandler => {
};

const websocket = new Primus(server, ioConfig);
websocket.authorize(async (req, done) => {
const maskedToken = maskToken(
req.uri.pathname.replaceAll(/^\/primus\/([^/]+)\//g, '$1').toLowerCase(),
);
const authHeader = req.headers['authorization'];

if (
(!authHeader || !authHeader.startsWith('Bearer')) &&
loadedServerOpts.config.BROKER_SERVER_MANDATORY_AUTH_ENABLED
) {
logger.error({ maskedToken }, 'request missing Authorization header');
done({
statusCode: 401,
authenticate: 'Bearer',
message: 'missing required authorization header',
});
return;
}

const jwt = authHeader
? authHeader.substring(authHeader.indexOf(' ') + 1)
: '';
if (!jwt) logger.debug({}, `TODO: Validate jwt`);
done();
// let oauthResponse = await axiosInstance.request({
// url: 'http://localhost:8080/oauth2/introspect',
// method: 'POST',
// headers: {
// 'Content-Type': 'application/x-www-form-urlencoded',
// },
// auth: {
// username: 'broker-connection-a',
// password: 'secret',
// },
// data: `token=${token}`,
// });

// if (!oauthResponse.data.active) {
// logger.error({maskedToken}, 'JWT is not active (could be expired, malformed, not issued by us, etc)');
// done({
// statusCode: 403,
// message: 'token not active',
// });
// } else {
// req.oauth_data = oauthResponse.data;
// done();
// }
});
websocket.socketType = 'server';
websocket.socketVersion = 1;
websocket.plugin('emitter', Emitter);
Expand Down
4 changes: 4 additions & 0 deletions test/functional/client-universal-server.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ describe('proxy requests originating from behind the broker client', () => {
process.env.SNYK_BROKER_CLIENT_CONFIGURATION__common__default__BROKER_SERVER_URL = `http://localhost:${bs.port}`;
process.env.SNYK_FILTER_RULES_PATHS__github = clientAccept;
process.env.SNYK_FILTER_RULES_PATHS__gitlab = clientAccept;
process.env.CLIENT_ID = 'clienid';
process.env.CLIENT_SECRET = 'clientsecret';

bc = await createUniversalBrokerClient();
});
Expand All @@ -57,6 +59,8 @@ describe('proxy requests originating from behind the broker client', () => {
delete process.env.SNYK_BROKER_SERVER_UNIVERSAL_CONFIG_ENABLED;
delete process.env
.SNYK_BROKER_CLIENT_CONFIGURATION__common__default__BROKER_SERVER_URL;
delete process.env.CLIENT_ID;
delete process.env.CLIENT_SECRET;
});

it('server identifies self to client', async () => {
Expand Down
4 changes: 4 additions & 0 deletions test/functional/healthcheck-universal.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ describe('proxy requests originating from behind the broker client', () => {
tws = await createTestWebServer();
bs = await createBrokerServer({ filters: serverAccept });
process.env.SNYK_BROKER_CLIENT_CONFIGURATION__common__default__BROKER_SERVER_URL = `http://localhost:${bs.port}`;
process.env.CLIENT_ID = 'clienid';
process.env.CLIENT_SECRET = 'clientsecret';
});

afterAll(async () => {
Expand All @@ -31,6 +33,8 @@ describe('proxy requests originating from behind the broker client', () => {
delete process.env.BROKER_SERVER_URL;
delete process.env
.SNYK_BROKER_CLIENT_CONFIGURATION__common__default__BROKER_SERVER_URL;
delete process.env.CLIENT_ID;
delete process.env.CLIENT_SECRET;
});

afterEach(async () => {
Expand Down
4 changes: 4 additions & 0 deletions test/functional/server-client-universal.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ describe('proxy requests originating from behind the broker server', () => {
process.env.SNYK_FILTER_RULES_PATHS__gitlab = clientAccept;
process.env['SNYK_FILTER_RULES_PATHS__azure-repos'] = clientAccept;
process.env['SNYK_FILTER_RULES_PATHS__jira-bearer-auth'] = clientAccept;
process.env.CLIENT_ID = 'clienid';
process.env.CLIENT_SECRET = 'clientsecret';

bc = await createUniversalBrokerClient();
await waitForUniversalBrokerClientsConnection(bs, 2);
Expand All @@ -68,6 +70,8 @@ describe('proxy requests originating from behind the broker server', () => {
delete process.env.SNYK_BROKER_SERVER_UNIVERSAL_CONFIG_ENABLED;
delete process.env
.SNYK_BROKER_CLIENT_CONFIGURATION__common__default__BROKER_SERVER_URL;
delete process.env.CLIENT_ID;
delete process.env.CLIENT_SECRET;
});

it('successfully broker GET', async () => {
Expand Down
Loading

0 comments on commit 02d599e

Please sign in to comment.