From 50735d59188c093de293836329bed474cd4c815b Mon Sep 17 00:00:00 2001 From: Michael Absolon Date: Tue, 19 Nov 2024 20:10:35 +0100 Subject: [PATCH] =?UTF-8?q?fix(sdk):=20Properly=20disconnect=20auto-create?= =?UTF-8?q?d=20API=20client=20=E2=9C=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../playground/src/components/XcmTransfer.tsx | 14 +- packages/sdk/src/api/IPolkadotApi.ts | 3 + .../assets/asset-claim/assetClaim.test.ts | 3 +- .../pallets/assets/asset-claim/assetClaim.ts | 34 ++-- .../assets/balance/getAssetBalance.test.ts | 21 ++- .../pallets/assets/balance/getAssetBalance.ts | 22 ++- .../assets/balance/getBalanceForeign.test.ts | 3 +- .../assets/balance/getBalanceForeign.ts | 22 ++- .../assets/balance/getBalanceNative.test.ts | 9 +- .../assets/balance/getBalanceNative.ts | 13 +- .../assets/getExistentialDeposit.test.ts | 12 +- .../pallets/assets/getExistentialDeposit.ts | 4 +- .../assets/getOriginFeeDetails.test.ts | 10 +- .../src/pallets/assets/getOriginFeeDetails.ts | 40 ++-- .../transfer-info/getTransferInfo.test.ts | 24 +-- .../assets/transfer-info/getTransferInfo.ts | 136 +++++++------- .../keepAlive/checkKeepAlive.test.ts | 4 +- .../xcmPallet/keepAlive/checkKeepAlive.ts | 5 + .../src/pallets/xcmPallet/transfer.test.ts | 3 +- .../sdk/src/pallets/xcmPallet/transfer.ts | 171 ++++++++++-------- packages/sdk/src/papi/PapiApi.test.ts | 37 ++++ packages/sdk/src/papi/PapiApi.ts | 26 +++ packages/sdk/src/pjs/PolkadotJsApi.test.ts | 41 ++++- packages/sdk/src/pjs/PolkadotJsApi.ts | 25 +++ packages/sdk/src/utils/isPjsClient.test.ts | 38 ++++ packages/sdk/src/utils/isPjsClient.ts | 8 + 26 files changed, 498 insertions(+), 230 deletions(-) create mode 100644 packages/sdk/src/utils/isPjsClient.test.ts create mode 100644 packages/sdk/src/utils/isPjsClient.ts diff --git a/apps/playground/src/components/XcmTransfer.tsx b/apps/playground/src/components/XcmTransfer.tsx index 977735e1..1164ba63 100644 --- a/apps/playground/src/components/XcmTransfer.tsx +++ b/apps/playground/src/components/XcmTransfer.tsx @@ -104,14 +104,14 @@ const XcmTransfer = () => { const signer = await getSigner(); - try { - const Sdk = - apiType === "PAPI" - ? await import("@paraspell/sdk/papi") - : await import("@paraspell/sdk"); + const Sdk = + apiType === "PAPI" + ? await import("@paraspell/sdk/papi") + : await import("@paraspell/sdk"); - const api = await Sdk.createApiInstanceForNode(from); + const api = await Sdk.createApiInstanceForNode(from); + try { let tx: Extrinsic | TPapiTransaction; if (useApi) { tx = await getTxFromApi( @@ -178,6 +178,8 @@ const XcmTransfer = () => { } } finally { setLoading(false); + if ("disconnect" in api) await api.disconnect(); + else api.destroy(); } }; diff --git a/packages/sdk/src/api/IPolkadotApi.ts b/packages/sdk/src/api/IPolkadotApi.ts index 426889e0..b29bd442 100644 --- a/packages/sdk/src/api/IPolkadotApi.ts +++ b/packages/sdk/src/api/IPolkadotApi.ts @@ -26,4 +26,7 @@ export interface IPolkadotApi { getFromStorage(key: string): Promise clone(): IPolkadotApi createApiForNode(node: TNodeWithRelayChains): Promise> + setDisconnectAllowed(allowed: boolean): void + getDisconnectAllowed(): boolean + disconnect(): Promise } diff --git a/packages/sdk/src/pallets/assets/asset-claim/assetClaim.test.ts b/packages/sdk/src/pallets/assets/asset-claim/assetClaim.test.ts index 70177ac1..418043b0 100644 --- a/packages/sdk/src/pallets/assets/asset-claim/assetClaim.test.ts +++ b/packages/sdk/src/pallets/assets/asset-claim/assetClaim.test.ts @@ -21,7 +21,8 @@ vi.mock('./buildClaimAssetsInput', () => ({ describe('claimAssets', () => { const apiMock = { init: vi.fn(), - callTxMethod: vi.fn() + callTxMethod: vi.fn(), + disconnect: vi.fn() } as unknown as IPolkadotApi const nodeMock = 'Acala' diff --git a/packages/sdk/src/pallets/assets/asset-claim/assetClaim.ts b/packages/sdk/src/pallets/assets/asset-claim/assetClaim.ts index 45d3354a..13517f1f 100644 --- a/packages/sdk/src/pallets/assets/asset-claim/assetClaim.ts +++ b/packages/sdk/src/pallets/assets/asset-claim/assetClaim.ts @@ -2,6 +2,7 @@ import type { TPallet } from '../../../types' import { type TSerializedApiCall } from '../../../types' import { type TAssetClaimOptions } from '../../../types/TAssetClaim' import { isRelayChain } from '../../../utils' +import { isPjsClient } from '../../../utils/isPjsClient' import { buildClaimAssetsInput } from './buildClaimAssetsInput' export const claimAssets = async ( @@ -11,23 +12,28 @@ export const claimAssets = async ( await api.init(node) - const args = buildClaimAssetsInput(options) + try { + const args = buildClaimAssetsInput(options) - const module: TPallet = isRelayChain(node) ? 'XcmPallet' : 'PolkadotXcm' + const module: TPallet = isRelayChain(node) ? 'XcmPallet' : 'PolkadotXcm' - const call = { - module, - section: 'claim_assets', - parameters: args - } + const call = { + module, + section: 'claim_assets', + parameters: args + } - if (serializedApiCallEnabled === true) { - return { - ...call, - // Keep compatible with the old SerializedCall type - parameters: Object.values(args) + if (serializedApiCallEnabled === true) { + return { + ...call, + parameters: Object.values(args) + } } - } - return api.callTxMethod(call) + return api.callTxMethod(call) + } finally { + if (isPjsClient(api)) { + await api.disconnect() + } + } } diff --git a/packages/sdk/src/pallets/assets/balance/getAssetBalance.test.ts b/packages/sdk/src/pallets/assets/balance/getAssetBalance.test.ts index 65d46813..8e1b5f3f 100644 --- a/packages/sdk/src/pallets/assets/balance/getAssetBalance.test.ts +++ b/packages/sdk/src/pallets/assets/balance/getAssetBalance.test.ts @@ -3,8 +3,8 @@ import { createApiInstanceForNode } from '../../../utils' import { getNativeAssetSymbol } from '../assets' import { getAssetBalance } from './getAssetBalance' import type { ApiPromise } from '@polkadot/api' -import { getBalanceNative } from './getBalanceNative' -import { getBalanceForeign } from './getBalanceForeign' +import { getBalanceNativeInternal } from './getBalanceNative' +import { getBalanceForeignInternal } from './getBalanceForeign' import type { IPolkadotApi } from '../../../api/IPolkadotApi' import type { Extrinsic } from '../../../pjs/types' @@ -17,11 +17,11 @@ vi.mock('../assets', () => ({ })) vi.mock('./getBalanceNative', () => ({ - getBalanceNative: vi.fn() + getBalanceNativeInternal: vi.fn() })) vi.mock('./getBalanceForeign', () => ({ - getBalanceForeign: vi.fn() + getBalanceForeignInternal: vi.fn() })) describe('getAssetBalance', () => { @@ -29,7 +29,8 @@ describe('getAssetBalance', () => { beforeEach(() => { apiMock = { - init: vi.fn() + init: vi.fn(), + disconnect: vi.fn() } as unknown as IPolkadotApi vi.mocked(createApiInstanceForNode).mockResolvedValue(apiMock) }) @@ -39,11 +40,11 @@ describe('getAssetBalance', () => { const node = 'Polkadot' const currency = { symbol: 'DOT' } vi.mocked(getNativeAssetSymbol).mockReturnValue('DOT') - vi.mocked(getBalanceNative).mockResolvedValue(BigInt(1000)) + vi.mocked(getBalanceNativeInternal).mockResolvedValue(BigInt(1000)) const result = await getAssetBalance({ api: apiMock, address: account, node, currency }) expect(result).toEqual(BigInt(1000)) - expect(getBalanceNative).toHaveBeenCalledWith({ address: account, node, api: apiMock }) + expect(getBalanceNativeInternal).toHaveBeenCalledWith({ address: account, node, api: apiMock }) }) it('returns the foreign asset balance when the currency symbol does not match the native symbol', async () => { @@ -51,11 +52,11 @@ describe('getAssetBalance', () => { const node = 'Kusama' const currency = { symbol: 'KSM' } vi.mocked(getNativeAssetSymbol).mockReturnValue('DOT') - vi.mocked(getBalanceForeign).mockResolvedValue(BigInt(200)) + vi.mocked(getBalanceForeignInternal).mockResolvedValue(BigInt(200)) const result = await getAssetBalance({ api: apiMock, address: account, node, currency }) expect(result).toEqual(BigInt(200)) - expect(getBalanceForeign).toHaveBeenCalledWith({ + expect(getBalanceForeignInternal).toHaveBeenCalledWith({ address: account, node, currency, @@ -68,7 +69,7 @@ describe('getAssetBalance', () => { const node = 'Kusama' const currency = { symbol: 'XYZ' } vi.mocked(getNativeAssetSymbol).mockReturnValue('DOT') - vi.mocked(getBalanceForeign).mockResolvedValue(BigInt(0)) + vi.mocked(getBalanceForeignInternal).mockResolvedValue(BigInt(0)) const result = await getAssetBalance({ api: apiMock, address: account, node, currency }) expect(result).toEqual(BigInt(0)) diff --git a/packages/sdk/src/pallets/assets/balance/getAssetBalance.ts b/packages/sdk/src/pallets/assets/balance/getAssetBalance.ts index fb0bb734..caaa5d1e 100644 --- a/packages/sdk/src/pallets/assets/balance/getAssetBalance.ts +++ b/packages/sdk/src/pallets/assets/balance/getAssetBalance.ts @@ -1,10 +1,10 @@ import type { TNodePolkadotKusama } from '../../../types' import { getNativeAssetSymbol } from '../assets' -import { getBalanceNative } from './getBalanceNative' -import { getBalanceForeign } from './getBalanceForeign' +import { getBalanceNativeInternal } from './getBalanceNative' +import { getBalanceForeignInternal } from './getBalanceForeign' import type { TGetAssetBalanceOptions } from '../../../types/TBalance' -export const getAssetBalance = async ({ +export const getAssetBalanceInternal = async ({ address, node, currency, @@ -14,16 +14,28 @@ export const getAssetBalance = async ({ const isNativeSymbol = 'symbol' in currency ? getNativeAssetSymbol(node) === currency.symbol : false + return isNativeSymbol - ? await getBalanceNative({ + ? await getBalanceNativeInternal({ address, node, api }) - : ((await getBalanceForeign({ + : ((await getBalanceForeignInternal({ address, node: node as TNodePolkadotKusama, api, currency })) ?? BigInt(0)) } + +export const getAssetBalance = async ( + options: TGetAssetBalanceOptions +): Promise => { + const { api } = options + try { + return await getAssetBalanceInternal(options) + } finally { + await api.disconnect() + } +} diff --git a/packages/sdk/src/pallets/assets/balance/getBalanceForeign.test.ts b/packages/sdk/src/pallets/assets/balance/getBalanceForeign.test.ts index 73cd65c8..cb0b1c7a 100644 --- a/packages/sdk/src/pallets/assets/balance/getBalanceForeign.test.ts +++ b/packages/sdk/src/pallets/assets/balance/getBalanceForeign.test.ts @@ -30,7 +30,8 @@ describe('getBalanceForeign', () => { const mockApi = { init: vi.fn(), getBalanceForeignXTokens: vi.fn(), - getBalanceForeign: vi.fn() + getBalanceForeign: vi.fn(), + disconnect: vi.fn() } as unknown as IPolkadotApi beforeEach(() => { diff --git a/packages/sdk/src/pallets/assets/balance/getBalanceForeign.ts b/packages/sdk/src/pallets/assets/balance/getBalanceForeign.ts index fb6aef44..42d48105 100644 --- a/packages/sdk/src/pallets/assets/balance/getBalanceForeign.ts +++ b/packages/sdk/src/pallets/assets/balance/getBalanceForeign.ts @@ -1,11 +1,11 @@ import { getDefaultPallet } from '../../pallets' import { getAssetBySymbolOrId } from '../getAssetBySymbolOrId' import { getBalanceForeignPolkadotXcm } from './getBalanceForeignPolkadotXcm' +import { getBalanceForeignXTokens } from './getBalanceForeignXTokens' import type { TGetBalanceForeignOptions } from '../../../types/TBalance' import { InvalidCurrencyError } from '../../../errors' -import { getBalanceForeignXTokens } from './getBalanceForeignXTokens' -export const getBalanceForeign = async ({ +export const getBalanceForeignInternal = async ({ address, node, currency, @@ -21,10 +21,24 @@ export const getBalanceForeign = async ({ throw new InvalidCurrencyError(`Asset ${JSON.stringify(currency)} not found on ${node}`) } - if (getDefaultPallet(node) === 'XTokens') { + const defaultPallet = getDefaultPallet(node) + + if (defaultPallet === 'XTokens') { return await getBalanceForeignXTokens(api, node, address, asset) - } else if (getDefaultPallet(node) === 'PolkadotXcm') { + } else if (defaultPallet === 'PolkadotXcm') { return await getBalanceForeignPolkadotXcm(api, node, address, asset) } + throw new Error('Unsupported pallet') } + +export const getBalanceForeign = async ( + options: TGetBalanceForeignOptions +): Promise => { + const { api } = options + try { + return await getBalanceForeignInternal(options) + } finally { + await api.disconnect() + } +} diff --git a/packages/sdk/src/pallets/assets/balance/getBalanceNative.test.ts b/packages/sdk/src/pallets/assets/balance/getBalanceNative.test.ts index cbe1f7ab..886f74b6 100644 --- a/packages/sdk/src/pallets/assets/balance/getBalanceNative.test.ts +++ b/packages/sdk/src/pallets/assets/balance/getBalanceNative.test.ts @@ -11,16 +11,15 @@ vi.mock('../../../utils', () => ({ describe('getBalanceNative', () => { const apiMock = { init: vi.fn(), - getBalanceNative: vi.fn() + getBalanceNative: vi.fn(), + disconnect: vi.fn() } as unknown as IPolkadotApi it('returns the correct balance when API is provided', async () => { const address = '0x123' const node = 'Polkadot' - const apiMock = { - init: vi.fn(), - getBalanceNative: vi.fn().mockResolvedValue(BigInt(1000)) - } as unknown as IPolkadotApi + + vi.spyOn(apiMock, 'getBalanceNative').mockResolvedValue(BigInt(1000)) const balance = await getBalanceNative({ address, diff --git a/packages/sdk/src/pallets/assets/balance/getBalanceNative.ts b/packages/sdk/src/pallets/assets/balance/getBalanceNative.ts index e9104ced..8a07657e 100644 --- a/packages/sdk/src/pallets/assets/balance/getBalanceNative.ts +++ b/packages/sdk/src/pallets/assets/balance/getBalanceNative.ts @@ -1,6 +1,6 @@ import type { TGetBalanceNativeOptions } from '../../../types/TBalance' -export const getBalanceNative = async ({ +export const getBalanceNativeInternal = async ({ address, node, api @@ -8,3 +8,14 @@ export const getBalanceNative = async ({ await api.init(node) return await api.getBalanceNative(address) } + +export const getBalanceNative = async ( + options: TGetBalanceNativeOptions +): Promise => { + const { api } = options + try { + return await getBalanceNativeInternal(options) + } finally { + await api.disconnect() + } +} diff --git a/packages/sdk/src/pallets/assets/getExistentialDeposit.test.ts b/packages/sdk/src/pallets/assets/getExistentialDeposit.test.ts index 5370e5a5..922fd4d7 100644 --- a/packages/sdk/src/pallets/assets/getExistentialDeposit.test.ts +++ b/packages/sdk/src/pallets/assets/getExistentialDeposit.test.ts @@ -5,18 +5,20 @@ import { getMaxNativeTransferableAmount } from './getExistentialDeposit' import * as edsMapJson from '../../maps/existential-deposits.json' -import { getBalanceNative } from './balance/getBalanceNative' +import { getBalanceNativeInternal } from './balance/getBalanceNative' import type { TNodeDotKsmWithRelayChains } from '../../types' import type { IPolkadotApi } from '../../api/IPolkadotApi' import type { ApiPromise } from '@polkadot/api' import type { Extrinsic } from '../../pjs/types' vi.mock('./balance/getBalanceNative', () => ({ - getBalanceNative: vi.fn() + getBalanceNativeInternal: vi.fn() })) describe('Existential Deposit and Transferable Amounts', () => { - const apiMock = {} as unknown as IPolkadotApi + const apiMock = { + disconnect: vi.fn() + } as unknown as IPolkadotApi const mockPalletsMap = edsMapJson as { [key: string]: string } const mockNode: TNodeDotKsmWithRelayChains = 'Polkadot' const mockAddress = '1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa' @@ -36,7 +38,7 @@ describe('Existential Deposit and Transferable Amounts', () => { it('should return the correct maximum native transferable amount', async () => { const mockBalance = BigInt(1000000000000) - vi.mocked(getBalanceNative).mockResolvedValue(mockBalance) + vi.mocked(getBalanceNativeInternal).mockResolvedValue(mockBalance) const ed = getExistentialDeposit(mockNode) const expectedMaxTransferableAmount = mockBalance - ed - ed / BigInt(10) @@ -50,7 +52,7 @@ describe('Existential Deposit and Transferable Amounts', () => { it('should return 0 for maximum native transferable amount if balance is too low', async () => { const mockBalance = BigInt(5000) - vi.mocked(getBalanceNative).mockResolvedValue(mockBalance) + vi.mocked(getBalanceNativeInternal).mockResolvedValue(mockBalance) const result = await getMaxNativeTransferableAmount(apiMock, mockAddress, mockNode) diff --git a/packages/sdk/src/pallets/assets/getExistentialDeposit.ts b/packages/sdk/src/pallets/assets/getExistentialDeposit.ts index 56fa900a..d2adb0cd 100644 --- a/packages/sdk/src/pallets/assets/getExistentialDeposit.ts +++ b/packages/sdk/src/pallets/assets/getExistentialDeposit.ts @@ -1,6 +1,6 @@ import { type TNodeDotKsmWithRelayChains, type TEdJsonMap } from '../../types' import * as edsMapJson from '../../maps/existential-deposits.json' assert { type: 'json' } -import { getBalanceNative } from './balance/getBalanceNative' +import { getBalanceNativeInternal } from './balance/getBalanceNative' import type { IPolkadotApi } from '../../api/IPolkadotApi' const palletsMap = edsMapJson as TEdJsonMap @@ -19,7 +19,7 @@ export const getMaxNativeTransferableAmount = async ( node: TNodeDotKsmWithRelayChains ): Promise => { const ed = getExistentialDeposit(node) - const nativeBalance = await getBalanceNative({ + const nativeBalance = await getBalanceNativeInternal({ address, node, api diff --git a/packages/sdk/src/pallets/assets/getOriginFeeDetails.test.ts b/packages/sdk/src/pallets/assets/getOriginFeeDetails.test.ts index ed7dceef..8547cfcc 100644 --- a/packages/sdk/src/pallets/assets/getOriginFeeDetails.test.ts +++ b/packages/sdk/src/pallets/assets/getOriginFeeDetails.test.ts @@ -12,7 +12,9 @@ import type { IPolkadotApi } from '../../api/IPolkadotApi' import type { Extrinsic } from '../../pjs/types' const apiMock = { - calculateTransactionFee: vi.fn().mockResolvedValue(BigInt('1000000000')) + init: vi.fn(), + calculateTransactionFee: vi.fn().mockResolvedValue(BigInt('1000000000')), + disconnect: vi.fn() } as unknown as IPolkadotApi describe('getOriginFeeDetails', () => { @@ -27,7 +29,7 @@ describe('getOriginFeeDetails', () => { const minTransferableAmount = BigInt('1000000000000') const xcmFee = '1000000000' - vi.spyOn(balanceModule, 'getBalanceNative').mockResolvedValue(nativeBalance) + vi.spyOn(balanceModule, 'getBalanceNativeInternal').mockResolvedValue(nativeBalance) vi.spyOn(depositModule, 'getMinNativeTransferableAmount').mockReturnValue(minTransferableAmount) vi.spyOn(utilsModule, 'createApiInstanceForNode').mockResolvedValue({} as ApiPromise) @@ -83,7 +85,7 @@ describe('getOriginFeeDetails', () => { const minTransferableAmount = BigInt('1000000000000') const xcmFee = '1000000000' - vi.spyOn(balanceModule, 'getBalanceNative').mockResolvedValue(nativeBalance) + vi.spyOn(balanceModule, 'getBalanceNativeInternal').mockResolvedValue(nativeBalance) vi.spyOn(depositModule, 'getMinNativeTransferableAmount').mockReturnValue(minTransferableAmount) vi.spyOn(utilsModule, 'createApiInstanceForNode').mockResolvedValue({} as ApiPromise) @@ -139,7 +141,7 @@ describe('getOriginFeeDetails', () => { const minTransferableAmount = BigInt('1000000000000') const xcmFee = '1000000000' - vi.spyOn(balanceModule, 'getBalanceNative').mockResolvedValue(nativeBalance) + vi.spyOn(balanceModule, 'getBalanceNativeInternal').mockResolvedValue(nativeBalance) vi.spyOn(depositModule, 'getMinNativeTransferableAmount').mockReturnValue(minTransferableAmount) vi.spyOn(utilsModule, 'createApiInstanceForNode').mockResolvedValue({} as ApiPromise) diff --git a/packages/sdk/src/pallets/assets/getOriginFeeDetails.ts b/packages/sdk/src/pallets/assets/getOriginFeeDetails.ts index 7bfa905f..f8be5298 100644 --- a/packages/sdk/src/pallets/assets/getOriginFeeDetails.ts +++ b/packages/sdk/src/pallets/assets/getOriginFeeDetails.ts @@ -1,6 +1,6 @@ import type { TCurrencyCore, TNodePolkadotKusama, TOriginFeeDetails } from '../../types' import { type TNodeDotKsmWithRelayChains } from '../../types' -import { getBalanceNative } from './balance/getBalanceNative' +import { getBalanceNativeInternal } from './balance/getBalanceNative' import { getMinNativeTransferableAmount } from './getExistentialDeposit' import { isRelayChain } from '../../utils' import { Builder } from '../../builder' @@ -8,7 +8,7 @@ import type { IPolkadotApi } from '../../api/IPolkadotApi' import type { TGetOriginFeeDetailsOptions } from '../../types/TBalance' const createTx = async ( - originApi: IPolkadotApi, + api: IPolkadotApi, address: string, amount: string, currency: TCurrencyCore, @@ -16,19 +16,19 @@ const createTx = async ( destNode: TNodeDotKsmWithRelayChains ): Promise => { if (isRelayChain(originNode)) { - return await Builder(originApi) + return await Builder(api) .to(destNode as TNodePolkadotKusama) .amount(amount) .address(address) .build() } else if (isRelayChain(destNode)) { - return await Builder(originApi) + return await Builder(api) .from(originNode as TNodePolkadotKusama) .amount(amount) .address(address) .build() } else { - return await Builder(originApi) + return await Builder(api) .from(originNode as TNodePolkadotKusama) .to(destNode as TNodePolkadotKusama) .currency(currency) @@ -38,7 +38,7 @@ const createTx = async ( } } -export const getOriginFeeDetails = async ({ +export const getOriginFeeDetailsInternal = async ({ api, account, accountDestination, @@ -48,20 +48,20 @@ export const getOriginFeeDetails = async ({ destination, feeMarginPercentage = 10 }: TGetOriginFeeDetailsOptions): Promise => { - const nativeBalance = await getBalanceNative({ - address: account, - node: origin, - api - }) - - const minTransferableAmount = getMinNativeTransferableAmount(origin) + await api.init(origin) const tx = await createTx(api, accountDestination, amount, currency, origin, destination) const xcmFee = await api.calculateTransactionFee(tx, account) - const xcmFeeWithMargin = xcmFee + xcmFee / BigInt(feeMarginPercentage) + const nativeBalance = await getBalanceNativeInternal({ + address: account, + node: origin, + api + }) + + const minTransferableAmount = getMinNativeTransferableAmount(origin) const sufficientForXCM = nativeBalance - minTransferableAmount - xcmFeeWithMargin > 0 return { @@ -69,3 +69,15 @@ export const getOriginFeeDetails = async ({ xcmFee } } + +export const getOriginFeeDetails = async ( + options: TGetOriginFeeDetailsOptions +): Promise => { + const { api } = options + + try { + return await getOriginFeeDetailsInternal(options) + } finally { + await api.disconnect() + } +} diff --git a/packages/sdk/src/pallets/assets/transfer-info/getTransferInfo.test.ts b/packages/sdk/src/pallets/assets/transfer-info/getTransferInfo.test.ts index 1d92bf87..d151cbed 100644 --- a/packages/sdk/src/pallets/assets/transfer-info/getTransferInfo.test.ts +++ b/packages/sdk/src/pallets/assets/transfer-info/getTransferInfo.test.ts @@ -1,10 +1,10 @@ import { describe, it, expect, beforeEach, vi } from 'vitest' import { createApiInstanceForNode, determineRelayChainSymbol } from '../../../utils' import { getTransferInfo } from './getTransferInfo' -import { getBalanceNative } from '../balance/getBalanceNative' -import { getOriginFeeDetails } from '../getOriginFeeDetails' +import { getBalanceNativeInternal } from '../balance/getBalanceNative' +import { getOriginFeeDetailsInternal } from '../getOriginFeeDetails' import { getAssetBySymbolOrId } from '../getAssetBySymbolOrId' -import { getAssetBalance } from '../balance/getAssetBalance' +import { getAssetBalanceInternal } from '../balance/getAssetBalance' import { getExistentialDeposit, getMaxNativeTransferableAmount, @@ -37,19 +37,21 @@ vi.mock('../getAssetBySymbolOrId', () => ({ })) vi.mock('../balance/getAssetBalance', () => ({ - getAssetBalance: vi.fn() + getAssetBalanceInternal: vi.fn() })) vi.mock('../balance/getBalanceNative', () => ({ - getBalanceNative: vi.fn() + getBalanceNativeInternal: vi.fn() })) vi.mock('../getOriginFeeDetails', () => ({ - getOriginFeeDetails: vi.fn() + getOriginFeeDetailsInternal: vi.fn() })) const apiMock = { - init: vi.fn() + init: vi.fn(), + disconnect: vi.fn(), + setDisconnectAllowed: vi.fn() } as unknown as IPolkadotApi describe('getTransferInfo', () => { @@ -62,13 +64,13 @@ describe('getTransferInfo', () => { beforeEach(() => { vi.mocked(createApiInstanceForNode).mockResolvedValue({} as ApiPromise) - vi.mocked(getBalanceNative).mockResolvedValue(BigInt(5000)) - vi.mocked(getOriginFeeDetails).mockResolvedValue({ + vi.mocked(getBalanceNativeInternal).mockResolvedValue(BigInt(5000)) + vi.mocked(getOriginFeeDetailsInternal).mockResolvedValue({ xcmFee: BigInt(100), sufficientForXCM: true }) vi.mocked(getAssetBySymbolOrId).mockReturnValue({ symbol: 'DOT', assetId: '1' }) - vi.mocked(getAssetBalance).mockResolvedValue(BigInt(2000)) + vi.mocked(getAssetBalanceInternal).mockResolvedValue(BigInt(2000)) vi.mocked(getExistentialDeposit).mockReturnValue(BigInt('100')) vi.mocked(getMinNativeTransferableAmount).mockReturnValue(BigInt('10')) vi.mocked(getMaxNativeTransferableAmount).mockResolvedValue(BigInt(4000)) @@ -115,7 +117,7 @@ describe('getTransferInfo', () => { }) it('handles errors during API interactions', async () => { - vi.mocked(getBalanceNative).mockRejectedValue(new Error('API failure')) + vi.mocked(getBalanceNativeInternal).mockRejectedValue(new Error('API failure')) await expect( getTransferInfo({ api: apiMock, diff --git a/packages/sdk/src/pallets/assets/transfer-info/getTransferInfo.ts b/packages/sdk/src/pallets/assets/transfer-info/getTransferInfo.ts index 08920375..6248759f 100644 --- a/packages/sdk/src/pallets/assets/transfer-info/getTransferInfo.ts +++ b/packages/sdk/src/pallets/assets/transfer-info/getTransferInfo.ts @@ -2,15 +2,15 @@ import { InvalidCurrencyError } from '../../../errors' import type { TGetTransferInfoOptions, TTransferInfo } from '../../../types/TTransferInfo' import { determineRelayChainSymbol } from '../../../utils' import { getNativeAssetSymbol } from '../assets' -import { getAssetBalance } from '../balance/getAssetBalance' -import { getBalanceNative } from '../balance/getBalanceNative' +import { getAssetBalanceInternal } from '../balance/getAssetBalance' +import { getBalanceNativeInternal } from '../balance/getBalanceNative' import { getAssetBySymbolOrId } from '../getAssetBySymbolOrId' import { getExistentialDeposit, getMaxNativeTransferableAmount, getMinNativeTransferableAmount } from '../getExistentialDeposit' -import { getOriginFeeDetails } from '../getOriginFeeDetails' +import { getOriginFeeDetailsInternal } from '../getOriginFeeDetails' export const getTransferInfo = async ({ origin, @@ -19,74 +19,78 @@ export const getTransferInfo = async ({ accountDestination, currency, amount, - api: originApi + api }: TGetTransferInfoOptions): Promise => { - await originApi.init(origin) - const originBalance = await getBalanceNative({ - address: accountOrigin, - node: origin, - api: originApi - }) - const xcmFeeDetails = await getOriginFeeDetails({ - origin, - destination, - currency, - amount, - account: accountOrigin, - accountDestination, - api: originApi - }) + await api.init(origin) + api.setDisconnectAllowed(false) - const expectedBalanceAfterXCMDelivery = originBalance - xcmFeeDetails.xcmFee + try { + const originBalance = await getBalanceNativeInternal({ + address: accountOrigin, + node: origin, + api + }) - const asset = - getAssetBySymbolOrId(origin, currency, destination) ?? - (origin === 'AssetHubPolkadot' ? getAssetBySymbolOrId('Ethereum', currency, null) : null) - - if (!asset) { - throw new InvalidCurrencyError(`Asset ${JSON.stringify(currency)} not found on ${origin}`) - } - - return { - chain: { + const xcmFeeDetails = await getOriginFeeDetailsInternal({ origin, destination, - ecosystem: determineRelayChainSymbol(origin) - }, - currencyBalanceOrigin: { - balance: await getAssetBalance({ - api: originApi, - address: accountOrigin, - node: origin, - currency - }), - currency: asset?.symbol ?? '' - }, - originFeeBalance: { - balance: await getBalanceNative({ - address: accountOrigin, - node: origin, - api: originApi - }), - expectedBalanceAfterXCMFee: expectedBalanceAfterXCMDelivery, - xcmFee: xcmFeeDetails, - existentialDeposit: BigInt(getExistentialDeposit(origin) ?? 0), - asset: getNativeAssetSymbol(origin), - minNativeTransferableAmount: getMinNativeTransferableAmount(origin), - maxNativeTransferableAmount: await getMaxNativeTransferableAmount( - originApi, - accountOrigin, - origin - ) - }, - destinationFeeBalance: { - balance: await getBalanceNative({ - address: accountDestination, - node: destination, - api: originApi - }), - currency: getNativeAssetSymbol(destination), - existentialDeposit: getExistentialDeposit(destination) + currency, + amount, + account: accountOrigin, + accountDestination, + api + }) + + const expectedBalanceAfterXCMDelivery = originBalance - xcmFeeDetails.xcmFee + + const asset = + getAssetBySymbolOrId(origin, currency, destination) ?? + (origin === 'AssetHubPolkadot' ? getAssetBySymbolOrId('Ethereum', currency, null) : null) + + if (!asset) { + throw new InvalidCurrencyError(`Asset ${JSON.stringify(currency)} not found on ${origin}`) + } + + return { + chain: { + origin, + destination, + ecosystem: determineRelayChainSymbol(origin) + }, + currencyBalanceOrigin: { + balance: await getAssetBalanceInternal({ + api, + address: accountOrigin, + node: origin, + currency + }), + currency: asset?.symbol ?? '' + }, + originFeeBalance: { + balance: originBalance, + expectedBalanceAfterXCMFee: expectedBalanceAfterXCMDelivery, + xcmFee: xcmFeeDetails, + existentialDeposit: BigInt(getExistentialDeposit(origin) ?? 0), + asset: getNativeAssetSymbol(origin), + minNativeTransferableAmount: getMinNativeTransferableAmount(origin), + maxNativeTransferableAmount: await getMaxNativeTransferableAmount( + api, + accountOrigin, + origin + ) + }, + destinationFeeBalance: { + balance: await getBalanceNativeInternal({ + address: accountDestination, + node: destination, + api + }), + currency: getNativeAssetSymbol(destination), + existentialDeposit: getExistentialDeposit(destination) + } } + } finally { + api.setDisconnectAllowed(true) + await api.disconnect() } } diff --git a/packages/sdk/src/pallets/xcmPallet/keepAlive/checkKeepAlive.test.ts b/packages/sdk/src/pallets/xcmPallet/keepAlive/checkKeepAlive.test.ts index e023f28b..88e384ff 100644 --- a/packages/sdk/src/pallets/xcmPallet/keepAlive/checkKeepAlive.test.ts +++ b/packages/sdk/src/pallets/xcmPallet/keepAlive/checkKeepAlive.test.ts @@ -20,7 +20,9 @@ describe('checkKeepAlive', () => { const mockApi = { getApi: vi.fn().mockReturnValue({}), getBalanceNative: vi.fn().mockReturnValue(BigInt(1000000)), - calculateTransactionFee: vi.fn().mockResolvedValue(BigInt(100)) + calculateTransactionFee: vi.fn().mockResolvedValue(BigInt(100)), + getDisconnectAllowed: vi.fn().mockReturnValue(true), + setDisconnectAllowed: vi.fn() } as unknown as IPolkadotApi beforeEach(() => { diff --git a/packages/sdk/src/pallets/xcmPallet/keepAlive/checkKeepAlive.ts b/packages/sdk/src/pallets/xcmPallet/keepAlive/checkKeepAlive.ts index 26f2131c..ad821079 100644 --- a/packages/sdk/src/pallets/xcmPallet/keepAlive/checkKeepAlive.ts +++ b/packages/sdk/src/pallets/xcmPallet/keepAlive/checkKeepAlive.ts @@ -45,6 +45,9 @@ export const checkKeepAlive = async ({ originNode ?? determineRelayChain(destNode as TNodePolkadotKusama) ) + const oldDisconnectAllowed = originApi.getDisconnectAllowed() + originApi.setDisconnectAllowed(false) + const tx = await createTx( originApi, destApi, @@ -55,6 +58,8 @@ export const checkKeepAlive = async ({ destNode ) + originApi.setDisconnectAllowed(oldDisconnectAllowed) + if (tx === null) { throw new KeepAliveError('Transaction for XCM fee calculation could not be created.') } diff --git a/packages/sdk/src/pallets/xcmPallet/transfer.test.ts b/packages/sdk/src/pallets/xcmPallet/transfer.test.ts index 093b6d86..5d0936a9 100644 --- a/packages/sdk/src/pallets/xcmPallet/transfer.test.ts +++ b/packages/sdk/src/pallets/xcmPallet/transfer.test.ts @@ -47,7 +47,8 @@ const mockApi = { getApi: vi.fn(), setApi: vi.fn(), init: vi.fn(), - callTxMethod: vi.fn() + callTxMethod: vi.fn(), + disconnect: vi.fn() } as unknown as IPolkadotApi describe('send', () => { diff --git a/packages/sdk/src/pallets/xcmPallet/transfer.ts b/packages/sdk/src/pallets/xcmPallet/transfer.ts index 527604a8..8090b8f1 100644 --- a/packages/sdk/src/pallets/xcmPallet/transfer.ts +++ b/packages/sdk/src/pallets/xcmPallet/transfer.ts @@ -17,6 +17,7 @@ import { getNativeAssets, getRelayChainSymbol, hasSupportForAsset } from '../ass import { getNode, determineRelayChain } from '../../utils' import { isSymbolSpecifier } from '../../utils/assets/isSymbolSpecifier' import { isOverrideMultiLocationSpecifier } from '../../utils/multiLocation/isOverrideMultiLocationSpecifier' +import { isPjsClient } from '../../utils/isPjsClient' const sendCommon = async ( options: TSendOptions, @@ -178,56 +179,64 @@ const sendCommon = async ( await api.init(origin) - const amountStr = amount?.toString() - - if ('multilocation' in currency || 'multiasset' in currency) { - console.warn('Keep alive check is not supported when using MultiLocation as currency.') - } else if (typeof address === 'object') { - console.warn('Keep alive check is not supported when using MultiLocation as address.') - } else if (typeof destination === 'object') { - console.warn('Keep alive check is not supported when using MultiLocation as destination.') - } else if (destination === 'Ethereum') { - console.warn('Keep alive check is not supported when using Ethereum as origin or destination.') - } else if (!asset) { - console.warn('Keep alive check is not supported when asset check is disabled.') - } else { - await checkKeepAlive({ - originApi: api, - address, + try { + const amountStr = amount?.toString() + + if ('multilocation' in currency || 'multiasset' in currency) { + console.warn('Keep alive check is not supported when using MultiLocation as currency.') + } else if (typeof address === 'object') { + console.warn('Keep alive check is not supported when using MultiLocation as address.') + } else if (typeof destination === 'object') { + console.warn('Keep alive check is not supported when using MultiLocation as destination.') + } else if (destination === 'Ethereum') { + console.warn( + 'Keep alive check is not supported when using Ethereum as origin or destination.' + ) + } else if (!asset) { + console.warn('Keep alive check is not supported when asset check is disabled.') + } else { + await checkKeepAlive({ + originApi: api, + address, + amount: amountStr ?? '', + originNode: origin, + destApi: destApiForKeepAlive, + asset, + destNode: destination + }) + } + + // In case asset check is disabled, we create asset object from currency symbol + const resolvedAsset = + asset ?? + ({ + symbol: 'symbol' in currency ? currency.symbol : undefined + } as TNativeAsset) + + return originNode.transfer({ + api, + asset: resolvedAsset, amount: amountStr ?? '', - originNode: origin, - destApi: destApiForKeepAlive, - asset, - destNode: destination + address, + destination, + paraIdTo, + overridedCurrencyMultiLocation: + 'multilocation' in currency && isOverrideMultiLocationSpecifier(currency.multilocation) + ? currency.multilocation.value + : 'multiasset' in currency + ? currency.multiasset + : undefined, + feeAsset, + version, + destApiForKeepAlive, + serializedApiCallEnabled, + ahAddress }) + } finally { + if (isPjsClient(api)) { + await api.disconnect() + } } - - // In case asset check is disabled, we create asset object from currency symbol - const resolvedAsset = - asset ?? - ({ - symbol: 'symbol' in currency ? currency.symbol : undefined - } as TNativeAsset) - - return originNode.transfer({ - api, - asset: resolvedAsset, - amount: amountStr ?? '', - address, - destination, - paraIdTo, - overridedCurrencyMultiLocation: - 'multilocation' in currency && isOverrideMultiLocationSpecifier(currency.multilocation) - ? currency.multilocation.value - : 'multiasset' in currency - ? currency.multiasset - : undefined, - feeAsset, - version, - destApiForKeepAlive, - serializedApiCallEnabled, - ahAddress - }) } export const sendSerializedApiCall = async ( @@ -252,44 +261,50 @@ export const transferRelayToParaCommon = async ( await api.init(determineRelayChain(destination as TNode)) - const amountStr = amount.toString() + try { + const amountStr = amount.toString() + + if (isMultiLocationDestination) { + console.warn('Keep alive check is not supported when using MultiLocation as destination.') + } else if (isAddressMultiLocation) { + console.warn('Keep alive check is not supported when using MultiLocation as address.') + } else { + await checkKeepAlive({ + originApi: api, + address, + amount: amountStr, + destApi: destApiForKeepAlive, + asset: { symbol: getRelayChainSymbol(destination) }, + destNode: destination + }) + } - if (isMultiLocationDestination) { - console.warn('Keep alive check is not supported when using MultiLocation as destination.') - } else if (isAddressMultiLocation) { - console.warn('Keep alive check is not supported when using MultiLocation as address.') - } else { - await checkKeepAlive({ - originApi: api, + const serializedApiCall = getNode( + isMultiLocationDestination ? resolveTNodeFromMultiLocation(destination) : destination + ).transferRelayToPara({ + api, + destination, address, amount: amountStr, - destApi: destApiForKeepAlive, - asset: { symbol: getRelayChainSymbol(destination) }, - destNode: destination + paraIdTo, + destApiForKeepAlive, + version }) - } - const serializedApiCall = getNode( - isMultiLocationDestination ? resolveTNodeFromMultiLocation(destination) : destination - ).transferRelayToPara({ - api, - destination, - address, - amount: amountStr, - paraIdTo, - destApiForKeepAlive, - version - }) - - if (serializedApiCallEnabled) { - // Keep compatibility with old serialized call type - return { - ...serializedApiCall, - parameters: Object.values(serializedApiCall.parameters) + if (serializedApiCallEnabled) { + // Keep compatibility with old serialized call type + return { + ...serializedApiCall, + parameters: Object.values(serializedApiCall.parameters) + } } - } - return api.callTxMethod(serializedApiCall) + return api.callTxMethod(serializedApiCall) + } finally { + if (isPjsClient(api)) { + await api.disconnect() + } + } } export const transferRelayToPara = async ( diff --git a/packages/sdk/src/papi/PapiApi.test.ts b/packages/sdk/src/papi/PapiApi.test.ts index 52400f3a..60e7d001 100644 --- a/packages/sdk/src/papi/PapiApi.test.ts +++ b/packages/sdk/src/papi/PapiApi.test.ts @@ -46,6 +46,7 @@ describe('PapiApi', () => { mockPolkadotClient = { _request: vi.fn(), + destroy: vi.fn(), getUnsafeApi: vi.fn().mockReturnValue({ tx: { XcmPallet: { @@ -129,6 +130,7 @@ describe('PapiApi', () => { }) it('should create api instance when _api is undefined', async () => { + const papiApi = new PapiApi() papiApi.setApi(undefined) const mockCreateApiInstanceForNode = vi .spyOn(utils, 'createApiInstanceForNode') @@ -444,4 +446,39 @@ describe('PapiApi', () => { expect(apiInstance).toBeDefined() }) }) + + describe('disconnect', () => { + it('should disconnect the api when _api is a string', async () => { + const mockDisconnect = vi.spyOn(mockPolkadotClient, 'destroy').mockResolvedValue() + + papiApi.setApi('api') + await papiApi.disconnect() + + expect(mockDisconnect).toHaveBeenCalled() + + mockDisconnect.mockRestore() + }) + + it('should disconnect the api when _api is not provided', async () => { + const mockDisconnect = vi.spyOn(mockPolkadotClient, 'destroy').mockResolvedValue() + + papiApi.setApi(undefined) + await papiApi.disconnect() + + expect(mockDisconnect).toHaveBeenCalled() + + mockDisconnect.mockRestore() + }) + + it('should not disconnect the api when _api is provided', async () => { + const mockDisconnect = vi.spyOn(mockPolkadotClient, 'destroy').mockResolvedValue() + + papiApi.setApi(mockPolkadotClient) + await papiApi.disconnect() + + expect(mockDisconnect).not.toHaveBeenCalled() + + mockDisconnect.mockRestore() + }) + }) }) diff --git a/packages/sdk/src/papi/PapiApi.ts b/packages/sdk/src/papi/PapiApi.ts index fb498781..bb4fe70d 100644 --- a/packages/sdk/src/papi/PapiApi.ts +++ b/packages/sdk/src/papi/PapiApi.ts @@ -38,6 +38,8 @@ const unsupportedNodes = [ class PapiApi implements IPolkadotApi { private _api?: TPapiApiOrUrl private api: TPapiApi + private initialized = false + private disconnectAllowed = true setApi(api?: TPapiApiOrUrl): void { this._api = api @@ -48,6 +50,10 @@ class PapiApi implements IPolkadotApi { } async init(node: TNodeDotKsmWithRelayChains): Promise { + if (this.initialized) { + return + } + if (unsupportedNodes.includes(node)) { throw new NodeNotSupportedError(`The node ${node} is not yet supported by the Polkadot API.`) } @@ -57,6 +63,8 @@ class PapiApi implements IPolkadotApi { this.api = this._api ?? (await createApiInstanceForNode(this, node)) } + + this.initialized = true } async createApiInstance(wsUrl: string): Promise { @@ -175,6 +183,24 @@ class PapiApi implements IPolkadotApi { await api.init(node) return api } + + setDisconnectAllowed(allowed: boolean): void { + this.disconnectAllowed = allowed + } + + getDisconnectAllowed(): boolean { + return this.disconnectAllowed + } + + disconnect(): Promise { + if (!this.disconnectAllowed) return Promise.resolve() + + // Disconnect api only if it was created automatically + if (typeof this._api === 'string' || this._api === undefined) { + this.api.destroy() + } + return Promise.resolve() + } } export default PapiApi diff --git a/packages/sdk/src/pjs/PolkadotJsApi.test.ts b/packages/sdk/src/pjs/PolkadotJsApi.test.ts index 74c2a562..8c892c7c 100644 --- a/packages/sdk/src/pjs/PolkadotJsApi.test.ts +++ b/packages/sdk/src/pjs/PolkadotJsApi.test.ts @@ -52,7 +52,8 @@ describe('PolkadotJsApi', () => { entries: vi.fn() }) } - } + }, + disconnect: vi.fn() } as unknown as TPjsApi polkadotApi.setApi(mockApiPromise) await polkadotApi.init('Acala') @@ -60,6 +61,7 @@ describe('PolkadotJsApi', () => { describe('setApi and getApi', () => { it('should set and get the api', async () => { + const polkadotApi = new PolkadotJsApi() const newApi = {} as TPjsApi polkadotApi.setApi(newApi) await polkadotApi.init('Acala') @@ -70,6 +72,7 @@ describe('PolkadotJsApi', () => { describe('init', () => { it('should set api to _api when _api is defined', async () => { + const polkadotApi = new PolkadotJsApi() const mockApi = {} as TPjsApi polkadotApi.setApi(mockApi) await polkadotApi.init('Acala') @@ -77,6 +80,7 @@ describe('PolkadotJsApi', () => { }) it('should create api instance when _api is undefined', async () => { + const polkadotApi = new PolkadotJsApi() polkadotApi.setApi(undefined) const mockCreateApiInstanceForNode = vi .spyOn(utils, 'createApiInstanceForNode') @@ -470,4 +474,39 @@ describe('PolkadotJsApi', () => { mockCreateApiInstanceForNode.mockRestore() }) }) + + describe('disconnect', () => { + it('should disconnect the api when _api is a string', async () => { + const mockDisconnect = vi.spyOn(mockApiPromise, 'disconnect').mockResolvedValue() + + polkadotApi.setApi('api') + await polkadotApi.disconnect() + + expect(mockDisconnect).toHaveBeenCalled() + + mockDisconnect.mockRestore() + }) + + it('should disconnect the api when _api is not provided', async () => { + const mockDisconnect = vi.spyOn(mockApiPromise, 'disconnect').mockResolvedValue() + + polkadotApi.setApi(undefined) + await polkadotApi.disconnect() + + expect(mockDisconnect).toHaveBeenCalled() + + mockDisconnect.mockRestore() + }) + + it('should not disconnect the api when _api is provided', async () => { + const mockDisconnect = vi.spyOn(mockApiPromise, 'disconnect').mockResolvedValue() + + polkadotApi.setApi(mockApiPromise) + await polkadotApi.disconnect() + + expect(mockDisconnect).not.toHaveBeenCalled() + + mockDisconnect.mockRestore() + }) + }) }) diff --git a/packages/sdk/src/pjs/PolkadotJsApi.ts b/packages/sdk/src/pjs/PolkadotJsApi.ts index 4bd167d8..0be2946e 100644 --- a/packages/sdk/src/pjs/PolkadotJsApi.ts +++ b/packages/sdk/src/pjs/PolkadotJsApi.ts @@ -26,6 +26,8 @@ const snakeToCamel = (str: string) => class PolkadotJsApi implements IPolkadotApi { private _api?: TPjsApiOrUrl private api: TPjsApi + private initialized = false + private disconnectAllowed = true setApi(api?: TPjsApiOrUrl): void { this._api = api @@ -36,11 +38,17 @@ class PolkadotJsApi implements IPolkadotApi { } async init(node: TNodeDotKsmWithRelayChains): Promise { + if (this.initialized) { + return + } + if (typeof this._api === 'string') { this.api = await this.createApiInstance(this._api) } else { this.api = this._api ?? (await createApiInstanceForNode(this, node)) } + + this.initialized = true } async createApiInstance(wsUrl: string): Promise { @@ -160,6 +168,23 @@ class PolkadotJsApi implements IPolkadotApi { await api.init(node) return api } + + setDisconnectAllowed(allowed: boolean): void { + this.disconnectAllowed = allowed + } + + getDisconnectAllowed(): boolean { + return this.disconnectAllowed + } + + async disconnect(): Promise { + if (!this.disconnectAllowed) return + + // Disconnect api only if it was created automatically + if (typeof this._api === 'string' || this._api === undefined) { + await this.api.disconnect() + } + } } export default PolkadotJsApi diff --git a/packages/sdk/src/utils/isPjsClient.test.ts b/packages/sdk/src/utils/isPjsClient.test.ts new file mode 100644 index 00000000..942e1825 --- /dev/null +++ b/packages/sdk/src/utils/isPjsClient.test.ts @@ -0,0 +1,38 @@ +import { describe, it, expect } from 'vitest' +import { isPjsClient } from './isPjsClient' + +describe('isPjsClient', () => { + it('should return true for an object with a disconnect function', () => { + const mockApi = { + disconnect: async () => {} + } + expect(isPjsClient(mockApi)).toBe(true) + }) + + it('should return false for an object without a disconnect property', () => { + const mockApi = { + someOtherProperty: 'value' + } + expect(isPjsClient(mockApi)).toBe(false) + }) + + it('should return false for an object where disconnect is not a function', () => { + const mockApi = { + disconnect: 'not a function' + } + expect(isPjsClient(mockApi)).toBe(false) + }) + + it('should return false for null', () => { + expect(isPjsClient(null)).toBe(false) + }) + + it('should return false for a non-object type (e.g., number)', () => { + expect(isPjsClient(42)).toBe(false) + }) + + it('should return false for an object where disconnect is missing', () => { + const mockApi = {} + expect(isPjsClient(mockApi)).toBe(false) + }) +}) diff --git a/packages/sdk/src/utils/isPjsClient.ts b/packages/sdk/src/utils/isPjsClient.ts new file mode 100644 index 00000000..335893da --- /dev/null +++ b/packages/sdk/src/utils/isPjsClient.ts @@ -0,0 +1,8 @@ +export const isPjsClient = (api: unknown): api is { disconnect: () => Promise } => { + return ( + typeof api === 'object' && + api !== null && + 'disconnect' in api && + typeof api.disconnect === 'function' + ) +}