parseRef(props.tableRefUri) as WeaveObjectRef,
+ [props.tableRefUri]
+ );
+
// Determines if the table itself is truncated
const isTruncated = useMemo(() => {
return (fetchQuery.result ?? []).length > MAX_ROWS;
@@ -96,16 +106,19 @@ export const WeaveCHTable: FC<{
);
return (
-
+
+
+
);
};
@@ -133,7 +146,7 @@ export const DataTableView: FC<{
if (val == null) {
return {};
} else if (typeof val === 'object' && !Array.isArray(val)) {
- return flattenObject(val);
+ return flattenObjectPreservingWeaveTypes(val);
}
return {'': val};
});
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx
index 8868f8a1d7d..24e10a3dd09 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx
@@ -21,6 +21,7 @@ import {LoadingDots} from '../../../../../LoadingDots';
import {Browse2OpDefCode} from '../../../Browse2/Browse2OpDefCode';
import {parseRefMaybe} from '../../../Browse2/SmallRef';
import {StyledDataGrid} from '../../StyledDataGrid';
+import {isCustomWeaveTypePayload} from '../../typeViews/customWeaveType.types';
import {isRef} from '../common/util';
import {
LIST_INDEX_EDGE_NAME,
@@ -163,6 +164,37 @@ export const ObjectViewer = ({
}
> = [];
traverse(resolvedData, context => {
+ // Ops should be migrated to the generic CustomWeaveType pattern, but for
+ // now they are custom handled.
+ const isOpPayload = context.value?.weave_type?.type === 'Op';
+
+ if (isCustomWeaveTypePayload(context.value) && !isOpPayload) {
+ /**
+ * This block adds an "empty" key that is used to render the custom
+ * weave type. In the event that a custom type has both properties AND
+ * custom views, then we might need to extend / modify this part.
+ */
+ const refBackingData = context.value?._ref;
+ let depth = context.depth;
+ let path = context.path;
+ if (refBackingData) {
+ contexts.push({
+ ...context,
+ isExpandableRef: true,
+ });
+ depth += 1;
+ path = context.path.plus('');
+ }
+ contexts.push({
+ depth,
+ isLeaf: true,
+ path,
+ value: context.value,
+ valueType: context.valueType,
+ });
+ return 'skip';
+ }
+
if (context.depth !== 0) {
const contextTail = context.path.tail();
const isNullDescription =
@@ -207,7 +239,8 @@ export const ObjectViewer = ({
if (USE_TABLE_FOR_ARRAYS && context.valueType === 'array') {
return 'skip';
}
- if (context.value?._ref && context.value?.weave_type?.type === 'Op') {
+ if (context.value?._ref && isOpPayload) {
+ // This should be moved to the CustomWeaveType pattern.
contexts.push({
depth: context.depth + 1,
isLeaf: true,
@@ -377,11 +410,15 @@ export const ObjectViewer = ({
isRef(params.model.value) &&
(parseRefMaybe(params.model.value) as any).weaveKind === 'table';
const {isCode} = params.model;
+ const isCustomWeaveType = isCustomWeaveTypePayload(
+ params.model.value
+ );
if (
isNonRefString ||
(isArray && USE_TABLE_FOR_ARRAYS) ||
isTableRef ||
- isCode
+ isCode ||
+ isCustomWeaveType
) {
return 'auto';
}
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewerSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewerSection.tsx
index d5176339465..6a72ee6e6d1 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewerSection.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewerSection.tsx
@@ -14,6 +14,8 @@ import {isWeaveObjectRef, parseRef} from '../../../../../../react';
import {Alert} from '../../../../../Alert';
import {Button} from '../../../../../Button';
import {CodeEditor} from '../../../../../CodeEditor';
+import {isCustomWeaveTypePayload} from '../../typeViews/customWeaveType.types';
+import {CustomWeaveTypeDispatcher} from '../../typeViews/CustomWeaveTypeDispatcher';
import {isRef} from '../common/util';
import {OBJECT_ATTR_EDGE_NAME} from '../wfReactInterface/constants';
import {WeaveCHTable, WeaveCHTableSourceRefContext} from './DataTableView';
@@ -119,7 +121,7 @@ const ObjectViewerSectionNonEmpty = ({
);
}
return null;
- }, [apiRef, mode, data, expandedIds]);
+ }, [mode, apiRef, data, expandedIds]);
const setTreeExpanded = useCallback(
(setIsExpanded: boolean) => {
@@ -215,9 +217,20 @@ export const ObjectViewerSection = ({
noHide,
isExpanded,
}: ObjectViewerSectionProps) => {
- const numKeys = Object.keys(data).length;
const currentRef = useContext(WeaveCHTableSourceRefContext);
+ if (isCustomWeaveTypePayload(data)) {
+ return (
+ <>
+
+ {title}
+
+
+ >
+ );
+ }
+
+ const numKeys = Object.keys(data).length;
if (numKeys === 0) {
return (
<>
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ValueView.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ValueView.tsx
index 1536ff6e9a1..581aba7a729 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ValueView.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ValueView.tsx
@@ -1,7 +1,9 @@
import React, {useMemo} from 'react';
-import {parseRef} from '../../../../../../react';
+import {isWeaveObjectRef, parseRef} from '../../../../../../react';
import {parseRefMaybe, SmallRef} from '../../../Browse2/SmallRef';
+import {isCustomWeaveTypePayload} from '../../typeViews/customWeaveType.types';
+import {CustomWeaveTypeDispatcher} from '../../typeViews/CustomWeaveTypeDispatcher';
import {isRef} from '../common/util';
import {
DataTableView,
@@ -77,5 +79,36 @@ export const ValueView = ({data, isExpanded}: ValueViewProps) => {
return {JSON.stringify(data.value)}
;
}
+ if (data.valueType === 'object') {
+ if (isCustomWeaveTypePayload(data.value)) {
+ // This is a little ugly, but essentially if the data is coming from an
+ // expanded ref, then we want to use that ref to get the entity and project.
+ // Else we just use the current entity and project.
+ let entityForWeaveType: string | undefined;
+ let projectForWeaveType: string | undefined;
+
+ if (valueIsExpandedRef(data)) {
+ const parsedRef = parseRef((data.value as any)._ref);
+ if (isWeaveObjectRef(parsedRef)) {
+ entityForWeaveType = parsedRef.entityName;
+ projectForWeaveType = parsedRef.projectName;
+ }
+ }
+
+ // If we have a custom view for this weave type, use it.
+ return (
+
+ );
+ }
+ }
+
return {data.value.toString()}
;
};
+
+const valueIsExpandedRef = (data: ValueData) => {
+ return data.value != null && (data.value as any)._ref != null;
+};
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableColumns.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableColumns.tsx
index 7d8ba65cb86..c0007df46dd 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableColumns.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/callsTableColumns.tsx
@@ -284,6 +284,10 @@ function buildCallsTableColumns(
const {cols: newCols, groupingModel} = buildDynamicColumns(
filteredDynamicColumnNames,
+ row => {
+ const [rowEntity, rowProject] = row.project_id.split('/');
+ return {entity: rowEntity, project: rowProject};
+ },
(row, key) => (row as any)[key],
key => expandedRefCols.has(key),
key => columnsWithRefs.has(key),
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/exampleCompareSectionUtil.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/exampleCompareSectionUtil.ts
index 2ab5664005e..59570f4c14e 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/exampleCompareSectionUtil.ts
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/exampleCompareSectionUtil.ts
@@ -1,7 +1,7 @@
import _ from 'lodash';
import {useMemo} from 'react';
-import {flattenObject} from '../../../../../Browse2/browse2Util';
+import {flattenObjectPreservingWeaveTypes} from '../../../../../Browse2/browse2Util';
import {
buildCompositeMetricsMap,
CompositeScoreMetrics,
@@ -138,8 +138,10 @@ export const useFilteredAggregateRows = (state: EvaluationComparisonState) => {
evaluationCallId: predictAndScoreRes.evaluationCallId,
inputDigest: datasetRow.digest,
inputRef: predictAndScoreRes.exampleRef,
- input: flattenObject({input: datasetRow.val}),
- output: flattenObject({output}),
+ input: flattenObjectPreservingWeaveTypes({
+ input: datasetRow.val,
+ }),
+ output: flattenObjectPreservingWeaveTypes({output}),
scores: Object.fromEntries(
[...Object.entries(state.data.scoreMetrics)].map(
([scoreKey, scoreVal]) => {
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx
index 92e51f0e57d..f8c85adeae7 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx
@@ -4,6 +4,8 @@ import React, {useMemo} from 'react';
import {maybePluralizeWord} from '../../../../../core/util/string';
import {LoadingDots} from '../../../../LoadingDots';
+import {NotFoundPanel} from '../NotFoundPanel';
+import {CustomWeaveTypeProjectContext} from '../typeViews/CustomWeaveTypeDispatcher';
import {WeaveCHTableSourceRefContext} from './CallPage/DataTableView';
import {ObjectViewerSection} from './CallPage/ObjectViewerSection';
import {WFHighLevelCallFilter} from './CallsPage/callsTableFilter';
@@ -58,7 +60,7 @@ export const ObjectVersionPage: React.FC<{
if (objectVersion.loading) {
return ;
} else if (objectVersion.result == null) {
- return Object not found
;
+ return ;
}
return (
@@ -207,7 +209,7 @@ const ObjectVersionPageInner: React.FC<{
{
label: 'Values',
content: (
-
+
) : (
-
+
+
+
)}
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionsPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionsPage.tsx
index aa6a9161a25..93af4fcacc5 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionsPage.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionsPage.tsx
@@ -239,9 +239,16 @@ const ObjectVersionsTable: React.FC<{
});
});
- const {cols: newCols, groupingModel} =
- buildDynamicColumns(dynamicFields, (row, key) => {
- const obj: ObjectVersionSchema = (row as any).obj;
+ const {cols: newCols, groupingModel} = buildDynamicColumns<{
+ obj: ObjectVersionSchema;
+ }>(
+ dynamicFields,
+ row => ({
+ entity: row.obj.entity,
+ project: row.obj.project,
+ }),
+ (row, key) => {
+ const obj: ObjectVersionSchema = row.obj;
const res = obj.val?.[key];
if (isTableRef(res)) {
// This whole block is a hack to make the table ref clickable. This
@@ -258,7 +265,8 @@ const ObjectVersionsTable: React.FC<{
return makeRefExpandedPayload(targetRefUri, res);
}
return res;
- });
+ }
+ );
cols.push(...newCols);
groups = groupingModel;
}
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage.tsx
index 030b8980675..8b76964845a 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage.tsx
@@ -1,6 +1,7 @@
import React, {useMemo} from 'react';
import {LoadingDots} from '../../../../LoadingDots';
+import {NotFoundPanel} from '../NotFoundPanel';
import {OpCodeViewer} from '../OpCodeViewer';
import {
CallsLink,
@@ -35,7 +36,7 @@ export const OpVersionPage: React.FC<{
if (opVersion.loading) {
return ;
} else if (opVersion.result == null) {
- return Op version not found
;
+ return ;
}
return ;
};
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/tabularListViews/columnBuilder.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/tabularListViews/columnBuilder.tsx
index 9ff6ae4a191..1e81bbd9643 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/tabularListViews/columnBuilder.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/tabularListViews/columnBuilder.tsx
@@ -9,13 +9,15 @@ import React from 'react';
import {isWeaveObjectRef, parseRef} from '../../../../../../../react';
import {ErrorBoundary} from '../../../../../../ErrorBoundary';
-import {flattenObject} from '../../../../Browse2/browse2Util';
+import {flattenObjectPreservingWeaveTypes} from '../../../../Browse2/browse2Util';
import {CellValue} from '../../../../Browse2/CellValue';
import {CollapseHeader} from '../../../../Browse2/CollapseGroupHeader';
import {ExpandHeader} from '../../../../Browse2/ExpandHeader';
import {NotApplicable} from '../../../../Browse2/NotApplicable';
import {SmallRef} from '../../../../Browse2/SmallRef';
import {CellFilterWrapper} from '../../../filters/CellFilterWrapper';
+import {isCustomWeaveTypePayload} from '../../../typeViews/customWeaveType.types';
+import {CustomWeaveTypeProjectContext} from '../../../typeViews/CustomWeaveTypeDispatcher';
import {
OBJECT_ATTR_EDGE_NAME,
WEAVE_PRIVATE_PREFIX,
@@ -60,7 +62,15 @@ export function prepareFlattenedDataForTable(
): Array {
return data.map(r => {
// First, flatten the inner object
- let flattened = flattenObject(r ?? {});
+ let flattened = flattenObjectPreservingWeaveTypes(r ?? {});
+
+ // In the rare case that we have a custom object at the root (this only occurs if you
+ // directly publish a custom object), we instead nest it under a blank (single-space) key.
+ if (isCustomWeaveTypePayload(flattened)) {
+ flattened = {
+ ' ': flattened,
+ };
+ }
flattened = replaceTableRefsInFlattenedData(flattened);
@@ -182,6 +192,7 @@ const isExpandedRefWithValueAsTableRef = (
export const buildDynamicColumns = (
filteredDynamicColumnNames: string[],
+ entityProjectFromRow: (row: T) => {entity: string; project: string},
valueForKey: (row: T, key: string) => any,
columnIsExpanded?: (col: string) => boolean,
columnCanBeExpanded?: (col: string) => boolean,
@@ -269,6 +280,7 @@ export const buildDynamicColumns = (
return val;
},
renderCell: cellParams => {
+ const {entity, project} = entityProjectFromRow(cellParams.row);
const val = valueForKey(cellParams.row, key);
if (val === undefined) {
return (
@@ -287,7 +299,12 @@ export const buildDynamicColumns = (
onAddFilter={onAddFilter}
field={key}
operation={null}
- value={val}>
+ value={val}
+ style={{
+ width: '100%',
+ height: '100%',
+ alignContent: 'center',
+ }}>
{/* In the future, we may want to move this isExpandedRefWithValueAsTableRef condition
into `CellValue`. However, at the moment, `ExpandedRefWithValueAsTableRef` is a
Table-specific data structure and we might not want to leak that into the
@@ -295,7 +312,10 @@ export const buildDynamicColumns = (
{isExpandedRefWithValueAsTableRef(val) ? (
) : (
-
+
+
+
)}
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooks.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooks.ts
index 93f239b2d23..b659a35b845 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooks.ts
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooks.ts
@@ -669,6 +669,13 @@ const useOpVersion = (
};
}
+ if (opVersionRes.obj == null) {
+ return {
+ loading: false,
+ result: null,
+ };
+ }
+
const returnedResult = convertTraceServerObjectVersionToOpSchema(
opVersionRes.obj
);
@@ -812,6 +819,13 @@ const useObjectVersion = (
};
}
+ if (objectVersionRes.obj == null) {
+ return {
+ loading: false,
+ result: null,
+ };
+ }
+
const returnedResult: ObjectVersionSchema =
convertTraceServerObjectVersionToSchema(objectVersionRes.obj);
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/typeViews/CustomWeaveTypeDispatcher.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/typeViews/CustomWeaveTypeDispatcher.tsx
new file mode 100644
index 00000000000..cc8389fecb1
--- /dev/null
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/typeViews/CustomWeaveTypeDispatcher.tsx
@@ -0,0 +1,80 @@
+import React from 'react';
+
+import {CustomWeaveTypePayload} from './customWeaveType.types';
+import {PILImageImage} from './PIL.Image.Image/PILImageImage';
+
+type CustomWeaveTypeDispatcherProps = {
+ data: CustomWeaveTypePayload;
+ // Entity and Project can be optionally provided as props, but if they are not
+ // provided, they must be provided in context. Failure to provide them will
+ // result in a console warning and a fallback to a default component.
+ //
+ // This pattern is used because in many cases we are rendering data from
+ // hierarchical data structures, and we want to avoid passing entity and project
+ // down through the tree.
+ entity?: string;
+ project?: string;
+};
+
+const customWeaveTypeRegistry: {
+ [typeId: string]: {
+ component: React.FC<{
+ entity: string;
+ project: string;
+ data: any; // I wish this could be typed more specifically
+ }>;
+ };
+} = {
+ 'PIL.Image.Image': {
+ component: PILImageImage,
+ },
+};
+
+/**
+ * This context is used to provide the entity and project to the
+ * CustomWeaveTypeDispatcher. Importantly, this allows the
+ * developer to inject an entity/project context around some component tree, and
+ * then any CustomWeaveTypeDispatchers within that tree will be assumed to be
+ * within that entity/project context. This is far cleaner than passing
+ * entity/project down through the tree. We just have to remember in the future
+ * case when we support multiple entities/projects in the same tree, we will
+ * need to update this context if you end up traversing into a different
+ * entity/project. This should already be accounted for in all the current
+ * use-cases.
+ */
+export const CustomWeaveTypeProjectContext = React.createContext<{
+ entity: string;
+ project: string;
+} | null>(null);
+
+/**
+ * This is the primary entry-point for dispatching custom weave types. Currently
+ * we just have 1, but as we add more, we might want to add a more robust
+ * "registry"
+ */
+export const CustomWeaveTypeDispatcher: React.FC<
+ CustomWeaveTypeDispatcherProps
+> = ({data, entity, project}) => {
+ const projectContext = React.useContext(CustomWeaveTypeProjectContext);
+ const typeId = data.weave_type.type;
+ const comp = customWeaveTypeRegistry[typeId]?.component;
+ const defaultReturn = Custom Weave Type: {data.weave_type.type} ;
+
+ if (comp) {
+ const applicableEntity = entity || projectContext?.entity;
+ const applicableProject = project || projectContext?.project;
+ if (applicableEntity == null || applicableProject == null) {
+ console.warn(
+ 'CustomWeaveTypeDispatch: entity and project must be provided in context or as props'
+ );
+ return defaultReturn;
+ }
+ return React.createElement(comp, {
+ entity: applicableEntity,
+ project: applicableProject,
+ data,
+ });
+ }
+
+ return defaultReturn;
+};
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/typeViews/PIL.Image.Image/PILImageImage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/typeViews/PIL.Image.Image/PILImageImage.tsx
new file mode 100644
index 00000000000..fb07477497f
--- /dev/null
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/typeViews/PIL.Image.Image/PILImageImage.tsx
@@ -0,0 +1,53 @@
+import React from 'react';
+
+import {LoadingDots} from '../../../../../LoadingDots';
+import {useWFHooks} from '../../pages/wfReactInterface/context';
+import {CustomWeaveTypePayload} from '../customWeaveType.types';
+
+type PILImageImageTypePayload = CustomWeaveTypePayload<
+ 'PIL.Image.Image',
+ {'image.png': string}
+>;
+
+export const isPILImageImageType = (
+ data: CustomWeaveTypePayload
+): data is PILImageImageTypePayload => {
+ return data.weave_type.type === 'PIL.Image.Image';
+};
+
+export const PILImageImage: React.FC<{
+ entity: string;
+ project: string;
+ data: PILImageImageTypePayload;
+}> = props => {
+ const {useFileContent} = useWFHooks();
+ const imageBinary = useFileContent(
+ props.entity,
+ props.project,
+ props.data.files['image.png']
+ );
+
+ if (imageBinary.loading) {
+ return ;
+ } else if (imageBinary.result == null) {
+ return ;
+ }
+
+ const arrayBuffer = imageBinary.result as any as ArrayBuffer;
+ const blob = new Blob([arrayBuffer], {type: 'image/png'});
+ const url = URL.createObjectURL(blob);
+
+ // TODO: It would be nice to have a more general image render - similar to the
+ // ValueViewImage that does things like light box, general scaling,
+ // downloading, etc..
+ return (
+
+ );
+};
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/typeViews/customWeaveType.types.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/typeViews/customWeaveType.types.ts
new file mode 100644
index 00000000000..4677e58c407
--- /dev/null
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/typeViews/customWeaveType.types.ts
@@ -0,0 +1,48 @@
+export type CustomWeaveTypePayload<
+ T extends string = string,
+ FP extends {[filename: string]: string} = {[filename: string]: string}
+> = {
+ _type: 'CustomWeaveType';
+ weave_type: {
+ type: T;
+ };
+ files: FP;
+ load_op?: string | CustomWeaveTypePayload<'Op', {'obj.py': string}>;
+} & {[extra: string]: any};
+
+export const isCustomWeaveTypePayload = (
+ data: any
+): data is CustomWeaveTypePayload => {
+ if (typeof data !== 'object' || data == null) {
+ return false;
+ }
+ if (data._type !== 'CustomWeaveType') {
+ return false;
+ }
+ if (
+ typeof data.weave_type !== 'object' ||
+ data.weave_type == null ||
+ typeof data.weave_type.type !== 'string'
+ ) {
+ return false;
+ }
+ if (typeof data.files !== 'object' || data.files == null) {
+ return false;
+ }
+ if (data.weave_type.type === 'Op') {
+ if (data.load_op != null) {
+ return false;
+ }
+ } else {
+ if (data.load_op == null) {
+ return false;
+ }
+ if (
+ typeof data.load_op !== 'string' &&
+ !isCustomWeaveTypePayload(data.load_op)
+ ) {
+ return false;
+ }
+ }
+ return true;
+};
diff --git a/weave/flow/model.py b/weave/flow/model.py
index dc211902ba9..0cd23eabd54 100644
--- a/weave/flow/model.py
+++ b/weave/flow/model.py
@@ -2,6 +2,11 @@
from weave.flow.obj import Object
+INFER_METHOD_NAMES = {"predict", "infer", "forward", "invoke"}
+
+
+class MissingInferenceMethodError(Exception): ...
+
class Model(Object):
"""
@@ -32,20 +37,18 @@ def predict(self, input_data: str) -> dict:
# TODO: should be infer: Callable
def get_infer_method(self) -> Callable:
- for infer_method_names in ("predict", "infer", "forward"):
- infer_method = getattr(self, infer_method_names, None)
- if infer_method:
+ for name in INFER_METHOD_NAMES:
+ if infer_method := getattr(self, name, None):
return infer_method
- raise ValueError(
- f"Model {self} does not have a predict, infer, or forward method."
+ raise MissingInferenceMethodError(
+ f"Missing a method with name in ({INFER_METHOD_NAMES})"
)
def get_infer_method(model: Model) -> Callable:
- for infer_method_names in ("predict", "infer", "forward"):
- infer_method = getattr(model, infer_method_names, None)
- if infer_method:
+ for name in INFER_METHOD_NAMES:
+ if (infer_method := getattr(model, name, None)) is not None:
return infer_method
- raise ValueError(
- f"Model {model} does not have a predict, infer, or forward method."
+ raise MissingInferenceMethodError(
+ f"Missing a method with name in ({INFER_METHOD_NAMES})"
)
diff --git a/weave/frontend/index.html b/weave/frontend/index.html
index 74808c202a6..c96a45a9ae8 100644
--- a/weave/frontend/index.html
+++ b/weave/frontend/index.html
@@ -91,7 +91,7 @@
-
+
diff --git a/weave/frontend/sha1.txt b/weave/frontend/sha1.txt
index 01a425480b3..7a517474529 100644
--- a/weave/frontend/sha1.txt
+++ b/weave/frontend/sha1.txt
@@ -1 +1 @@
-cd3b2e94bf9dc8702f53efb3844074379cd3b951
+18eebed493dc14f0fbaa3ab62505c6bdfd42ad6f
diff --git a/weave/init_message.py b/weave/init_message.py
index ffb6da21cf6..34f107139dd 100644
--- a/weave/init_message.py
+++ b/weave/init_message.py
@@ -44,10 +44,17 @@ def _print_version_check() -> None:
if use_message:
print(use_message)
- orig_module = wandb._wandb_module
- wandb._wandb_module = "weave"
- weave_messages = wandb.sdk.internal.update.check_available(weave.__version__)
- wandb._wandb_module = orig_module
+ weave_messages = None
+ if hasattr(weave, "_wandb_module"):
+ try:
+ orig_module = wandb._wandb_module # type: ignore
+ wandb._wandb_module = "weave" # type: ignore
+ weave_messages = wandb.sdk.internal.update.check_available(
+ weave.__version__
+ )
+ wandb._wandb_module = orig_module # type: ignore
+ except Exception:
+ weave_messages = None
if weave_messages:
use_message = (
diff --git a/weave/legacy/wandb_interface/project_creator.py b/weave/legacy/wandb_interface/project_creator.py
index 617fe09a167..c12fe52e9a5 100644
--- a/weave/legacy/wandb_interface/project_creator.py
+++ b/weave/legacy/wandb_interface/project_creator.py
@@ -39,12 +39,19 @@ def wandb_logging_disabled() -> typing.Iterator[None]:
wandb.termerror = original_termerror
-def ensure_project_exists(entity_name: str, project_name: str) -> None:
+def ensure_project_exists(entity_name: str, project_name: str) -> typing.Dict[str, str]:
with wandb_logging_disabled():
return _ensure_project_exists(entity_name, project_name)
-def _ensure_project_exists(entity_name: str, project_name: str) -> None:
+def _ensure_project_exists(
+ entity_name: str, project_name: str
+) -> typing.Dict[str, str]:
+ """
+ Ensures that a W&B project exists by trying to access it, returns the project_name,
+ which is not guaranteed to be the same if the provided project_name contains invalid
+ characters. Adheres to trace_server_interface.EnsureProjectExistsRes
+ """
wandb_logging_disabled()
api = InternalApi({"entity": entity_name, "project": project_name})
# Since `UpsertProject` will fail if the user does not have permission to create a project
@@ -72,4 +79,4 @@ def _ensure_project_exists(entity_name: str, project_name: str) -> None:
raise UnableToCreateProject(
f"Failed to create project {entity_name}/{project_name}"
)
- return
+ return {"project_name": project["name"]}
diff --git a/weave/tests/test_client_trace.py b/weave/tests/test_client_trace.py
index ad0ddba7f87..c319c66cc12 100644
--- a/weave/tests/test_client_trace.py
+++ b/weave/tests/test_client_trace.py
@@ -2144,6 +2144,44 @@ def calculate(a: int, b: int) -> int:
assert i == len(calls.calls)
+def test_call_query_stream_columns(client):
+ @weave.op
+ def calculate(a: int, b: int) -> int:
+ return {"result": {"a + b": a + b}, "not result": 123}
+
+ for i in range(2):
+ calculate(i, i * i)
+
+ calls = client.server.calls_query_stream(
+ tsi.CallsQueryReq(
+ project_id=client._project_id(),
+ columns=["id", "inputs"],
+ )
+ )
+ calls = list(calls)
+ assert len(calls) == 2
+ assert len(calls[0].inputs) == 2
+
+ # NO output returned because not required and not requested
+ assert calls[0].output is None
+ assert calls[0].ended_at is None
+ assert calls[0].attributes == {}
+ assert calls[0].inputs == {"a": 0, "b": 0}
+
+ # now explicitly get output
+ calls = client.server.calls_query_stream(
+ tsi.CallsQueryReq(
+ project_id=client._project_id(),
+ columns=["id", "inputs", "output.result"],
+ )
+ )
+ calls = list(calls)
+ assert len(calls) == 2
+ assert calls[0].output["result"]["a + b"] == 0
+ assert calls[0].attributes == {}
+ assert calls[0].inputs == {"a": 0, "b": 0}
+
+
@pytest.mark.skip("Not implemented: filter / sort through refs")
def test_sort_and_filter_through_refs(client):
@weave.op()
diff --git a/weave/tests/test_op.py b/weave/tests/test_op.py
index a0c2d3240a2..bfa2271e8e5 100644
--- a/weave/tests/test_op.py
+++ b/weave/tests/test_op.py
@@ -249,3 +249,28 @@ def my_op(self, a: int) -> str: # type: ignore[empty-body]
"a": types.Int(),
}
assert SomeWeaveObj.my_op.concrete_output_type == types.String()
+
+
+def test_op_internal_tracing_enabled(client):
+ # This test verifies the behavior of `_tracing_enabled` which
+ # is not a user-facing API and is used internally to toggle
+ # tracing on and off.
+ @weave.op
+ def my_op():
+ return "hello"
+
+ my_op() # <-- this call will be traced
+
+ assert len(list(my_op.calls())) == 1
+
+ my_op._tracing_enabled = False
+
+ my_op() # <-- this call will not be traced
+
+ assert len(list(my_op.calls())) == 1
+
+ my_op._tracing_enabled = True
+
+ my_op() # <-- this call will be traced
+
+ assert len(list(my_op.calls())) == 2
diff --git a/weave/tests/test_weave_client.py b/weave/tests/test_weave_client.py
index 6fc102ce400..7a065f243d2 100644
--- a/weave/tests/test_weave_client.py
+++ b/weave/tests/test_weave_client.py
@@ -1324,3 +1324,13 @@ def test_summary_tokens_cost_sqlite(client):
assert noCostCallSummary is None
assert withCostCallSummary is None
+
+
+def test_ref_in_dict(client):
+ ref = client._save_object({"a": 5}, "d1")
+
+ # Put a ref directly in a dict.
+ ref2 = client._save_object({"b": ref}, "d2")
+
+ obj = weave.ref(ref2.uri()).get()
+ assert obj["b"] == {"a": 5}
diff --git a/weave/trace/custom_objs.py b/weave/trace/custom_objs.py
index 7f8b7215c32..6a8c0022d3e 100644
--- a/weave/trace/custom_objs.py
+++ b/weave/trace/custom_objs.py
@@ -156,6 +156,9 @@ def decode_custom_obj(
raise ValueError(f"No serializer found for {weave_type}")
load_instance_op = serializer.load
+ # Disable tracing on the load op so that loading the data does not itself get traced
+ load_instance_op._tracing_enabled = False # type: ignore
+
art = MemTraceFilesArtifact(
encoded_path_contents,
metadata={},
diff --git a/weave/trace/op.py b/weave/trace/op.py
index b784c2f0a92..d99ac6200b4 100644
--- a/weave/trace/op.py
+++ b/weave/trace/op.py
@@ -125,6 +125,15 @@ class Op(Protocol):
__call__: Callable[..., Any]
__self__: Any
+ # `_tracing_enabled` is a runtime-only flag that can be used to disable
+ # call tracing for an op. It is not persisted as a property of the op, but is
+ # respected by the current execution context. It is an underscore property
+ # because it is not intended to be used by users directly, but rather assists
+ # with internal Weave behavior. If we find a need to expose this to users, we
+ # should consider a more user-friendly API (perhaps a setter/getter) & whether
+ # it disables child ops as well.
+ _tracing_enabled: bool
+
def _set_on_output_handler(func: Op, on_output: OnOutputHandlerType) -> None:
if func._on_output_handler is not None:
@@ -337,6 +346,8 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:
return await func(*args, **kwargs)
if weave_client_context.get_weave_client() is None:
return await func(*args, **kwargs)
+ if not wrapper._tracing_enabled: # type: ignore
+ return await func(*args, **kwargs)
call = _create_call(wrapper, *args, **kwargs) # type: ignore
res, _ = await _execute_call(wrapper, call, *args, **kwargs) # type: ignore
return res
@@ -348,6 +359,8 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
return func(*args, **kwargs)
if weave_client_context.get_weave_client() is None:
return func(*args, **kwargs)
+ if not wrapper._tracing_enabled: # type: ignore
+ return func(*args, **kwargs)
call = _create_call(wrapper, *args, **kwargs) # type: ignore
res, _ = _execute_call(wrapper, call, *args, **kwargs) # type: ignore
return res
@@ -375,6 +388,8 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
wrapper._set_on_output_handler = partial(_set_on_output_handler, wrapper) # type: ignore
wrapper._on_output_handler = None # type: ignore
+ wrapper._tracing_enabled = True # type: ignore
+
return cast(Op, wrapper)
return create_wrapper(func)
diff --git a/weave/trace/serializer.py b/weave/trace/serializer.py
index 7d5bf90fc93..334811111d8 100644
--- a/weave/trace/serializer.py
+++ b/weave/trace/serializer.py
@@ -52,7 +52,7 @@ def id(self) -> str:
# "Op" in the database.
if ser_id.endswith(".Op"):
return "Op"
- return self.target_class.__name__
+ return ser_id
SERIALIZERS = []
diff --git a/weave/trace/vals.py b/weave/trace/vals.py
index 5deb9c4bf4e..8933dcc201b 100644
--- a/weave/trace/vals.py
+++ b/weave/trace/vals.py
@@ -292,7 +292,13 @@ def _remote_iter(self) -> Generator[dict, None, None]:
for item in response.rows:
new_ref = self.ref.with_item(item.digest) if self.ref else None
- yield make_trace_obj(item.val, new_ref, self.server, self.root)
+ res = from_json(
+ item.val,
+ self.table_ref.entity + "/" + self.table_ref.project,
+ self.server,
+ )
+ res = make_trace_obj(res, new_ref, self.server, self.root)
+ yield res
if len(response.rows) < page_size:
break
diff --git a/weave/trace_server/clickhouse_schema.py b/weave/trace_server/clickhouse_schema.py
index 1cead54c4a8..f41ba84cca3 100644
--- a/weave/trace_server/clickhouse_schema.py
+++ b/weave/trace_server/clickhouse_schema.py
@@ -105,8 +105,11 @@ class SelectableCHCallSchema(BaseModel):
ended_at: typing.Optional[datetime.datetime] = None
exception: typing.Optional[str] = None
- attributes_dump: str
- inputs_dump: str
+ # attributes and inputs are required on call schema, but can be
+ # optionally selected when querying
+ attributes_dump: typing.Optional[str] = None
+ inputs_dump: typing.Optional[str] = None
+
output_dump: typing.Optional[str] = None
summary_dump: typing.Optional[str] = None
diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py
index 9cab74f6895..57745ddeea6 100644
--- a/weave/trace_server/clickhouse_trace_server_batched.py
+++ b/weave/trace_server/clickhouse_trace_server_batched.py
@@ -106,10 +106,7 @@ class NotFoundError(Exception):
all_call_select_columns = list(SelectableCHCallSchema.model_fields.keys())
all_call_json_columns = ("inputs", "output", "attributes", "summary")
-
-
-# Let's just make everything required for now ... can optimize when we implement column selection
-required_call_columns = list(set(all_call_select_columns) - set([]))
+required_call_columns = ["id", "project_id", "trace_id", "op_name", "started_at"]
# Columns in the calls_merged table with special aggregation functions:
@@ -254,17 +251,20 @@ def calls_query_stream(
cq = CallsQuery(
project_id=req.project_id, include_costs=req.include_costs or False
)
-
- # TODO (Perf): By allowing a sub-selection of columns
- # we will gain increased performance by not having to
- # fetch all columns from the database. Currently all use
- # cases call for every column to be fetched, so we have not
- # implemented this yet.
columns = all_call_select_columns
+ if req.columns:
+ # Set columns to user-requested columns, w/ required columns
+ # These are all formatted by the CallsQuery, which prevents injection
+ # and other attack vectors.
+ columns = list(set(required_call_columns + req.columns))
+ # TODO: add support for json extract fields
+ # Split out any nested column requests
+ columns = [col.split(".")[0] for col in columns]
+
# We put summary_dump last so that when we compute the costs and summary its in the right place
if req.include_costs:
columns = [
- *[col for col in all_call_select_columns if col != "summary_dump"],
+ *[col for col in columns if col != "summary_dump"],
"summary_dump",
]
for col in columns:
@@ -291,9 +291,10 @@ def calls_query_stream(
pb.get_params(),
)
+ select_columns = [c.field for c in cq.select_fields]
for row in raw_res:
yield tsi.CallSchema.model_validate(
- _ch_call_dict_to_call_schema_dict(dict(zip(columns, row)))
+ _ch_call_dict_to_call_schema_dict(dict(zip(select_columns, row)))
)
def calls_delete(self, req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes:
@@ -1383,8 +1384,8 @@ def _ch_call_to_call_schema(ch_call: SelectableCHCallSchema) -> tsi.CallSchema:
op_name=ch_call.op_name,
started_at=_ensure_datetimes_have_tz(ch_call.started_at),
ended_at=_ensure_datetimes_have_tz(ch_call.ended_at),
- attributes=_dict_dump_to_dict(ch_call.attributes_dump),
- inputs=_dict_dump_to_dict(ch_call.inputs_dump),
+ attributes=_dict_dump_to_dict(ch_call.attributes_dump or "{}"),
+ inputs=_dict_dump_to_dict(ch_call.inputs_dump or "{}"),
output=_nullable_any_dump_to_any(ch_call.output_dump),
summary=_nullable_dict_dump_to_dict(ch_call.summary_dump),
exception=ch_call.exception,
@@ -1404,8 +1405,8 @@ def _ch_call_dict_to_call_schema_dict(ch_call_dict: typing.Dict) -> typing.Dict:
op_name=ch_call_dict.get("op_name"),
started_at=_ensure_datetimes_have_tz(ch_call_dict.get("started_at")),
ended_at=_ensure_datetimes_have_tz(ch_call_dict.get("ended_at")),
- attributes=_dict_dump_to_dict(ch_call_dict["attributes_dump"]),
- inputs=_dict_dump_to_dict(ch_call_dict["inputs_dump"]),
+ attributes=_dict_dump_to_dict(ch_call_dict.get("attributes_dump", "{}")),
+ inputs=_dict_dump_to_dict(ch_call_dict.get("inputs_dump", "{}")),
output=_nullable_any_dump_to_any(ch_call_dict.get("output_dump")),
summary=_nullable_dict_dump_to_dict(ch_call_dict.get("summary_dump")),
exception=ch_call_dict.get("exception"),
diff --git a/weave/trace_server/external_to_internal_trace_server_adapter.py b/weave/trace_server/external_to_internal_trace_server_adapter.py
index c75929f2f6f..372aa48d1b3 100644
--- a/weave/trace_server/external_to_internal_trace_server_adapter.py
+++ b/weave/trace_server/external_to_internal_trace_server_adapter.py
@@ -94,8 +94,10 @@ def cached_int_to_ext_project_id(project_id: str) -> typing.Optional[str]:
yield universal_int_to_ext_ref_converter(item, cached_int_to_ext_project_id)
# Standard API Below:
- def ensure_project_exists(self, entity: str, project: str) -> None:
- self._internal_trace_server.ensure_project_exists(entity, project)
+ def ensure_project_exists(
+ self, entity: str, project: str
+ ) -> tsi.EnsureProjectExistsRes:
+ return self._internal_trace_server.ensure_project_exists(entity, project)
def call_start(self, req: tsi.CallStartReq) -> tsi.CallStartRes:
req.start.project_id = self._idc.ext_to_int_project_id(req.start.project_id)
diff --git a/weave/trace_server/remote_http_trace_server.py b/weave/trace_server/remote_http_trace_server.py
index 45617f1a0c6..145d2f50906 100644
--- a/weave/trace_server/remote_http_trace_server.py
+++ b/weave/trace_server/remote_http_trace_server.py
@@ -108,10 +108,14 @@ def __init__(
self._auth: t.Optional[t.Tuple[str, str]] = None
self.remote_request_bytes_limit = remote_request_bytes_limit
- def ensure_project_exists(self, entity: str, project: str) -> None:
+ def ensure_project_exists(
+ self, entity: str, project: str
+ ) -> tsi.EnsureProjectExistsRes:
# TODO: This should happen in the wandb backend, not here, and it's slow
# (hundreds of ms)
- project_creator.ensure_project_exists(entity, project)
+ return tsi.EnsureProjectExistsRes.model_validate(
+ project_creator.ensure_project_exists(entity, project)
+ )
@classmethod
def from_env(cls, should_batch: bool = False) -> "RemoteHTTPTraceServer":
diff --git a/weave/trace_server/sqlite_trace_server.py b/weave/trace_server/sqlite_trace_server.py
index 2402067b062..82a512bcfc7 100644
--- a/weave/trace_server/sqlite_trace_server.py
+++ b/weave/trace_server/sqlite_trace_server.py
@@ -379,7 +379,17 @@ def process_operand(operand: tsi_query.Operand) -> str:
conds.append(filter_cond)
- query = f"SELECT * FROM calls WHERE deleted_at IS NULL AND project_id = '{req.project_id}'"
+ required_columns = ["id", "trace_id", "project_id", "op_name", "started_at"]
+ select_columns = list(tsi.CallSchema.model_fields.keys())
+ if req.columns:
+ # TODO(gst): allow json fields to be selected
+ simple_columns = [x.split(".")[0] for x in req.columns]
+ select_columns = [x for x in simple_columns if x in select_columns]
+ # add required columns, preserving requested column order
+ select_columns += [
+ rcol for rcol in required_columns if rcol not in select_columns
+ ]
+ query = f"SELECT {', '.join(select_columns)} FROM calls WHERE deleted_at IS NULL AND project_id = '{req.project_id}'"
conditions_part = " AND ".join(conds)
@@ -431,29 +441,29 @@ def process_operand(operand: tsi_query.Operand) -> str:
cursor.execute(query)
query_result = cursor.fetchall()
- return tsi.CallsQueryRes(
- calls=[
- tsi.CallSchema(
- project_id=row[0],
- id=row[1],
- trace_id=row[2],
- parent_id=row[3],
- op_name=row[4],
- started_at=row[5],
- ended_at=row[6],
- exception=row[7],
- attributes=json.loads(row[8]),
- inputs=json.loads(row[9]),
- output=None if row[11] is None else json.loads(row[11]),
- output_refs=None if row[12] is None else json.loads(row[12]),
- summary=json.loads(row[13]) if row[13] else None,
- wb_user_id=row[14],
- wb_run_id=row[15],
- display_name=row[17] if row[17] != "" else None,
- )
- for row in query_result
- ]
- )
+ calls = []
+ for row in query_result:
+ call_dict = {k: v for k, v in zip(select_columns, row)}
+ # convert json dump fields into json
+ for json_field in ["attributes", "summary", "inputs", "output"]:
+ if call_dict.get(json_field):
+ call_dict[json_field] = json.loads(call_dict[json_field])
+ # convert empty string display_names to None
+ if "display_name" in call_dict and call_dict["display_name"] == "":
+ call_dict["display_name"] = None
+            # Fill in missing required fields with defaults. `mfield.annotation` is
+            # the declared type itself (e.g. `str`), not an instance of it, so it is
+            # compared by identity/membership; `isinstance` would never match.
+            for col, mfield in tsi.CallSchema.model_fields.items():
+                if mfield.is_required() and col not in call_dict:
+                    if mfield.annotation is str:
+                        call_dict[col] = ""
+                    elif mfield.annotation in (datetime.datetime, datetime.date):
+                        raise ValueError(f"Field '{col}' is required for selection")
+                    else:
+                        call_dict[col] = {}
+ calls.append(tsi.CallSchema(**call_dict))
+ return tsi.CallsQueryRes(calls=calls)
def calls_query_stream(self, req: tsi.CallsQueryReq) -> Iterator[tsi.CallSchema]:
return iter(self.calls_query(req).calls)
diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py
index 64df67db54d..a4ad1625c9b 100644
--- a/weave/trace_server/trace_server_interface.py
+++ b/weave/trace_server/trace_server_interface.py
@@ -1,6 +1,5 @@
-import abc
import datetime
-import typing
+from typing import Any, Dict, Iterator, List, Literal, Optional, Protocol, Union
from pydantic import BaseModel, ConfigDict, Field, field_serializer
from typing_extensions import TypedDict
@@ -21,42 +20,42 @@ class ExtraKeysTypedDict(TypedDict):
class LLMUsageSchema(TypedDict, total=False):
- prompt_tokens: typing.Optional[int]
- input_tokens: typing.Optional[int]
- completion_tokens: typing.Optional[int]
- output_tokens: typing.Optional[int]
- requests: typing.Optional[int]
- total_tokens: typing.Optional[int]
+ prompt_tokens: Optional[int]
+ input_tokens: Optional[int]
+ completion_tokens: Optional[int]
+ output_tokens: Optional[int]
+ requests: Optional[int]
+ total_tokens: Optional[int]
class LLMCostSchema(LLMUsageSchema):
- prompt_tokens_cost: typing.Optional[float]
- completion_tokens_cost: typing.Optional[float]
- prompt_token_cost: typing.Optional[float]
- completion_token_cost: typing.Optional[float]
- prompt_token_cost_unit: typing.Optional[str]
- completion_token_cost_unit: typing.Optional[str]
- effective_date: typing.Optional[str]
- provider_id: typing.Optional[str]
- pricing_level: typing.Optional[str]
- pricing_level_id: typing.Optional[str]
- created_at: typing.Optional[str]
- created_by: typing.Optional[str]
+ prompt_tokens_cost: Optional[float]
+ completion_tokens_cost: Optional[float]
+ prompt_token_cost: Optional[float]
+ completion_token_cost: Optional[float]
+ prompt_token_cost_unit: Optional[str]
+ completion_token_cost_unit: Optional[str]
+ effective_date: Optional[str]
+ provider_id: Optional[str]
+ pricing_level: Optional[str]
+ pricing_level_id: Optional[str]
+ created_at: Optional[str]
+ created_by: Optional[str]
class WeaveSummarySchema(ExtraKeysTypedDict, total=False):
- status: typing.Optional[typing.Literal["success", "error", "running"]]
- nice_trace_name: typing.Optional[str]
- latency: typing.Optional[int]
- costs: typing.Optional[typing.Dict[str, LLMCostSchema]]
+ status: Optional[Literal["success", "error", "running"]]
+ nice_trace_name: Optional[str]
+ latency: Optional[int]
+ costs: Optional[Dict[str, LLMCostSchema]]
class SummaryInsertMap(ExtraKeysTypedDict, total=False):
- usage: typing.Dict[str, LLMUsageSchema]
+ usage: Dict[str, LLMUsageSchema]
class SummaryMap(SummaryInsertMap, total=False):
- weave: typing.Optional[WeaveSummarySchema]
+ weave: Optional[WeaveSummarySchema]
class CallSchema(BaseModel):
@@ -66,43 +65,41 @@ class CallSchema(BaseModel):
# Name of the calling function (op)
op_name: str
# Optional display name of the call
- display_name: typing.Optional[str] = None
+ display_name: Optional[str] = None
- ## Trace ID
+ # Trace ID
trace_id: str
- ## Parent ID is optional because the call may be a root
- parent_id: typing.Optional[str] = None
+ # Parent ID is optional because the call may be a root
+ parent_id: Optional[str] = None
- ## Start time is required
+ # Start time is required
started_at: datetime.datetime
- ## Attributes: properties of the call
- attributes: typing.Dict[str, typing.Any]
+ # Attributes: properties of the call
+ attributes: Dict[str, Any]
- ## Inputs
- inputs: typing.Dict[str, typing.Any]
+ # Inputs
+ inputs: Dict[str, Any]
- ## End time is required if finished
- ended_at: typing.Optional[datetime.datetime] = None
+ # End time is required if finished
+ ended_at: Optional[datetime.datetime] = None
- ## Exception is present if the call failed
- exception: typing.Optional[str] = None
+ # Exception is present if the call failed
+ exception: Optional[str] = None
- ## Outputs
- output: typing.Optional[typing.Any] = None
+ # Outputs
+ output: Optional[Any] = None
- ## Summary: a summary of the call
- summary: typing.Optional[SummaryMap] = None
+ # Summary: a summary of the call
+ summary: Optional[SummaryMap] = None
# WB Metadata
- wb_user_id: typing.Optional[str] = None
- wb_run_id: typing.Optional[str] = None
+ wb_user_id: Optional[str] = None
+ wb_run_id: Optional[str] = None
- deleted_at: typing.Optional[datetime.datetime] = None
+ deleted_at: Optional[datetime.datetime] = None
@field_serializer("attributes", "summary", when_used="unless-none")
- def serialize_typed_dicts(
- self, v: typing.Dict[str, typing.Any]
- ) -> typing.Dict[str, typing.Any]:
+ def serialize_typed_dicts(self, v: Dict[str, Any]) -> Dict[str, Any]:
return dict(v)
@@ -111,51 +108,49 @@ def serialize_typed_dicts(
# - trace_id is not required (will be generated)
class StartedCallSchemaForInsert(BaseModel):
project_id: str
- id: typing.Optional[str] = None # Will be generated if not provided
+ id: Optional[str] = None # Will be generated if not provided
# Name of the calling function (op)
op_name: str
# Optional display name of the call
- display_name: typing.Optional[str] = None
+ display_name: Optional[str] = None
- ## Trace ID
- trace_id: typing.Optional[str] = None # Will be generated if not provided
- ## Parent ID is optional because the call may be a root
- parent_id: typing.Optional[str] = None
+ # Trace ID
+ trace_id: Optional[str] = None # Will be generated if not provided
+ # Parent ID is optional because the call may be a root
+ parent_id: Optional[str] = None
- ## Start time is required
+ # Start time is required
started_at: datetime.datetime
- ## Attributes: properties of the call
- attributes: typing.Dict[str, typing.Any]
+ # Attributes: properties of the call
+ attributes: Dict[str, Any]
- ## Inputs
- inputs: typing.Dict[str, typing.Any]
+ # Inputs
+ inputs: Dict[str, Any]
# WB Metadata
- wb_user_id: typing.Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION)
- wb_run_id: typing.Optional[str] = None
+ wb_user_id: Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION)
+ wb_run_id: Optional[str] = None
class EndedCallSchemaForInsert(BaseModel):
project_id: str
id: str
- ## End time is required
+ # End time is required
ended_at: datetime.datetime
- ## Exception is present if the call failed
- exception: typing.Optional[str] = None
+ # Exception is present if the call failed
+ exception: Optional[str] = None
- ## Outputs
- output: typing.Optional[typing.Any] = None
+ # Outputs
+ output: Optional[Any] = None
- ## Summary: a summary of the call
+ # Summary: a summary of the call
summary: SummaryInsertMap
@field_serializer("summary")
- def serialize_typed_dicts(
- self, v: typing.Dict[str, typing.Any]
- ) -> typing.Dict[str, typing.Any]:
+ def serialize_typed_dicts(self, v: Dict[str, Any]) -> Dict[str, Any]:
return dict(v)
@@ -163,24 +158,24 @@ class ObjSchema(BaseModel):
project_id: str
object_id: str
created_at: datetime.datetime
- deleted_at: typing.Optional[datetime.datetime] = None
+ deleted_at: Optional[datetime.datetime] = None
digest: str
version_index: int
is_latest: int
kind: str
- base_object_class: typing.Optional[str]
- val: typing.Any
+ base_object_class: Optional[str]
+ val: Any
class ObjSchemaForInsert(BaseModel):
project_id: str
object_id: str
- val: typing.Any
+ val: Any
class TableSchemaForInsert(BaseModel):
project_id: str
- rows: list[dict[str, typing.Any]]
+ rows: list[dict[str, Any]]
class CallStartReq(BaseModel):
@@ -203,19 +198,19 @@ class CallEndRes(BaseModel):
class CallReadReq(BaseModel):
project_id: str
id: str
- include_costs: typing.Optional[bool] = False
+ include_costs: Optional[bool] = False
class CallReadRes(BaseModel):
- call: typing.Optional[CallSchema]
+ call: Optional[CallSchema]
class CallsDeleteReq(BaseModel):
project_id: str
- call_ids: typing.List[str]
+ call_ids: List[str]
# wb_user_id is automatically populated by the server
- wb_user_id: typing.Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION)
+ wb_user_id: Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION)
class CallsDeleteRes(BaseModel):
@@ -223,15 +218,15 @@ class CallsDeleteRes(BaseModel):
class CallsFilter(BaseModel):
- op_names: typing.Optional[typing.List[str]] = None
- input_refs: typing.Optional[typing.List[str]] = None
- output_refs: typing.Optional[typing.List[str]] = None
- parent_ids: typing.Optional[typing.List[str]] = None
- trace_ids: typing.Optional[typing.List[str]] = None
- call_ids: typing.Optional[typing.List[str]] = None
- trace_roots_only: typing.Optional[bool] = None
- wb_user_ids: typing.Optional[typing.List[str]] = None
- wb_run_ids: typing.Optional[typing.List[str]] = None
+ op_names: Optional[List[str]] = None
+ input_refs: Optional[List[str]] = None
+ output_refs: Optional[List[str]] = None
+ parent_ids: Optional[List[str]] = None
+ trace_ids: Optional[List[str]] = None
+ call_ids: Optional[List[str]] = None
+ trace_roots_only: Optional[bool] = None
+ wb_user_ids: Optional[List[str]] = None
+ wb_run_ids: Optional[List[str]] = None
class SortBy(BaseModel):
@@ -240,32 +235,32 @@ class SortBy(BaseModel):
# dot-separated.
field: str # Consider changing this to _FieldSelect
# Direction should be either 'asc' or 'desc'
- direction: typing.Literal["asc", "desc"]
+ direction: Literal["asc", "desc"]
class CallsQueryReq(BaseModel):
project_id: str
- filter: typing.Optional[CallsFilter] = None
- limit: typing.Optional[int] = None
- offset: typing.Optional[int] = None
+ filter: Optional[CallsFilter] = None
+ limit: Optional[int] = None
+ offset: Optional[int] = None
# Sort by multiple fields
- sort_by: typing.Optional[typing.List[SortBy]] = None
- query: typing.Optional[Query] = None
- include_costs: typing.Optional[bool] = False
+ sort_by: Optional[List[SortBy]] = None
+ query: Optional[Query] = None
+ include_costs: Optional[bool] = False
# TODO: type this with call schema columns, following the same rules as
# SortBy and thus GetFieldOperator.get_field_ (without direction)
- columns: typing.Optional[typing.List[str]] = None
+ columns: Optional[List[str]] = None
class CallsQueryRes(BaseModel):
- calls: typing.List[CallSchema]
+ calls: List[CallSchema]
class CallsQueryStatsReq(BaseModel):
project_id: str
- filter: typing.Optional[CallsFilter] = None
- query: typing.Optional[Query] = None
+ filter: Optional[CallsFilter] = None
+ query: Optional[Query] = None
class CallsQueryStatsRes(BaseModel):
@@ -278,10 +273,10 @@ class CallUpdateReq(BaseModel):
call_id: str
# optional update fields
- display_name: typing.Optional[str] = None
+ display_name: Optional[str] = None
# wb_user_id is automatically populated by the server
- wb_user_id: typing.Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION)
+ wb_user_id: Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION)
class CallUpdateRes(BaseModel):
@@ -307,17 +302,17 @@ class OpReadRes(BaseModel):
class OpVersionFilter(BaseModel):
- op_names: typing.Optional[typing.List[str]] = None
- latest_only: typing.Optional[bool] = None
+ op_names: Optional[List[str]] = None
+ latest_only: Optional[bool] = None
class OpQueryReq(BaseModel):
project_id: str
- filter: typing.Optional[OpVersionFilter] = None
+ filter: Optional[OpVersionFilter] = None
class OpQueryRes(BaseModel):
- op_objs: typing.List[ObjSchema]
+ op_objs: List[ObjSchema]
class ObjCreateReq(BaseModel):
@@ -339,19 +334,19 @@ class ObjReadRes(BaseModel):
class ObjectVersionFilter(BaseModel):
- base_object_classes: typing.Optional[typing.List[str]] = None
- object_ids: typing.Optional[typing.List[str]] = None
- is_op: typing.Optional[bool] = None
- latest_only: typing.Optional[bool] = None
+ base_object_classes: Optional[List[str]] = None
+ object_ids: Optional[List[str]] = None
+ is_op: Optional[bool] = None
+ latest_only: Optional[bool] = None
class ObjQueryReq(BaseModel):
project_id: str
- filter: typing.Optional[ObjectVersionFilter] = None
+ filter: Optional[ObjectVersionFilter] = None
class ObjQueryRes(BaseModel):
- objs: typing.List[ObjSchema]
+ objs: List[ObjSchema]
class TableCreateReq(BaseModel):
@@ -410,7 +405,7 @@ class Table[OPERATION]Spec(BaseModel):
class TableAppendSpecPayload(BaseModel):
- row: dict[str, typing.Any]
+ row: dict[str, Any]
class TableAppendSpec(BaseModel):
@@ -427,14 +422,14 @@ class TablePopSpec(BaseModel):
class TableInsertSpecPayload(BaseModel):
index: int
- row: dict[str, typing.Any]
+ row: dict[str, Any]
class TableInsertSpec(BaseModel):
insert: TableInsertSpecPayload
-TableUpdateSpec = typing.Union[TableAppendSpec, TablePopSpec, TableInsertSpec]
+TableUpdateSpec = Union[TableAppendSpec, TablePopSpec, TableInsertSpec]
class TableUpdateReq(BaseModel):
@@ -449,7 +444,7 @@ class TableUpdateRes(BaseModel):
class TableRowSchema(BaseModel):
digest: str
- val: typing.Any
+ val: Any
class TableCreateRes(BaseModel):
@@ -457,27 +452,27 @@ class TableCreateRes(BaseModel):
class TableRowFilter(BaseModel):
- row_digests: typing.Optional[typing.List[str]] = None
+ row_digests: Optional[List[str]] = None
class TableQueryReq(BaseModel):
project_id: str
digest: str
- filter: typing.Optional[TableRowFilter] = None
- limit: typing.Optional[int] = None
- offset: typing.Optional[int] = None
+ filter: Optional[TableRowFilter] = None
+ limit: Optional[int] = None
+ offset: Optional[int] = None
class TableQueryRes(BaseModel):
- rows: typing.List[TableRowSchema]
+ rows: List[TableRowSchema]
class RefsReadBatchReq(BaseModel):
- refs: typing.List[str]
+ refs: List[str]
class RefsReadBatchRes(BaseModel):
- vals: typing.List[typing.Any]
+ vals: List[Any]
class FeedbackPayloadReactionReq(BaseModel):
@@ -491,9 +486,9 @@ class FeedbackPayloadNoteReq(BaseModel):
class FeedbackCreateReq(BaseModel):
project_id: str = Field(examples=["entity/project"])
weave_ref: str = Field(examples=["weave:///entity/project/object/name:digest"])
- creator: typing.Optional[str] = Field(default=None, examples=["Jane Smith"])
+ creator: Optional[str] = Field(default=None, examples=["Jane Smith"])
feedback_type: str = Field(examples=["custom"])
- payload: typing.Dict[str, typing.Any] = Field(
+ payload: Dict[str, Any] = Field(
examples=[
{
"key": "value",
@@ -502,7 +497,7 @@ class FeedbackCreateReq(BaseModel):
)
# wb_user_id is automatically populated by the server
- wb_user_id: typing.Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION)
+ wb_user_id: Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION)
# The response provides the additional fields needed to convert a request
@@ -511,7 +506,7 @@ class FeedbackCreateRes(BaseModel):
id: str
created_at: datetime.datetime
wb_user_id: str
- payload: typing.Dict[str, typing.Any] # If not empty, replace payload
+ payload: Dict[str, Any] # If not empty, replace payload
class Feedback(FeedbackCreateReq):
@@ -521,20 +516,20 @@ class Feedback(FeedbackCreateReq):
class FeedbackQueryReq(BaseModel):
project_id: str = Field(examples=["entity/project"])
- fields: typing.Optional[list[str]] = Field(
+ fields: Optional[list[str]] = Field(
default=None, examples=[["id", "feedback_type", "payload.note"]]
)
- query: typing.Optional[Query] = None
+ query: Optional[Query] = None
# TODO: I think I would prefer to call this order_by to match SQL, but this is what calls API uses
# TODO: Might be nice to have shortcut for single field and implied ASC direction
- sort_by: typing.Optional[typing.List[SortBy]] = None
- limit: typing.Optional[int] = Field(default=None, examples=[10])
- offset: typing.Optional[int] = Field(default=None, examples=[0])
+ sort_by: Optional[List[SortBy]] = None
+ limit: Optional[int] = Field(default=None, examples=[10])
+ offset: Optional[int] = Field(default=None, examples=[0])
class FeedbackQueryRes(BaseModel):
# Note: this is not a list of Feedback because user can request any fields.
- result: list[dict[str, typing.Any]]
+ result: list[dict[str, Any]]
class FeedbackPurgeReq(BaseModel):
@@ -565,114 +560,49 @@ class FileContentReadRes(BaseModel):
content: bytes
-class TraceServerInterface:
- def ensure_project_exists(self, entity: str, project: str) -> None:
- pass
+class EnsureProjectExistsRes(BaseModel):
+ project_name: str
- # Call API
- @abc.abstractmethod
- def call_start(self, req: CallStartReq) -> CallStartRes:
- raise NotImplementedError()
-
- @abc.abstractmethod
- def call_end(self, req: CallEndReq) -> CallEndRes:
- raise NotImplementedError()
-
- @abc.abstractmethod
- def call_read(self, req: CallReadReq) -> CallReadRes:
- raise NotImplementedError()
-
- @abc.abstractmethod
- def calls_query(self, req: CallsQueryReq) -> CallsQueryRes:
- raise NotImplementedError()
-
- @abc.abstractmethod
- def calls_query_stream(self, req: CallsQueryReq) -> typing.Iterator[CallSchema]:
- raise NotImplementedError()
- @abc.abstractmethod
- def calls_delete(self, req: CallsDeleteReq) -> CallsDeleteRes:
- raise NotImplementedError()
+class TraceServerInterface(Protocol):
+ def ensure_project_exists(
+ self, entity: str, project: str
+ ) -> EnsureProjectExistsRes:
+ return EnsureProjectExistsRes(project_name=project)
- @abc.abstractmethod
- def calls_query_stats(self, req: CallsQueryStatsReq) -> CallsQueryStatsRes:
- raise NotImplementedError()
-
- @abc.abstractmethod
- def call_update(self, req: CallUpdateReq) -> CallUpdateRes:
- raise NotImplementedError()
+ # Call API
+ def call_start(self, req: CallStartReq) -> CallStartRes: ...
+ def call_end(self, req: CallEndReq) -> CallEndRes: ...
+ def call_read(self, req: CallReadReq) -> CallReadRes: ...
+ def calls_query(self, req: CallsQueryReq) -> CallsQueryRes: ...
+ def calls_query_stream(self, req: CallsQueryReq) -> Iterator[CallSchema]: ...
+ def calls_delete(self, req: CallsDeleteReq) -> CallsDeleteRes: ...
+ def calls_query_stats(self, req: CallsQueryStatsReq) -> CallsQueryStatsRes: ...
+ def call_update(self, req: CallUpdateReq) -> CallUpdateRes: ...
# Op API
- @abc.abstractmethod
- def op_create(self, req: OpCreateReq) -> OpCreateRes:
- raise NotImplementedError()
-
- @abc.abstractmethod
- def op_read(self, req: OpReadReq) -> OpReadRes:
- raise NotImplementedError()
-
- @abc.abstractmethod
- def ops_query(self, req: OpQueryReq) -> OpQueryRes:
- raise NotImplementedError()
+ def op_create(self, req: OpCreateReq) -> OpCreateRes: ...
+ def op_read(self, req: OpReadReq) -> OpReadRes: ...
+ def ops_query(self, req: OpQueryReq) -> OpQueryRes: ...
# Obj API
- @abc.abstractmethod
- def obj_create(self, req: ObjCreateReq) -> ObjCreateRes:
- raise NotImplementedError()
-
- @abc.abstractmethod
- def obj_read(self, req: ObjReadReq) -> ObjReadRes:
- raise NotImplementedError()
-
- @abc.abstractmethod
- def objs_query(self, req: ObjQueryReq) -> ObjQueryRes:
- raise NotImplementedError()
-
- @abc.abstractmethod
- def table_create(self, req: TableCreateReq) -> TableCreateRes:
- raise NotImplementedError()
-
- @abc.abstractmethod
- def table_update(self, req: TableUpdateReq) -> TableUpdateRes:
- raise NotImplementedError()
-
- @abc.abstractmethod
- def table_query(self, req: TableQueryReq) -> TableQueryRes:
- raise NotImplementedError()
-
- @abc.abstractmethod
- def refs_read_batch(self, req: RefsReadBatchReq) -> RefsReadBatchRes:
- raise NotImplementedError()
-
- @abc.abstractmethod
- def file_create(self, req: FileCreateReq) -> FileCreateRes:
- raise NotImplementedError()
-
- @abc.abstractmethod
- def file_content_read(self, req: FileContentReadReq) -> FileContentReadRes:
- raise NotImplementedError()
-
- @abc.abstractmethod
- def feedback_create(self, req: FeedbackCreateReq) -> FeedbackCreateRes:
- raise NotImplementedError()
-
- @abc.abstractmethod
- def feedback_query(self, req: FeedbackQueryReq) -> FeedbackQueryRes:
- raise NotImplementedError()
-
- @abc.abstractmethod
- def feedback_purge(self, req: FeedbackPurgeReq) -> FeedbackPurgeRes:
- raise NotImplementedError()
+ def obj_create(self, req: ObjCreateReq) -> ObjCreateRes: ...
+ def obj_read(self, req: ObjReadReq) -> ObjReadRes: ...
+ def objs_query(self, req: ObjQueryReq) -> ObjQueryRes: ...
+ def table_create(self, req: TableCreateReq) -> TableCreateRes: ...
+ def table_update(self, req: TableUpdateReq) -> TableUpdateRes: ...
+ def table_query(self, req: TableQueryReq) -> TableQueryRes: ...
+ def refs_read_batch(self, req: RefsReadBatchReq) -> RefsReadBatchRes: ...
+ def file_create(self, req: FileCreateReq) -> FileCreateRes: ...
+ def file_content_read(self, req: FileContentReadReq) -> FileContentReadRes: ...
+ def feedback_create(self, req: FeedbackCreateReq) -> FeedbackCreateRes: ...
+ def feedback_query(self, req: FeedbackQueryReq) -> FeedbackQueryRes: ...
+ def feedback_purge(self, req: FeedbackPurgeReq) -> FeedbackPurgeRes: ...
# These symbols are used in the WB Trace Server and it is not safe
# to remove them, else it will break the server. Once the server
# is updated to use the new symbols, these can be removed.
-#
-# Remove once https://github.com/wandb/core/pull/22040 lands
-CallsDeleteReqForInsert = CallsDeleteReq
-CallUpdateReqForInsert = CallUpdateReq
-FeedbackCreateReqForInsert = FeedbackCreateReq
# Legacy Names (i think these might be used in a few growth examples, so keeping
# around until we clean those up of them)
diff --git a/weave/type_serializers/Image/__init__.py b/weave/type_serializers/Image/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/weave/type_serializers/Image/image.py b/weave/type_serializers/Image/image.py
new file mode 100644
index 00000000000..c9b9402a4de
--- /dev/null
+++ b/weave/type_serializers/Image/image.py
@@ -0,0 +1,44 @@
+"""Defines the custom Image weave type."""
+
+from weave.trace import serializer
+from weave.trace.custom_objs import MemTraceFilesArtifact
+
+dependencies_met = False
+
+try:
+ from PIL import Image
+
+ dependencies_met = True
+except ImportError:
+ pass
+
+
+def save(obj: "Image.Image", artifact: MemTraceFilesArtifact, name: str) -> None:
+ # Note: I am purposely ignoring the `name` here and hard-coding the filename to "image.png".
+ # There is an extensive internal discussion here:
+ # https://weightsandbiases.slack.com/archives/C03BSTEBD7F/p1723670081582949
+ #
+ # In summary, there is an outstanding design decision to be made about how to handle the
+ # `name` parameter. One school of thought is that using the `name` parameter allows multiple
+ # objects to use the same artifact more cleanly. However, another school of thought is that
+ # the payload should not be dependent on an external name - resulting in different payloads
+ # for the same logical object.
+ #
+ # Using `image.png` is fine for now since we don't have any cases of multiple objects
+ # using the same artifact. Moreover, since we package the deserialization logic with the
+ # object payload, we can always change the serialization logic later without breaking
+ # existing payloads.
+ with artifact.new_file("image.png", binary=True) as f:
+ obj.save(f, format="png") # type: ignore
+
+
+def load(artifact: MemTraceFilesArtifact, name: str) -> "Image.Image":
+ # Note: I am purposely ignoring the `name` here and hard-coding the filename. See comment
+ # on save.
+ path = artifact.path("image.png")
+ return Image.open(path)
+
+
+def register() -> None:
+ if dependencies_met:
+ serializer.register_serializer(Image.Image, save, load)
diff --git a/weave/type_serializers/Image/image_test.py b/weave/type_serializers/Image/image_test.py
new file mode 100644
index 00000000000..cf42c07c4d6
--- /dev/null
+++ b/weave/type_serializers/Image/image_test.py
@@ -0,0 +1,102 @@
+from PIL import Image
+
+import weave
+from weave.weave_client import WeaveClient, get_ref
+
+"""When testing types, it is important to test:
+Objects:
+1. Publishing Directly
+2. Publishing as a property
+3. Using as a cell in a table
+
+Calls:
+4. Using as inputs, output, and output component (raw)
+5. Using as inputs, output, and output component (refs)
+
+"""
+
+
+def test_image_publish(client: WeaveClient) -> None:
+ img = Image.new("RGB", (512, 512), "purple")
+ weave.publish(img)
+
+ ref = get_ref(img)
+
+ assert ref is not None
+ gotten_img = weave.ref(ref.uri()).get()
+ assert img.tobytes() == gotten_img.tobytes()
+
+
+class ImageWrapper(weave.Object):
+ img: Image.Image
+
+
+def test_image_as_property(client: WeaveClient) -> None:
+ img = Image.new("RGB", (512, 512), "purple")
+ img_wrapper = ImageWrapper(img=img)
+ assert img_wrapper.img == img
+
+ weave.publish(img_wrapper)
+
+ ref = get_ref(img_wrapper)
+ assert ref is not None
+
+ gotten_img_wrapper = weave.ref(ref.uri()).get()
+ assert gotten_img_wrapper.img.tobytes() == img.tobytes()
+
+
+def test_image_as_dataset_cell(client: WeaveClient) -> None:
+ img = Image.new("RGB", (512, 512), "purple")
+ dataset = weave.Dataset(rows=[{"img": img}])
+ assert dataset.rows[0]["img"] == img
+
+ weave.publish(dataset)
+
+ ref = get_ref(dataset)
+ assert ref is not None
+
+ gotten_dataset = weave.ref(ref.uri()).get()
+ assert gotten_dataset.rows[0]["img"].tobytes() == img.tobytes()
+
+
+@weave.op
+def image_as_solo_output(publish_first: bool) -> Image.Image:
+ img = Image.new("RGB", (512, 512), "purple")
+ if publish_first:
+ weave.publish(img)
+ return img
+
+
+@weave.op
+def image_as_input_and_output_part(in_img: Image.Image) -> dict:
+ return {"out_img": in_img}
+
+
+def test_image_as_call_io(client: WeaveClient) -> None:
+ non_published_img = image_as_solo_output(publish_first=False)
+ img_dict = image_as_input_and_output_part(non_published_img)
+
+ exp_bytes = non_published_img.tobytes()
+ assert img_dict["out_img"].tobytes() == exp_bytes
+
+ image_as_solo_output_call = image_as_solo_output.calls()[0]
+ image_as_input_and_output_part_call = image_as_input_and_output_part.calls()[0]
+
+ assert image_as_solo_output_call.output.tobytes() == exp_bytes
+ assert image_as_input_and_output_part_call.inputs["in_img"].tobytes() == exp_bytes
+ assert image_as_input_and_output_part_call.output["out_img"].tobytes() == exp_bytes
+
+
+def test_image_as_call_io_refs(client: WeaveClient) -> None:
+ non_published_img = image_as_solo_output(publish_first=True)
+ img_dict = image_as_input_and_output_part(non_published_img)
+
+ exp_bytes = non_published_img.tobytes()
+ assert img_dict["out_img"].tobytes() == exp_bytes
+
+ image_as_solo_output_call = image_as_solo_output.calls()[0]
+ image_as_input_and_output_part_call = image_as_input_and_output_part.calls()[0]
+
+ assert image_as_solo_output_call.output.tobytes() == exp_bytes
+ assert image_as_input_and_output_part_call.inputs["in_img"].tobytes() == exp_bytes
+ assert image_as_input_and_output_part_call.output["out_img"].tobytes() == exp_bytes
diff --git a/weave/type_serializers/__init__.py b/weave/type_serializers/__init__.py
new file mode 100644
index 00000000000..396af8f791e
--- /dev/null
+++ b/weave/type_serializers/__init__.py
@@ -0,0 +1,3 @@
+from .Image import image
+
+image.register()
diff --git a/weave/weave_client.py b/weave/weave_client.py
index 88e205f605a..d65275180d9 100644
--- a/weave/weave_client.py
+++ b/weave/weave_client.py
@@ -23,6 +23,7 @@
from weave.trace.op import op as op_deco
from weave.trace.refs import CallRef, ObjectRef, OpRef, Ref, TableRef
from weave.trace.serialize import from_json, isinstance_namedtuple, to_json
+from weave.trace.serializer import get_serializer_for_obj
from weave.trace.vals import WeaveObject, WeaveTable, make_trace_obj
from weave.trace_server.ids import generate_id
from weave.trace_server.trace_server_interface import (
@@ -101,6 +102,8 @@ def _get_direct_ref(obj: Any) -> Optional[Ref]:
def map_to_refs(obj: Any) -> Any:
+ if isinstance(obj, Ref):
+ return obj
if ref := _get_direct_ref(obj):
return ref
@@ -288,7 +291,7 @@ def make_client_call(
parent_id=server_call.parent_id,
id=server_call.id,
inputs=from_json(server_call.inputs, server_call.project_id, server),
- output=output,
+ output=from_json(output, server_call.project_id, server),
summary=dict(server_call.summary) if server_call.summary is not None else None,
display_name=server_call.display_name,
attributes=server_call.attributes,
@@ -380,7 +383,9 @@ def __init__(
self.ensure_project_exists = ensure_project_exists
if ensure_project_exists:
- self.server.ensure_project_exists(entity, project)
+ resp = self.server.ensure_project_exists(entity, project)
+ # Set Client project name with updated project name
+ self.project = resp.project_name
################ High Level Convenience Methods ################
@@ -727,10 +732,18 @@ def _project_id(self) -> str:
@trace_sentry.global_trace_sentry.watch()
def _save_object(self, val: Any, name: str, branch: str = "latest") -> ObjectRef:
self._save_nested_objects(val, name=name)
+
+ # Typically, this condition would belong inside of the
+ # `_save_nested_objects` switch. However, we don't want to recursively
+ # publish all custom objects. Instead, we only want to do this at the
+ # top-most level, if requested.
+ if get_serializer_for_obj(val) is not None:
+ self._save_and_attach_ref(val)
+
return self._save_object_basic(val, name, branch)
def _save_object_basic(
- self, val: Any, name: str, branch: str = "latest"
+ self, val: Any, name: Optional[str] = None, branch: str = "latest"
) -> ObjectRef:
# The WeaveTable case is special because object saving happens inside
# _save_object_nested and it has a special table_ref -- skip it here.
@@ -743,6 +756,14 @@ def _save_object_basic(
return val
json_val = to_json(val, self._project_id(), self.server)
+ if name is None:
+ if json_val.get("_type") == "CustomWeaveType":
+ custom_name = json_val.get("weave_type", {}).get("type")
+ name = custom_name
+
+ if name is None:
+ raise ValueError("Name must be provided for object saving")
+
response = self.server.obj_create(
ObjCreateReq(
obj=ObjSchemaForInsert(
@@ -778,7 +799,7 @@ def _save_nested_objects(self, obj: Any, name: Optional[str] = None) -> Any:
self._save_nested_objects(v)
ref = self._save_object_basic(obj_rec, name or get_obj_name(obj_rec))
obj.__dict__["ref"] = ref
- elif dataclasses.is_dataclass(obj):
+ elif dataclasses.is_dataclass(obj) and not isinstance(obj, Ref):
obj_rec = dataclass_object_record(obj)
for v in obj_rec.__dict__.values():
self._save_nested_objects(v)
@@ -808,11 +829,10 @@ def _save_nested_objects(self, obj: Any, name: Optional[str] = None) -> Any:
@trace_sentry.global_trace_sentry.watch()
def _save_table(self, table: Table) -> TableRef:
+ rows = to_json(table.rows, self._project_id(), self.server)
response = self.server.table_create(
TableCreateReq(
- table=TableSchemaForInsert(
- project_id=self._project_id(), rows=table.rows
- )
+ table=TableSchemaForInsert(project_id=self._project_id(), rows=rows)
)
)
return TableRef(
@@ -846,8 +866,16 @@ def _objects(self, filter: Optional[ObjectVersionFilter] = None) -> list[ObjSche
def _save_op(self, op: Op, name: Optional[str] = None) -> Ref:
if op.ref is not None:
return op.ref
+
if name is None:
name = op.name
+
+ return self._save_and_attach_ref(op, name)
+
+ def _save_and_attach_ref(self, op: Any, name: Optional[str] = None) -> Ref:
+ if (ref := getattr(op, "ref", None)) is not None:
+ return ref
+
op_def_ref = self._save_object_basic(op, name)
# setattr(op, "ref", op_def_ref) fails here
diff --git a/weave/weave_init.py b/weave/weave_init.py
index b1744ff0ce1..d2af6bf5eb7 100644
--- a/weave/weave_init.py
+++ b/weave/weave_init.py
@@ -104,6 +104,8 @@ def init_weave(
client = weave_client.WeaveClient(
entity_name, project_name, remote_server, ensure_project_exists
)
+ # If the project name was formatted by init, update the project name
+ project_name = client.project
_current_inited_client = InitializedClient(client)
# entity_name, project_name = get_entity_project_from_project_name(project_name)