Skip to content

Commit

Permalink
fix(DataStore): retry MutationEvents on signed-out and token expired …
Browse files Browse the repository at this point in the history
…errors (#3487)

* fix(DataStore): retry MutationEvents on signed-out and token expired errors

* address PR comments
  • Loading branch information
lawmicha authored Jan 30, 2024
1 parent a512b2e commit f64c471
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,20 @@ struct AWSAPIEndpointInterceptors {

var postludeInterceptors: [URLRequestInterceptor] = []

/// Validates whether the access token has expired. A best-effort attempt is made,
/// and it returns `false` if the expiration cannot be determined.
var expiryValidator: ((String) -> Bool) {
{ token in
guard let authService,
let claims = try? authService.getTokenClaims(tokenString: token).get(),
let tokenExpiration = claims["exp"]?.doubleValue else {
return false
}
let currentTime = Date().timeIntervalSince1970
return currentTime > tokenExpiration
}
}

init(endpointName: APIEndpointName,
apiAuthProviderFactory: APIAuthProviderFactory,
authService: AWSAuthServiceBehavior? = nil) {
Expand Down Expand Up @@ -71,7 +85,8 @@ struct AWSAPIEndpointInterceptors {
"")
}
let provider = BasicUserPoolTokenProvider(authService: authService)
let interceptor = AuthTokenURLRequestInterceptor(authTokenProvider: provider)
let interceptor = AuthTokenURLRequestInterceptor(authTokenProvider: provider,
isTokenExpired: expiryValidator)
preludeInterceptors.append(interceptor)
case .openIDConnect:
guard let oidcAuthProvider = apiAuthProviderFactory.oidcAuthProvider() else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@ struct AuthTokenURLRequestInterceptor: URLRequestInterceptor {

private let userAgent = AmplifyAWSServiceConfiguration.userAgentLib
let authTokenProvider: AuthTokenProvider
let isTokenExpired: ((String) -> Bool)?

init(authTokenProvider: AuthTokenProvider) {
init(authTokenProvider: AuthTokenProvider,
isTokenExpired: ((String) -> Bool)? = nil) {
self.authTokenProvider = authTokenProvider
self.isTokenExpired = isTokenExpired
}

func intercept(_ request: URLRequest) async throws -> URLRequest {
Expand All @@ -41,6 +44,14 @@ struct AuthTokenURLRequestInterceptor: URLRequestInterceptor {
} catch {
throw APIError.operationError("Failed to retrieve authorization token.", "", error)
}

if isTokenExpired?(token) ?? false {
// If the access token has expired, we send back the underlying "AuthError.sessionExpired" error.
// Without a more specific AuthError case like "tokenExpired", this is the closest representation.
throw APIError.operationError("Auth Token Provider returned a expired token.",
"Please call `Amplify.Auth.fetchAuthSession()` or sign in again.",
AuthError.sessionExpired("", "", nil))
}

mutableRequest.setValue(token, forHTTPHeaderField: "authorization")
return mutableRequest as URLRequest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,33 @@ class AWSAPIEndpointInterceptorsTests: XCTestCase {
XCTAssertNotNil(interceptorConfig.postludeInterceptors[0] as? IAMURLRequestInterceptor)
}

func testExpiryValidator_Valid() {
let validToken = Date().timeIntervalSince1970 + 1
let authService = MockAWSAuthService()
authService.tokenClaims = ["exp": validToken as AnyObject]
let interceptorConfig = createAPIInterceptorConfig(authService: authService)

let result = interceptorConfig.expiryValidator("")
XCTAssertFalse(result)
}

func testExpiryValidator_Expired() {
let expiredToken = Date().timeIntervalSince1970 - 1
let authService = MockAWSAuthService()
authService.tokenClaims = ["exp": expiredToken as AnyObject]
let interceptorConfig = createAPIInterceptorConfig(authService: authService)

let result = interceptorConfig.expiryValidator("")
XCTAssertTrue(result)
}

// MARK: - Test Helpers

func createAPIInterceptorConfig() -> AWSAPIEndpointInterceptors {
func createAPIInterceptorConfig(authService: AWSAuthServiceBehavior = MockAWSAuthService()) -> AWSAPIEndpointInterceptors {
return AWSAPIEndpointInterceptors(
endpointName: endpointName,
apiAuthProviderFactory: APIAuthProviderFactory(),
authService: MockAWSAuthService())
authService: authService)
}

struct CustomInterceptor: URLRequestInterceptor {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,28 @@ class AuthTokenURLRequestInterceptorTests: XCTestCase {
XCTAssertNotNil(headers[URLRequestConstants.Header.xAmzDate])
XCTAssertNotNil(headers[URLRequestConstants.Header.userAgent])
}

func testAuthTokenInterceptor_ThrowsInvalid() async throws {
let mockTokenProvider = MockTokenProvider()
let interceptor = AuthTokenURLRequestInterceptor(authTokenProvider: mockTokenProvider,
isTokenExpired: { _ in return true })
let request = RESTOperationRequestUtils.constructURLRequest(
with: URL(string: "http://anapiendpoint.ca")!,
operationType: .get,
requestPayload: nil
)

do {
_ = try await interceptor.intercept(request).allHTTPHeaderFields
} catch {
guard case .operationError(let description, _, let underlyingError) = error as? APIError,
let authError = underlyingError as? AuthError,
case .sessionExpired = authError else {
XCTFail("Should be API.operationError with underlying AuthError.sessionExpired")
return
}
}
}
}

// MARK: - Mocks
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ class SyncMutationToCloudOperation: AsynchronousOperation {
}

/// - Warning: Must be invoked from a locking context
private func getRetryAdviceIfRetryable(error: APIError) -> RequestRetryAdvice {
func getRetryAdviceIfRetryable(error: APIError) -> RequestRetryAdvice {
var advice = RequestRetryAdvice(shouldRetry: false, retryInterval: DispatchTimeInterval.never)

switch error {
Expand All @@ -288,23 +288,25 @@ class SyncMutationToCloudOperation: AsynchronousOperation {
httpURLResponse: nil,
attemptNumber: currentAttemptNumber)

// we can't unify the following two cases as they have different associated values.
// we can't unify the following two cases (case 1 and case 2) as they have different associated values.
// should retry with a different authType if server returned "Unauthorized Error"
case .httpStatusError(_, let httpURLResponse) where httpURLResponse.statusCode == 401:
case .httpStatusError(_, let httpURLResponse) where httpURLResponse.statusCode == 401: // case 1
advice = shouldRetryWithDifferentAuthType()
// should retry with a different authType if request failed locally with an AuthError
case .operationError(_, _, let error) where (error as? AuthError) != nil:

// Not all AuthError's are unauthorized errors. If `AuthError.sessionExpired` then
// the request never made it to the server. We should keep trying until the user is signed in.
// Otherwise we may be making the wrong determination to remove this mutation event.
if case .sessionExpired = error as? AuthError {
// Use `userAuthenticationRequired` to ensure advice to retry is true.
advice = requestRetryablePolicy.retryRequestAdvice(urlError: URLError(.userAuthenticationRequired),
httpURLResponse: nil,
attemptNumber: currentAttemptNumber)
} else {
advice = shouldRetryWithDifferentAuthType()
case .operationError(_, _, let error): // case 2
if let authError = error as? AuthError { // case 2
// Not all AuthError's are unauthorized errors. If `AuthError.sessionExpired` or `.signedOut` then
// the request never made it to the server. We should keep trying until the user is signed in.
// Otherwise we may be making the wrong determination to remove this mutation event.
switch authError {
case .sessionExpired, .signedOut:
// use `userAuthenticationRequired` to ensure advice to retry is true.
advice = requestRetryablePolicy.retryRequestAdvice(urlError: URLError(.userAuthenticationRequired),
httpURLResponse: nil,
attemptNumber: currentAttemptNumber)
default:
// should retry with a different authType if request failed locally with any other AuthError
advice = shouldRetryWithDifferentAuthType()
}
}
case .httpStatusError(_, let httpURLResponse):
advice = requestRetryablePolicy.retryRequestAdvice(urlError: nil,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,132 @@ class SyncMutationToCloudOperationTests: XCTestCase {
operation.cancel()
await fulfillment(of: [expectMutationRequestFailed], timeout: defaultAsyncWaitTimeout)
}

// MARK: - GetRetryAdviceIfRetryableTests

func testGetRetryAdvice_NetworkError_RetryTrue() async throws {
let operation = await SyncMutationToCloudOperation(
mutationEvent: try createMutationEvent(),
getLatestSyncMetadata: { nil },
api: mockAPIPlugin,
authModeStrategy: AWSDefaultAuthModeStrategy(),
networkReachabilityPublisher: publisher,
currentAttemptNumber: 1,
completion: { _ in }
)

let error = APIError.networkError("", nil, URLError(.userAuthenticationRequired))
let advice = operation.getRetryAdviceIfRetryable(error: error)
XCTAssertTrue(advice.shouldRetry)
}

func testGetRetryAdvice_HTTPStatusError401WithMultiAuth_RetryTrue() async throws {
let operation = await SyncMutationToCloudOperation(
mutationEvent: try createMutationEvent(),
getLatestSyncMetadata: { nil },
api: mockAPIPlugin,
authModeStrategy: MockMultiAuthModeStrategy(),
networkReachabilityPublisher: publisher,
currentAttemptNumber: 1,
completion: { _ in }
)
let response = HTTPURLResponse(url: URL(string: "http://localhost")!,
statusCode: 401,
httpVersion: nil,
headerFields: nil)!
let error = APIError.httpStatusError(401, response)
let advice = operation.getRetryAdviceIfRetryable(error: error)
XCTAssertTrue(advice.shouldRetry)
}

func testGetRetryAdvice_OperationErrorAuthErrorWithMultiAuth_RetryTrue() async throws {
let operation = await SyncMutationToCloudOperation(
mutationEvent: try createMutationEvent(),
getLatestSyncMetadata: { nil },
api: mockAPIPlugin,
authModeStrategy: MockMultiAuthModeStrategy(),
networkReachabilityPublisher: publisher,
currentAttemptNumber: 1,
completion: { _ in }
)

let authError = AuthError.notAuthorized("", "", nil)
let error = APIError.operationError("", "", authError)
let advice = operation.getRetryAdviceIfRetryable(error: error)
XCTAssertTrue(advice.shouldRetry)
}

func testGetRetryAdvice_OperationErrorAuthErrorWithSingleAuth_RetryFalse() async throws {
let operation = await SyncMutationToCloudOperation(
mutationEvent: try createMutationEvent(),
getLatestSyncMetadata: { nil },
api: mockAPIPlugin,
authModeStrategy: AWSDefaultAuthModeStrategy(),
networkReachabilityPublisher: publisher,
currentAttemptNumber: 1,
completion: { _ in }
)

let authError = AuthError.notAuthorized("", "", nil)
let error = APIError.operationError("", "", authError)
let advice = operation.getRetryAdviceIfRetryable(error: error)
XCTAssertFalse(advice.shouldRetry)
}

func testGetRetryAdvice_OperationErrorAuthErrorSessionExpired_RetryTrue() async throws {
let operation = await SyncMutationToCloudOperation(
mutationEvent: try createMutationEvent(),
getLatestSyncMetadata: { nil },
api: mockAPIPlugin,
authModeStrategy: AWSDefaultAuthModeStrategy(),
networkReachabilityPublisher: publisher,
currentAttemptNumber: 1,
completion: { _ in }
)

let authError = AuthError.sessionExpired("", "", nil)
let error = APIError.operationError("", "", authError)
let advice = operation.getRetryAdviceIfRetryable(error: error)
XCTAssertTrue(advice.shouldRetry)
}

func testGetRetryAdvice_OperationErrorAuthErrorSignedOut_RetryTrue() async throws {
let operation = await SyncMutationToCloudOperation(
mutationEvent: try createMutationEvent(),
getLatestSyncMetadata: { nil },
api: mockAPIPlugin,
authModeStrategy: AWSDefaultAuthModeStrategy(),
networkReachabilityPublisher: publisher,
currentAttemptNumber: 1,
completion: { _ in }
)

let authError = AuthError.signedOut("", "", nil)
let error = APIError.operationError("", "", authError)
let advice = operation.getRetryAdviceIfRetryable(error: error)
XCTAssertTrue(advice.shouldRetry)
}

private func createMutationEvent() throws -> MutationEvent {
let post1 = Post(title: "post1", content: "content1", createdAt: .now())
return try MutationEvent(model: post1, modelSchema: post1.schema, mutationType: .create)
}

}

public class MockMultiAuthModeStrategy: AuthModeStrategy {
public weak var authDelegate: AuthModeStrategyDelegate?
required public init() {}

public func authTypesFor(schema: ModelSchema,
operation: ModelOperation) -> AWSAuthorizationTypeIterator {
return AWSAuthorizationTypeIterator(withValues: [.amazonCognitoUserPools, .apiKey])
}

public func authTypesFor(schema: ModelSchema,
operations: [ModelOperation]) -> AWSAuthorizationTypeIterator {
return AWSAuthorizationTypeIterator(withValues: [.amazonCognitoUserPools, .apiKey])
}
}

extension SyncMutationToCloudOperationTests {
Expand Down

0 comments on commit f64c471

Please sign in to comment.