diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncRealTimeRequestAuth.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncRealTimeRequestAuth.swift index 87e01b1842..f649c3c380 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncRealTimeRequestAuth.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/AppSyncRealTimeClient/AppSyncRealTimeRequestAuth.swift @@ -9,6 +9,9 @@ import Foundation public enum AppSyncRealTimeRequestAuth { + private static let jsonEncoder = JSONEncoder() + private static let jsonDecoder = JSONDecoder() + case authToken(AuthToken) case apiKey(ApiKey) case iam(IAM) @@ -31,33 +34,10 @@ public enum AppSyncRealTimeRequestAuth { let amzDate: String } - public struct URLQuery { - let header: AppSyncRealTimeRequestAuth - let payload: String - - init(header: AppSyncRealTimeRequestAuth, payload: String = "{}") { - self.header = header - self.payload = payload - } - - func withBaseURL(_ url: URL, encoder: JSONEncoder? = nil) -> URL { - let jsonEncoder: JSONEncoder = encoder ?? JSONEncoder() - guard let headerJsonData = try? jsonEncoder.encode(header) else { - return url - } - - guard var urlComponents = URLComponents(url: url, resolvingAgainstBaseURL: false) - else { - return url - } - - urlComponents.queryItems = [ - URLQueryItem(name: "header", value: headerJsonData.base64EncodedString()), - URLQueryItem(name: "payload", value: try? payload.base64EncodedString()) - ] - - return urlComponents.url ?? url - } + var authHeaders: [String: String] { + (try? Self.jsonEncoder.encode(self)).flatMap { + try? Self.jsonDecoder.decode([String: String].self, from: $0) + } ?? [:] } } diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/APIKeyAuthInterceptor.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/APIKeyAuthInterceptor.swift index f52ded490e..0a12d295ef 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/APIKeyAuthInterceptor.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/APIKeyAuthInterceptor.swift @@ -21,11 +21,12 @@ class APIKeyAuthInterceptor { } extension APIKeyAuthInterceptor: WebSocketInterceptor { - func interceptConnection(url: URL) async -> URL { + + func interceptConnection(request: URLRequest) async -> URLRequest { + guard let url = request.url else { return request } + let authHeader = getAuthHeader(apiKey, AppSyncRealTimeClientFactory.appSyncApiEndpoint(url).host!) - return AppSyncRealTimeRequestAuth.URLQuery( - header: .apiKey(authHeader) - ).withBaseURL(url) + return request.injectAppSyncAuthToWebSocketSubprotocolsHeader(auth: .apiKey(authHeader)) } } diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/AuthTokenInterceptor.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/AuthTokenInterceptor.swift index b0f19ffd78..09b110d8ae 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/AuthTokenInterceptor.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/AuthTokenInterceptor.swift @@ -57,15 +57,16 @@ extension AuthTokenInterceptor: AppSyncRequestInterceptor { } extension AuthTokenInterceptor: WebSocketInterceptor { - func interceptConnection(url: URL) async -> URL { + func interceptConnection(request: URLRequest) async -> URLRequest { + guard let url = request.url else { return request } let authToken = await getAuthToken() - return AppSyncRealTimeRequestAuth.URLQuery( - header: .authToken(.init( + return request.injectAppSyncAuthToWebSocketSubprotocolsHeader( + auth: .authToken(.init( host: AppSyncRealTimeClientFactory.appSyncApiEndpoint(url).host!, authToken: authToken - )) - ).withBaseURL(url) + ) + )) } } diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/IAMAuthInterceptor.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/IAMAuthInterceptor.swift index cd023676c7..eaf4be8616 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/IAMAuthInterceptor.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Interceptor/SubscriptionInterceptor/IAMAuthInterceptor.swift @@ -24,7 +24,7 @@ class IAMAuthInterceptor { func getAuthHeader( _ endpoint: URL, - with payload: String, + with payload: String? = nil, signer: AWSSignatureV4Signer = AmplifyAWSSignatureV4Signer() ) async -> AppSyncRealTimeRequestAuth.IAM? { guard let host = endpoint.host else { @@ -45,7 +45,7 @@ class IAMAuthInterceptor { .withHeader(name: "content-encoding", value: "amz-1.0") .withHeader(name: URLRequestConstants.Header.contentType, value: "application/json; charset=UTF-8") .withHeader(name: URLRequestConstants.Header.host, value: host) - .withBody(.data(Data(payload.utf8))) + .withBody(.data(payload.map { Data($0.utf8) })) /// 2. The request is SigV4 signed by using all the available headers on the request. By signing the request, the signature is added to /// the request headers as authorization and security token. @@ -88,15 +88,19 @@ class IAMAuthInterceptor { } extension IAMAuthInterceptor: WebSocketInterceptor { - func interceptConnection(url: URL) async -> URL { + + // TODO: (5D) it seems the new auth in header doesn't require payload, needs confirm + func interceptConnection(request: URLRequest) async -> URLRequest { + guard let url = request.url else { return request } let connectUrl = AppSyncRealTimeClientFactory.appSyncApiEndpoint(url).appendingPathComponent("connect") - guard let authHeader = await getAuthHeader(connectUrl, with: "{}") else { - return connectUrl + + var requestCopy = request + requestCopy.url = connectUrl + + guard let authHeader = await getAuthHeader(connectUrl) else { + return requestCopy } - - return AppSyncRealTimeRequestAuth.URLQuery( - header: .iam(authHeader) - ).withBaseURL(url) + return requestCopy.injectAppSyncAuthToWebSocketSubprotocolsHeader(auth: .iam(authHeader)) } } diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/URLRequest+AppSyncAuth.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/URLRequest+AppSyncAuth.swift new file mode 100644 index 0000000000..36f21b1d2d --- /dev/null +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/URLRequest+AppSyncAuth.swift @@ -0,0 +1,17 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import Foundation + +extension URLRequest { + func injectAppSyncAuthToWebSocketSubprotocolsHeader(auth: AppSyncRealTimeRequestAuth) -> URLRequest { + var requstCopy = self + auth.authHeaders.forEach { requstCopy.setValue($0.value, forHTTPHeaderField: $0.key) } + return requstCopy + } +} diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AppSyncRealTimeClient/AppSyncRealTimeRequestAuthTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AppSyncRealTimeClient/AppSyncRealTimeRequestAuthTests.swift index 6ab7af0692..e9f4061431 100644 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AppSyncRealTimeClient/AppSyncRealTimeRequestAuthTests.swift +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/AppSyncRealTimeClient/AppSyncRealTimeRequestAuthTests.swift @@ -147,61 +147,6 @@ class AppSyncRealTimeRequestAuthTests: XCTestCase { """.shrink()) } - func testAppSyncRealTimeRequestAuth_URLQueryWithCognitoAuthHeader() { - let expectedURL = """ - https://example.com?\ - header=eyJBdXRob3JpemF0aW9uIjoiNDk4NTljN2MtNzQwNS00ZDU4LWFmZjctNTJiZ\ - TRiNDczNTU3IiwiaG9zdCI6ImV4YW1wbGUuY29tIn0%3D\ - &payload=e30%3D - """ - let encodedURL = AppSyncRealTimeRequestAuth.URLQuery( - header: .authToken(.init( - host: "example.com", - authToken: "49859c7c-7405-4d58-aff7-52be4b473557" - )) - ).withBaseURL(URL(string: "https://example.com")!, encoder: jsonEncoder) - XCTAssertEqual(encodedURL.absoluteString, expectedURL) - } - - func testAppSyncRealTimeRequestAuth_URLQueryWithApiKeyAuthHeader() { - let expectedURL = """ - https://example.com?\ - header=eyJob3N0IjoiZXhhbXBsZS5jb20iLCJ4LWFtei1kYXRlIjoiOWUwZTJkZjktMmVlNy00NjU5L\ - TgzNjItMWM4ODFlMTE4YzlmIiwieC1hcGkta2V5IjoiNjVlMmZhY2EtOGUxZS00ZDM3LThkYzctNjQ0N\ - 2Q5Njk4MjQ3In0%3D\ - &payload=e30%3D - """ - let encodedURL = AppSyncRealTimeRequestAuth.URLQuery( - header: .apiKey(.init( - host: "example.com", - apiKey: "65e2faca-8e1e-4d37-8dc7-6447d9698247", - amzDate: "9e0e2df9-2ee7-4659-8362-1c881e118c9f" - )) - ).withBaseURL(URL(string: "https://example.com")!, encoder: jsonEncoder) - XCTAssertEqual(encodedURL.absoluteString, expectedURL) - } - - func testAppSyncRealTimeRequestAuth_URLQueryWithIAMAuthHeader() { - - let expectedURL = """ - https://example.com?\ - header=eyJhY2NlcHQiOiJhcHBsaWNhdGlvblwvanNvbiwgdGV4dFwvamF2YXNjcmlwdCIsIkF1dGhvcml6YXR\ - pb24iOiJjOWRhZDg5Ny05MGQxLTRhNGMtYTVjOS0yYjM2YTI0NzczNWYiLCJjb250ZW50LWVuY29kaW5nIjoiY\ - W16LTEuMCIsImNvbnRlbnQtdHlwZSI6ImFwcGxpY2F0aW9uXC9qc29uOyBjaGFyc2V0PVVURi04IiwiaG9zdCI\ - 6ImV4YW1wbGUuY29tIiwieC1hbXotZGF0ZSI6IjllMGUyZGY5LTJlZTctNDY1OS04MzYyLTFjODgxZTExOGM5Z\ - iIsIlgtQW16LVNlY3VyaXR5LVRva2VuIjoiZTdlNjI2OWUtZmRhMS00ZGUwLThiZGItYmFhN2I2ZGQwYTBkIn0%3D\ - &payload=e30%3D - """ - let encodedURL = AppSyncRealTimeRequestAuth.URLQuery( - header: .iam(.init( - host: "example.com", - authToken: "c9dad897-90d1-4a4c-a5c9-2b36a247735f", - securityToken: "e7e6269e-fda1-4de0-8bdb-baa7b6dd0a0d", - amzDate: "9e0e2df9-2ee7-4659-8362-1c881e118c9f")) - ).withBaseURL(URL(string: "https://example.com")!, encoder: jsonEncoder) - XCTAssertEqual(encodedURL.absoluteString, expectedURL) - } - private func toJson(_ value: Encodable) -> String? { return try? String(data: jsonEncoder.encode(value), encoding: .utf8) } diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/SubscriptionInterceptor/APIKeyAuthInterceptorTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/SubscriptionInterceptor/APIKeyAuthInterceptorTests.swift index 8c89c0a53a..7c8ebff620 100644 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/SubscriptionInterceptor/APIKeyAuthInterceptorTests.swift +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/SubscriptionInterceptor/APIKeyAuthInterceptorTests.swift @@ -12,20 +12,13 @@ import Amplify class APIKeyAuthInterceptorTests: XCTestCase { - func testInterceptConnection_addApiKeySignatureInURLQuery() async { + func testInterceptConnection_addApiKeyInRequestHeader() async { let apiKey = UUID().uuidString let interceptor = APIKeyAuthInterceptor(apiKey: apiKey) - let resultUrl = await interceptor.interceptConnection(url: URL(string: "https://example.com")!) - guard let components = URLComponents(url: resultUrl, resolvingAgainstBaseURL: false) else { - XCTFail("Failed to decode decorated URL") - return - } - - let header = components.queryItems?.first { $0.name == "header" } - XCTAssertNotNil(header?.value) - let headerData = try! header?.value!.base64DecodedString().data(using: .utf8) - let decodedHeader = try! JSONDecoder().decode(JSONValue.self, from: headerData!) - XCTAssertEqual(decodedHeader["x-api-key"]?.stringValue, apiKey) + let resultUrlRequest = await interceptor.interceptConnection(request: URLRequest(url: URL(string: "https://example.com")!)) + + let header = resultUrlRequest.value(forHTTPHeaderField: "x-api-key") + XCTAssertEqual(header, apiKey) } func testInterceptRequest_appendAuthInfoInPayload() async { diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/SubscriptionInterceptor/CognitoAuthInterceptorTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/SubscriptionInterceptor/CognitoAuthInterceptorTests.swift index 4127f018fd..d0383bff21 100644 --- a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/SubscriptionInterceptor/CognitoAuthInterceptorTests.swift +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Interceptor/SubscriptionInterceptor/CognitoAuthInterceptorTests.swift @@ -13,56 +13,24 @@ import Amplify class CognitoAuthInterceptorTests: XCTestCase { - func testInterceptConnection_withAuthTokenProvider_appendCorrectAuthHeaderToQuery() async { + func testInterceptConnection_withAuthTokenProvider_appendCorrectAuthHeader() async { let authTokenProvider = MockAuthTokenProvider() let interceptor = AuthTokenInterceptor(authTokenProvider: authTokenProvider) - let decoratedURL = await interceptor.interceptConnection(url: URL(string: "https://example.com")!) - guard let components = URLComponents(url: decoratedURL, resolvingAgainstBaseURL: false) else { - XCTFail("Failed to get url components from decorated URL") - return - } + let decoratedURLRequest = await interceptor.interceptConnection(request: URLRequest(url:URL(string: "https://example.com")!)) - guard let queryHeaderString = - try? components.queryItems?.first(where: { $0.name == "header" })?.value?.base64DecodedString() - else { - XCTFail("Failed to extract header field from query string") - return - } - - guard let queryHeader = try? JSONDecoder().decode(JSONValue.self, from: queryHeaderString.data(using: .utf8)!) - else { - XCTFail("Failed to decode query header to json object") - return - } - XCTAssertEqual(authTokenProvider.authToken, queryHeader.Authorization?.stringValue) - XCTAssertEqual("example.com", queryHeader.host?.stringValue) + XCTAssertEqual(authTokenProvider.authToken, decoratedURLRequest.value(forHTTPHeaderField: "Authorization")) + XCTAssertEqual("example.com", decoratedURLRequest.value(forHTTPHeaderField: "host")) } - func testInterceptConnection_withAuthTokenProviderFailed_appendEmptyAuthHeaderToQuery() async { + func testInterceptConnection_withAuthTokenProviderFailed_appendEmptyAuthHeader() async { let authTokenProvider = MockAuthTokenProviderFailed() let interceptor = AuthTokenInterceptor(authTokenProvider: authTokenProvider) - let decoratedURL = await interceptor.interceptConnection(url: URL(string: "https://example.com")!) - guard let components = URLComponents(url: decoratedURL, resolvingAgainstBaseURL: false) else { - XCTFail("Failed to get url components from decorated URL") - return - } + let decoratedURLRequest = await interceptor.interceptConnection(request: URLRequest(url:URL(string: "https://example.com")!)) - guard let queryHeaderString = - try? components.queryItems?.first(where: { $0.name == "header" })?.value?.base64DecodedString() - else { - XCTFail("Failed to extract header field from query string") - return - } - - guard let queryHeader = try? JSONDecoder().decode(JSONValue.self, from: queryHeaderString.data(using: .utf8)!) - else { - XCTFail("Failed to decode query header to json object") - return - } - XCTAssertEqual("", queryHeader.Authorization?.stringValue) - XCTAssertEqual("example.com", queryHeader.host?.stringValue) + XCTAssertEqual("", decoratedURLRequest.value(forHTTPHeaderField: "Authorization")) + XCTAssertEqual("example.com", decoratedURLRequest.value(forHTTPHeaderField: "host")) } func testInterceptRequest_withAuthTokenProvider_appendCorrectAuthInfoToPayload() async { diff --git a/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/WebSocketClient.swift b/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/WebSocketClient.swift index cc1149ac27..4d66e83c52 100644 --- a/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/WebSocketClient.swift +++ b/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/WebSocketClient.swift @@ -160,6 +160,8 @@ public final actor WebSocketClient: NSObject { var urlRequest = URLRequest(url: decoratedURL) self.handshakeHttpHeaders.forEach { urlRequest.setValue($0.value, forHTTPHeaderField: $0.key) } + urlRequest = await self.interceptor?.interceptConnection(request: urlRequest) ?? urlRequest + let urlSession = URLSession(configuration: .default, delegate: self, delegateQueue: nil) return urlSession.webSocketTask(with: urlRequest) } diff --git a/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/WebSocketInterceptor.swift b/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/WebSocketInterceptor.swift index a53ec3b950..351119ff03 100644 --- a/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/WebSocketInterceptor.swift +++ b/AmplifyPlugins/Core/AWSPluginsCore/WebSocket/WebSocketInterceptor.swift @@ -11,4 +11,18 @@ import Foundation @_spi(WebSocket) public protocol WebSocketInterceptor { func interceptConnection(url: URL) async -> URL + + func interceptConnection(request: URLRequest) async -> URLRequest +} + +public extension WebSocketInterceptor { + + func interceptConnection(url: URL) async -> URL { + return url + } + + func interceptConnection(request: URLRequest) async -> URLRequest { + return request + } + }