diff --git a/.changeset/sour-monkeys-type.md b/.changeset/sour-monkeys-type.md new file mode 100644 index 000000000..212c34683 --- /dev/null +++ b/.changeset/sour-monkeys-type.md @@ -0,0 +1,5 @@ +--- +'@metaplex-foundation/kinobi': minor +--- + +Create PDA Seed nodes diff --git a/src/nodes/AccountNode.ts b/src/nodes/AccountNode.ts index 1d58de4f3..d685075be 100644 --- a/src/nodes/AccountNode.ts +++ b/src/nodes/AccountNode.ts @@ -1,7 +1,6 @@ import type { IdlAccount } from '../idl'; import { AccountDiscriminator, - AccountSeed, InvalidKinobiTreeError, MainCaseString, PartialExcept, @@ -9,6 +8,12 @@ import { } from '../shared'; import { AccountDataNode, accountDataNode } from './AccountDataNode'; import type { Node } from './Node'; +import { + PdaSeedNode, + constantPdaSeedNode, + programIdPdaSeedNode, + variablePdaSeedNode, +} from './pdaSeedNodes'; import { remainderSizeNode } from './sizeNodes'; import { bytesTypeNode } from './typeNodes/BytesTypeNode'; import { stringTypeNode } from './typeNodes/StringTypeNode'; @@ -28,7 +33,7 @@ export type AccountNode = { readonly docs: string[]; readonly internal: boolean; readonly size?: number | null; - readonly seeds: AccountSeed[]; + readonly seeds: PdaSeedNode[]; readonly discriminator?: AccountDiscriminator; }; @@ -62,7 +67,7 @@ export function accountNodeFromIdl(idl: Partial): AccountNode { const idlStruct = idl.type ?? { kind: 'struct', fields: [] }; const struct = createTypeNodeFromIdl(idlStruct); assertStructTypeNode(struct); - const seeds = (idl.seeds ?? []).map((seed): AccountSeed => { + const seeds = (idl.seeds ?? []).map((seed): PdaSeedNode => { if (seed.kind === 'constant') { const value = (() => { if (typeof seed.value === 'string') return stringValueNode(seed.value); @@ -77,17 +82,16 @@ export function accountNodeFromIdl(idl: Partial): AccountNode { } else { type = createTypeNodeFromIdl(seed.type); } - return { ...seed, type, value }; + return constantPdaSeedNode(type, value); } if (seed.kind === 'variable') { - return { - ...seed, - name: mainCase(seed.name), - type: createTypeNodeFromIdl(seed.type), - docs: seed.description ? [seed.description] : [], - }; + return variablePdaSeedNode( + seed.name, + createTypeNodeFromIdl(seed.type), + seed.description ? [seed.description] : [] + ); } - return { kind: 'programId' }; + return programIdPdaSeedNode(); }); return accountNode({ name, diff --git a/src/nodes/Node.ts b/src/nodes/Node.ts index 9035f41c5..c9d3787a0 100644 --- a/src/nodes/Node.ts +++ b/src/nodes/Node.ts @@ -8,6 +8,7 @@ import type { InstructionExtraArgsNode } from './InstructionExtraArgsNode'; import type { InstructionNode } from './InstructionNode'; import type { ProgramNode } from './ProgramNode'; import type { RootNode } from './RootNode'; +import { REGISTERED_PDA_SEED_NODES } from './pdaSeedNodes'; import { REGISTERED_SIZE_NODES } from './sizeNodes'; import { REGISTERED_TYPE_NODES } from './typeNodes'; import { REGISTERED_VALUE_NODES } from './valueNodes'; @@ -25,6 +26,7 @@ const REGISTERED_NODES = { definedTypeNode: {} as DefinedTypeNode, // Groups. + ...REGISTERED_PDA_SEED_NODES, ...REGISTERED_SIZE_NODES, ...REGISTERED_TYPE_NODES, ...REGISTERED_VALUE_NODES, diff --git a/src/nodes/index.ts b/src/nodes/index.ts index 51dd31d8e..da53c713a 100644 --- a/src/nodes/index.ts +++ b/src/nodes/index.ts @@ -10,6 +10,7 @@ export * from './Node'; export * from './ProgramNode'; export * from './RootNode'; +export * from './pdaSeedNodes'; export * from './sizeNodes'; export * from './typeNodes'; export * from './valueNodes'; diff --git a/src/nodes/pdaSeedNodes/ConstantPdaSeedNode.ts b/src/nodes/pdaSeedNodes/ConstantPdaSeedNode.ts new file mode 100644 index 000000000..245f8e60a --- /dev/null +++ b/src/nodes/pdaSeedNodes/ConstantPdaSeedNode.ts @@ -0,0 +1,43 @@ +import { Node } from '../Node'; +import { remainderSizeNode } from '../sizeNodes'; +import { TypeNode, stringTypeNode } from '../typeNodes'; +import { ValueNode, stringValueNode } from '../valueNodes'; + +export type ConstantPdaSeedNode = { + readonly kind: 'constantPdaSeedNode'; + readonly type: TypeNode; + readonly value: ValueNode; +}; + +export function constantPdaSeedNode( + type: TypeNode, + value: ValueNode +): ConstantPdaSeedNode { + return { kind: 'constantPdaSeedNode', type, value }; +} + +export function constantPdaSeedNodeFromString( + value: string +): ConstantPdaSeedNode { + return { + kind: 'constantPdaSeedNode', + type: stringTypeNode({ size: remainderSizeNode() }), + value: stringValueNode(value), + }; +} + +export function isConstantPdaSeedNode( + node: Node | null +): node is ConstantPdaSeedNode { + return !!node && node.kind === 'constantPdaSeedNode'; +} + +export function assertConstantPdaSeedNode( + node: Node | null +): asserts node is ConstantPdaSeedNode { + if (!isConstantPdaSeedNode(node)) { + throw new Error( + `Expected constantPdaSeedNode, got ${node?.kind ?? 'null'}.` + ); + } +} diff --git a/src/nodes/pdaSeedNodes/PdaSeedNode.ts b/src/nodes/pdaSeedNodes/PdaSeedNode.ts new file mode 100644 index 000000000..6d93695d8 --- /dev/null +++ b/src/nodes/pdaSeedNodes/PdaSeedNode.ts @@ -0,0 +1,32 @@ +import { Node } from '../Node'; +import type { ConstantPdaSeedNode } from './ConstantPdaSeedNode'; +import type { ProgramIdPdaSeedNode } from './ProgramIdPdaSeedNode'; +import type { VariablePdaSeedNode } from './VariablePdaSeedNode'; + +export const REGISTERED_PDA_SEED_NODES = { + constantPdaSeedNode: {} as ConstantPdaSeedNode, + programIdPdaSeedNode: {} as ProgramIdPdaSeedNode, + variablePdaSeedNode: {} as VariablePdaSeedNode, +}; + +export const REGISTERED_PDA_SEED_NODE_KEYS = Object.keys( + REGISTERED_PDA_SEED_NODES +) as (keyof typeof REGISTERED_PDA_SEED_NODES)[]; + +export type RegisteredPdaSeedNodes = typeof REGISTERED_PDA_SEED_NODES; + +export type PdaSeedNode = RegisteredPdaSeedNodes[keyof RegisteredPdaSeedNodes]; + +export function isPdaSeedNode(node: Node | null): node is PdaSeedNode { + return ( + !!node && (REGISTERED_PDA_SEED_NODE_KEYS as string[]).includes(node.kind) + ); +} + +export function assertPdaSeedNode( + node: Node | null +): asserts node is PdaSeedNode { + if (!isPdaSeedNode(node)) { + throw new Error(`Expected typeNode, got ${node?.kind ?? 'null'}.`); + } +} diff --git a/src/nodes/pdaSeedNodes/ProgramIdPdaSeedNode.ts b/src/nodes/pdaSeedNodes/ProgramIdPdaSeedNode.ts new file mode 100644 index 000000000..e2fa745a1 --- /dev/null +++ b/src/nodes/pdaSeedNodes/ProgramIdPdaSeedNode.ts @@ -0,0 +1,25 @@ +import { Node } from '../Node'; + +export type ProgramIdPdaSeedNode = { + readonly kind: 'programIdPdaSeedNode'; +}; + +export function programIdPdaSeedNode(): ProgramIdPdaSeedNode { + return { kind: 'programIdPdaSeedNode' }; +} + +export function isProgramIdPdaSeedNode( + node: Node | null +): node is ProgramIdPdaSeedNode { + return !!node && node.kind === 'programIdPdaSeedNode'; +} + +export function assertProgramIdPdaSeedNode( + node: Node | null +): asserts node is ProgramIdPdaSeedNode { + if (!isProgramIdPdaSeedNode(node)) { + throw new Error( + `Expected programIdPdaSeedNode, got ${node?.kind ?? 'null'}.` + ); + } +} diff --git a/src/nodes/pdaSeedNodes/VariablePdaSeedNode.ts b/src/nodes/pdaSeedNodes/VariablePdaSeedNode.ts new file mode 100644 index 000000000..1426f0239 --- /dev/null +++ b/src/nodes/pdaSeedNodes/VariablePdaSeedNode.ts @@ -0,0 +1,39 @@ +import { MainCaseString, mainCase } from '../../shared'; +import { Node } from '../Node'; +import { TypeNode } from '../typeNodes'; + +export type VariablePdaSeedNode = { + readonly kind: 'variablePdaSeedNode'; + readonly name: MainCaseString; + readonly type: TypeNode; + readonly docs: string[]; +}; + +export function variablePdaSeedNode( + name: string, + type: TypeNode, + docs: string | string[] = [] +): VariablePdaSeedNode { + return { + kind: 'variablePdaSeedNode', + name: mainCase(name), + type, + docs: Array.isArray(docs) ? docs : [docs], + }; +} + +export function isVariablePdaSeedNode( + node: Node | null +): node is VariablePdaSeedNode { + return !!node && node.kind === 'variablePdaSeedNode'; +} + +export function assertVariablePdaSeedNode( + node: Node | null +): asserts node is VariablePdaSeedNode { + if (!isVariablePdaSeedNode(node)) { + throw new Error( + `Expected variablePdaSeedNode, got ${node?.kind ?? 'null'}.` + ); + } +} diff --git a/src/nodes/pdaSeedNodes/index.ts b/src/nodes/pdaSeedNodes/index.ts new file mode 100644 index 000000000..bfa30da61 --- /dev/null +++ b/src/nodes/pdaSeedNodes/index.ts @@ -0,0 +1,4 @@ +export * from './ConstantPdaSeedNode'; +export * from './PdaSeedNode'; +export * from './ProgramIdPdaSeedNode'; +export * from './VariablePdaSeedNode'; diff --git a/src/renderers/js-experimental/fragments/accountPdaHelpers.njk b/src/renderers/js-experimental/fragments/accountPdaHelpers.njk index d0aa89350..cadaad6ae 100644 --- a/src/renderers/js-experimental/fragments/accountPdaHelpers.njk +++ b/src/renderers/js-experimental/fragments/accountPdaHelpers.njk @@ -3,7 +3,7 @@ {% if hasVariableSeeds %} export type {{ accountSeedsType }} = { {% for seed in seeds %} - {% if seed.kind === 'variable' %} + {% if seed.kind === 'variablePdaSeedNode' %} {{ macros.docblock(seed.docs) }} {{ seed.name | camelCase }}: {{ seed.typeManifest.looseType.render }}; {% endif %} @@ -20,9 +20,9 @@ export async function {{ findPdaFunction }}( const { programAddress = '{{ program.publicKey }}' as Address<'{{ program.publicKey }}'> } = config; return getProgramDerivedAddress({ programAddress, seeds: [ {% for seed in seeds %} - {% if seed.kind === 'programId' %} + {% if seed.kind === 'programIdPdaSeedNode' %} getAddressEncoder().encode(programAddress), - {% elif seed.kind === 'constant' %} + {% elif seed.kind === 'constantPdaSeedNode' %} {{ seed.typeManifest.encoder.render }}.encode({{ seed.value.render }}), {% else %} {{ seed.typeManifest.encoder.render }}.encode(seeds.{{ seed.name | camelCase }}), diff --git a/src/renderers/js-experimental/fragments/accountPdaHelpers.ts b/src/renderers/js-experimental/fragments/accountPdaHelpers.ts index f66ce1fa5..81365b08e 100644 --- a/src/renderers/js-experimental/fragments/accountPdaHelpers.ts +++ b/src/renderers/js-experimental/fragments/accountPdaHelpers.ts @@ -1,4 +1,10 @@ -import { AccountNode, ProgramNode, RegisteredTypeNodes } from '../../../nodes'; +import { + AccountNode, + ProgramNode, + RegisteredTypeNodes, + isConstantPdaSeedNode, + isVariablePdaSeedNode, +} from '../../../nodes'; import { Visitor, visit } from '../../../visitors'; import { ImportMap } from '../ImportMap'; import { TypeManifest } from '../TypeManifest'; @@ -20,7 +26,7 @@ export function getAccountPdaHelpersFragment(scope: { // Seeds. const imports = new ImportMap(); const seeds = accountNode.seeds.map((seed) => { - if (seed.kind === 'constant') { + if (isConstantPdaSeedNode(seed)) { const seedManifest = visit(seed.type, typeManifestVisitor); imports.mergeWith(seedManifest.encoder); const seedValue = seed.value; @@ -29,7 +35,7 @@ export function getAccountPdaHelpersFragment(scope: { imports.mergeWith(valueManifest.imports); return { ...seed, typeManifest: seedManifest }; } - if (seed.kind === 'variable') { + if (isVariablePdaSeedNode(seed)) { const seedManifest = visit(seed.type, typeManifestVisitor); imports.mergeWith(seedManifest.looseType, seedManifest.encoder); return { ...seed, typeManifest: seedManifest }; @@ -38,7 +44,7 @@ export function getAccountPdaHelpersFragment(scope: { return seed; }); const hasVariableSeeds = - accountNode.seeds.filter((seed) => seed.kind === 'variable').length > 0; + accountNode.seeds.filter(isVariablePdaSeedNode).length > 0; return fragmentFromTemplate('accountPdaHelpers.njk', { accountType: nameApi.accountType(accountNode.name), diff --git a/src/renderers/js/getRenderMapVisitor.ts b/src/renderers/js/getRenderMapVisitor.ts index bc5dea250..8f6b1c02a 100644 --- a/src/renderers/js/getRenderMapVisitor.ts +++ b/src/renderers/js/getRenderMapVisitor.ts @@ -8,8 +8,10 @@ import { getAllDefinedTypes, getAllInstructionsWithSubs, InstructionNode, + isConstantPdaSeedNode, isDataEnum, isEnumTypeNode, + isVariablePdaSeedNode, ProgramNode, } from '../../nodes'; import { @@ -328,7 +330,7 @@ export function getRenderMapVisitor( // Seeds. const seeds = node.seeds.map((seed) => { - if (seed.kind === 'constant') { + if (isConstantPdaSeedNode(seed)) { const seedManifest = visit(seed.type, typeManifestVisitor); imports.mergeWith(seedManifest.serializerImports); const seedValue = seed.value; @@ -337,7 +339,7 @@ export function getRenderMapVisitor( imports.mergeWith(valueManifest.imports); return { ...seed, typeManifest: seedManifest }; } - if (seed.kind === 'variable') { + if (isVariablePdaSeedNode(seed)) { const seedManifest = visit(seed.type, typeManifestVisitor); imports.mergeWith( seedManifest.looseImports, @@ -354,7 +356,7 @@ export function getRenderMapVisitor( imports.add('umi', ['Pda']); } const hasVariableSeeds = - node.seeds.filter((seed) => seed.kind === 'variable').length > 0; + node.seeds.filter(isVariablePdaSeedNode).length > 0; return new RenderMap().add( `accounts/${camelCase(node.name)}.ts`, diff --git a/src/renderers/js/templates/accountsPage.njk b/src/renderers/js/templates/accountsPage.njk index 4ead9ace8..904be8935 100644 --- a/src/renderers/js/templates/accountsPage.njk +++ b/src/renderers/js/templates/accountsPage.njk @@ -89,7 +89,7 @@ export function find{{ account.name | pascalCase }}Pda( {% if hasVariableSeeds %} seeds: { {% for seed in seeds %} - {% if seed.kind === 'variable' %} + {% if seed.kind === 'variablePdaSeedNode' %} {{ macros.docblock(seed.docs) }} {{ seed.name | camelCase }}: {{ seed.typeManifest.looseType }}; {% endif %} @@ -100,9 +100,9 @@ export function find{{ account.name | pascalCase }}Pda( const programId = context.programs.getPublicKey('{{ program.name | camelCase }}', '{{ program.publicKey }}'); return context.eddsa.findPda(programId, [ {% for seed in seeds %} - {% if seed.kind === 'programId' %} + {% if seed.kind === 'programIdPdaSeedNode' %} publicKeySerializer().serialize(programId), - {% elif seed.kind === 'constant' %} + {% elif seed.kind === 'constantPdaSeedNode' %} {{ seed.typeManifest.serializer }}.serialize({{ seed.value.render }}), {% else %} {{ seed.typeManifest.serializer }}.serialize(seeds.{{ seed.name | camelCase }}), diff --git a/src/renderers/rust/getRenderMapVisitor.ts b/src/renderers/rust/getRenderMapVisitor.ts index 737a0992b..a48b25d6d 100644 --- a/src/renderers/rust/getRenderMapVisitor.ts +++ b/src/renderers/rust/getRenderMapVisitor.ts @@ -5,7 +5,9 @@ import { getAllAccounts, getAllDefinedTypes, getAllInstructionsWithSubs, + isConstantPdaSeedNode, isOptionTypeNode, + isVariablePdaSeedNode, } from '../../nodes'; import { ImportFrom, @@ -135,7 +137,7 @@ export function getRenderMapVisitor(options: GetRustRenderMapOptions = {}) { // Seeds. const seedsImports = new RustImportMap(); const seeds = node.seeds.map((seed) => { - if (seed.kind === 'constant') { + if (isConstantPdaSeedNode(seed)) { const seedManifest = visit(seed.type, typeManifestVisitor); const seedValue = seed.value; const valueManifest = renderValueNode(seedValue, true); @@ -143,7 +145,7 @@ export function getRenderMapVisitor(options: GetRustRenderMapOptions = {}) { seedsImports.mergeWith(valueManifest.imports); return { ...seed, typeManifest: seedManifest }; } - if (seed.kind === 'variable') { + if (isVariablePdaSeedNode(seed)) { const seedManifest = visit(seed.type, typeManifestVisitor); seedsImports.mergeWith(seedManifest.imports); return { ...seed, typeManifest: seedManifest }; @@ -151,7 +153,7 @@ export function getRenderMapVisitor(options: GetRustRenderMapOptions = {}) { return seed; }); const hasVariableSeeds = - node.seeds.filter((seed) => seed.kind === 'variable').length > 0; + node.seeds.filter(isVariablePdaSeedNode).length > 0; const { imports } = typeManifest; diff --git a/src/renderers/rust/templates/accountsPage.njk b/src/renderers/rust/templates/accountsPage.njk index 0232c8580..6ec3a889c 100644 --- a/src/renderers/rust/templates/accountsPage.njk +++ b/src/renderers/rust/templates/accountsPage.njk @@ -21,7 +21,7 @@ impl {{ account.name | pascalCase }} { pub fn create_pda( {% if hasVariableSeeds %} {% for seed in seeds %} - {% if seed.kind === 'variable' %} + {% if seed.kind === 'variablePdaSeedNode' %} {{ seed.name | snakeCase }}: {{ seed.typeManifest.type }}, {% endif %} {% endfor %} @@ -31,11 +31,11 @@ impl {{ account.name | pascalCase }} { solana_program::pubkey::Pubkey::create_program_address( &[ {% for seed in seeds %} - {% if seed.kind === 'programId' %} + {% if seed.kind === 'programIdPdaSeedNode' %} crate::{{ program.name | snakeCase | upper }}_ID.as_ref(), - {% elif seed.kind === 'constant' %} + {% elif seed.kind === 'constantPdaSeedNode' %} {{ seed.value.render }}.as_bytes(), - {% elif seed.kind == 'variable' and seed.type.kind == 'publicKeyTypeNode' %} + {% elif seed.kind == 'variablePdaSeedNode' and seed.type.kind == 'publicKeyTypeNode' %} {{ seed.name | snakeCase }}.as_ref(), {% else %} {{ seed.name | snakeCase }}.to_string().as_ref(), @@ -50,9 +50,9 @@ impl {{ account.name | pascalCase }} { pub fn find_pda( {% if hasVariableSeeds %} {% for seed in seeds %} - {% if seed.kind == 'variable' and seed.type.kind == 'publicKeyTypeNode' %} + {% if seed.kind == 'variablePdaSeedNode' and seed.type.kind == 'publicKeyTypeNode' %} {{ seed.name | snakeCase }}: &{{ seed.typeManifest.type }}, - {% elif seed.kind === 'variable' %} + {% elif seed.kind === 'variablePdaSeedNode' %} {{ seed.name | snakeCase }}: {{ seed.typeManifest.type }}, {% endif %} {% endfor %} @@ -61,11 +61,11 @@ impl {{ account.name | pascalCase }} { solana_program::pubkey::Pubkey::find_program_address( &[ {% for seed in seeds %} - {% if seed.kind === 'programId' %} + {% if seed.kind === 'programIdPdaSeedNode' %} crate::{{ program.name | snakeCase | upper }}_ID.as_ref(), - {% elif seed.kind === 'constant' %} + {% elif seed.kind === 'constantPdaSeedNode' %} {{ seed.value.render }}.as_bytes(), - {% elif seed.kind == 'variable' and seed.type.kind == 'publicKeyTypeNode' %} + {% elif seed.kind == 'variablePdaSeedNode' and seed.type.kind == 'publicKeyTypeNode' %} {{ seed.name | snakeCase }}.as_ref(), {% else %} {{ seed.name | snakeCase }}.to_string().as_ref(), diff --git a/src/shared/AccountSeed.ts b/src/shared/AccountSeed.ts deleted file mode 100644 index 906622bed..000000000 --- a/src/shared/AccountSeed.ts +++ /dev/null @@ -1,49 +0,0 @@ -import { - TypeNode, - ValueNode, - publicKeyTypeNode, - remainderSizeNode, - stringTypeNode, - stringValueNode, -} from '../nodes'; -import { MainCaseString, mainCase } from './utils'; - -export type AccountSeed = - | { kind: 'programId' } - | { kind: 'constant'; type: TypeNode; value: ValueNode } - | { kind: 'variable'; name: MainCaseString; type: TypeNode; docs: string[] }; - -export const programSeed = (): AccountSeed => ({ kind: 'programId' }); - -export const constantSeed = ( - type: TypeNode, - value: ValueNode -): AccountSeed => ({ kind: 'constant', type, value }); - -export const stringConstantSeed = (value: string): AccountSeed => - constantSeed( - stringTypeNode({ size: remainderSizeNode() }), - stringValueNode(value) - ); - -export const variableSeed = ( - name: string, - type: TypeNode, - docs: string | string[] = [] -): AccountSeed => ({ - kind: 'variable', - name: mainCase(name), - type, - docs: Array.isArray(docs) ? docs : [docs], -}); - -export const publicKeySeed = ( - name: string, - docs: string | string[] = [] -): AccountSeed => variableSeed(name, publicKeyTypeNode(), docs); - -export const stringSeed = ( - name: string, - docs: string | string[] = [] -): AccountSeed => - variableSeed(name, stringTypeNode({ size: remainderSizeNode() }), docs); diff --git a/src/shared/InstructionDefault.ts b/src/shared/InstructionDefault.ts index 0392a6bdf..10182bc45 100644 --- a/src/shared/InstructionDefault.ts +++ b/src/shared/InstructionDefault.ts @@ -1,4 +1,9 @@ -import { AccountNode, ValueNode, isPublicKeyTypeNode } from '../nodes'; +import { + AccountNode, + ValueNode, + isPublicKeyTypeNode, + isVariablePdaSeedNode, +} from '../nodes'; import { ImportFrom } from './ImportFrom'; import { MainCaseString, mainCase } from './utils'; @@ -188,7 +193,7 @@ export const getDefaultSeedsFromAccount = ( node: AccountNode ): Record => node.seeds.reduce((acc, seed) => { - if (seed.kind !== 'variable') return acc; + if (!isVariablePdaSeedNode(seed)) return acc; if (isPublicKeyTypeNode(seed.type)) { acc[seed.name] = { kind: 'account', name: seed.name }; } else { diff --git a/src/shared/index.ts b/src/shared/index.ts index 369d09e18..9972e14a3 100644 --- a/src/shared/index.ts +++ b/src/shared/index.ts @@ -1,5 +1,4 @@ export * from './AccountDiscriminator'; -export * from './AccountSeed'; export * from './AnchorDiscriminator'; export * from './BytesCreatedOnChain'; export * from './GpaField'; diff --git a/src/visitors/identityVisitor.ts b/src/visitors/identityVisitor.ts index fd833830d..aa13cfb5f 100644 --- a/src/visitors/identityVisitor.ts +++ b/src/visitors/identityVisitor.ts @@ -18,6 +18,7 @@ import { assertInstructionNode, assertLinkTypeNode, assertNumberTypeNode, + assertPdaSeedNode, assertProgramNode, assertSizeNode, assertStructFieldTypeNode, @@ -27,6 +28,7 @@ import { assertValueNode, booleanTypeNode, bytesTypeNode, + constantPdaSeedNode, dateTimeTypeNode, definedTypeNode, enumStructVariantTypeNode, @@ -55,8 +57,8 @@ import { structValueNode, tupleTypeNode, tupleValueNode, + variablePdaSeedNode, } from '../nodes'; -import { AccountSeed } from '../shared'; import { staticVisitor } from './staticVisitor'; import { Visitor, visit as baseVisit } from './visitor'; @@ -111,14 +113,8 @@ export function identityVisitor< if (data === null) return null; assertAccountDataNode(data); const seeds = node.seeds - .map((seed) => { - if (seed.kind !== 'variable') return seed; - const newType = visit(this)(seed.type); - if (newType === null) return null; - assertTypeNode(newType); - return { ...seed, type: newType }; - }) - .filter((s): s is AccountSeed => s !== null); + .map((type) => visit(this)(type)) + .filter(removeNullAndAssertNodeFilter(assertPdaSeedNode)); return accountNode({ ...node, data, seeds }); }; } @@ -447,5 +443,26 @@ export function identityVisitor< }; } + if (castedNodeKeys.includes('constantPdaSeedNode')) { + visitor.visitConstantPdaSeed = function visitConstantPdaSeed(node) { + const type = visit(this)(node.type); + if (type === null) return null; + assertTypeNode(type); + const value = visit(this)(node.value); + if (value === null) return null; + assertValueNode(value); + return constantPdaSeedNode(type, value); + }; + } + + if (castedNodeKeys.includes('variablePdaSeedNode')) { + visitor.visitVariablePdaSeed = function visitVariablePdaSeed(node) { + const type = visit(this)(node.type); + if (type === null) return null; + assertTypeNode(type); + return variablePdaSeedNode(node.name, type, node.docs); + }; + } + return visitor as Visitor; } diff --git a/src/visitors/mergeVisitor.ts b/src/visitors/mergeVisitor.ts index d198ac3bd..f1856bdae 100644 --- a/src/visitors/mergeVisitor.ts +++ b/src/visitors/mergeVisitor.ts @@ -38,10 +38,7 @@ export function mergeVisitor< visitor.visitAccount = function visitAccount(node) { return merge(node, [ ...visit(this)(node.data), - ...node.seeds.flatMap((seed) => { - if (seed.kind !== 'variable') return []; - return visit(this)(seed.type); - }), + ...node.seeds.flatMap(visit(this)), ]); }; } @@ -261,5 +258,20 @@ export function mergeVisitor< }; } + if (castedNodeKeys.includes('constantPdaSeedNode')) { + visitor.visitConstantPdaSeed = function visitConstantPdaSeed(node) { + return merge(node, [ + ...visit(this)(node.type), + ...visit(this)(node.value), + ]); + }; + } + + if (castedNodeKeys.includes('variablePdaSeedNode')) { + visitor.visitVariablePdaSeed = function visitVariablePdaSeed(node) { + return merge(node, visit(this)(node.type)); + }; + } + return visitor as Visitor; } diff --git a/test/testFile.cjs b/test/testFile.cjs index f34c63fa7..56a65f540 100644 --- a/test/testFile.cjs +++ b/test/testFile.cjs @@ -1,4 +1,7 @@ const k = require('../dist/cjs/index.js'); +const { + publicKeyTypeNode, +} = require('../dist/cjs/nodes/typeNodes/PublicKeyTypeNode.js'); const kinobi = k.createFromIdls([ __dirname + '/spl_system.json', @@ -20,9 +23,9 @@ kinobi.update( Metadata: { size: 679 }, MasterEditionV1: { seeds: [ - k.stringConstantSeed('metadata'), - k.programSeed(), - k.variableSeed( + k.constantPdaSeedNodeFromString('metadata'), + k.programIdPdaSeedNode(), + k.variablePdaSeedNode( 'delegateRole', k.linkTypeNode('delegateRole'), 'The role of the delegate' @@ -32,18 +35,22 @@ kinobi.update( MasterEditionV2: { size: 282, seeds: [ - k.stringConstantSeed('metadata'), - k.programSeed(), - k.publicKeySeed('mint', 'The address of the mint account'), - k.stringConstantSeed('edition'), + k.constantPdaSeedNodeFromString('metadata'), + k.programIdPdaSeedNode(), + k.variablePdaSeedNode( + 'mint', + publicKeyTypeNode(), + 'The address of the mint account' + ), + k.constantPdaSeedNodeFromString('edition'), ], }, delegateRecord: { size: 282, seeds: [ - k.stringConstantSeed('delegate_record'), - k.programSeed(), - k.variableSeed( + k.constantPdaSeedNodeFromString('delegate_record'), + k.programIdPdaSeedNode(), + k.variablePdaSeedNode( 'role', k.linkTypeNode('delegateRole'), 'The delegate role' @@ -51,7 +58,10 @@ kinobi.update( ], }, FrequencyAccount: { - seeds: [k.stringConstantSeed('frequency_pda'), k.programSeed()], + seeds: [ + k.constantPdaSeedNodeFromString('frequency_pda'), + k.programIdPdaSeedNode(), + ], }, }) );