Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
tssweeney committed Dec 10, 2024
1 parent 8d0f3f7 commit ef84eb6
Show file tree
Hide file tree
Showing 16 changed files with 116 additions and 86 deletions.
8 changes: 4 additions & 4 deletions dev_docs/BaseObjectClasses.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ curl -X POST 'https://trace.wandb.ai/obj/create' \
"project_id": "user/project",
"object_id": "my_config",
"val": {...},
"set_base_object_class": "MyConfig"
"set_leaf_object_class": "MyConfig"
}
}'

Expand Down Expand Up @@ -162,7 +162,7 @@ Run `make synchronize-base-object-schemas` to ensure the frontend TypeScript typ
4. Now, each use case uses different parts:
1. `Python Writing`. Users can directly import these classes and use them as normal Pydantic models, which get published with `weave.publish`. The python client correct builds the requisite payload.
2. `Python Reading`. Users can `weave.ref().get()` and the weave python SDK will return the instance with the correct type. Note: we do some special handling such that the returned object is not a WeaveObject, but literally the exact pydantic class.
3. `HTTP Writing`. In cases where the client/user does not want to add the special type information, users can publish base objects by setting the `set_base_object_class` setting on `POST obj/create` to the name of the class. The weave server will validate the object against the schema, update the metadata fields, and store the object.
3. `HTTP Writing`. In cases where the client/user does not want to add the special type information, users can publish objects by setting the `set_leaf_object_class` setting on `POST obj/create` to the name of the class. The weave server will validate the object against the schema, update the metadata fields, and store the object.
4. `HTTP Reading`. When querying for objects, the server will return the object with the correct type if the `base_object_class` metadata field is set.
5. `Frontend`. The frontend will read the zod schema from `weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/generatedBaseObjectClasses.zod.ts` and use that to provide compile time type safety when using `useBaseObjectInstances` and runtime type safety when using `useCreateBaseObjectInstance`.
* Note: it is critical that all techniques produce the same digest for the same data - which is tested in the tests. This way versions are not thrashed by different clients/users.
Expand All @@ -185,7 +185,7 @@ graph TD
subgraph "Trace Server"
subgraph "HTTP API"
R --> |validates using| HW["POST obj/create<br>set_base_object_class"]
R --> |validates using| HW["POST obj/create<br>set_leaf_object_class"]
HW --> DB[(Weave Object Store)]
HR["POST objs/query<br>base_object_classes"] --> |Filters base_object_class| DB
end
Expand All @@ -203,7 +203,7 @@ graph TD
Z --> |import| UBI["useBaseObjectInstances"]
Z --> |import| UCI["useCreateBaseObjectInstance"]
UBI --> |Filters base_object_class| HR
UCI --> |set_base_object_class| HW
UCI --> |set_leaf_object_class| HW
UI[React UI] --> UBI
UI --> UCI
end
Expand Down
14 changes: 7 additions & 7 deletions tests/trace/test_base_object_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def test_interface_creation(client):
"project_id": client._project_id(),
"object_id": nested_obj_id,
"val": nested_obj.model_dump(),
"set_base_object_class": "TestOnlyNestedBaseObject",
"set_leaf_object_class": "TestOnlyNestedBaseObject",
}
}
)
Expand All @@ -164,7 +164,7 @@ def test_interface_creation(client):
"project_id": client._project_id(),
"object_id": top_level_obj_id,
"val": top_obj.model_dump(),
"set_base_object_class": "TestOnlyExample",
"set_leaf_object_class": "TestOnlyExample",
}
}
)
Expand Down Expand Up @@ -271,7 +271,7 @@ def test_digest_equality(client):
"project_id": client._project_id(),
"object_id": nested_obj_id,
"val": nested_obj.model_dump(),
"set_base_object_class": "TestOnlyNestedBaseObject",
"set_leaf_object_class": "TestOnlyNestedBaseObject",
}
}
)
Expand Down Expand Up @@ -300,7 +300,7 @@ def test_digest_equality(client):
"project_id": client._project_id(),
"object_id": top_level_obj_id,
"val": top_obj.model_dump(),
"set_base_object_class": "TestOnlyExample",
"set_leaf_object_class": "TestOnlyExample",
}
}
)
Expand All @@ -322,7 +322,7 @@ def test_schema_validation(client):
"object_id": "nested_obj",
# Incorrect schema, should raise!
"val": {"a": 2},
"set_base_object_class": "TestOnlyNestedBaseObject",
"set_leaf_object_class": "TestOnlyNestedBaseObject",
}
}
)
Expand All @@ -340,7 +340,7 @@ def test_schema_validation(client):
"_class_name": "TestOnlyNestedBaseObject",
"_bases": ["BaseObject", "BaseModel"],
},
"set_base_object_class": "TestOnlyNestedBaseObject",
"set_leaf_object_class": "TestOnlyNestedBaseObject",
}
}
)
Expand All @@ -359,7 +359,7 @@ def test_schema_validation(client):
"_class_name": "TestOnlyNestedBaseObject",
"_bases": ["BaseObject", "BaseModel"],
},
"set_base_object_class": "TestOnlyExample",
"set_leaf_object_class": "TestOnlyExample",
}
}
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import {NotApplicable} from '../../../Browse2/NotApplicable';
import {SmallRef} from '../../../Browse2/SmallRef';
import {StyledDataGrid} from '../../StyledDataGrid'; // Import the StyledDataGrid component
import {
TraceObjSchemaForBaseObjectClass,
TraceObjSchemaForObjectClass,
useBaseObjectInstances,
} from '../wfReactInterface/baseObjectClassQuery';
import {WEAVE_REF_SCHEME} from '../wfReactInterface/constants';
Expand Down Expand Up @@ -61,7 +61,7 @@ const useRunnableFeedbacksForCall = (call: CallSchema) => {

const useRunnableFeedbackTypeToLatestActionRef = (
call: CallSchema,
actionSpecs: Array<TraceObjSchemaForBaseObjectClass<'ActionSpec'>>
actionSpecs: Array<TraceObjSchemaForObjectClass<'ActionSpec'>>
): Record<string, string> => {
return useMemo(() => {
return _.fromPairs(
Expand Down Expand Up @@ -92,7 +92,7 @@ type GroupedRowType = {
};

const useTableRowsForRunnableFeedbacks = (
actionSpecs: Array<TraceObjSchemaForBaseObjectClass<'ActionSpec'>>,
actionSpecs: Array<TraceObjSchemaForObjectClass<'ActionSpec'>>,
runnableFeedbacks: Feedback[],
runnableFeedbackTypeToLatestActionRef: Record<string, string>
): GroupedRowType[] => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import {SimplePageLayout} from '../common/SimplePageLayout';
import {ObjectVersionsTable} from '../ObjectVersionsPage';
import {
useBaseObjectInstances,
useCreateBaseObjectInstance,
useCreateLeafObjectInstance,
} from '../wfReactInterface/baseObjectClassQuery';
import {sanitizeObjectId} from '../wfReactInterface/traceServerDirectClient';
import {
Expand Down Expand Up @@ -162,7 +162,7 @@ const generateLeaderboardId = () => {
};

const useCreateLeaderboard = (entity: string, project: string) => {
const createLeaderboardInstance = useCreateBaseObjectInstance('Leaderboard');
const createLeaderboardInstance = useCreateLeafObjectInstance('Leaderboard');

const createLeaderboard = async () => {
const objectId = sanitizeObjectId(generateLeaderboardId());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import {LeaderboardObjectVal} from '../../views/Leaderboard/types/leaderboardCon
import {SimplePageLayout} from '../common/SimplePageLayout';
import {
useBaseObjectInstances,
useCreateBaseObjectInstance,
useCreateLeafObjectInstance,
} from '../wfReactInterface/baseObjectClassQuery';
import {projectIdFromParts} from '../wfReactInterface/tsDataModelHooks';
import {LeaderboardConfigEditor} from './LeaderboardConfigEditor';
Expand Down Expand Up @@ -131,7 +131,7 @@ const useUpdateLeaderboard = (
project: string,
objectId: string
) => {
const createLeaderboard = useCreateBaseObjectInstance('Leaderboard');
const createLeaderboard = useCreateLeafObjectInstance('Leaderboard');

const updateLeaderboard = async (leaderboardVal: LeaderboardObjectVal) => {
return await createLeaderboard({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import {Box} from '@material-ui/core';
import React, {FC, useCallback, useState} from 'react';
import {z} from 'zod';

import {createBaseObjectInstance} from '../wfReactInterface/baseObjectClassQuery';
import {createLeafObjectInstance} from '../wfReactInterface/baseObjectClassQuery';
import {TraceServerClient} from '../wfReactInterface/traceServerClient';
import {sanitizeObjectId} from '../wfReactInterface/traceServerDirectClient';
import {projectIdFromParts} from '../wfReactInterface/tsDataModelHooks';
Expand Down Expand Up @@ -83,7 +83,7 @@ export const onAnnotationScorerSave = async (
) => {
const jsonSchemaType = convertTypeToJsonSchemaType(data.Type.type);
const typeExtras = convertTypeExtrasToJsonSchema(data);
return createBaseObjectInstance(client, 'AnnotationSpec', {
return createLeafObjectInstance(client, 'AnnotationSpec', {
obj: {
project_id: projectIdFromParts({entity, project}),
object_id: sanitizeObjectId(data.Name),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import React, {FC, useCallback, useState} from 'react';
import {z} from 'zod';

import {LlmJudgeActionSpecSchema} from '../wfReactInterface/baseObjectClasses.zod';
import {createBaseObjectInstance} from '../wfReactInterface/baseObjectClassQuery';
import {createLeafObjectInstance} from '../wfReactInterface/baseObjectClassQuery';
import {ActionSpecSchema} from '../wfReactInterface/generatedBaseObjectClasses.zod';
import {TraceServerClient} from '../wfReactInterface/traceServerClient';
import {projectIdFromParts} from '../wfReactInterface/tsDataModelHooks';
Expand Down Expand Up @@ -185,7 +185,7 @@ export const onLLMJudgeScorerSave = async (
config: judgeAction,
});

return createBaseObjectInstance(client, 'ActionSpec', {
return createLeafObjectInstance(client, 'ActionSpec', {
obj: {
project_id: projectIdFromParts({entity, project}),
object_id: objectId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import {expectType} from 'tsd';

import {
useBaseObjectInstances,
useCreateBaseObjectInstance,
useCreateLeafObjectInstance,
} from './baseObjectClassQuery';
import {
TestOnlyExample,
Expand Down Expand Up @@ -74,7 +74,7 @@ describe('Type Tests', () => {

it('useCreateCollectionObject return type matches expected structure', () => {
type CreateCollectionObjectReturn = ReturnType<
typeof useCreateBaseObjectInstance<'TestOnlyExample'>
typeof useCreateLeafObjectInstance<'TestOnlyExample'>
>;

// Define the expected type structure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,22 @@ import {
} from './traceServerClientTypes';
import {Loadable} from './wfDataModelHooksInterface';

type BaseObjectClassRegistry = typeof baseObjectClassRegistry;
type BaseObjectClassRegistryKeys = keyof BaseObjectClassRegistry;
type BaseObjectClassType<C extends BaseObjectClassRegistryKeys> = z.infer<
BaseObjectClassRegistry[C]
type ObjectClassRegistry = typeof baseObjectClassRegistry; // TODO: Add more here - not just bases!
type ObjectClassRegistryKeys = keyof ObjectClassRegistry;
type ObjectClassType<C extends ObjectClassRegistryKeys> = z.infer<
ObjectClassRegistry[C]
>;

export type TraceObjSchemaForBaseObjectClass<
C extends BaseObjectClassRegistryKeys
> = TraceObjSchema<BaseObjectClassType<C>, C>;
export type TraceObjSchemaForObjectClass<
C extends ObjectClassRegistryKeys
> = TraceObjSchema<ObjectClassType<C>, C>;

export const useBaseObjectInstances = <C extends BaseObjectClassRegistryKeys>(
export const useBaseObjectInstances = <C extends ObjectClassRegistryKeys>(
baseObjectClassName: C,
req: TraceObjQueryReq
): Loadable<Array<TraceObjSchemaForBaseObjectClass<C>>> => {
): Loadable<Array<TraceObjSchemaForObjectClass<C>>> => {
const [objects, setObjects] = useState<
Array<TraceObjSchemaForBaseObjectClass<C>>
Array<TraceObjSchemaForObjectClass<C>>
>([]);
const getTsClient = useGetTraceServerClientContext();
const client = getTsClient();
Expand Down Expand Up @@ -56,11 +56,11 @@ export const useBaseObjectInstances = <C extends BaseObjectClassRegistryKeys>(
return {result: objects, loading};
};

const getBaseObjectInstances = async <C extends BaseObjectClassRegistryKeys>(
const getBaseObjectInstances = async <C extends ObjectClassRegistryKeys>(
client: TraceServerClient,
baseObjectClassName: C,
req: TraceObjQueryReq
): Promise<Array<TraceObjSchema<BaseObjectClassType<C>, C>>> => {
): Promise<Array<TraceObjSchema<ObjectClassType<C>, C>>> => {
const knownObjectClass = baseObjectClassRegistry[baseObjectClassName];
if (!knownObjectClass) {
console.warn(`Unknown object class: ${baseObjectClassName}`);
Expand All @@ -86,61 +86,61 @@ const getBaseObjectInstances = async <C extends BaseObjectClassRegistryKeys>(
.map(
({obj, parsed}) =>
({...obj, val: parsed.data} as TraceObjSchema<
BaseObjectClassType<C>,
ObjectClassType<C>,
C
>)
);
};

export const useCreateBaseObjectInstance = <
C extends BaseObjectClassRegistryKeys,
T = BaseObjectClassType<C>
export const useCreateLeafObjectInstance = <
C extends ObjectClassRegistryKeys,
T = ObjectClassType<C>
>(
baseObjectClassName: C
leafObjectClassName: C
): ((req: TraceObjCreateReq<T>) => Promise<TraceObjCreateRes>) => {
const getTsClient = useGetTraceServerClientContext();
const client = getTsClient();
return (req: TraceObjCreateReq<T>) =>
createBaseObjectInstance(client, baseObjectClassName, req);
createLeafObjectInstance(client, leafObjectClassName, req);
};

export const createBaseObjectInstance = async <
C extends BaseObjectClassRegistryKeys,
T = BaseObjectClassType<C>
export const createLeafObjectInstance = async <
C extends ObjectClassRegistryKeys,
T = ObjectClassType<C>
>(
client: TraceServerClient,
baseObjectClassName: C,
leafObjectClassName: C,
req: TraceObjCreateReq<T>
): Promise<TraceObjCreateRes> => {
if (
req.obj.set_base_object_class != null &&
req.obj.set_base_object_class !== baseObjectClassName
req.obj.set_leaf_object_class != null &&
req.obj.set_leaf_object_class !== leafObjectClassName
) {
throw new Error(
`set_base_object_class must match baseObjectClassName: ${baseObjectClassName}`
`set_leaf_object_class must match leafObjectClassName: ${leafObjectClassName}`
);
}

const knownBaseObjectClass = baseObjectClassRegistry[baseObjectClassName];
if (!knownBaseObjectClass) {
throw new Error(`Unknown object class: ${baseObjectClassName}`);
const knownObjectClass = baseObjectClassRegistry[leafObjectClassName];
if (!knownObjectClass) {
throw new Error(`Unknown object class: ${leafObjectClassName}`);
}

const verifiedObject = knownBaseObjectClass.safeParse(req.obj.val);
const verifiedObject = knownObjectClass.safeParse(req.obj.val);

if (!verifiedObject.success) {
throw new Error(
`Invalid object: ${JSON.stringify(verifiedObject.error.errors)}`
);
}

const reqWithBaseObjectClass: TraceObjCreateReq = {
const reqWithLeafObjectClass: TraceObjCreateReq = {
...req,
obj: {
...req.obj,
set_base_object_class: baseObjectClassName,
set_leaf_object_class: leafObjectClassName,
},
};

return client.objCreate(reqWithBaseObjectClass);
return client.objCreate(reqWithLeafObjectClass);
};
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ export type TraceObjCreateReq<T extends any = any> = {
project_id: string;
object_id: string;
val: T;
set_base_object_class?: string;
set_leaf_object_class?: string;
};
};

Expand Down
Loading

0 comments on commit ef84eb6

Please sign in to comment.