Skip to content

Commit

Permalink
fix(auth): Fix isSignedIn states (#2830)
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerjroach authored Jun 14, 2024
1 parent 37fc35e commit f2fd6f7
Show file tree
Hide file tree
Showing 18 changed files with 506 additions and 85 deletions.
1 change: 1 addition & 0 deletions aws-auth-cognito/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ dependencies {
testImplementation(libs.test.androidx.core)
testImplementation(libs.test.kotlin.reflection)
testImplementation(libs.test.kotest.assertions)
testImplementation(libs.test.kotest.assertions.json)

androidTestImplementation(libs.gson)
//noinspection GradleDependency
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,21 +65,39 @@ internal fun AmplifyCredential.isValid(): Boolean {
}

internal fun AmplifyCredential.getCognitoSession(
exception: AuthException = SignedOutException()
exception: AuthException? = null
): AWSAuthSessionBehavior<AWSCognitoUserPoolTokens> {
fun getCredentialsResult(awsCredentials: CognitoCredentials): AuthSessionResult<AWSCredentials> =
with(awsCredentials) {

fun getCredentialsResult(
awsCredentials: CognitoCredentials,
exception: AuthException?
): AuthSessionResult<AWSCredentials> {
if (exception != null && exception !is SignedOutException) {
return AuthSessionResult.failure(exception)
}

return with(awsCredentials) {
AWSCredentials.createAWSCredentials(accessKeyId, secretAccessKey, sessionToken, expiration)
}?.let {
AuthSessionResult.success(it)
} ?: AuthSessionResult.failure(UnknownException("Failed to fetch AWS credentials."))
}

fun getIdentityIdResult(identityId: String): AuthSessionResult<String> {
return if (identityId.isNotEmpty()) AuthSessionResult.success(identityId)
else AuthSessionResult.failure(UnknownException("Failed to fetch identity id."))
fun getIdentityIdResult(identityId: String, exception: AuthException?): AuthSessionResult<String> {
return if (exception != null && exception !is SignedOutException) {
AuthSessionResult.failure(exception)
} else if (identityId.isNotEmpty()) {
AuthSessionResult.success(identityId)
} else {
AuthSessionResult.failure(UnknownException("Failed to fetch identity id."))
}
}

fun getUserSubResult(userPoolTokens: CognitoUserPoolTokens?): AuthSessionResult<String> {
fun getUserSubResult(userPoolTokens: CognitoUserPoolTokens?, exception: AuthException?): AuthSessionResult<String> {
if (exception != null && exception !is SignedOutException) {
return AuthSessionResult.failure(exception)
}

return try {
AuthSessionResult.success(userPoolTokens?.accessToken?.let(SessionHelper::getUserSub))
} catch (e: Exception) {
Expand All @@ -88,8 +106,13 @@ internal fun AmplifyCredential.getCognitoSession(
}

fun getUserPoolTokensResult(
cognitoUserPoolTokens: CognitoUserPoolTokens
cognitoUserPoolTokens: CognitoUserPoolTokens,
exception: AuthException?
): AuthSessionResult<AWSCognitoUserPoolTokens> {
if (exception != null && exception !is SignedOutException) {
return AuthSessionResult.failure(exception)
}

return AuthSessionResult.success(
AWSCognitoUserPoolTokens(
accessToken = cognitoUserPoolTokens.accessToken,
Expand All @@ -113,20 +136,20 @@ internal fun AmplifyCredential.getCognitoSession(
"Cognito Identity not configured. Please check amplifyconfiguration.json file."
)
),
userSubResult = getUserSubResult(signedInData.cognitoUserPoolTokens),
userPoolTokensResult = getUserPoolTokensResult(signedInData.cognitoUserPoolTokens)
userSubResult = getUserSubResult(signedInData.cognitoUserPoolTokens, exception),
userPoolTokensResult = getUserPoolTokensResult(signedInData.cognitoUserPoolTokens, exception)
)
is AmplifyCredential.UserAndIdentityPool -> AWSCognitoAuthSession(
true,
identityIdResult = getIdentityIdResult(identityId),
awsCredentialsResult = getCredentialsResult(credentials),
userSubResult = getUserSubResult(signedInData.cognitoUserPoolTokens),
userPoolTokensResult = getUserPoolTokensResult(signedInData.cognitoUserPoolTokens)
identityIdResult = getIdentityIdResult(identityId, exception),
awsCredentialsResult = getCredentialsResult(credentials, exception),
userSubResult = getUserSubResult(signedInData.cognitoUserPoolTokens, exception),
userPoolTokensResult = getUserPoolTokensResult(signedInData.cognitoUserPoolTokens, exception)
)
is AmplifyCredential.IdentityPool -> AWSCognitoAuthSession(
false,
identityIdResult = getIdentityIdResult(identityId),
awsCredentialsResult = getCredentialsResult(credentials),
identityIdResult = getIdentityIdResult(identityId, exception),
awsCredentialsResult = getCredentialsResult(credentials, exception),
userSubResult = AuthSessionResult.failure(SignedOutException()),
userPoolTokensResult = AuthSessionResult.failure(SignedOutException())
)
Expand All @@ -137,8 +160,8 @@ internal fun AmplifyCredential.getCognitoSession(
)
AWSCognitoAuthSession(
true,
identityIdResult = getIdentityIdResult(identityId),
awsCredentialsResult = getCredentialsResult(credentials),
identityIdResult = getIdentityIdResult(identityId, exception),
awsCredentialsResult = getCredentialsResult(credentials, exception),
userSubResult = AuthSessionResult.failure(userPoolException),
userPoolTokensResult = AuthSessionResult.failure(userPoolException)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1127,7 +1127,7 @@ internal class RealAWSCognitoAuthPlugin(
when (val authZState = authState.authZState) {
is AuthorizationState.Configured -> {
authStateMachine.send(AuthorizationEvent(AuthorizationEvent.EventType.FetchUnAuthSession))
_fetchAuthSession(onSuccess, onError)
_fetchAuthSession(onSuccess)
}
is AuthorizationState.SessionEstablished -> {
val credential = authZState.amplifyCredential
Expand All @@ -1147,7 +1147,7 @@ internal class RealAWSCognitoAuthPlugin(
AuthorizationEvent(AuthorizationEvent.EventType.RefreshSession(credential))
)
}
_fetchAuthSession(onSuccess, onError)
_fetchAuthSession(onSuccess)
} else {
onSuccess.accept(credential.getCognitoSession())
}
Expand All @@ -1171,7 +1171,7 @@ internal class RealAWSCognitoAuthPlugin(
AuthorizationEvent(AuthorizationEvent.EventType.RefreshSession(amplifyCredential))
)
}
_fetchAuthSession(onSuccess, onError)
_fetchAuthSession(onSuccess)
} else {
onError.accept(InvalidStateException())
}
Expand All @@ -1182,8 +1182,7 @@ internal class RealAWSCognitoAuthPlugin(
}

private fun _fetchAuthSession(
onSuccess: Consumer<AuthSession>,
onError: Consumer<AuthException>
onSuccess: Consumer<AuthSession>
) {
val token = StateChangeListenerToken()
authStateMachine.listen(
Expand All @@ -1198,23 +1197,23 @@ internal class RealAWSCognitoAuthPlugin(
authStateMachine.cancel(token)
when (val error = authZState.exception) {
is SessionError -> {
when (error.exception) {
when (val innerException = error.exception) {
is SignedOutException -> {
onSuccess.accept(error.amplifyCredential.getCognitoSession(error.exception))
onSuccess.accept(error.amplifyCredential.getCognitoSession(innerException))
}
is SessionExpiredException -> {
onSuccess.accept(AmplifyCredential.Empty.getCognitoSession(error.exception))
onSuccess.accept(error.amplifyCredential.getCognitoSession(innerException))
sendHubEvent(AuthChannelEventName.SESSION_EXPIRED.toString())
}
is ServiceException -> {
onSuccess.accept(AmplifyCredential.Empty.getCognitoSession(error.exception))
onSuccess.accept(error.amplifyCredential.getCognitoSession(innerException))
}
is NotAuthorizedException -> {
onSuccess.accept(AmplifyCredential.Empty.getCognitoSession(error.exception))
onSuccess.accept(error.amplifyCredential.getCognitoSession(innerException))
}
else -> {
val errorResult = UnknownException("Fetch auth session failed.", error)
onSuccess.accept(AmplifyCredential.Empty.getCognitoSession(errorResult))
val errorResult = UnknownException("Fetch auth session failed.", innerException)
onSuccess.accept(error.amplifyCredential.getCognitoSession(errorResult))
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,7 @@ internal object FetchAuthSessionCognitoActions : FetchAuthSessionActions {
)
AuthorizationEvent(AuthorizationEvent.EventType.ThrowError(exception))
} catch (e: Exception) {
val exception = SignedOutException(
recoverySuggestion = SignedOutException.RECOVERY_SUGGESTION_GUEST_ACCESS_POSSIBLE,
cause = e
)
AuthorizationEvent(AuthorizationEvent.EventType.ThrowError(exception))
AuthorizationEvent(AuthorizationEvent.EventType.ThrowError(e))
}
logger.verbose("$id Sending event ${evt.type}")
dispatcher.send(evt)
Expand Down Expand Up @@ -173,11 +169,7 @@ internal object FetchAuthSessionCognitoActions : FetchAuthSessionActions {
)
AuthorizationEvent(AuthorizationEvent.EventType.ThrowError(exception))
} catch (e: Exception) {
val exception = SignedOutException(
recoverySuggestion = SignedOutException.RECOVERY_SUGGESTION_GUEST_ACCESS_POSSIBLE,
cause = e
)
AuthorizationEvent(AuthorizationEvent.EventType.ThrowError(exception))
AuthorizationEvent(AuthorizationEvent.EventType.ThrowError(e))
}
logger.verbose("$id Sending event ${evt.type}")
dispatcher.send(evt)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ import com.amplifyframework.statemachine.codegen.data.DeviceMetadata
import com.amplifyframework.statemachine.codegen.states.AuthState
import featureTest.utilities.CognitoMockFactory
import featureTest.utilities.CognitoRequestFactory
import featureTest.utilities.TimeZoneRule
import featureTest.utilities.apiExecutor
import io.kotest.assertions.json.shouldEqualJson
import io.mockk.clearAllMocks
import io.mockk.coEvery
import io.mockk.coVerify
Expand All @@ -43,6 +45,7 @@ import io.mockk.mockkObject
import io.mockk.mockkStatic
import io.mockk.slot
import java.io.File
import java.util.TimeZone
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit
import kotlin.reflect.full.callSuspend
Expand All @@ -56,13 +59,16 @@ import kotlinx.serialization.json.Json
import org.json.JSONObject
import org.junit.After
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized

@RunWith(Parameterized::class)
class AWSCognitoAuthPluginFeatureTest(private val testCase: FeatureTestCase) {

@Rule @JvmField val timeZoneRule = TimeZoneRule(TimeZone.getTimeZone("US/Pacific"))

lateinit var feature: FeatureTestCase
private var apiExecutionResult: Any? = null

Expand Down Expand Up @@ -114,6 +120,7 @@ class AWSCognitoAuthPluginFeatureTest(private val testCase: FeatureTestCase) {

@Before
fun setUp() {
// set timezone to be same as generated json from JsonGenerator
Dispatchers.setMain(mainThreadSurrogate)
feature = testCase
sut.realPlugin = readConfiguration(feature.preConditions.`amplify-configuration`)
Expand Down Expand Up @@ -189,9 +196,9 @@ class AWSCognitoAuthPluginFeatureTest(private val testCase: FeatureTestCase) {
is Cognito -> verifyCognito(validation)

is ExpectationShapes.Amplify -> {
val expectedResponse = validation.response

assertEquals(expectedResponse, apiExecutionResult.toJsonElement())
val expected = validation.response.toString()
val actual = apiExecutionResult.toJsonElement().toString()
actual shouldEqualJson expected
}
is ExpectationShapes.State -> {
val getStateLatch = CountDownLatch(1)
Expand All @@ -214,7 +221,7 @@ class AWSCognitoAuthPluginFeatureTest(private val testCase: FeatureTestCase) {

coVerify {
when (validation) {
is CognitoIdentity -> mockCognitoIdClient to mockCognitoIPClient::class
is CognitoIdentity -> mockCognitoIdClient to mockCognitoIdClient::class
is CognitoIdentityProvider -> mockCognitoIPClient to mockCognitoIPClient::class
}.apply {
second.declaredFunctions.first {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,22 @@ import com.amplifyframework.auth.cognito.featuretest.generators.testcasegenerato
import com.amplifyframework.auth.cognito.featuretest.generators.testcasegenerators.SignOutTestCaseGenerator
import com.amplifyframework.auth.cognito.featuretest.generators.testcasegenerators.SignUpTestCaseGenerator
import com.amplifyframework.statemachine.codegen.states.AuthState
import org.junit.Ignore
import org.junit.Test

interface SerializableProvider {
val serializables: List<Any>

// used to reset any global state changes during generation
fun tearDown() {
// default no op
}
}

/**
* Top level generator for generating Json and writing to the destination directory
*/
object JsonGenerator {
class JsonGenerator {
private val providers: List<SerializableProvider> = listOf(
AuthStateJsonGenerator,
ResetPasswordTestCaseGenerator,
Expand All @@ -53,7 +60,26 @@ object JsonGenerator {
FetchUserAttributesTestCaseGenerator,
)

@Ignore("Uncomment and run to clean feature test directory")
@Test
fun clean() {
cleanDirectory()
}

@Ignore("Uncomment and run to clean feature test directory as well as generate json for feature tests.")
@Test
fun cleanAndGenerate() {
cleanDirectory()
generateJson()
}

@Ignore("Uncomment and run to generate json for feature tests.")
@Test
fun generate() {
generateJson()
}

private fun generateJson() {
providers.forEach { provider ->
provider.serializables.forEach {
when (it) {
Expand All @@ -64,11 +90,7 @@ object JsonGenerator {
}
}
}
provider.tearDown()
}
}
}

fun main() {
// cleanDirectory()
JsonGenerator.generate()
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ import kotlinx.serialization.json.JsonNull
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonPrimitive

const val basePath = "aws-auth-cognito/src/test/resources/feature-test"
const val basePath = "src/test/resources/feature-test"

fun writeFile(json: String, dirName: String, fileName: String) {
val directory = File("$basePath/$dirName")
Expand Down
Loading

0 comments on commit f2fd6f7

Please sign in to comment.