Skip to content

Commit

Permalink
Merge pull request #130 from MetaMask/session-persistence
Browse files Browse the repository at this point in the history
feat: Session persistence v2
  • Loading branch information
elefantel authored Jul 18, 2024
2 parents 1c91e8b + 27942f7 commit aad8a58
Show file tree
Hide file tree
Showing 8 changed files with 134 additions and 40 deletions.
2 changes: 1 addition & 1 deletion app/src/main/java/com/metamask/dapp/AppModule.kt
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ internal object AppModule {

@Provides // Add SDKOptions(infuraAPIKey="supply_your_key_here") to Ethereum constructor for read-only calls
fun provideEthereum(@ApplicationContext context: Context, dappMetadata: DappMetadata, logger: Logger): Ethereum {
return Ethereum(context, dappMetadata, SDKOptions(infuraAPIKey = "#####"), logger)
return Ethereum(context, dappMetadata, null, logger)
}

@Provides
Expand Down
10 changes: 8 additions & 2 deletions app/src/main/java/com/metamask/dapp/Setup.kt
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@ fun Setup(ethereumViewModel: EthereumFlowViewModel, screenViewModel: ScreenViewM
var isConnecting by remember { mutableStateOf(false) }
var isConnectSigning by remember { mutableStateOf(false) }
var connectResult by remember { mutableStateOf<Result>(Result.Success.Item("")) }
var signMessage by remember { mutableStateOf("") }
var account by remember { mutableStateOf(ethereumState.selectedAddress) }

LaunchedEffect(ethereumState.selectedAddress) {
if (ethereumState.selectedAddress.isNotEmpty()) {
screenViewModel.setScreen(ACTIONS)
}
}

// Connect
LaunchedEffect(isConnecting) {
Expand All @@ -38,7 +44,7 @@ fun Setup(ethereumViewModel: EthereumFlowViewModel, screenViewModel: ScreenViewM
}
}

NavHost(navController = navController, startDestination = CONNECT.name) {
NavHost(navController = navController, startDestination = if (account.isNotEmpty()) { DappScreen.ACTIONS.name } else { DappScreen.CONNECT.name }) {
composable(CONNECT.name) {
ConnectScreen(
ethereumState = ethereumState,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@ import kotlinx.serialization.Serializable
import org.json.JSONObject
import java.lang.ref.WeakReference

internal class CommunicationClient(context: Context, callback: EthereumEventCallback?, private val logger: Logger = DefaultLogger) {
class CommunicationClient(
context: Context,
callback: EthereumEventCallback?,
private val sessionManager: SessionManager,
private val keyExchange: KeyExchange,
private val logger: Logger = DefaultLogger) {

var sessionId: String = ""
private val keyExchange: KeyExchange = KeyExchange()

var dappMetadata: DappMetadata? = null
var isServiceConnected = false
Expand All @@ -35,20 +39,23 @@ internal class CommunicationClient(context: Context, callback: EthereumEventCall
private var submittedRequests: MutableMap<String, SubmittedRequest> = mutableMapOf()
private var queuedRequests: MutableMap<String, SubmittedRequest> = mutableMapOf()

private var sessionManager: SessionManager

private var isMetaMaskReady = false
private var sentOriginatorInfo = false
private var requestedBindService = false

var hasSubmittedRequests: Boolean = submittedRequests.isEmpty()
var hasRequestJobs: Boolean = requestJobs.isEmpty()
var hasQueuedRequests: Boolean = queuedRequests.isEmpty()

var enableDebug: Boolean = false
set(value) {
field = value
tracker.enableDebug = value
}

init {
sessionManager = SessionManager(KeyStorage(context))
sessionId = sessionManager.sessionId
// in case not yet initialised in SessionManager
sessionManager.onInitialized = {
sessionId = sessionManager.sessionId
}
Expand Down Expand Up @@ -142,7 +149,7 @@ internal class CommunicationClient(context: Context, callback: EthereumEventCall
sentOriginatorInfo = false
}

private fun handleMessage(message: String) {
fun handleMessage(message: String) {
val jsonString = keyExchange.decrypt(message)
val json = JSONObject(jsonString)

Expand Down Expand Up @@ -187,7 +194,7 @@ internal class CommunicationClient(context: Context, callback: EthereumEventCall
}
}

private fun resumeRequestJobs() {
fun resumeRequestJobs() {
logger.log("CommunicationClient:: Resuming jobs")

while (requestJobs.isNotEmpty()) {
Expand All @@ -196,12 +203,12 @@ internal class CommunicationClient(context: Context, callback: EthereumEventCall
}
}

private fun queueRequestJob(job: () -> Unit) {
fun queueRequestJob(job: () -> Unit) {
requestJobs.add(job)
logger.log("CommunicationClient:: Queued job")
}

private fun clearPendingRequests() {
fun clearPendingRequests() {
queuedRequests = mutableMapOf()
requestJobs = mutableListOf()
submittedRequests = mutableMapOf()
Expand Down Expand Up @@ -311,7 +318,7 @@ internal class CommunicationClient(context: Context, callback: EthereumEventCall
}
}

private fun handleError(error: String, id: String): Boolean {
fun handleError(error: String, id: String): Boolean {
if (error.isEmpty()) {
return false
}
Expand All @@ -329,7 +336,7 @@ internal class CommunicationClient(context: Context, callback: EthereumEventCall
return true
}

private fun completeRequest(id: String, result: Result) {
fun completeRequest(id: String, result: Result) {
if (queuedRequests[id] != null) {
queuedRequests[id]?.callback?.invoke(result)
queuedRequests.remove(id)
Expand All @@ -338,13 +345,13 @@ internal class CommunicationClient(context: Context, callback: EthereumEventCall
submittedRequests.remove(id)
}

private fun handleEvent(event: JSONObject) {
fun handleEvent(event: JSONObject) {
when (event.optString("method")) {
EthereumMethod.METAMASK_ACCOUNTS_CHANGED.value -> {
val accountsJson = event.optString("params")
val accounts: List<String> = Gson().fromJson(accountsJson, object : TypeToken<List<String>>() {}.type)
accounts.getOrNull(0)?.let { account ->
logger.error("CommunicationClient:: Event Updated to account $account")
logger.log("CommunicationClient:: Event Updated to account $account")
updateAccount(account)
}
}
Expand All @@ -362,17 +369,17 @@ internal class CommunicationClient(context: Context, callback: EthereumEventCall
}
}

private fun updateAccount(account: String) {
fun updateAccount(account: String) {
val callback = ethereumEventCallbackRef.get()
callback?.updateAccount(account)
}

private fun updateChainId(chainId: String) {
fun updateChainId(chainId: String) {
val callback = ethereumEventCallbackRef.get()
callback?.updateChainId(chainId)
}

private fun handleKeyExchange(message: String) {
fun handleKeyExchange(message: String) {
val json = JSONObject(message)

val keyExchangeStep = json.optString(KeyExchange.TYPE, KeyExchangeMessageType.KEY_HANDSHAKE_SYN.name)
Expand All @@ -396,7 +403,7 @@ internal class CommunicationClient(context: Context, callback: EthereumEventCall
}
}

private fun sendMessage(message: String) {
fun sendMessage(message: String) {
val bundle = Bundle().apply {
putString(MESSAGE, message)
}
Expand Down Expand Up @@ -439,7 +446,7 @@ internal class CommunicationClient(context: Context, callback: EthereumEventCall
}
}

private fun processRequest(request: RpcRequest, callback: (Result) -> Unit) {
fun processRequest(request: RpcRequest, callback: (Result) -> Unit) {
logger.log("CommunicationClient:: sending request $request")
if (queuedRequests[request.id] != null) {
queuedRequests.remove(request.id)
Expand All @@ -455,7 +462,7 @@ internal class CommunicationClient(context: Context, callback: EthereumEventCall
sendMessage(messageJson)
}

private fun sendOriginatorInfo() {
fun sendOriginatorInfo() {
if (sentOriginatorInfo) { return }
sentOriginatorInfo = true

Expand All @@ -479,7 +486,7 @@ internal class CommunicationClient(context: Context, callback: EthereumEventCall
sendMessage(messageJson)
}

private fun isQA(): Boolean {
fun isQA(): Boolean {
if (Build.VERSION.SDK_INT < 33 ) { // i.e Build.VERSION_CODES.TIRAMISU
return false
}
Expand All @@ -494,7 +501,7 @@ internal class CommunicationClient(context: Context, callback: EthereumEventCall
}
}

private fun bindService() {
fun bindService() {
logger.log("CommunicationClient:: Binding service")
requestedBindService = true

Expand Down Expand Up @@ -538,7 +545,7 @@ internal class CommunicationClient(context: Context, callback: EthereumEventCall
sendKeyExchangeMesage(keyExchange.toString())
}

private fun sendKeyExchangeMesage(message: String) {
fun sendKeyExchangeMesage(message: String) {
val bundle = Bundle().apply {
putString(KEY_EXCHANGE, message)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package io.metamask.androidsdk

import android.content.Context

class CommunicationClientModule(private val context: Context): CommunicationClientModuleInterface {
override fun provideKeyStorage(): KeyStorage {
return KeyStorage(context)
}

override fun provideSessionManager(keyStorage: SecureStorage): SessionManager {
return SessionManager(keyStorage)
}

override fun provideKeyExchange(): KeyExchange {
return KeyExchange()
}

override fun provideLogger(): Logger {
return DefaultLogger
}

override fun provideCommunicationClient(callback: EthereumEventCallback?): CommunicationClient {
val keyStorage = provideKeyStorage()
val sessionManager = provideSessionManager(keyStorage)
val keyExchange = provideKeyExchange()
val logger = provideLogger()
return CommunicationClient(context, callback, sessionManager, keyExchange, logger)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package io.metamask.androidsdk

interface CommunicationClientModuleInterface {
fun provideKeyStorage(): SecureStorage
fun provideSessionManager(keyStorage: SecureStorage): SessionManager
fun provideKeyExchange(): KeyExchange
fun provideLogger(): Logger
fun provideCommunicationClient(callback: EthereumEventCallback?): CommunicationClient
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,31 @@ import android.content.Intent
import android.net.Uri
import androidx.lifecycle.LiveData
import androidx.lifecycle.MutableLiveData
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.launch
import java.lang.ref.WeakReference

private const val METAMASK_DEEPLINK = "https://metamask.app.link"
private const val METAMASK_BIND_DEEPLINK = "$METAMASK_DEEPLINK/bind"
private const val DEFAULT_SESSION_DURATION: Long = 30 * 24 * 3600 // 30 days default

class Ethereum (
private val context: Context,
private val dappMetadata: DappMetadata,
sdkOptions: SDKOptions? = null,
private val logger: Logger = DefaultLogger
private val logger: Logger = DefaultLogger,
private val communicationClientModule: CommunicationClientModule = CommunicationClientModule(context)
): EthereumEventCallback {
private var connectRequestSent = false

private val communicationClient: CommunicationClient? by lazy {
CommunicationClient(context, null)
communicationClientModule.provideCommunicationClient(this)
}

private val storage = communicationClientModule.provideKeyStorage()
private val coroutineScope = CoroutineScope(SupervisorJob() + Dispatchers.Main)

private val infuraProvider: InfuraProvider? = sdkOptions?.let {
if (it.infuraAPIKey.isNotEmpty()) {
InfuraProvider(it.infuraAPIKey)
Expand Down Expand Up @@ -51,13 +59,31 @@ class Ethereum (

init {
updateSessionDuration()
initializeEthereumState()
}

private fun initializeEthereumState() {
coroutineScope.launch(Dispatchers.IO) {
try {
val account = storage.getValue(key = SessionManager.SESSION_ACCOUNT_KEY, file = SessionManager.SESSION_CONFIG_FILE)
val chainId = storage.getValue(key = SessionManager.SESSION_CHAIN_ID_KEY, file = SessionManager.SESSION_CONFIG_FILE)
_ethereumState.postValue(
currentEthereumState.copy(
selectedAddress = account ?: "",
chainId = chainId ?: ""
)
)
} catch (e: Exception) {
logger.error(e.localizedMessage)
}
}
}

fun enableDebug(enable: Boolean) = apply {
this.enableDebug = enable
}

private var sessionDuration: Long = DEFAULT_SESSION_DURATION
private var sessionDuration: Long = SessionManager.DEFAULT_SESSION_DURATION

override fun updateAccount(account: String) {
logger.log("Ethereum:: Selected account changed: $account")
Expand All @@ -67,6 +93,9 @@ class Ethereum (
sessionId = communicationClient?.sessionId ?: ""
)
)
if (account.isNotEmpty()) {
storage.putValue(account, key = SessionManager.SESSION_ACCOUNT_KEY, SessionManager.SESSION_CONFIG_FILE)
}
}

override fun updateChainId(newChainId: String) {
Expand All @@ -77,17 +106,21 @@ class Ethereum (
sessionId = communicationClient?.sessionId ?: ""
)
)
if (newChainId.isNotEmpty()) {
storage.putValue(newChainId, key = SessionManager.SESSION_CHAIN_ID_KEY, SessionManager.SESSION_CONFIG_FILE)
}
}

// Set session duration in seconds
fun updateSessionDuration(duration: Long = DEFAULT_SESSION_DURATION) = apply {
fun updateSessionDuration(duration: Long = SessionManager.DEFAULT_SESSION_DURATION) = apply {
sessionDuration = duration
communicationClient?.updateSessionDuration(duration)
}

// Clear persisted session. Subsequent MetaMask connection request will need approval
fun clearSession() {
disconnect(true)
storage.clear(SessionManager.SESSION_CONFIG_FILE)
}

fun connect(callback: ((Result) -> Unit)? = null) {
Expand Down Expand Up @@ -308,7 +341,7 @@ class Ethereum (
fun sendRequest(request: RpcRequest, callback: ((Result) -> Unit)? = null) {
logger.log("Ethereum:: Sending request $request")

if (!connectRequestSent) {
if (!connectRequestSent && selectedAddress.isEmpty()) {
requestAccounts {
sendRequest(request, callback)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,10 @@ constructor(
ethereumRequest(method = EthereumMethod.SWITCH_ETHEREUM_CHAIN, params = listOf(mapOf("chainId" to targetChainId)))

override fun disconnect(clearSession: Boolean) {
ethereum.disconnect(clearSession)
if (clearSession) {
ethereum.clearSession()
} else {
ethereum.disconnect()
}
}
}
Loading

0 comments on commit aad8a58

Please sign in to comment.