diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx index be81c42d08b5..f6c0d2ad16c4 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx @@ -284,7 +284,7 @@ const Browse3Mounted: FC<{ const MainPeekingLayout: FC = () => { const {baseRouter} = useWeaveflowRouteContext(); - const params = useParams(); + const params = useParamsDecoded(); const baseRouterProjectRoot = baseRouter.projectUrl(':entity', ':project'); const generalProjectRoot = browse2Context.projectUrl(':entity', ':project'); const query = useURLSearchParamsDict(); @@ -406,7 +406,7 @@ const MainPeekingLayout: FC = () => { }; const ProjectRedirect: FC = () => { - const {entity, project} = useParams(); + const {entity, project} = useParamsDecoded(); const {baseRouter} = useWeaveflowRouteContext(); const url = baseRouter.tracesUIUrl(entity, project); return ; @@ -497,7 +497,7 @@ const Browse3ProjectRoot: FC<{ // TODO(tim/weaveflow_improved_nav): Generalize this const ObjectVersionRoutePageBinding = () => { - const params = useParams(); + const params = useParamsDecoded(); const query = useURLSearchParamsDict(); const history = useHistory(); @@ -538,7 +538,7 @@ const ObjectVersionRoutePageBinding = () => { // TODO(tim/weaveflow_improved_nav): Generalize this const OpVersionRoutePageBinding = () => { - const params = useParams(); + const params = useParamsDecoded(); const history = useHistory(); const routerContext = useWeaveflowCurrentRouteContext(); useEffect(() => { @@ -573,7 +573,7 @@ const useCallPeekRedirect = () => { // This is a "hack" since the client doesn't have all the info // needed to make a correct peek URL. This allows the client to request // such a view and we can redirect to the correct URL. - const params = useParams(); + const params = useParamsDecoded(); const {baseRouter} = useWeaveflowRouteContext(); const history = useHistory(); const {useCall} = useWFHooks(); @@ -618,10 +618,23 @@ const useCallPeekRedirect = () => { ]); }; +const useParamsDecoded = () => { + // Handle the case where entity/project (old) have spaces + const params = useParams(); + return useMemo(() => { + return Object.fromEntries( + Object.entries(params).map(([key, value]) => [ + key, + decodeURIComponent(value), + ]) + ); + }, [params]); +}; + // TODO(tim/weaveflow_improved_nav): Generalize this const CallPageBinding = () => { useCallPeekRedirect(); - const params = useParams(); + const params = useParamsDecoded(); const query = useURLSearchParamsDict(); return ( @@ -636,7 +649,7 @@ const CallPageBinding = () => { // TODO(tim/weaveflow_improved_nav): Generalize this const CallsPageBinding = () => { - const {entity, project, tab} = useParams(); + const {entity, project, tab} = useParamsDecoded(); const query = useURLSearchParamsDict(); const initialFilter = useMemo(() => { if (tab === 'evaluations') { @@ -769,7 +782,7 @@ const CallsPageBinding = () => { // TODO(tim/weaveflow_improved_nav): Generalize this const ObjectVersionsPageBinding = () => { - const {entity, project, tab} = useParams(); + const {entity, project, tab} = useParamsDecoded(); const query = useURLSearchParamsDict(); const filters: WFHighLevelObjectVersionFilter = useMemo(() => { let queryFilter: WFHighLevelObjectVersionFilter = {}; @@ -815,7 +828,7 @@ const ObjectVersionsPageBinding = () => { // TODO(tim/weaveflow_improved_nav): Generalize this const OpVersionsPageBinding = () => { - const params = useParams(); + const params = useParamsDecoded(); const query = useURLSearchParamsDict(); const filters = useMemo(() => { @@ -851,7 +864,7 @@ const OpVersionsPageBinding = () => { // TODO(tim/weaveflow_improved_nav): Generalize this const BoardPageBinding = () => { - const params = useParams(); + const params = useParamsDecoded(); return ( { // TODO(tim/weaveflow_improved_nav): Generalize this const ObjectPageBinding = () => { - const params = useParams(); + const params = useParamsDecoded(); return ( { }; const OpPageBinding = () => { - const params = useParams(); + const params = useParamsDecoded(); return ( { }; const CompareEvaluationsBinding = () => { - const {entity, project} = useParams(); + const {entity, project} = useParamsDecoded(); const query = useURLSearchParamsDict(); const evaluationCallIds = useMemo(() => { return JSON.parse(query.evaluationCallIds); @@ -902,19 +915,19 @@ const CompareEvaluationsBinding = () => { }; const OpsPageBinding = () => { - const params = useParams(); + const params = useParamsDecoded(); return ; }; const BoardsPageBinding = () => { - const params = useParams(); + const params = useParamsDecoded(); return ; }; const TablesPageBinding = () => { - const params = useParams(); + const params = useParamsDecoded(); return ; }; @@ -934,7 +947,7 @@ const AppBarLink = (props: ComponentProps) => ( ); const Browse3Breadcrumbs: FC = props => { - const params = useParams(); + const params = useParamsDecoded(); const query = useURLSearchParamsDict(); const filePathParts = query.path?.split('/') ?? []; const refFields = query.extra?.split('/') ?? []; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx index 0e3005b1051f..d36678bfbd56 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ObjectViewer.tsx @@ -467,7 +467,7 @@ const buildBaseRef = ( const parts = path.toPath().slice(-startIndex); parts.forEach(part => { if (typeof part === 'string') { - baseRef += '/' + OBJECT_ATTR_EDGE_NAME + '/' + part; + baseRef += '/' + OBJECT_ATTR_EDGE_NAME + '/' + encodeURIComponent(part); } else if (typeof part === 'number') { baseRef += '/' + LIST_INDEX_EDGE_NAME + '/' + part.toString(); } else { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/tabularListViews/columnBuilder.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/tabularListViews/columnBuilder.tsx index 1e81bbd96436..73f26cf297b4 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/tabularListViews/columnBuilder.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/tabularListViews/columnBuilder.tsx @@ -160,7 +160,11 @@ function replaceTableRefsInFlattenedData(flattened: Record) { if (parentRef) { const newVal: ExpandedRefWithValueAsTableRef = makeRefExpandedPayload( - parentRef + '/' + OBJECT_ATTR_EDGE_NAME + '/' + attr, + parentRef + + '/' + + OBJECT_ATTR_EDGE_NAME + + '/' + + encodeURIComponent(attr), val ); return [key, newVal]; diff --git a/weave-js/src/react.test.ts b/weave-js/src/react.test.ts index a20e61692fa0..02877f1ba574 100644 --- a/weave-js/src/react.test.ts +++ b/weave-js/src/react.test.ts @@ -34,7 +34,7 @@ describe('parseRef', () => { }); it('parses a ref with spaces in entity', () => { const parsed = parseRef( - 'weave:///Entity%20Name/project/object/artifact-name:artifactversion' + 'weave:///Entity Name/project/object/artifact-name:artifactversion' ); expect(parsed).toEqual({ artifactName: 'artifact-name', @@ -102,20 +102,6 @@ describe('parseRef', () => { weaveKind: 'object', }); }); - it('parses a ref with escaped spaces in name and projectName', () => { - const parsed = parseRef( - 'weave:///entity/project%20with%20spaces/object/artifact%20name%20with%20spaces:artifactversion' - ); - expect(parsed).toEqual({ - artifactName: 'artifact name with spaces', - artifactRefExtra: '', - artifactVersion: 'artifactversion', - entityName: 'entity', - projectName: 'project with spaces', - scheme: 'weave', - weaveKind: 'object', - }); - }); }); it('parses a weave table ref', () => { const parsed = parseRef( diff --git a/weave-js/src/react.tsx b/weave-js/src/react.tsx index d585b54421f6..d11c404d8cd5 100644 --- a/weave-js/src/react.tsx +++ b/weave-js/src/react.tsx @@ -47,6 +47,7 @@ import { useState, } from 'react'; +import {WEAVE_REF_PREFIX} from './components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/constants'; import {PanelCompContext} from './components/Panel2/PanelComp'; import {usePanelContext} from './components/Panel2/PanelContext'; import {toWeaveType} from './components/Panel2/toWeaveType'; @@ -64,7 +65,6 @@ import { getChainRootVar, isConstructor, } from './core/mutate'; -import {trimStartChar} from './core/util/string'; import {UseNodeValueServerExecutionError} from './errors'; import {useDeepMemo} from './hookUtils'; import {consoleLog} from './util'; @@ -505,6 +505,7 @@ export const isWeaveObjectRef = (ref: ObjectRef): ref is WeaveObjectRef => { // Unfortunately many teams have been created that violate this. const PATTERN_ENTITY = '([^/]+)'; const PATTERN_PROJECT = '([^\\#?%:]{1,128})'; // Project name +const PATTERN_REF_EXTRA = '([a-zA-Z0-9_.~/%-]*)'; // Optional ref extra (valid chars are result of python urllib.parse.quote and javascript encodeURIComponent) const RE_WEAVE_OBJECT_REF_PATHNAME = new RegExp( [ '^', // Start of the string @@ -518,7 +519,7 @@ const RE_WEAVE_OBJECT_REF_PATHNAME = new RegExp( ':', '([*]|[a-zA-Z0-9]+)', // Artifact version, allowing '*' for any version '/?', // Ref extra portion is optional - '([a-zA-Z0-9_/]*)', // Optional ref extra + PATTERN_REF_EXTRA, // Optional ref extra '$', // End of the string ].join('') ); @@ -531,7 +532,7 @@ const RE_WEAVE_TABLE_REF_PATHNAME = new RegExp( '/table/', '([a-f0-9]+)', // Digest '/?', // Ref extra portion is optional - '([a-zA-Z0-9_/]*)', // Optional ref extra + PATTERN_REF_EXTRA, // Optional ref extra '$', // End of the string ].join('') ); @@ -544,7 +545,7 @@ const RE_WEAVE_CALL_REF_PATHNAME = new RegExp( '/call/', '([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})', // Call UUID '/?', // Ref extra portion is optional - '([a-zA-Z0-9_/]*)', // Optional ref extra + PATTERN_REF_EXTRA, // Optional ref extra '$', // End of the string ].join('') ); @@ -561,7 +562,7 @@ export const parseRef = (ref: string): ObjectRef => { } else if (isLocalArtifact) { splitLimit = 2; } else if (isWeaveRef) { - splitLimit = 4; + return parseWeaveRef(ref); } else { throw new Error(`Unknown protocol: ${url.protocol}`); } @@ -598,58 +599,58 @@ export const parseRef = (ref: string): ObjectRef => { artifactPath, }; } + throw new Error(`Unknown protocol: ${url.protocol}`); +}; - if (isWeaveRef) { - const trimmed = trimStartChar(decodedUri, '/'); - const tableMatch = trimmed.match(RE_WEAVE_TABLE_REF_PATHNAME); - if (tableMatch !== null) { - const [entity, project, digest] = tableMatch.slice(1); - return { - scheme: 'weave', - entityName: entity, - projectName: project, - weaveKind: 'table' as WeaveKind, - artifactName: '', - artifactVersion: digest, - artifactRefExtra: '', - }; - } - const callMatch = trimmed.match(RE_WEAVE_CALL_REF_PATHNAME); - if (callMatch !== null) { - const [entity, project, callId] = callMatch.slice(1); - return { - scheme: 'weave', - entityName: entity, - projectName: project, - weaveKind: 'call' as WeaveKind, - artifactName: callId, - artifactVersion: '', - artifactRefExtra: '', - }; - } - const match = trimmed.match(RE_WEAVE_OBJECT_REF_PATHNAME); - if (match === null) { - throw new Error('Invalid weave ref uri: ' + ref); - } - const [ - entityName, - projectName, - weaveKind, - artifactName, - artifactVersion, - artifactRefExtra, - ] = match.slice(1); +const parseWeaveRef = (ref: string): WeaveObjectRef => { + const trimmed = ref.slice(WEAVE_REF_PREFIX.length); + const tableMatch = trimmed.match(RE_WEAVE_TABLE_REF_PATHNAME); + if (tableMatch !== null) { + const [entity, project, digest] = tableMatch.slice(1); return { scheme: 'weave', - entityName, - projectName, - weaveKind: weaveKind as WeaveKind, - artifactName, - artifactVersion, - artifactRefExtra: artifactRefExtra ?? '', + entityName: entity, + projectName: project, + weaveKind: 'table' as WeaveKind, + artifactName: '', + artifactVersion: digest, + artifactRefExtra: '', }; } - throw new Error(`Unknown protocol: ${url.protocol}`); + const callMatch = trimmed.match(RE_WEAVE_CALL_REF_PATHNAME); + if (callMatch !== null) { + const [entity, project, callId] = callMatch.slice(1); + return { + scheme: 'weave', + entityName: entity, + projectName: project, + weaveKind: 'call' as WeaveKind, + artifactName: callId, + artifactVersion: '', + artifactRefExtra: '', + }; + } + const match = trimmed.match(RE_WEAVE_OBJECT_REF_PATHNAME); + if (match === null) { + throw new Error('Invalid weave ref uri: ' + ref); + } + const [ + entityName, + projectName, + weaveKind, + artifactName, + artifactVersion, + artifactRefExtra, + ] = match.slice(1); + return { + scheme: 'weave', + entityName, + projectName, + weaveKind: weaveKind as WeaveKind, + artifactName, + artifactVersion, + artifactRefExtra: artifactRefExtra ?? '', + }; }; export const objectRefWithExtra = ( diff --git a/weave/integrations/langchain/langchain_test.py b/weave/integrations/langchain/langchain_test.py index 61decdb18193..03956bf3348c 100644 --- a/weave/integrations/langchain/langchain_test.py +++ b/weave/integrations/langchain/langchain_test.py @@ -75,12 +75,12 @@ def test_simple_chain_invoke( ) prompt = PromptTemplate.from_template("1 + {number} = ") long_str = ( - "really_massive_name_that_is_longer_than_max_characters_which_would_be_crazy_" + "really_massive_name_that_is_longer_than_max_characters_which_would_be_crazy" ) name = long_str + long_str prompt.name = name - exp_name = "really_massive_name_that_is_longer_than_max_characte_9ad6_t_is_longer_than_max_characters_which_would_be_crazy_" + exp_name = "really_massive_name_that_is_longer_than_max_characte_ff6e_at_is_longer_than_max_characters_which_would_be_crazy" llm_chain = prompt | llm _ = llm_chain.invoke({"number": 2}) diff --git a/weave/tests/trace/test_client_trace.py b/weave/tests/trace/test_client_trace.py index d20ad76a8a9c..c2846ed126c8 100644 --- a/weave/tests/trace/test_client_trace.py +++ b/weave/tests/trace/test_client_trace.py @@ -16,8 +16,10 @@ from weave import Thread, ThreadPoolExecutor from weave.trace import weave_client from weave.trace.vals import MissingSelfInstanceError +from weave.trace.weave_client import sanitize_object_name from weave.trace_server import trace_server_interface as tsi from weave.trace_server.ids import generate_id +from weave.trace_server.refs_internal import extra_value_quoter from weave.trace_server.sqlite_trace_server import SqliteTraceServer from weave.trace_server.trace_server_interface_util import ( TRACE_REF_SCHEME, @@ -2566,3 +2568,129 @@ def return_nested_object(nested_obj: NestedObject): ) call_result = list(res)[0] assert call_result.output == nested_ref.uri() + + +class Custom(weave.Object): + val: dict + + +def test_object_with_disallowed_keys(client): + name = "thing % with / disallowed : keys" + obj = Custom(name=name, val={"1": 1}) + + weave.publish(obj) + + # we sanitize the name + assert obj.ref.name == "thing-with-disallowed-keys" + + create_req = tsi.ObjCreateReq.model_validate( + dict( + obj=dict( + project_id=client._project_id(), + object_id=name, + val={"1": 1}, + ) + ) + ) + with pytest.raises(Exception): + client.server.obj_create(create_req) + + +CHAR_LIMIT = 128 + + +def test_object_with_char_limit(client): + name = "l" * CHAR_LIMIT + obj = Custom(name=name, val={"1": 1}) + + weave.publish(obj) + + # we sanitize the name + assert obj.ref.name == name + + create_req = tsi.ObjCreateReq.model_validate( + dict( + obj=dict( + project_id=client._project_id(), + object_id=name, + val={"1": 1}, + ) + ) + ) + client.server.obj_create(create_req) + + +def test_object_with_char_over_limit(client): + name = "l" * (CHAR_LIMIT + 1) + obj = Custom(name=name, val={"1": 1}) + + weave.publish(obj) + + # we sanitize the name + assert obj.ref.name == name[:-1] + + create_req = tsi.ObjCreateReq.model_validate( + dict( + obj=dict( + project_id=client._project_id(), + object_id=name, + val={"1": 1}, + ) + ) + ) + with pytest.raises(Exception): + client.server.obj_create(create_req) + + +chars = "+_(){}|\"'<>!@$^&*#:,.[]-=;~`" + + +def test_objects_and_keys_with_special_characters(client): + # make sure to include ":", "/" which are URI-related + + name_with_special_characters = "n-a_m.e: /" + chars + "100" + dict_payload = {name_with_special_characters: "hello world"} + + obj = Custom(name=name_with_special_characters, val=dict_payload) + + weave.publish(obj) + assert obj.ref is not None + + entity, project = client._project_id().split("/") + project_id = f"{entity}/{project}" + ref_base = f"weave:///{project_id}" + exp_name = sanitize_object_name(name_with_special_characters) + assert exp_name == "n-a_m.e-100" + exp_key = extra_value_quoter(name_with_special_characters) + assert ( + exp_key + == "n-a_m.e%3A%20%2F%2B_%28%29%7B%7D%7C%22%27%3C%3E%21%40%24%5E%26%2A%23%3A%2C.%5B%5D-%3D%3B~%60100" + ) + exp_digest = "O66Mk7g91rlAUtcGYOFR1Y2Wk94YyPXJy2UEAzDQcYM" + + exp_obj_ref = f"{ref_base}/object/{exp_name}:{exp_digest}" + assert obj.ref.uri() == exp_obj_ref + + @weave.op + def test(obj: Custom): + return obj.val[name_with_special_characters] + + test.name = name_with_special_characters + + res = test(obj) + + exp_res_ref = f"{exp_obj_ref}/attr/val/key/{exp_key}" + found_ref = res.ref.uri() + assert res == "hello world" + assert found_ref == exp_res_ref + + gotten_res = weave.ref(found_ref).get() + assert gotten_res == "hello world" + + exp_op_digest = "xEPCVKKjDWxKzqaCxxU09jD82FGGf5WcNy2fC9VUF3M" + exp_op_ref = f"{ref_base}/op/{exp_name}:{exp_op_digest}" + + found_ref = test.ref.uri() + assert found_ref == exp_op_ref + gotten_fn = weave.ref(found_ref).get() + assert gotten_fn(obj) == "hello world" diff --git a/weave/tests/trace/test_trace_server.py b/weave/tests/trace/test_trace_server.py index cab5bb56efee..e02899370a56 100644 --- a/weave/tests/trace/test_trace_server.py +++ b/weave/tests/trace/test_trace_server.py @@ -25,10 +25,8 @@ def test_save_object(client): def test_robust_to_url_sensitive_chars(client): - entity = "entity" - project = "project" - project_id = f"{entity}/{project}" - object_id = "mali:cious/obj%ect" + project_id = client._project_id() + object_id = "mali_cious-obj.ect" bad_key = "mali:cious/ke%y" bad_val = {bad_key: "hello world"} @@ -53,22 +51,13 @@ def test_robust_to_url_sensitive_chars(client): assert read_res.obj.val == bad_val # Object ID that contains reserved characters should be rejected. - with pytest.raises(Exception): - read_res = client.server.refs_read_batch( - tsi.RefsReadBatchReq( - refs=[f"weave:///{project_id}/object/{object_id}:{create_res.digest}"] - ) - ) - encoded_object_id = urllib.parse.quote_plus(object_id) - assert encoded_object_id == "mali%3Acious%2Fobj%25ect" read_res = client.server.refs_read_batch( tsi.RefsReadBatchReq( - refs=[ - f"weave:///{project_id}/object/{encoded_object_id}:{create_res.digest}" - ] + refs=[f"weave:///{project_id}/object/{object_id}:{create_res.digest}"] ) ) + assert read_res.vals[0] == bad_val # Key that contains reserved characters should be rejected. @@ -76,7 +65,7 @@ def test_robust_to_url_sensitive_chars(client): read_res = client.server.refs_read_batch( tsi.RefsReadBatchReq( refs=[ - f"weave:///{project_id}/object/{encoded_object_id}:{create_res.digest}/key/{bad_key}" + f"weave:///{project_id}/object/{object_id}:{create_res.digest}/key/{bad_key}" ] ) ) @@ -86,7 +75,7 @@ def test_robust_to_url_sensitive_chars(client): read_res = client.server.refs_read_batch( tsi.RefsReadBatchReq( refs=[ - f"weave:///{project_id}/object/{encoded_object_id}:{create_res.digest}/key/{encoded_bad_key}" + f"weave:///{project_id}/object/{object_id}:{create_res.digest}/key/{encoded_bad_key}" ] ) ) diff --git a/weave/trace/refs.py b/weave/trace/refs.py index cfffe99436c1..16fb47d8f21b 100644 --- a/weave/trace/refs.py +++ b/weave/trace/refs.py @@ -1,7 +1,8 @@ import dataclasses +import urllib from typing import Any, Union -from ..trace_server import refs_internal +from ..trace_server import refs_internal, validation DICT_KEY_EDGE_NAME = refs_internal.DICT_KEY_EDGE_NAME LIST_INDEX_EDGE_NAME = refs_internal.LIST_INDEX_EDGE_NAME @@ -53,10 +54,16 @@ class ObjectRef(RefWithExtra): digest: str extra: tuple[str, ...] = () + def __post_init__(self) -> None: + refs_internal.validate_no_slashes(self.digest, "digest") + refs_internal.validate_no_colons(self.digest, "digest") + refs_internal.validate_extra(list(self.extra)) + validation.object_id_validator(self.name) + def uri(self) -> str: u = f"weave:///{self.entity}/{self.project}/object/{self.name}:{self.digest}" if self.extra: - u += "/" + "/".join(self.extra) + u += "/" + "/".join(refs_internal.extra_value_quoter(e) for e in self.extra) return u def get(self) -> Any: @@ -106,7 +113,7 @@ class OpRef(ObjectRef): def uri(self) -> str: u = f"weave:///{self.entity}/{self.project}/op/{self.name}:{self.digest}" if self.extra: - u += "/" + "/".join(self.extra) + u += "/" + "/".join(refs_internal.extra_value_quoter(e) for e in self.extra) return u @@ -120,7 +127,7 @@ class CallRef(RefWithExtra): def uri(self) -> str: u = f"weave:///{self.entity}/{self.project}/call/{self.id}" if self.extra: - u += "/" + "/".join(self.extra) + u += "/" + "/".join(refs_internal.extra_value_quoter(e) for e in self.extra) return u @@ -138,27 +145,18 @@ def parse_uri(uri: str) -> AnyRef: remaining = tuple(parts[3:]) if kind == "table": return TableRef(entity=entity, project=project, digest=remaining[0]) - elif kind == "call": - return CallRef( - entity=entity, project=project, id=remaining[0], extra=remaining[1:] - ) + extra = tuple(urllib.parse.unquote(r) for r in remaining[1:]) + if kind == "call": + return CallRef(entity=entity, project=project, id=remaining[0], extra=extra) elif kind == "object": name, version = remaining[0].split(":") return ObjectRef( - entity=entity, - project=project, - name=name, - digest=version, - extra=remaining[1:], + entity=entity, project=project, name=name, digest=version, extra=extra ) elif kind == "op": name, version = remaining[0].split(":") return OpRef( - entity=entity, - project=project, - name=name, - digest=version, - extra=remaining[1:], + entity=entity, project=project, name=name, digest=version, extra=extra ) else: raise ValueError(f"Unknown ref kind: {kind}") diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index ce2ea38eea1c..5872afbf85c2 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -1,6 +1,7 @@ import dataclasses import datetime import platform +import re import sys import typing from functools import lru_cache @@ -765,6 +766,8 @@ def _save_object_basic( if name is None: raise ValueError("Name must be provided for object saving") + name = sanitize_object_name(name) + response = self.server.obj_create( ObjCreateReq( obj=ObjSchemaForInsert( @@ -1002,4 +1005,16 @@ def redact_sensitive_keys(obj: typing.Any) -> typing.Any: return obj +def sanitize_object_name(name: str) -> str: + # Replaces any non-alphanumeric characters with a single dash and removes + # any leading or trailing dashes. This is more restrictive than the DB + # constraints and can be relaxed if needed. + res = re.sub(r"([._-]{2,})+", "-", re.sub(r"[^\w._]+", "-", name)).strip("-_") + if not res: + raise ValueError(f"Invalid object name: {name}") + if len(res) > 128: + res = res[:128] + return res + + __docspec__ = [WeaveClient, Call, CallsIter] diff --git a/weave/trace_server/refs_internal.py b/weave/trace_server/refs_internal.py index a11e8bdf80ca..ecf2a5fc75c5 100644 --- a/weave/trace_server/refs_internal.py +++ b/weave/trace_server/refs_internal.py @@ -8,6 +8,8 @@ import urllib from typing import Any, Union +from . import validation + WEAVE_INTERNAL_SCHEME = "weave-trace-internal" WEAVE_SCHEME = "weave" WEAVE_PRIVATE_SCHEME = "weave-private" @@ -29,23 +31,9 @@ class InvalidInternalRef(ValueError): pass -def quote_select(s: str, quote_chars: tuple[str, ...] = ("/", ":", "%")) -> str: - """We don't need to quote every single character, rather - we just need to quote the characters that are used as - delimiters in the URI. Right now, we only use "/", ":", and "%". Moreover, - the only user-controlled fields are the object name and the extra fields. - - All other fields are generated by the system and are safe for parsing. - """ - # We have to be careful here since quoting creates new "%" characters. - # We don't want to double-quote characters. - if "%" in quote_chars and quote_chars[-1] != "%": - raise ValueError("Quoting '%' must be the last character in the list.") - - for c in quote_chars: - s = s.replace(c, urllib.parse.quote(c)) - - return s +def extra_value_quoter(s: str) -> str: + # Here, we encode all non alpha-numerics or `_.-~`. + return urllib.parse.quote(s, safe="") def validate_extra(extra: list[str]) -> None: @@ -118,20 +106,21 @@ def __post_init__(self) -> None: validate_no_slashes(self.version, "version") validate_no_colons(self.version, "version") validate_extra(self.extra) + validation.object_id_validator(self.name) def uri(self) -> str: - u = f"{WEAVE_INTERNAL_SCHEME}:///{self.project_id}/object/{quote_select(self.name)}:{self.version}" + u = f"{WEAVE_INTERNAL_SCHEME}:///{self.project_id}/object/{self.name}:{self.version}" if self.extra: - u += "/" + "/".join(quote_select(e) for e in self.extra) + u += "/" + "/".join(extra_value_quoter(e) for e in self.extra) return u @dataclasses.dataclass(frozen=True) class InternalOpRef(InternalObjectRef): def uri(self) -> str: - u = f"{WEAVE_INTERNAL_SCHEME}:///{self.project_id}/op/{quote_select(self.name)}:{self.version}" + u = f"{WEAVE_INTERNAL_SCHEME}:///{self.project_id}/op/{self.name}:{self.version}" if self.extra: - u += "/" + "/".join(quote_select(e) for e in self.extra) + u += "/" + "/".join(extra_value_quoter(e) for e in self.extra) return u @@ -151,7 +140,7 @@ def __post_init__(self) -> None: def uri(self) -> str: u = f"{WEAVE_INTERNAL_SCHEME}:///{self.project_id}/call/{self.id}" if self.extra: - u += "/" + "/".join(quote_select(e) for e in self.extra) + u += "/" + "/".join(extra_value_quoter(e) for e in self.extra) return u @@ -202,13 +191,12 @@ def _parse_remaining(remaining: list[str]) -> tuple[str, str, list[str]]: It is expected to be pre-split by slashes into parts. The return is a tuple of name, version, and extra parts, properly unquoted. """ - name_encoded, version = remaining[0].split(":") - name = urllib.parse.unquote_plus(name_encoded) + name, version = remaining[0].split(":") extra = remaining[1:] if len(extra) == 1 and extra[0] == "": extra = [] else: - extra = [urllib.parse.unquote_plus(r) for r in extra] + extra = [urllib.parse.unquote(r) for r in extra] return name, version, extra diff --git a/weave/trace_server/sqlite_trace_server.py b/weave/trace_server/sqlite_trace_server.py index e9910c4c2401..54a636873fd1 100644 --- a/weave/trace_server/sqlite_trace_server.py +++ b/weave/trace_server/sqlite_trace_server.py @@ -34,6 +34,7 @@ from weave.trace_server.trace_server_interface_util import ( WILDCARD_ARTIFACT_VERSION_AND_PATH, ) +from weave.trace_server.validation import object_id_validator from . import trace_server_interface as tsi from .ids import generate_id @@ -585,6 +586,9 @@ def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes: json_val = json.dumps(req.obj.val) digest = str_digest(json_val) + # Validate + object_id_validator(req.obj.object_id) + req_obj = req.obj # TODO: version index isn't right here, what if we delete stuff? with self.lock: diff --git a/weave/trace_server/tests/test_refs.py b/weave/trace_server/tests/test_refs.py index 41ce450a1716..42b90235d349 100644 --- a/weave/trace_server/tests/test_refs.py +++ b/weave/trace_server/tests/test_refs.py @@ -1,8 +1,84 @@ -from weave.trace.refs import ObjectRef +import random + +import pytest + +from weave.trace import refs +from weave.trace.weave_client import sanitize_object_name +from weave.trace_server import refs_internal + +quote = refs_internal.extra_value_quoter def test_isdescended_from(): - a = ObjectRef(entity="e", project="p", name="n", digest="v", extra=["x1"]) - b = ObjectRef(entity="e", project="p", name="n", digest="v", extra=["x1", "x2"]) + a = refs.ObjectRef( + entity="e", project="p", name="n", digest="v", extra=["attr", "x2"] + ) + b = refs.ObjectRef( + entity="e", + project="p", + name="n", + digest="v", + extra=["attr", "x2", "attr", "x4"], + ) assert a.is_descended_from(b) == False assert b.is_descended_from(a) == True + + +def string_with_every_char(disallowed_chars=[]): + char_codes = list(range(256)) + random.shuffle(char_codes) + return "".join(chr(i) for i in char_codes if chr(i) not in disallowed_chars) + + +def test_ref_parsing_external_invalid(): + with pytest.raises(Exception): + ref_start = refs.ObjectRef( + entity="entity", + project="project", + name=string_with_every_char(), + digest="1234567890", + extra=("key", string_with_every_char()), + ) + + +def test_ref_parsing_external_sanitized(): + ref_start = refs.ObjectRef( + entity="entity", + project="project", + name=sanitize_object_name(string_with_every_char()), + digest="1234567890", + extra=("key", string_with_every_char()), + ) + + ref_str = ref_start.uri() + exp_ref = f"{refs_internal.WEAVE_SCHEME}:///{ref_start.entity}/{ref_start.project}/object/{ref_start.name}:{ref_start.digest}/{ref_start.extra[0]}/{quote(ref_start.extra[1])}" + assert ref_str == exp_ref + + parsed = refs.parse_uri(ref_str) + assert parsed == ref_start + + +def test_ref_parsing_internal_invalid(): + with pytest.raises(Exception): + ref_start = refs_internal.InternalObjectRef( + project_id="project", + name=string_with_every_char(), + version="1234567890", + extra=("key", string_with_every_char()), + ) + + +def test_ref_parsing_internal_sanitized(): + ref_start = refs_internal.InternalObjectRef( + project_id="project", + name=sanitize_object_name(string_with_every_char()), + version="1234567890", + extra=["key", string_with_every_char()], + ) + + ref_str = ref_start.uri() + exp_ref = f"{refs_internal.WEAVE_INTERNAL_SCHEME}:///{ref_start.project_id}/object/{ref_start.name}:{ref_start.version}/{ref_start.extra[0]}/{quote(ref_start.extra[1])}" + assert ref_str == exp_ref + + parsed = refs_internal.parse_internal_uri(ref_str) + assert parsed == ref_start diff --git a/weave/trace_server/validation.py b/weave/trace_server/validation.py index 226e946f691a..97f3bc4b1f13 100644 --- a/weave/trace_server/validation.py +++ b/weave/trace_server/validation.py @@ -1,3 +1,4 @@ +import re import typing from weave.trace_server import refs_internal @@ -57,7 +58,20 @@ def wb_run_id_validator(s: typing.Optional[str]) -> typing.Optional[str]: return s +def _validate_object_name_charset(name: str) -> None: + # Object names must be alphanumeric with dashes + invalid_chars = re.findall(r"[^\w._-]", name) + if invalid_chars: + raise ValueError( + f"Invalid object name: {name}. Contains invalid characters: {invalid_chars}" + ) + + if not name: + raise ValueError("Object name cannot be empty") + + def object_id_validator(s: str) -> str: + _validate_object_name_charset(s) return validation_util.require_max_str_len(s, 128) diff --git a/weave/trace_server/validation_util.py b/weave/trace_server/validation_util.py index 2879c5b2ddae..2496a597926e 100644 --- a/weave/trace_server/validation_util.py +++ b/weave/trace_server/validation_util.py @@ -58,6 +58,6 @@ def require_internal_ref_uri( def require_max_str_len(s: str, length: int) -> str: - if len(s) >= length: + if len(s) > length: raise CHValidationError(f"String too long: {s}. Max length is {length}") return s