Skip to content

Commit

Permalink
first stab at version negotiation
Browse files Browse the repository at this point in the history
  • Loading branch information
Funkatronics committed Oct 2, 2023
1 parent 0fdfdc4 commit b41573d
Show file tree
Hide file tree
Showing 15 changed files with 236 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,45 @@
import com.solana.mobilewalletadapter.common.protocol.MobileWalletAdapterSessionCommon;
import com.solana.mobilewalletadapter.common.crypto.ECDSASignatures;

import org.json.JSONException;
import org.json.JSONObject;

import java.io.IOException;
import java.security.InvalidKeyException;
import java.security.KeyPair;
import java.security.NoSuchAlgorithmException;
import java.security.Signature;
import java.security.SignatureException;
import java.security.interfaces.ECPublicKey;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class MobileWalletAdapterSession extends MobileWalletAdapterSessionCommon {
private static final String TAG = MobileWalletAdapterSession.class.getSimpleName();

@NonNull
private final KeyPair mAssociationKey;

@NonNull
private final List<Integer> mSupportedProtocolVersions;

private ECPublicKey mCachedEncryptionPublicKey;

@Nullable
private Integer mSelectedProtocolVersion;

public MobileWalletAdapterSession(@NonNull MessageReceiver decryptedPayloadReceiver,
@Nullable StateCallbacks stateCallbacks) {
this(decryptedPayloadReceiver, stateCallbacks, new ArrayList<>());
}

public MobileWalletAdapterSession(@NonNull MessageReceiver decryptedPayloadReceiver,
@Nullable StateCallbacks stateCallbacks,
@NonNull List<Integer> supportedProtocolVersions) {
super(decryptedPayloadReceiver, stateCallbacks);
mAssociationKey = generateECP256KeyPair();
mSupportedProtocolVersions = supportedProtocolVersions;
}

@NonNull
Expand All @@ -41,6 +61,12 @@ protected ECPublicKey getAssociationPublicKey() {
return (ECPublicKey) mAssociationKey.getPublic();
}

@Nullable
@Override
protected Integer getSelectedProtocolVersion() {
return mSelectedProtocolVersion;
}

// N.B. Does not need to be synchronized; it consumes only a final immutable object
@NonNull
public byte[] getEncodedAssociationPublicKey() {
Expand Down Expand Up @@ -91,8 +117,19 @@ protected void handleSessionEstablishmentMessage(@NonNull byte[] payload)
throws SessionMessageException {
Log.v(TAG, "handleSessionEstablishmentMessage");

final ECPublicKey theirPublicKey = parseHelloRsp(payload);
generateSessionECDHSecret(theirPublicKey);
// TODO: this only works if the wallet actually sends the SESSION_PROPS message
// if the wallet is legacy and does not send SESSION_PROPS, we will get stuck

// first the wallet sends HELLO_RSP, so we parse that
if (mCachedEncryptionPublicKey == null) {
mCachedEncryptionPublicKey = parseHelloRsp(payload);
} else {
// then wallet should send SESSION_PROPS
mSelectedProtocolVersion = parseSessionProps(payload);
// now we can move session state to ENCRYPTED_SESSION
generateSessionECDHSecret(mCachedEncryptionPublicKey);
mCachedEncryptionPublicKey = null;
}
}

@NonNull
Expand All @@ -108,4 +145,19 @@ private ECPublicKey parseHelloRsp(@NonNull byte[] message) throws SessionMessage
otherPublicKey.getW().getAffineY());
return otherPublicKey;
}

@NonNull
private Integer parseSessionProps(@NonNull byte[] message) throws SessionMessageException {
final Integer version;
try {
String versionContent = new JSONObject(new String(message)).getString("v");
version = Integer.parseInt(versionContent);
} catch (JSONException e) {
throw new SessionMessageException("Failed to parse SESSION_PROPS", e);
}

Log.v(TAG, "Received session properties: version = " + version);

return version;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
import com.solana.mobilewalletadapter.clientlib.protocol.MobileWalletAdapterSession;
import com.solana.mobilewalletadapter.common.AssociationContract;

import java.util.ArrayList;
import java.util.List;

public class LocalAssociationIntentCreator {

private LocalAssociationIntentCreator() { }
Expand All @@ -25,28 +28,37 @@ private LocalAssociationIntentCreator() { }
public static Intent createAssociationIntent(@Nullable Uri endpointPrefix,
@IntRange(from = 0, to = 65535) int port,
@NonNull MobileWalletAdapterSession session) {
return createAssociationIntent(endpointPrefix, port, session, new ArrayList<>());
}

@NonNull
public static Intent createAssociationIntent(@Nullable Uri endpointPrefix,
@IntRange(from = 0, to = 65535) int port,
@NonNull MobileWalletAdapterSession session,
@NonNull List<Integer> supportedProtocolVersions) {
final byte[] associationPublicKey = session.getEncodedAssociationPublicKey();
final String associationToken = Base64.encodeToString(associationPublicKey,
Base64.URL_SAFE | Base64.NO_PADDING | Base64.NO_WRAP);
return new Intent()
.setAction(Intent.ACTION_VIEW)
.addCategory(Intent.CATEGORY_BROWSABLE)
.setData(createAssociationUri(endpointPrefix, port, associationToken));
.setData(createAssociationUri(endpointPrefix, port, associationToken, supportedProtocolVersions));
}

public static boolean isWalletEndpointAvailable(@NonNull PackageManager pm) {
final Intent intent = new Intent()
.setAction(Intent.ACTION_VIEW)
.addCategory(Intent.CATEGORY_BROWSABLE)
.setData(createAssociationUri(null, 0, ""));
.setData(createAssociationUri(null, 0, "", new ArrayList<>()));
final ResolveInfo resolveInfo = pm.resolveActivity(intent, PackageManager.MATCH_DEFAULT_ONLY);
return (resolveInfo != null);
}

@NonNull
private static Uri createAssociationUri(@Nullable Uri endpointPrefix,
@IntRange(from = 0, to = 65535) int port,
@NonNull String associationToken) {
@NonNull String associationToken,
@NonNull List<Integer> supportedProtocolVersions) {
if (endpointPrefix != null && (!"https".equals(endpointPrefix.getScheme()) || !endpointPrefix.isHierarchical())) {
throw new IllegalArgumentException("Endpoint-specific URI prefix must be absolute with scheme 'https' and hierarchical");
}
Expand All @@ -61,12 +73,18 @@ private static Uri createAssociationUri(@Nullable Uri endpointPrefix,
dataUriBuilder = new Uri.Builder()
.scheme(AssociationContract.SCHEME_MOBILE_WALLET_ADAPTER);
}
return dataUriBuilder
.appendEncodedPath(AssociationContract.LOCAL_PATH_SUFFIX)

dataUriBuilder.appendEncodedPath(AssociationContract.LOCAL_PATH_SUFFIX)
.appendQueryParameter(AssociationContract.PARAMETER_ASSOCIATION_TOKEN,
associationToken)
.appendQueryParameter(AssociationContract.LOCAL_PARAMETER_PORT,
Integer.toString(port))
.build();
Integer.toString(port));

for (int version : supportedProtocolVersions) {
dataUriBuilder.appendQueryParameter(AssociationContract.PARAMETER_PROTOCOL_VERSION,
String.valueOf(version));
}

return dataUriBuilder.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
Expand All @@ -37,6 +38,9 @@ public class LocalAssociationScenario extends Scenario {
@NonNull
private final URI mWebSocketUri;

@NonNull
private final List<Integer> mSupportedProtocolVersions;

// All access to these members must be protected by mLock
private final Object mLock = new Object();
private State mState = State.NOT_STARTED;
Expand All @@ -55,9 +59,18 @@ public MobileWalletAdapterSession getSession() {
return mMobileWalletAdapterSession;
}

public List<Integer> getSupportedProtocolVersions() { return mSupportedProtocolVersions; }

public LocalAssociationScenario(@IntRange(from = 0) int clientTimeoutMs) {
this(clientTimeoutMs, new ArrayList<>());
}

public LocalAssociationScenario(@IntRange(from = 0) int clientTimeoutMs,
@NonNull List<Integer> supportedProtocolVersions) {
super(clientTimeoutMs);

mSupportedProtocolVersions = supportedProtocolVersions;

mPort = new Random().nextInt(WebSocketsTransportContract.WEBSOCKETS_LOCAL_PORT_MAX -
WebSocketsTransportContract.WEBSOCKETS_LOCAL_PORT_MIN + 1) +
WebSocketsTransportContract.WEBSOCKETS_LOCAL_PORT_MIN;
Expand All @@ -71,7 +84,8 @@ public LocalAssociationScenario(@IntRange(from = 0) int clientTimeoutMs) {

mMobileWalletAdapterSession = new MobileWalletAdapterSession(
mMobileWalletAdapterClient,
mSessionStateCallbacks);
mSessionStateCallbacks,
mSupportedProtocolVersions);

Log.v(TAG, "Creating local association scenario for " + mWebSocketUri);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ public class AssociationContract {

public static final String PARAMETER_ASSOCIATION_TOKEN = "association";

public static final String PARAMETER_PROTOCOL_VERSION = "v";

public static final String LOCAL_PATH_SUFFIX = "v1/associate/local";
public static final String LOCAL_PARAMETER_PORT = "port"; // type: Int

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.security.spec.ECParameterSpec;
import java.security.spec.InvalidParameterSpecException;
import java.util.Arrays;
import java.util.List;

import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
Expand Down Expand Up @@ -65,6 +66,9 @@ protected MobileWalletAdapterSessionCommon(@NonNull MessageReceiver decryptedPay
@NonNull
protected abstract ECPublicKey getAssociationPublicKey();

@Nullable
protected abstract Integer getSelectedProtocolVersion();

@Override
public synchronized void receiverConnected(@NonNull MessageSender messageSender) {
Log.v(TAG, "receiverConnected");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -329,10 +329,11 @@ object MobileWalletAdapterUseCase {
async {
mobileWalletAdapterClientSem.withPermit {
val contract = intentLauncher.contract as StartMobileWalletAdapterActivity
val localAssociation = LocalAssociationScenario(Scenario.DEFAULT_CLIENT_TIMEOUT_MS)
val localAssociation = LocalAssociationScenario(Scenario.DEFAULT_CLIENT_TIMEOUT_MS, listOf(1))

val associationIntent = LocalAssociationIntentCreator.createAssociationIntent(
uriPrefix, localAssociation.port, localAssociation.session
uriPrefix, localAssociation.port, localAssociation.session,
localAssociation.supportedProtocolVersions
)
try {
contract.waitForActivityResumed() // may throw TimeoutCancellationException
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ class MainActivityTest {
val associationIntent = LocalAssociationIntentCreator.createAssociationIntent(
null,
localAssociation.port,
localAssociation.session
localAssociation.session,
listOf(1)
)

// when
Expand Down Expand Up @@ -111,7 +112,8 @@ class MainActivityTest {
val associationIntent = LocalAssociationIntentCreator.createAssociationIntent(
null,
localAssociation.port,
localAssociation.session
localAssociation.session,
listOf(1)
)

// when
Expand Down Expand Up @@ -149,7 +151,8 @@ class MainActivityTest {
val associationIntent = LocalAssociationIntentCreator.createAssociationIntent(
null,
localAssociation.port,
localAssociation.session
localAssociation.session,
listOf(1)
)

// when
Expand Down Expand Up @@ -180,7 +183,8 @@ class MainActivityTest {
val associationIntent = LocalAssociationIntentCreator.createAssociationIntent(
null,
localAssociation.port,
localAssociation.session
localAssociation.session,
listOf(1)
)

TestScopeLowPowerMode = false
Expand All @@ -202,7 +206,8 @@ class MainActivityTest {
val associationIntent = LocalAssociationIntentCreator.createAssociationIntent(
null,
localAssociation.port,
localAssociation.session
localAssociation.session,
listOf(1)
)

TestScopeLowPowerMode = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,24 @@
import com.solana.mobilewalletadapter.walletlib.protocol.MobileWalletAdapterConfig;
import com.solana.mobilewalletadapter.walletlib.scenario.Scenario;

import java.util.ArrayList;
import java.util.List;

public abstract class AssociationUri {
@NonNull
public final Uri uri;

@NonNull
public final byte[] associationPublicKey;

@NonNull
public final List<Integer> supportedProtocolVersions;

protected AssociationUri(@NonNull Uri uri) {
this.uri = uri;
validate(uri);
associationPublicKey = parseAssociationToken(uri);
supportedProtocolVersions = parseSupportedProtocolVersions(uri);
}

private static void validate(@NonNull Uri uri) {
Expand All @@ -52,6 +59,22 @@ private static byte[] parseAssociationToken(@NonNull Uri uri) {
return Base64.decode(associationToken, Base64.URL_SAFE);
}

@NonNull
private static List<Integer> parseSupportedProtocolVersions(@NonNull Uri uri) {
final List<Integer> supportedVersions = new ArrayList<>();

for (String supportVersionStr : uri.getQueryParameters(
AssociationContract.PARAMETER_PROTOCOL_VERSION)) {
try {
supportedVersions.add(Integer.parseInt(supportVersionStr, 10));
} catch (NumberFormatException e) {
throw new IllegalArgumentException("port parameter must be a number", e);
}
}

return supportedVersions;
}

@Nullable
public static AssociationUri parse(@NonNull Uri uri) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public LocalWebSocketServerScenario createScenario(@NonNull Context context,
@NonNull Scenario.Callbacks callbacks) {
if (callbacks instanceof LocalScenario.Callbacks) {
return new LocalWebSocketServerScenario(context, mobileWalletAdapterConfig,
authIssuerConfig, (LocalScenario.Callbacks) callbacks, associationPublicKey, port);
authIssuerConfig, (LocalScenario.Callbacks) callbacks, associationPublicKey, port, supportedProtocolVersions);
} else {
throw new IllegalArgumentException("callbacks must implement " + LocalScenario.Callbacks.class.getName());
}
Expand Down
Loading

0 comments on commit b41573d

Please sign in to comment.