Skip to content

Commit

Permalink
Rename ByteDiscriminatorNode to ConstantDiscriminatorNode and use Con…
Browse files Browse the repository at this point in the history
…stantValueNode (#203)
  • Loading branch information
lorisleiva authored Apr 11, 2024
1 parent a56e919 commit 484da01
Show file tree
Hide file tree
Showing 15 changed files with 150 additions and 100 deletions.
5 changes: 5 additions & 0 deletions .changeset/quiet-peaches-move.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@metaplex-foundation/kinobi": minor
---

Rename ByteDiscriminatorNode to ConstantDiscriminatorNode and use ConstantValueNode
26 changes: 0 additions & 26 deletions src/nodes/discriminatorNodes/ByteDiscriminatorNode.ts

This file was deleted.

20 changes: 20 additions & 0 deletions src/nodes/discriminatorNodes/ConstantDiscriminatorNode.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import { ConstantValueNode } from '../valueNodes';

export interface ConstantDiscriminatorNode<
TConstant extends ConstantValueNode = ConstantValueNode,
> {
readonly kind: 'constantDiscriminatorNode';

// Children.
readonly constant: TConstant;

// Data.
readonly offset: number;
}

export function constantDiscriminatorNode<TConstant extends ConstantValueNode>(
constant: TConstant,
offset: number = 0
): ConstantDiscriminatorNode {
return { kind: 'constantDiscriminatorNode', constant, offset };
}
6 changes: 3 additions & 3 deletions src/nodes/discriminatorNodes/DiscriminatorNode.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import type { ByteDiscriminatorNode } from './ByteDiscriminatorNode';
import type { ConstantDiscriminatorNode } from './ConstantDiscriminatorNode';
import type { FieldDiscriminatorNode } from './FieldDiscriminatorNode';
import type { SizeDiscriminatorNode } from './SizeDiscriminatorNode';

// Discriminator Node Registration.
export type RegisteredDiscriminatorNode =
| ByteDiscriminatorNode
| ConstantDiscriminatorNode
| FieldDiscriminatorNode
| SizeDiscriminatorNode;
export const REGISTERED_DISCRIMINATOR_NODE_KINDS = [
'byteDiscriminatorNode',
'constantDiscriminatorNode',
'fieldDiscriminatorNode',
'sizeDiscriminatorNode',
] satisfies readonly RegisteredDiscriminatorNode['kind'][];
Expand Down
2 changes: 1 addition & 1 deletion src/nodes/discriminatorNodes/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export * from './ByteDiscriminatorNode';
export * from './ConstantDiscriminatorNode';
export * from './DiscriminatorNode';
export * from './FieldDiscriminatorNode';
export * from './SizeDiscriminatorNode';
64 changes: 36 additions & 28 deletions src/renderers/js-experimental/fragments/discriminatorCondition.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import { getBase64Decoder } from '@solana/codecs-strings';
import {
constantDiscriminatorNode,
constantValueNode,
constantValueNodeFromBytes,
isNode,
type ByteDiscriminatorNode,
isNodeFilter,
type ConstantDiscriminatorNode,
type DiscriminatorNode,
type FieldDiscriminatorNode,
type ProgramNode,
type SizeDiscriminatorNode,
type StructTypeNode,
isNodeFilter,
byteDiscriminatorNode,
} from '../../../nodes';
import { InvalidKinobiTreeError } from '../../../shared';
import { visit } from '../../../visitors';
Expand Down Expand Up @@ -41,10 +44,10 @@ export function getDiscriminatorConditionFragment(
return mergeFragments(
scope.discriminators.flatMap((discriminator) => {
if (isNode(discriminator, 'sizeDiscriminatorNode')) {
return [getSizeConditionFragment(discriminator, scope.dataName)];
return [getSizeConditionFragment(discriminator, scope)];
}
if (isNode(discriminator, 'byteDiscriminatorNode')) {
return [getByteConditionFragment(discriminator, scope.dataName)];
if (isNode(discriminator, 'constantDiscriminatorNode')) {
return [getByteConditionFragment(discriminator, scope)];
}
if (isNode(discriminator, 'fieldDiscriminatorNode')) {
return [getFieldConditionFragment(discriminator, scope)];
Expand All @@ -57,19 +60,25 @@ export function getDiscriminatorConditionFragment(

function getSizeConditionFragment(
discriminator: SizeDiscriminatorNode,
dataName: string
scope: Pick<GlobalFragmentScope, 'typeManifestVisitor'> & {
dataName: string;
}
): Fragment {
const { dataName } = scope;
return fragment(`${dataName}.length === ${discriminator.size}`);
}

function getByteConditionFragment(
discriminator: ByteDiscriminatorNode,
dataName: string
discriminator: ConstantDiscriminatorNode,
scope: Pick<GlobalFragmentScope, 'typeManifestVisitor'> & {
dataName: string;
}
): Fragment {
const bytes = discriminator.bytes.join(', ');
return fragment(
`memcmp(${dataName}, new Uint8Array([${bytes}]), ${discriminator.offset})`
).addImports('shared', 'memcmp');
const { dataName, typeManifestVisitor } = scope;
const constant = visit(discriminator.constant, typeManifestVisitor).value;
return constant
.mapRender((r) => `memcmp(${dataName}, ${r}, ${discriminator.offset})`)
.addImports('shared', 'memcmp');
}

function getFieldConditionFragment(
Expand All @@ -87,7 +96,7 @@ function getFieldConditionFragment(
}

// This handles the case where a field uses an u8 array to represent its discriminator.
// In this case, we can simplify the generated code by delegating to a byteDiscriminatorNode.
// In this case, we can simplify the generated code by delegating to a constantDiscriminatorNode.
if (
isNode(field.type, 'arrayTypeNode') &&
isNode(field.type.item, 'numberTypeNode') &&
Expand All @@ -96,24 +105,23 @@ function getFieldConditionFragment(
isNode(field.defaultValue, 'arrayValueNode') &&
field.defaultValue.items.every(isNodeFilter('numberValueNode'))
) {
const base64Bytes = getBase64Decoder().decode(
new Uint8Array(field.defaultValue.items.map((node) => node.number))
);
return getByteConditionFragment(
byteDiscriminatorNode(
field.defaultValue.items.map((node) => node.number),
constantDiscriminatorNode(
constantValueNodeFromBytes('base64', base64Bytes),
discriminator.offset
),
scope.dataName
scope
);
}

return mergeFragments(
[
visit(field.type, scope.typeManifestVisitor).encoder,
visit(field.defaultValue, scope.typeManifestVisitor).value,
],
([encoderFunction, value]) => `${encoderFunction}.encode(${value})`
)
.mapRender(
(r) => `memcmp(${scope.dataName}, ${r}, ${discriminator.offset})`
)
.addImports('shared', 'memcmp');
return getByteConditionFragment(
constantDiscriminatorNode(
constantValueNode(field.type, field.defaultValue),
discriminator.offset
),
scope
);
}
20 changes: 14 additions & 6 deletions src/renderers/js-experimental/getTypeManifestVisitor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -771,12 +771,20 @@ export function getTypeManifestVisitor(input: {
},

visitConstantValue(node, { self }) {
const manifest = typeManifest();
manifest.value = mergeFragments(
[visit(node.type, self).encoder, visit(node.value, self).value],
([encoderFunction, value]) => `${encoderFunction}.encode(${value})`
);
return manifest;
if (
isNode(node.type, 'bytesTypeNode') &&
isNode(node.value, 'bytesValueNode')
) {
return visit(node.value, self);
}
return {
...typeManifest(),
value: mergeFragments(
[visit(node.type, self).encoder, visit(node.value, self).value],
([encoderFunction, value]) =>
`${encoderFunction}.encode(${value})`
),
};
},

visitEnumValue(node, { self }) {
Expand Down
2 changes: 1 addition & 1 deletion src/renderers/js/getRenderMapVisitor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ export function getRenderMapVisitor(
// Discriminator.
const discriminator =
(node.discriminators ?? []).find(
(d) => !isNode(d, 'byteDiscriminatorNode')
(d) => !isNode(d, 'constantDiscriminatorNode')
) ?? null;
let resolvedDiscriminator:
| SizeDiscriminatorNode
Expand Down
10 changes: 8 additions & 2 deletions src/renderers/js/getTypeManifestVisitor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -775,11 +775,17 @@ export function getTypeManifestVisitor(input: {
},

visitConstantValue(node, { self }) {
if (
isNode(node.type, 'bytesTypeNode') &&
isNode(node.value, 'bytesValueNode')
) {
return visit(node.value, self);
}
const imports = new JavaScriptImportMap();
const type = visit(node.type, self);
imports.mergeWith(type.serializerImports);
const value = visit(node.value, self);
imports.mergeWith(value.valueImports);
const type = visit(node.type, self);
imports.mergeWith(type.serializerImports);
return {
...typeManifest(),
value: `${type.serializer}.serialize(${value.value})`,
Expand Down
13 changes: 2 additions & 11 deletions src/visitors/getDebugStringVisitor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -111,17 +111,8 @@ function getNodeDetails(node: Node): string[] {
node.name,
...(node.importFrom ? [`from:${node.importFrom}`] : []),
];
case 'byteDiscriminatorNode':
return [
...(node.bytes.length > 0
? [
`0x${node.bytes
.map((byte) => byte.toString(16).padStart(2, '0'))
.join('')}`,
]
: []),
...(node.offset > 0 ? [`offset:${node.offset}`] : []),
];
case 'constantDiscriminatorNode':
return [...(node.offset > 0 ? [`offset:${node.offset}`] : [])];
case 'fieldDiscriminatorNode':
return [node.name, ...(node.offset > 0 ? [`offset:${node.offset}`] : [])];
case 'sizeDiscriminatorNode':
Expand Down
12 changes: 12 additions & 0 deletions src/visitors/identityVisitor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import {
assertIsNode,
booleanTypeNode,
conditionalValueNode,
constantDiscriminatorNode,
constantPdaSeedNode,
constantValueNode,
dateTimeTypeNode,
Expand Down Expand Up @@ -642,5 +643,16 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
};
}

if (castedNodeKeys.includes('constantDiscriminatorNode')) {
visitor.visitConstantDiscriminator = function visitConstantDiscriminator(
node
) {
const constant = visit(this)(node.constant);
if (constant === null) return null;
assertIsNode(constant, 'constantValueNode');
return constantDiscriminatorNode(constant, node.offset);
};
}

return visitor as Visitor<Node, TNodeKind>;
}
8 changes: 8 additions & 0 deletions src/visitors/mergeVisitor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -356,5 +356,13 @@ export function mergeVisitor<TReturn, TNodeKind extends NodeKind = NodeKind>(
};
}

if (castedNodeKeys.includes('constantDiscriminatorNode')) {
visitor.visitConstantDiscriminator = function visitConstantDiscriminator(
node
) {
return merge(node, visit(this)(node.constant));
};
}

return visitor as Visitor<TReturn, TNodeKind>;
}
13 changes: 10 additions & 3 deletions test/renderers/js-experimental/programsPage.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import test from 'ava';
import {
accountNode,
byteDiscriminatorNode,
constantDiscriminatorNode,
constantValueNodeFromBytes,
fieldDiscriminatorNode,
instructionArgumentNode,
instructionNode,
Expand Down Expand Up @@ -77,7 +78,10 @@ test('it renders an function that identifies accounts in a program', (t) => {
name: 'token',
discriminators: [
sizeDiscriminatorNode(72),
byteDiscriminatorNode([1, 2, 3], 4),
constantDiscriminatorNode(
constantValueNodeFromBytes('base16', '010203'),
4
),
],
}),
// No discriminator.
Expand Down Expand Up @@ -149,7 +153,10 @@ test('it renders an function that identifies instructions in a program', (t) =>
name: 'transferTokens',
discriminators: [
sizeDiscriminatorNode(72),
byteDiscriminatorNode([1, 2, 3], 4),
constantDiscriminatorNode(
constantValueNodeFromBytes('base16', '010203'),
4
),
],
}),
// No discriminator.
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import test from 'ava';
import {
constantDiscriminatorNode,
constantValueNodeFromBytes,
} from '../../../../src';
import {
deleteNodesVisitorMacro,
getDebugStringVisitorMacro,
identityVisitorMacro,
mergeVisitorMacro,
} from '../_setup';

const node = constantDiscriminatorNode(
constantValueNodeFromBytes('base16', '01020304'),
42
);

test(mergeVisitorMacro, node, 4);
test(identityVisitorMacro, node);
test(deleteNodesVisitorMacro, node, '[constantDiscriminatorNode]', null);
test(deleteNodesVisitorMacro, node, '[constantValueNode]', null);
test(
getDebugStringVisitorMacro,
node,
`
constantDiscriminatorNode [offset:42]
| constantValueNode
| | bytesTypeNode
| | bytesValueNode [base16.01020304]`
);

0 comments on commit 484da01

Please sign in to comment.