Skip to content

Commit

Permalink
feat(credential-providers): source accountId from credential providers (
Browse files Browse the repository at this point in the history
#6019)

* feat(codegen): source accountId from STS

* feat(credential-provider-sso): source accountId from SSO

* feat(credential-provider-ini): source accountId from static credentials

* feat(credential-provider-process): source accountId from process JSON/profile

* feat(credential-provider-env): source accountId from env

* fix(codegen): error safety for accountId sourcing from STS

* chore(codegen): run codegen for sts client

* test(codegen): add accId test for RoleAssumerWebIdentity

* test(client-sts): copy test file manually

* fix(credential-provider-sso): accountId resolution fix

* test(credentials): bug fixes for static and sts creds tests

* fix(credential-providers): undefined safety for returning creds object

* chore(codegen): add interface and refactor for AssumedRoleUser and arn parsing
  • Loading branch information
siddsriv authored Jul 24, 2024
1 parent f4400fe commit 83cd253
Show file tree
Hide file tree
Showing 15 changed files with 261 additions and 42 deletions.
41 changes: 37 additions & 4 deletions clients/client-sts/src/defaultStsRoleAssumers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,31 @@ export type RoleAssumer = (

const ASSUME_ROLE_DEFAULT_REGION = "us-east-1";

interface AssumedRoleUser {
/**
* The ARN of the temporary security credentials that are returned from the AssumeRole action.
*/
Arn?: string;

/**
* A unique identifier that contains the role ID and the role session name of the role that is being assumed.
*/
AssumedRoleId?: string;
}

/**
* @internal
*/
const getAccountIdFromAssumedRoleUser = (assumedRoleUser?: AssumedRoleUser) => {
if (typeof assumedRoleUser?.Arn === "string") {
const arnComponents = assumedRoleUser.Arn.split(":");
if (arnComponents.length > 4 && arnComponents[4] !== "") {
return arnComponents[4];
}
}
return undefined;
};

/**
* @internal
*
Expand Down Expand Up @@ -84,17 +109,21 @@ export const getDefaultRoleAssumer = (
logger: logger as any,
});
}
const { Credentials } = await stsClient.send(new AssumeRoleCommand(params));
const { Credentials, AssumedRoleUser } = await stsClient.send(new AssumeRoleCommand(params));
if (!Credentials || !Credentials.AccessKeyId || !Credentials.SecretAccessKey) {
throw new Error(`Invalid response from STS.assumeRole call with role ${params.RoleArn}`);
}

const accountId = getAccountIdFromAssumedRoleUser(AssumedRoleUser);

return {
accessKeyId: Credentials.AccessKeyId,
secretAccessKey: Credentials.SecretAccessKey,
sessionToken: Credentials.SessionToken,
expiration: Credentials.Expiration,
// TODO(credentialScope): access normally when shape is updated.
credentialScope: (Credentials as any).CredentialScope,
...((Credentials as any).CredentialScope && { credentialScope: (Credentials as any).CredentialScope }),
...(accountId && { accountId }),
};
};
};
Expand Down Expand Up @@ -134,17 +163,21 @@ export const getDefaultRoleAssumerWithWebIdentity = (
logger: logger as any,
});
}
const { Credentials } = await stsClient.send(new AssumeRoleWithWebIdentityCommand(params));
const { Credentials, AssumedRoleUser } = await stsClient.send(new AssumeRoleWithWebIdentityCommand(params));
if (!Credentials || !Credentials.AccessKeyId || !Credentials.SecretAccessKey) {
throw new Error(`Invalid response from STS.assumeRoleWithWebIdentity call with role ${params.RoleArn}`);
}

const accountId = getAccountIdFromAssumedRoleUser(AssumedRoleUser);

return {
accessKeyId: Credentials.AccessKeyId,
secretAccessKey: Credentials.SecretAccessKey,
sessionToken: Credentials.SessionToken,
expiration: Credentials.Expiration,
// TODO(credentialScope): access normally when shape is updated.
credentialScope: (Credentials as any).CredentialScope,
...((Credentials as any).CredentialScope && { credentialScope: (Credentials as any).CredentialScope }),
...(accountId && { accountId }),
};
};
};
Expand Down
26 changes: 26 additions & 0 deletions clients/client-sts/test/defaultRoleAssumers.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,17 @@ describe("getDefaultRoleAssumer", () => {
);
});

it("should return accountId in the credentials", async () => {
const roleAssumer = getDefaultRoleAssumer();
const params: AssumeRoleCommandInput = {
RoleArn: "arn:aws:foo",
RoleSessionName: "session",
};
const sourceCred = { accessKeyId: "key", secretAccessKey: "secrete" };
const assumedRole = await roleAssumer(sourceCred, params);
expect(assumedRole.accountId).toEqual("123");
});

it("should use the STS client config", async () => {
const logger = console;
const region = "some-region";
Expand Down Expand Up @@ -169,6 +180,10 @@ describe("getDefaultRoleAssumer", () => {
describe("getDefaultRoleAssumerWithWebIdentity", () => {
const assumeRoleResponse = `<Response xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<AssumeRoleWithWebIdentityResult>
<AssumedRoleUser>
<AssumedRoleId>AROAZOX2IL27GNRBJHWC2:session</AssumedRoleId>
<Arn>arn:aws:sts::123456789012:assumed-role/assume-role-test/session</Arn>
</AssumedRoleUser>
<Credentials>
<AccessKeyId>key</AccessKeyId>
<SecretAccessKey>secrete</SecretAccessKey>
Expand Down Expand Up @@ -209,6 +224,17 @@ describe("getDefaultRoleAssumerWithWebIdentity", () => {
});
});

it("should return accountId in the credentials", async () => {
const roleAssumerWithWebIdentity = getDefaultRoleAssumerWithWebIdentity();
const params: AssumeRoleWithWebIdentityCommandInput = {
RoleArn: "arn:aws:foo",
RoleSessionName: "session",
WebIdentityToken: "token",
};
const assumedRole = await roleAssumerWithWebIdentity(params);
expect(assumedRole.accountId).toEqual("123456789012");
});

it("should use the STS client middleware", async () => {
const customMiddlewareFunction = jest.fn();
const roleAssumerWithWebIdentity = getDefaultRoleAssumerWithWebIdentity({}, [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,17 @@ describe("getDefaultRoleAssumer", () => {
);
});

it("should return accountId in the credentials", async () => {
const roleAssumer = getDefaultRoleAssumer();
const params: AssumeRoleCommandInput = {
RoleArn: "arn:aws:foo",
RoleSessionName: "session",
};
const sourceCred = { accessKeyId: "key", secretAccessKey: "secrete" };
const assumedRole = await roleAssumer(sourceCred, params);
expect(assumedRole.accountId).toEqual("123");
});

it("should use the STS client config", async () => {
const logger = console;
const region = "some-region";
Expand Down Expand Up @@ -167,6 +178,10 @@ describe("getDefaultRoleAssumer", () => {
describe("getDefaultRoleAssumerWithWebIdentity", () => {
const assumeRoleResponse = `<Response xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<AssumeRoleWithWebIdentityResult>
<AssumedRoleUser>
<AssumedRoleId>AROAZOX2IL27GNRBJHWC2:session</AssumedRoleId>
<Arn>arn:aws:sts::123456789012:assumed-role/assume-role-test/session</Arn>
</AssumedRoleUser>
<Credentials>
<AccessKeyId>key</AccessKeyId>
<SecretAccessKey>secrete</SecretAccessKey>
Expand Down Expand Up @@ -207,6 +222,17 @@ describe("getDefaultRoleAssumerWithWebIdentity", () => {
});
});

it("should return accountId in the credentials", async () => {
const roleAssumerWithWebIdentity = getDefaultRoleAssumerWithWebIdentity();
const params: AssumeRoleWithWebIdentityCommandInput = {
RoleArn: "arn:aws:foo",
RoleSessionName: "session",
WebIdentityToken: "token",
};
const assumedRole = await roleAssumerWithWebIdentity(params);
expect(assumedRole.accountId).toEqual("123456789012");
});

it("should use the STS client middleware", async () => {
const customMiddlewareFunction = jest.fn();
const roleAssumerWithWebIdentity = getDefaultRoleAssumerWithWebIdentity({}, [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,31 @@ export type RoleAssumer = (

const ASSUME_ROLE_DEFAULT_REGION = "us-east-1";

interface AssumedRoleUser {
/**
* The ARN of the temporary security credentials that are returned from the AssumeRole action.
*/
Arn?: string;

/**
* A unique identifier that contains the role ID and the role session name of the role that is being assumed.
*/
AssumedRoleId?: string;
}

/**
* @internal
*/
const getAccountIdFromAssumedRoleUser = (assumedRoleUser?: AssumedRoleUser) => {
if (typeof assumedRoleUser?.Arn === "string") {
const arnComponents = assumedRoleUser.Arn.split(":");
if (arnComponents.length > 4 && arnComponents[4] !== "") {
return arnComponents[4];
}
}
return undefined;
};

/**
* @internal
*
Expand Down Expand Up @@ -81,17 +106,21 @@ export const getDefaultRoleAssumer = (
logger: logger as any,
});
}
const { Credentials } = await stsClient.send(new AssumeRoleCommand(params));
const { Credentials, AssumedRoleUser } = await stsClient.send(new AssumeRoleCommand(params));
if (!Credentials || !Credentials.AccessKeyId || !Credentials.SecretAccessKey) {
throw new Error(`Invalid response from STS.assumeRole call with role ${params.RoleArn}`);
}

const accountId = getAccountIdFromAssumedRoleUser(AssumedRoleUser);

return {
accessKeyId: Credentials.AccessKeyId,
secretAccessKey: Credentials.SecretAccessKey,
sessionToken: Credentials.SessionToken,
expiration: Credentials.Expiration,
// TODO(credentialScope): access normally when shape is updated.
credentialScope: (Credentials as any).CredentialScope,
...((Credentials as any).CredentialScope && { credentialScope: (Credentials as any).CredentialScope }),
...(accountId && { accountId }),
};
};
};
Expand Down Expand Up @@ -131,17 +160,21 @@ export const getDefaultRoleAssumerWithWebIdentity = (
logger: logger as any,
});
}
const { Credentials } = await stsClient.send(new AssumeRoleWithWebIdentityCommand(params));
const { Credentials, AssumedRoleUser } = await stsClient.send(new AssumeRoleWithWebIdentityCommand(params));
if (!Credentials || !Credentials.AccessKeyId || !Credentials.SecretAccessKey) {
throw new Error(`Invalid response from STS.assumeRoleWithWebIdentity call with role ${params.RoleArn}`);
}

const accountId = getAccountIdFromAssumedRoleUser(AssumedRoleUser);

return {
accessKeyId: Credentials.AccessKeyId,
secretAccessKey: Credentials.SecretAccessKey,
sessionToken: Credentials.SessionToken,
expiration: Credentials.Expiration,
// TODO(credentialScope): access normally when shape is updated.
credentialScope: (Credentials as any).CredentialScope,
...((Credentials as any).CredentialScope && { credentialScope: (Credentials as any).CredentialScope }),
...(accountId && { accountId }),
};
};
};
Expand Down
20 changes: 18 additions & 2 deletions packages/credential-provider-env/src/fromEnv.spec.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import { CredentialsProviderError } from "@smithy/property-provider";

import { ENV_EXPIRATION, ENV_KEY, ENV_SECRET, ENV_SESSION, fromEnv } from "./fromEnv";
import { ENV_ACCOUNT_ID, ENV_EXPIRATION, ENV_KEY, ENV_SECRET, ENV_SESSION, fromEnv } from "./fromEnv";

describe(fromEnv.name, () => {
const ORIGINAL_ENV = process.env;
const mockAccessKeyId = "mockAccessKeyId";
const mockSecretAccessKey = "mockSecretAccessKey";
const mockSessionToken = "mockSessionToken";
const mockExpiration = new Date().toISOString();
const mockAccountId = "123456789012";

beforeEach(() => {
process.env = {
Expand All @@ -16,6 +17,7 @@ describe(fromEnv.name, () => {
[ENV_SECRET]: mockSecretAccessKey,
[ENV_SESSION]: mockSessionToken,
[ENV_EXPIRATION]: mockExpiration,
[ENV_ACCOUNT_ID]: mockAccountId,
};
});

Expand All @@ -30,19 +32,33 @@ describe(fromEnv.name, () => {
secretAccessKey: mockSecretAccessKey,
sessionToken: mockSessionToken,
expiration: new Date(mockExpiration),
accountId: mockAccountId,
});
});

it("can create credentials without a session token or expiration", async () => {
it("can create credentials without a session token, accountId, or expiration", async () => {
delete process.env[ENV_SESSION];
delete process.env[ENV_EXPIRATION];
delete process.env[ENV_ACCOUNT_ID];
const receivedCreds = await fromEnv()();
expect(receivedCreds).toStrictEqual({
accessKeyId: mockAccessKeyId,
secretAccessKey: mockSecretAccessKey,
});
});

it("should include accountId when it is provided in environment variables", async () => {
process.env[ENV_ACCOUNT_ID] = mockAccountId;
const receivedCreds = await fromEnv()();
expect(receivedCreds).toHaveProperty("accountId", mockAccountId);
});

it("should not include accountId when it is not provided in environment variables", async () => {
delete process.env[ENV_ACCOUNT_ID]; // Ensure accountId is not set
const receivedCreds = await fromEnv()();
expect(receivedCreds).not.toHaveProperty("accountId");
});

it.each([ENV_KEY, ENV_SECRET])("throws if env['%s'] is not found", async (key) => {
delete process.env[key];
const expectedError = new CredentialsProviderError("Unable to find environment variable credentials.");
Expand Down
6 changes: 6 additions & 0 deletions packages/credential-provider-env/src/fromEnv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ export const ENV_EXPIRATION = "AWS_CREDENTIAL_EXPIRATION";
* @internal
*/
export const ENV_CREDENTIAL_SCOPE = "AWS_CREDENTIAL_SCOPE";
/**
* @internal
*/
export const ENV_ACCOUNT_ID = "AWS_ACCOUNT_ID";

/**
* @internal
Expand All @@ -41,6 +45,7 @@ export const fromEnv =
const sessionToken: string | undefined = process.env[ENV_SESSION];
const expiry: string | undefined = process.env[ENV_EXPIRATION];
const credentialScope: string | undefined = process.env[ENV_CREDENTIAL_SCOPE];
const accountId: string | undefined = process.env[ENV_ACCOUNT_ID];

if (accessKeyId && secretAccessKey) {
return {
Expand All @@ -49,6 +54,7 @@ export const fromEnv =
...(sessionToken && { sessionToken }),
...(expiry && { expiration: new Date(expiry) }),
...(credentialScope && { credentialScope }),
...(accountId && { accountId }),
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ const getMockStaticCredsProfile = () => ({
aws_secret_access_key: "mock_aws_secret_access_key",
aws_session_token: "mock_aws_session_token",
aws_credential_scope: "mock_aws_credential_scope",
aws_account_id: "mock_aws_account_id",
});

describe(isStaticCredsProfile.name, () => {
Expand Down Expand Up @@ -32,6 +33,12 @@ describe(isStaticCredsProfile.name, () => {
});
});

it.each(["aws_account_id"])("value at '%s' is not of type string | undefined", (key) => {
[true, null, 1, NaN, {}].forEach((value) => {
expect(isStaticCredsProfile({ ...getMockStaticCredsProfile(), [key]: value })).toEqual(false);
});
});

it("returns true for StaticCredentialsProfile", () => {
expect(isStaticCredsProfile(getMockStaticCredsProfile())).toEqual(true);
});
Expand All @@ -46,6 +53,7 @@ describe(resolveStaticCredentials.name, () => {
secretAccessKey: mockProfile.aws_secret_access_key,
sessionToken: mockProfile.aws_session_token,
credentialScope: mockProfile.aws_credential_scope,
accountId: mockProfile.aws_account_id,
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export interface StaticCredsProfile extends Profile {
aws_secret_access_key: string;
aws_session_token?: string;
aws_credential_scope?: string;
aws_account_id?: string;
}

/**
Expand All @@ -20,7 +21,8 @@ export const isStaticCredsProfile = (arg: any): arg is StaticCredsProfile =>
typeof arg === "object" &&
typeof arg.aws_access_key_id === "string" &&
typeof arg.aws_secret_access_key === "string" &&
["undefined", "string"].indexOf(typeof arg.aws_session_token) > -1;
["undefined", "string"].indexOf(typeof arg.aws_session_token) > -1 &&
["undefined", "string"].indexOf(typeof arg.aws_account_id) > -1;

/**
* @internal
Expand All @@ -34,6 +36,7 @@ export const resolveStaticCredentials = (
accessKeyId: profile.aws_access_key_id,
secretAccessKey: profile.aws_secret_access_key,
sessionToken: profile.aws_session_token,
credentialScope: profile.aws_credential_scope,
...(profile.aws_credential_scope && { credentialScope: profile.aws_credential_scope }),
...(profile.aws_account_id && { accountId: profile.aws_account_id }),
});
};
Loading

0 comments on commit 83cd253

Please sign in to comment.