Skip to content

Commit

Permalink
fix(auth): Add userAttributes to confirmSignIn call (#2640)
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerjroach authored Nov 21, 2023
1 parent 644be2d commit 3f2ea63
Show file tree
Hide file tree
Showing 8 changed files with 183 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -785,12 +785,15 @@ internal class RealAWSCognitoAuthPlugin(
},
{
val awsCognitoConfirmSignInOptions = options as? AWSCognitoAuthConfirmSignInOptions
val metadata = awsCognitoConfirmSignInOptions?.metadata ?: emptyMap()
val userAttributes = awsCognitoConfirmSignInOptions?.userAttributes ?: emptyList()
when (signInState) {
is SignInState.ResolvingChallenge -> {
val event = SignInChallengeEvent(
SignInChallengeEvent.EventType.VerifyChallengeAnswer(
challengeResponse,
awsCognitoConfirmSignInOptions?.metadata ?: mapOf()
metadata,
userAttributes
)
)
authStateMachine.send(event)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package com.amplifyframework.auth.cognito.actions
import aws.sdk.kotlin.services.cognitoidentityprovider.model.ChallengeNameType
import aws.sdk.kotlin.services.cognitoidentityprovider.model.ResourceNotFoundException
import aws.sdk.kotlin.services.cognitoidentityprovider.respondToAuthChallenge
import com.amplifyframework.auth.AuthUserAttribute
import com.amplifyframework.auth.cognito.AuthEnvironment
import com.amplifyframework.auth.cognito.helpers.AuthHelper
import com.amplifyframework.auth.cognito.helpers.SignInChallengeHelper
Expand All @@ -32,9 +33,11 @@ import com.amplifyframework.statemachine.codegen.events.SignInChallengeEvent
internal object SignInChallengeCognitoActions : SignInChallengeActions {
private const val KEY_SECRET_HASH = "SECRET_HASH"
private const val KEY_USERNAME = "USERNAME"
private const val KEY_PREFIX_USER_ATTRIBUTE = "userAttributes."
override fun verifyChallengeAuthAction(
answer: String,
metadata: Map<String, String>,
attributes: List<AuthUserAttribute>,
challenge: AuthChallenge
): Action = Action<AuthEnvironment>("VerifySignInChallenge") { id, dispatcher ->
logger.verbose("$id Starting execution")
Expand All @@ -50,6 +53,12 @@ internal object SignInChallengeCognitoActions : SignInChallengeActions {
challengeResponses[responseKey] = answer
}

challengeResponses.putAll(
attributes.map {
Pair("${KEY_PREFIX_USER_ATTRIBUTE}${it.key.keyString}", it.value)
}
)

val secretHash = AuthHelper.getSecretHash(
username,
configuration.userPool?.appClient,
Expand Down Expand Up @@ -90,6 +99,7 @@ internal object SignInChallengeCognitoActions : SignInChallengeActions {
SignInChallengeEvent.EventType.RetryVerifyChallengeAnswer(
answer,
metadata,
attributes,
challenge
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@

package com.amplifyframework.statemachine.codegen.actions

import com.amplifyframework.auth.AuthUserAttribute
import com.amplifyframework.statemachine.Action
import com.amplifyframework.statemachine.codegen.data.AuthChallenge

internal interface SignInChallengeActions {
fun verifyChallengeAuthAction(
answer: String,
metadata: Map<String, String>,
userAttributes: List<AuthUserAttribute>,
challenge: AuthChallenge
): Action
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,24 @@

package com.amplifyframework.statemachine.codegen.events

import com.amplifyframework.auth.AuthUserAttribute
import com.amplifyframework.statemachine.StateMachineEvent
import com.amplifyframework.statemachine.codegen.data.AuthChallenge
import java.util.Date

internal class SignInChallengeEvent(val eventType: EventType, override val time: Date? = null) : StateMachineEvent {
sealed class EventType {
data class WaitForAnswer(val challenge: AuthChallenge, val hasNewResponse: Boolean = false) : EventType()
data class VerifyChallengeAnswer(val answer: String, val metadata: Map<String, String>) : EventType()
data class VerifyChallengeAnswer(
val answer: String,
val metadata: Map<String, String>,
val userAttributes: List<AuthUserAttribute>
) : EventType()

data class RetryVerifyChallengeAnswer(
val answer: String,
val metadata: Map<String, String>,
val userAttributes: List<AuthUserAttribute>,
val authChallenge: AuthChallenge
) : EventType()
data class FinalizeSignIn(val accessToken: String) : EventType()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ internal sealed class SignInChallengeState : State {
is WaitingForAnswer -> when (challengeEvent) {
is SignInChallengeEvent.EventType.VerifyChallengeAnswer -> {
val action = challengeActions.verifyChallengeAuthAction(
challengeEvent.answer, challengeEvent.metadata, oldState.challenge
challengeEvent.answer,
challengeEvent.metadata,
challengeEvent.userAttributes,
oldState.challenge
)
StateResolution(Verifying(oldState.challenge.challengeName), listOf(action))
}
Expand All @@ -78,7 +81,10 @@ internal sealed class SignInChallengeState : State {
}
is SignInChallengeEvent.EventType.RetryVerifyChallengeAnswer -> {
val action = challengeActions.verifyChallengeAuthAction(
challengeEvent.answer, challengeEvent.metadata, challengeEvent.authChallenge
challengeEvent.answer,
challengeEvent.metadata,
challengeEvent.userAttributes,
challengeEvent.authChallenge,
)
StateResolution(Verifying(challengeEvent.authChallenge.challengeName), listOf(action))
}
Expand All @@ -92,7 +98,10 @@ internal sealed class SignInChallengeState : State {
when (challengeEvent) {
is SignInChallengeEvent.EventType.VerifyChallengeAnswer -> {
val action = challengeActions.verifyChallengeAuthAction(
challengeEvent.answer, challengeEvent.metadata, oldState.challenge
challengeEvent.answer,
challengeEvent.metadata,
challengeEvent.userAttributes,
oldState.challenge,
)
StateResolution(Verifying(oldState.challenge.challengeName), listOf(action))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ open class StateTransitionTestBase {

Mockito.`when`(
mockSignInChallengeActions.verifyChallengeAuthAction(
MockitoHelper.anyObject(),
MockitoHelper.anyObject(),
MockitoHelper.anyObject(),
MockitoHelper.anyObject()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,12 +387,13 @@ class StateTransitionTests : StateTransitionTestBase() {
SignInChallengeEvent(
SignInChallengeEvent.EventType.RetryVerifyChallengeAnswer(
"test",
mapOf(),
emptyMap(),
emptyList(),
AuthChallenge(
ChallengeNameType.CustomChallenge.toString(),
"Test",
"session_mock_value",
mapOf()
emptyMap(),
)
)
)
Expand All @@ -401,7 +402,8 @@ class StateTransitionTests : StateTransitionTestBase() {
SignInChallengeEvent(
SignInChallengeEvent.EventType.VerifyChallengeAnswer(
"test",
mapOf()
emptyMap(),
emptyList()
)
)
)
Expand Down Expand Up @@ -481,7 +483,7 @@ class StateTransitionTests : StateTransitionTestBase() {
challengeState?.apply {
stateMachine.send(
SignInChallengeEvent(
SignInChallengeEvent.EventType.VerifyChallengeAnswer("test", mapOf())
SignInChallengeEvent.EventType.VerifyChallengeAnswer("test", emptyMap(), emptyList())
)
)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package com.amplifyframework.auth.cognito.actions

import androidx.test.core.app.ApplicationProvider
import aws.sdk.kotlin.services.cognitoidentityprovider.CognitoIdentityProviderClient
import aws.sdk.kotlin.services.cognitoidentityprovider.model.RespondToAuthChallengeRequest
import com.amplifyframework.auth.AuthUserAttribute
import com.amplifyframework.auth.AuthUserAttributeKey
import com.amplifyframework.auth.cognito.AWSCognitoAuthService
import com.amplifyframework.auth.cognito.AuthEnvironment
import com.amplifyframework.auth.cognito.StoreClientBehavior
import com.amplifyframework.logging.Logger
import com.amplifyframework.statemachine.EventDispatcher
import com.amplifyframework.statemachine.StateMachineEvent
import com.amplifyframework.statemachine.codegen.data.AmplifyCredential
import com.amplifyframework.statemachine.codegen.data.AuthChallenge
import com.amplifyframework.statemachine.codegen.data.AuthConfiguration
import com.amplifyframework.statemachine.codegen.data.CredentialType
import com.amplifyframework.statemachine.codegen.data.UserPoolConfiguration
import io.mockk.coEvery
import io.mockk.every
import io.mockk.mockk
import io.mockk.slot
import junit.framework.TestCase.assertTrue
import kotlin.test.assertEquals
import kotlinx.coroutines.test.runTest
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.robolectric.RobolectricTestRunner

@RunWith(RobolectricTestRunner::class)
class SignInChallengeCognitoActionsTest {

private val pool = mockk<UserPoolConfiguration> {
every { appClient } returns "client"
every { appClientSecret } returns null
every { pinpointAppId } returns null
}
private val configuration = mockk<AuthConfiguration> {
every { userPool } returns pool
}
private val cognitoAuthService = mockk<AWSCognitoAuthService>()
private val credentialStoreClient = mockk<StoreClientBehavior> {
coEvery { loadCredentials(CredentialType.ASF) } returns AmplifyCredential.ASFDevice("asf_id")
}
private val logger = mockk<Logger>(relaxed = true)
private val cognitoIdentityProviderClientMock = mockk<CognitoIdentityProviderClient>()

private val capturedEvent = slot<StateMachineEvent>()
private val dispatcher = mockk<EventDispatcher> {
every { send(capture(capturedEvent)) }.answers { }
}

private lateinit var authEnvironment: AuthEnvironment

@Before
fun setup() {
every { cognitoAuthService.cognitoIdentityProviderClient }.answers { cognitoIdentityProviderClientMock }
authEnvironment = AuthEnvironment(
ApplicationProvider.getApplicationContext(),
configuration,
cognitoAuthService,
credentialStoreClient,
null,
null,
logger
)
}

@Test
fun `very auth challenge without user attributes`() = runTest {
val expectedChallengeResponses = mapOf(
"USERNAME" to "testUser"
)
val capturedRequest = slot<RespondToAuthChallengeRequest>()
coEvery {
cognitoIdentityProviderClientMock.respondToAuthChallenge(capture(capturedRequest))
}.answers {
mockk()
}

SignInChallengeCognitoActions.verifyChallengeAuthAction(
"myAnswer",
emptyMap(),
emptyList(),
AuthChallenge(
"CONFIRM_SIGN_IN_WITH_NEW_PASSWORD",
username = "testUser",
session = null,
parameters = null
)
).execute(dispatcher, authEnvironment)

assertTrue(capturedRequest.isCaptured)
assertEquals(expectedChallengeResponses, capturedRequest.captured.challengeResponses)
}

@Test
fun `user attributes are added to auth challenge`() = runTest {
val providedUserAttributes = listOf(AuthUserAttribute(AuthUserAttributeKey.phoneNumber(), "+15555555555"))
val expectedChallengeResponses = mapOf(
"USERNAME" to "testUser",
"userAttributes.phone_number" to "+15555555555"
)
val capturedRequest = slot<RespondToAuthChallengeRequest>()
coEvery {
cognitoIdentityProviderClientMock.respondToAuthChallenge(capture(capturedRequest))
}.answers {
mockk()
}

SignInChallengeCognitoActions.verifyChallengeAuthAction(
"myAnswer",
emptyMap(),
providedUserAttributes,
AuthChallenge(
"CONFIRM_SIGN_IN_WITH_NEW_PASSWORD",
username = "testUser",
session = null,
parameters = null
)
).execute(dispatcher, authEnvironment)

assertTrue(capturedRequest.isCaptured)
assertEquals(expectedChallengeResponses, capturedRequest.captured.challengeResponses)
}
}

0 comments on commit 3f2ea63

Please sign in to comment.