Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
tssweeney committed Dec 14, 2024
1 parent 5e7ffb8 commit f01f97a
Show file tree
Hide file tree
Showing 14 changed files with 388 additions and 26 deletions.
12 changes: 6 additions & 6 deletions weave-js/scripts/generate-schemas.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
# Exit on error
set -e

SCHEMA_INPUT_PATH="../weave/trace_server/interface/builtin_object_classes/generated/generated_builtin_object_class_schemas.json"
SCHEMA_OUTPUT_PATH="./src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/generatedBuiltinObjectClasses.zod.ts"
SCHEMA_INPUT_PATH="../weave/trace_server/interface/base_object_classes/generated/generated_base_object_class_schemas.json"
SCHEMA_OUTPUT_PATH="./src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/generatedBaseObjectClasses.zod.ts"

echo "Generating schemas..."

Expand All @@ -13,10 +13,10 @@ yarn quicktype -s schema "$SCHEMA_INPUT_PATH" -o "$SCHEMA_OUTPUT_PATH" --lang ty

# Transform the schema to extract the type map
sed -i.bak '
# Find the GeneratedBuiltinObjectClassesZodSchema definition and capture its contents
/export const GeneratedBuiltinObjectClassesZodSchema = z.object({/,/});/ {
# Find the GeneratedBaseObjectClassesZodSchema definition and capture its contents
/export const GeneratedBaseObjectClassesZodSchema = z.object({/,/});/ {
# Replace the opening line with typeMap declaration
s/export const GeneratedBuiltinObjectClassesZodSchema = z.object({/export const builtinObjectClassRegistry = ({/
s/export const GeneratedBaseObjectClassesZodSchema = z.object({/export const baseObjectClassRegistry = ({/
# Store the pattern
h
# If this is the last line (with closing brace), append the schema definition
Expand All @@ -27,7 +27,7 @@ sed -i.bak '
s/.*//
i\
\
export const GeneratedBuiltinObjectClassesZodSchema = z.object(builtinObjectClassRegistry)
export const GeneratedBaseObjectClassesZodSchema = z.object(baseObjectClassRegistry)
}
}
' "$SCHEMA_OUTPUT_PATH"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import {AnnotationSpec} from '../../pages/wfReactInterface/generatedBuiltinObjectClasses.zod';
import {AnnotationSpec} from '../../pages/wfReactInterface/generatedBaseObjectClasses.zod';
import {Feedback} from '../../pages/wfReactInterface/traceServerClientTypes';

export const HUMAN_ANNOTATION_BASE_TYPE = 'wandb.annotation';
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import {useEffect, useMemo, useState} from 'react';

import {useBaseObjectInstances} from '../../pages/wfReactInterface/objectClassQuery';
import {useBaseObjectInstances} from '../../pages/wfReactInterface/baseObjectClassQuery';
import {
TraceObjQueryReq,
TraceObjSchema,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ import {CellValue} from '../../../Browse2/CellValue';
import {NotApplicable} from '../../../Browse2/NotApplicable';
import {SmallRef} from '../../../Browse2/SmallRef';
import {StyledDataGrid} from '../../StyledDataGrid'; // Import the StyledDataGrid component
import {WEAVE_REF_SCHEME} from '../wfReactInterface/constants';
import {useWFHooks} from '../wfReactInterface/context';
import {
TraceObjSchemaForBaseObjectClass,
useBaseObjectInstances,
} from '../wfReactInterface/objectClassQuery';
} from '../wfReactInterface/baseObjectClassQuery';
import {WEAVE_REF_SCHEME} from '../wfReactInterface/constants';
import {useWFHooks} from '../wfReactInterface/context';
import {useGetTraceServerClientContext} from '../wfReactInterface/traceServerClientContext';
import {Feedback} from '../wfReactInterface/traceServerClientTypes';
import {projectIdFromParts} from '../wfReactInterface/tsDataModelHooks';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ import {SimplePageLayout} from '../common/SimplePageLayout';
import {ObjectVersionsTable} from '../ObjectVersionsPage';
import {
useBaseObjectInstances,
useCreateBuiltinObjectInstance,
} from '../wfReactInterface/objectClassQuery';
useCreateBaseObjectInstance,
} from '../wfReactInterface/baseObjectClassQuery';
import {sanitizeObjectId} from '../wfReactInterface/traceServerDirectClient';
import {
convertTraceServerObjectVersionToSchema,
Expand Down Expand Up @@ -162,8 +162,7 @@ const generateLeaderboardId = () => {
};

const useCreateLeaderboard = (entity: string, project: string) => {
const createLeaderboardInstance =
useCreateBuiltinObjectInstance('Leaderboard');
const createLeaderboardInstance = useCreateBaseObjectInstance('Leaderboard');

const createLeaderboard = async () => {
const objectId = sanitizeObjectId(generateLeaderboardId());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ import {LeaderboardObjectVal} from '../../views/Leaderboard/types/leaderboardCon
import {SimplePageLayout} from '../common/SimplePageLayout';
import {
useBaseObjectInstances,
useCreateBuiltinObjectInstance,
} from '../wfReactInterface/objectClassQuery';
useCreateBaseObjectInstance,
} from '../wfReactInterface/baseObjectClassQuery';
import {projectIdFromParts} from '../wfReactInterface/tsDataModelHooks';
import {LeaderboardConfigEditor} from './LeaderboardConfigEditor';

Expand Down Expand Up @@ -131,7 +131,7 @@ const useUpdateLeaderboard = (
project: string,
objectId: string
) => {
const createLeaderboard = useCreateBuiltinObjectInstance('Leaderboard');
const createLeaderboard = useCreateBaseObjectInstance('Leaderboard');

const updateLeaderboard = async (leaderboardVal: LeaderboardObjectVal) => {
return await createLeaderboard({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import {Box} from '@material-ui/core';
import React, {FC, useCallback, useEffect, useState} from 'react';
import {z} from 'zod';

import {createBuiltinObjectInstance} from '../wfReactInterface/objectClassQuery';
import {createBaseObjectInstance} from '../wfReactInterface/baseObjectClassQuery';
import {TraceServerClient} from '../wfReactInterface/traceServerClient';
import {sanitizeObjectId} from '../wfReactInterface/traceServerDirectClient';
import {projectIdFromParts} from '../wfReactInterface/tsDataModelHooks';
Expand Down Expand Up @@ -86,7 +86,7 @@ export const onAnnotationScorerSave = async (
) => {
const jsonSchemaType = convertTypeToJsonSchemaType(data.Type.type);
const typeExtras = convertTypeExtrasToJsonSchema(data);
return createBuiltinObjectInstance(client, 'AnnotationSpec', {
return createBaseObjectInstance(client, 'AnnotationSpec', {
obj: {
project_id: projectIdFromParts({entity, project}),
object_id: sanitizeObjectId(data.Name),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ import _ from 'lodash';
import React, {FC, useCallback, useState} from 'react';
import {z} from 'zod';

import {LlmJudgeActionSpecSchema} from '../wfReactInterface/builtinObjectClasses.zod';
import {ActionSpecSchema} from '../wfReactInterface/generatedBuiltinObjectClasses.zod';
import {createBuiltinObjectInstance} from '../wfReactInterface/objectClassQuery';
import {LlmJudgeActionSpecSchema} from '../wfReactInterface/baseObjectClasses.zod';
import {createBaseObjectInstance} from '../wfReactInterface/baseObjectClassQuery';
import {ActionSpecSchema} from '../wfReactInterface/generatedBaseObjectClasses.zod';
import {TraceServerClient} from '../wfReactInterface/traceServerClient';
import {projectIdFromParts} from '../wfReactInterface/tsDataModelHooks';
import {AutocompleteWithLabel} from './FormComponents';
Expand Down Expand Up @@ -185,7 +185,7 @@ export const onLLMJudgeScorerSave = async (
config: judgeAction,
});

return createBuiltinObjectInstance(client, 'ActionSpec', {
return createBaseObjectInstance(client, 'ActionSpec', {
obj: {
project_id: projectIdFromParts({entity, project}),
object_id: objectId,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import {expectType} from 'tsd';

import {
useBaseObjectInstances,
useCreateBaseObjectInstance,
} from './baseObjectClassQuery';
import {
TestOnlyExample,
TestOnlyExampleSchema,
} from './generatedBaseObjectClasses.zod';
import {
TraceObjCreateReq,
TraceObjCreateRes,
TraceObjSchema,
} from './traceServerClientTypes';
import {Loadable} from './wfDataModelHooksInterface';

type TypesAreEqual<T, U> = [T] extends [U]
? [U] extends [T]
? true
: false
: false;

describe('Type Tests', () => {
it('useCollectionObjects return type matches expected structure', () => {
type CollectionObjectsReturn = ReturnType<
typeof useBaseObjectInstances<'TestOnlyExample'>
>;

// Define the expected type structure
type ExpectedType = Loadable<
Array<TraceObjSchema<TestOnlyExample, 'TestOnlyExample'>>
>;

// Type assertion tests
type AssertTypesAreEqual = TypesAreEqual<
CollectionObjectsReturn,
ExpectedType
>;
type Assert = AssertTypesAreEqual extends true ? true : never;

// This will fail compilation if the types don't match exactly
const _assert: Assert = true;
expect(_assert).toBe(true);

// Additional runtime sample for documentation
const sampleResult: CollectionObjectsReturn = {
loading: false,
result: [
{
project_id: '',
object_id: '',
created_at: '',
digest: '',
version_index: 0,
is_latest: 0,
kind: 'object',
base_object_class: 'TestOnlyExample',
val: TestOnlyExampleSchema.parse({
name: '',
description: '',
nested_base_model: {
a: 1,
},
nested_base_object: '',
primitive: 1,
}),
},
],
};

expectType<ExpectedType>(sampleResult);
});

it('useCreateCollectionObject return type matches expected structure', () => {
type CreateCollectionObjectReturn = ReturnType<
typeof useCreateBaseObjectInstance<'TestOnlyExample'>
>;

// Define the expected type structure
type ExpectedType = (
req: TraceObjCreateReq<TestOnlyExample>
) => Promise<TraceObjCreateRes>;

// Type assertion tests
type AssertTypesAreEqual = TypesAreEqual<
CreateCollectionObjectReturn,
ExpectedType
>;
type Assert = AssertTypesAreEqual extends true ? true : never;

// This will fail compilation if the types don't match exactly
const _assert: Assert = true;
expect(_assert).toBe(true);

// Additional runtime sample for documentation
const sampleResult: CreateCollectionObjectReturn = async req => {
return {
digest: '',
};
};

expectType<ExpectedType>(sampleResult);
});
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import {useDeepMemo} from '@wandb/weave/hookUtils';
import {useEffect, useRef, useState} from 'react';
import {z} from 'zod';

import {baseObjectClassRegistry} from './generatedBaseObjectClasses.zod';
import {TraceServerClient} from './traceServerClient';
import {useGetTraceServerClientContext} from './traceServerClientContext';
import {
TraceObjCreateReq,
TraceObjCreateRes,
TraceObjQueryReq,
TraceObjSchema,
} from './traceServerClientTypes';
import {Loadable} from './wfDataModelHooksInterface';

type BaseObjectClassRegistry = typeof baseObjectClassRegistry;
type BaseObjectClassRegistryKeys = keyof BaseObjectClassRegistry;
type BaseObjectClassType<C extends BaseObjectClassRegistryKeys> = z.infer<
BaseObjectClassRegistry[C]
>;

export type TraceObjSchemaForBaseObjectClass<
C extends BaseObjectClassRegistryKeys
> = TraceObjSchema<BaseObjectClassType<C>, C>;

export const useBaseObjectInstances = <C extends BaseObjectClassRegistryKeys>(
baseObjectClassName: C,
req: TraceObjQueryReq
): Loadable<Array<TraceObjSchemaForBaseObjectClass<C>>> => {
const [objects, setObjects] = useState<
Array<TraceObjSchemaForBaseObjectClass<C>>
>([]);
const getTsClient = useGetTraceServerClientContext();
const client = getTsClient();
const deepReq = useDeepMemo(req);
const currReq = useRef(deepReq);
const [loading, setLoading] = useState(true);

useEffect(() => {
let isMounted = true;
setLoading(true);
currReq.current = deepReq;
getBaseObjectInstances(client, baseObjectClassName, deepReq).then(
collectionObjects => {
if (isMounted && currReq.current === deepReq) {
setObjects(collectionObjects);
setLoading(false);
}
}
);
return () => {
isMounted = false;
};
}, [client, baseObjectClassName, deepReq]);

return {result: objects, loading};
};

const getBaseObjectInstances = async <C extends BaseObjectClassRegistryKeys>(
client: TraceServerClient,
baseObjectClassName: C,
req: TraceObjQueryReq
): Promise<Array<TraceObjSchema<BaseObjectClassType<C>, C>>> => {
const knownObjectClass = baseObjectClassRegistry[baseObjectClassName];
if (!knownObjectClass) {
console.warn(`Unknown object class: ${baseObjectClassName}`);
return [];
}

const reqWithBaseObjectClass: TraceObjQueryReq = {
...req,
filter: {...req.filter, base_object_classes: [baseObjectClassName]},
};

const objectPromise = client.objsQuery(reqWithBaseObjectClass);

const objects = await objectPromise;

// We would expect that this filtering does not filter anything
// out because the backend enforces the base object class, but this
// is here as a sanity check.
return objects.objs
.map(obj => ({obj, parsed: knownObjectClass.safeParse(obj.val)}))
.filter(({parsed}) => parsed.success)
.filter(({obj}) => obj.base_object_class === baseObjectClassName)
.map(
({obj, parsed}) =>
({...obj, val: parsed.data} as TraceObjSchema<
BaseObjectClassType<C>,
C
>)
);
};

export const useCreateBaseObjectInstance = <
C extends BaseObjectClassRegistryKeys,
T = BaseObjectClassType<C>
>(
baseObjectClassName: C
): ((req: TraceObjCreateReq<T>) => Promise<TraceObjCreateRes>) => {
const getTsClient = useGetTraceServerClientContext();
const client = getTsClient();
return (req: TraceObjCreateReq<T>) =>
createBaseObjectInstance(client, baseObjectClassName, req);
};

export const createBaseObjectInstance = async <
C extends BaseObjectClassRegistryKeys,
T = BaseObjectClassType<C>
>(
client: TraceServerClient,
baseObjectClassName: C,
req: TraceObjCreateReq<T>
): Promise<TraceObjCreateRes> => {
if (
req.obj.set_base_object_class != null &&
req.obj.set_base_object_class !== baseObjectClassName
) {
throw new Error(
`set_base_object_class must match baseObjectClassName: ${baseObjectClassName}`
);
}

const knownBaseObjectClass = baseObjectClassRegistry[baseObjectClassName];
if (!knownBaseObjectClass) {
throw new Error(`Unknown object class: ${baseObjectClassName}`);
}

const verifiedObject = knownBaseObjectClass.safeParse(req.obj.val);

if (!verifiedObject.success) {
throw new Error(
`Invalid object: ${JSON.stringify(verifiedObject.error.errors)}`
);
}

const reqWithBaseObjectClass: TraceObjCreateReq = {
...req,
obj: {
...req.obj,
set_base_object_class: baseObjectClassName,
},
};

return client.objCreate(reqWithBaseObjectClass);
};
Loading

0 comments on commit f01f97a

Please sign in to comment.