diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSGraphQLSubscriptionTaskRunner.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSGraphQLSubscriptionTaskRunner.swift index 3e70654298..3f3889566a 100644 --- a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSGraphQLSubscriptionTaskRunner.swift +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Operation/AWSGraphQLSubscriptionTaskRunner.swift @@ -378,6 +378,31 @@ fileprivate func toAPIError(_ errors: [Error], type: R.Type) -> AP (hasAuthorizationError ? ": \(APIError.UnauthorizedMessageString)" : "") } +#if swift(<5.8) + if let errors = errors.cast(to: AppSyncRealTimeRequest.Error.self) { + let hasAuthorizationError = errors.contains(where: { $0 == .unauthorized}) + return APIError.operationError( + errorDescription(hasAuthorizationError), + "", + errors.first + ) + } else if let errors = errors.cast(to: GraphQLError.self) { + let hasAuthorizationError = errors.map(\.extensions) + .compactMap { $0.flatMap { $0["errorType"]?.stringValue } } + .contains(where: { AppSyncErrorType($0) == .unauthorized }) + return APIError.operationError( + errorDescription(hasAuthorizationError), + "", + GraphQLResponseError.error(errors) + ) + } else { + return APIError.operationError( + errorDescription(), + "", + errors.first + ) + } +#else switch errors { case let errors as [AppSyncRealTimeRequest.Error]: let hasAuthorizationError = errors.contains(where: { $0 == .unauthorized}) @@ -402,5 +427,5 @@ fileprivate func toAPIError(_ errors: [Error], type: R.Type) -> AP errors.first ) } - +#endif } diff --git a/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/Array+Error+TypeCast.swift b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/Array+Error+TypeCast.swift new file mode 100644 index 0000000000..3592791dc2 --- /dev/null +++ b/AmplifyPlugins/API/Sources/AWSAPIPlugin/Support/Utils/Array+Error+TypeCast.swift @@ -0,0 +1,21 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import Foundation + +@_spi(AmplifyAPI) +extension Array where Element == Error { + func cast(to type: T.Type) -> [T]? { + self.reduce([]) { partialResult, ele in + if let partialResult, let ele = ele as? T { + return partialResult + [ele] + } + return nil + } + } +} diff --git a/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Support/Utils/Array+Error+TypeCastTests.swift b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Support/Utils/Array+Error+TypeCastTests.swift new file mode 100644 index 0000000000..d1d6861a74 --- /dev/null +++ b/AmplifyPlugins/API/Tests/AWSAPIPluginTests/Support/Utils/Array+Error+TypeCastTests.swift @@ -0,0 +1,62 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + + +import XCTest +@testable @_spi(AmplifyAPI) import AWSAPIPlugin + +class ArrayWithErrorElementExtensionTests: XCTestCase { + + /** + Given: errors with generic protocol type + When: cast to the correct underlying concrete type + Then: successfully casted to underlying concrete type + */ + func testCast_toCorrectErrorType_returnCastedErrorType() { + let errors: [Error] = [ + Error1(), Error1(), Error1() + ] + + let error1s = errors.cast(to: Error1.self) + XCTAssertNotNil(error1s) + XCTAssertTrue(!error1s!.isEmpty) + XCTAssertEqual(errors.count, error1s!.count) + } + + /** + Given: errors with generic protocol type + When: cast to the wong underlying concrete type + Then: return nil + */ + func testCast_toWrongErrorType_returnNil() { + let errors: [Error] = [ + Error1(), Error1(), Error1() + ] + + let error2s = errors.cast(to: Error2.self) + XCTAssertNil(error2s) + } + + /** + Given: errors with generic protocol type + When: some of the elements failed to cast to the underlying concrete type + Then: return nil + */ + + func testCast_partiallyToWrongErrorType_returnNil() { + let errors: [Error] = [ + Error2(), Error2(), Error1() + ] + + let error2s = errors.cast(to: Error2.self) + XCTAssertNil(error2s) + } + + struct Error1: Error { } + + struct Error2: Error { } +}