diff --git a/packages/starknet-snap/src/state/transaction-state-manager.test.ts b/packages/starknet-snap/src/state/transaction-state-manager.test.ts index aee2264c..bc0c90e9 100644 --- a/packages/starknet-snap/src/state/transaction-state-manager.test.ts +++ b/packages/starknet-snap/src/state/transaction-state-manager.test.ts @@ -6,6 +6,7 @@ import { } from 'starknet'; import { generateTransactions } from '../__tests__/helper'; +import type { V2Transaction } from '../types/snapState'; import { TransactionDataVersion } from '../types/snapState'; import { PRELOADED_TOKENS } from '../utils/constants'; import { mockAcccounts, mockState } from './__tests__/helper'; @@ -137,7 +138,7 @@ describe('TransactionStateManager', () => { const chainId = constants.StarknetChainId.SN_SEPOLIA; const { txns: [legacyData, ...newData], - getDataSpy, + state, } = await prepareMockData(chainId); const legacyTxn = { @@ -155,7 +156,7 @@ describe('TransactionStateManager', () => { failureReason: legacyData.failureReason, }; // simulate the data source return the legacy data and new data - getDataSpy.mockResolvedValue(newData.concat([legacyTxn])); + state.transactions = newData.concat([legacyTxn]); const stateManager = new TransactionStateManager(); @@ -170,16 +171,23 @@ describe('TransactionStateManager', () => { const { txns, stateManager } = await prepareFindTransctions(); const tokenAddress1 = PRELOADED_TOKENS.map((token) => token.address)[0]; const tokenAddress2 = PRELOADED_TOKENS.map((token) => token.address)[2]; + const contractAddress = [ + tokenAddress1.toLowerCase(), + tokenAddress2.toLowerCase(), + ]; + const contractAddressSet = new Set(contractAddress); const result = await stateManager.findTransactions({ - contractAddress: [tokenAddress1, tokenAddress2], + contractAddress, }); expect(result).toStrictEqual( txns.filter( - (txn) => - txn.contractAddress === tokenAddress1 || - txn.contractAddress === tokenAddress2, + (txn: V2Transaction) => + txn.accountCalls && + Object.keys(txn.accountCalls).some((contract) => + contractAddressSet.has(contract.toLowerCase()), + ), ), ); }); @@ -246,8 +254,9 @@ describe('TransactionStateManager', () => { TransactionExecutionStatus.REJECTED, ]; const contractAddressCond = [ - PRELOADED_TOKENS.map((token) => token.address)[0], + PRELOADED_TOKENS.map((token) => token.address.toLowerCase())[0], ]; + const contractAddressSet = new Set(contractAddressCond); const timestampCond = txns[5].timestamp * 1000; const chainIdCond = [ txns[0].chainId as unknown as constants.StarknetChainId, @@ -263,7 +272,7 @@ describe('TransactionStateManager', () => { }); expect(result).toStrictEqual( - txns.filter((txn) => { + txns.filter((txn: V2Transaction) => { return ( (finalityStatusCond.includes( txn.finalityStatus as unknown as TransactionFinalityStatus, @@ -272,7 +281,10 @@ describe('TransactionStateManager', () => { txn.executionStatus as unknown as TransactionExecutionStatus, )) && txn.timestamp >= txns[5].timestamp && - contractAddressCond.includes(txn.contractAddress) && + txn.accountCalls && + Object.keys(txn.accountCalls).some((contract) => + contractAddressSet.has(contract.toLowerCase()), + ) && chainIdCond.includes( txn.chainId as unknown as constants.StarknetChainId, ) && diff --git a/packages/starknet-snap/src/state/transaction-state-manager.ts b/packages/starknet-snap/src/state/transaction-state-manager.ts index 597a98e1..0973ce82 100644 --- a/packages/starknet-snap/src/state/transaction-state-manager.ts +++ b/packages/starknet-snap/src/state/transaction-state-manager.ts @@ -16,7 +16,6 @@ import { ChainIdFilter as BaseChainIdFilter, StringFllter, Filter, - MultiFilter, } from './filter'; import { StateManager, StateManagerError } from './state-manager'; @@ -27,21 +26,17 @@ export class ChainIdFilter implements ITxFilter {} export class ContractAddressFilter - extends MultiFilter + extends StringFllter implements ITxFilter { - protected _prepareSearch(search: string[]): void { - this.search = new Set(search?.map((val) => val)); - } - protected _apply(data: Transaction): boolean { const txn = data as V2Transaction; const { accountCalls } = txn; if (!accountCalls) { return false; } - for (const contract in this.search) { - if (accountCalls[contract]) { + for (const contract in accountCalls) { + if (this.search.has(contract.toLowerCase())) { return true; } }