Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fixes issues with special characters in object name #2156

Merged
merged 25 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ const buildBaseRef = (
const parts = path.toPath().slice(-startIndex);
parts.forEach(part => {
if (typeof part === 'string') {
baseRef += '/' + OBJECT_ATTR_EDGE_NAME + '/' + part;
baseRef += '/' + OBJECT_ATTR_EDGE_NAME + '/' + encodeURIComponent(part);
} else if (typeof part === 'number') {
baseRef += '/' + LIST_INDEX_EDGE_NAME + '/' + part.toString();
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,11 @@ function replaceTableRefsInFlattenedData(flattened: Record<string, any>) {
if (parentRef) {
const newVal: ExpandedRefWithValueAsTableRef =
makeRefExpandedPayload(
parentRef + '/' + OBJECT_ATTR_EDGE_NAME + '/' + attr,
parentRef +
'/' +
OBJECT_ATTR_EDGE_NAME +
'/' +
encodeURIComponent(attr),
val
);
return [key, newVal];
Expand Down
105 changes: 53 additions & 52 deletions weave-js/src/react.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import {
useState,
} from 'react';

import {WEAVE_REF_PREFIX} from './components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/constants';
import {PanelCompContext} from './components/Panel2/PanelComp';
import {usePanelContext} from './components/Panel2/PanelContext';
import {toWeaveType} from './components/Panel2/toWeaveType';
Expand All @@ -64,7 +65,6 @@ import {
getChainRootVar,
isConstructor,
} from './core/mutate';
import {trimStartChar} from './core/util/string';
import {UseNodeValueServerExecutionError} from './errors';
import {useDeepMemo} from './hookUtils';
import {consoleLog} from './util';
Expand Down Expand Up @@ -505,6 +505,7 @@ export const isWeaveObjectRef = (ref: ObjectRef): ref is WeaveObjectRef => {
// Unfortunately many teams have been created that violate this.
const PATTERN_ENTITY = '([^/]+)'; // Entity name: one or more non-slash characters
const PATTERN_PROJECT = '([^\\#?%:]{1,128})'; // Project name
// Optional ref extra: slash-separated, percent-encoded path segments.
// The character class must accept everything a percent-encoder can emit:
//   - python urllib.parse.quote output: unreserved chars `A-Za-z0-9_.~-` plus `%XX` escapes
//   - javascript encodeURIComponent output: the same set, plus `! * ' ( )`, which
//     encodeURIComponent deliberately leaves unescaped.
// Omitting `!*'()` would make refs whose keys contain those characters
// (e.g. `f(x)`) fail to parse after being encoded on the JS side.
const PATTERN_REF_EXTRA = "([a-zA-Z0-9_.~/%!*'()-]*)";
const RE_WEAVE_OBJECT_REF_PATHNAME = new RegExp(
[
'^', // Start of the string
Expand All @@ -518,7 +519,7 @@ const RE_WEAVE_OBJECT_REF_PATHNAME = new RegExp(
':',
'([*]|[a-zA-Z0-9]+)', // Artifact version, allowing '*' for any version
'/?', // Ref extra portion is optional
'([a-zA-Z0-9_/]*)', // Optional ref extra
PATTERN_REF_EXTRA, // Optional ref extra
'$', // End of the string
].join('')
);
Expand All @@ -531,7 +532,7 @@ const RE_WEAVE_TABLE_REF_PATHNAME = new RegExp(
'/table/',
'([a-f0-9]+)', // Digest
'/?', // Ref extra portion is optional
'([a-zA-Z0-9_/]*)', // Optional ref extra
PATTERN_REF_EXTRA, // Optional ref extra
'$', // End of the string
].join('')
);
Expand All @@ -544,7 +545,7 @@ const RE_WEAVE_CALL_REF_PATHNAME = new RegExp(
'/call/',
'([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})', // Call UUID
'/?', // Ref extra portion is optional
'([a-zA-Z0-9_/]*)', // Optional ref extra
PATTERN_REF_EXTRA, // Optional ref extra
'$', // End of the string
].join('')
);
Expand All @@ -561,7 +562,7 @@ export const parseRef = (ref: string): ObjectRef => {
} else if (isLocalArtifact) {
splitLimit = 2;
} else if (isWeaveRef) {
splitLimit = 4;
return parseWeaveRef(ref);
} else {
throw new Error(`Unknown protocol: ${url.protocol}`);
}
Expand Down Expand Up @@ -598,58 +599,58 @@ export const parseRef = (ref: string): ObjectRef => {
artifactPath,
};
}
throw new Error(`Unknown protocol: ${url.protocol}`);
};

if (isWeaveRef) {
const trimmed = trimStartChar(decodedUri, '/');
const tableMatch = trimmed.match(RE_WEAVE_TABLE_REF_PATHNAME);
if (tableMatch !== null) {
const [entity, project, digest] = tableMatch.slice(1);
return {
scheme: 'weave',
entityName: entity,
projectName: project,
weaveKind: 'table' as WeaveKind,
artifactName: '',
artifactVersion: digest,
artifactRefExtra: '',
};
}
const callMatch = trimmed.match(RE_WEAVE_CALL_REF_PATHNAME);
if (callMatch !== null) {
const [entity, project, callId] = callMatch.slice(1);
return {
scheme: 'weave',
entityName: entity,
projectName: project,
weaveKind: 'call' as WeaveKind,
artifactName: callId,
artifactVersion: '',
artifactRefExtra: '',
};
}
const match = trimmed.match(RE_WEAVE_OBJECT_REF_PATHNAME);
if (match === null) {
throw new Error('Invalid weave ref uri: ' + ref);
}
const [
entityName,
projectName,
weaveKind,
artifactName,
artifactVersion,
artifactRefExtra,
] = match.slice(1);
const parseWeaveRef = (ref: string): WeaveObjectRef => {
const trimmed = ref.slice(WEAVE_REF_PREFIX.length);
const tableMatch = trimmed.match(RE_WEAVE_TABLE_REF_PATHNAME);
if (tableMatch !== null) {
const [entity, project, digest] = tableMatch.slice(1);
return {
scheme: 'weave',
entityName,
projectName,
weaveKind: weaveKind as WeaveKind,
artifactName,
artifactVersion,
artifactRefExtra: artifactRefExtra ?? '',
entityName: entity,
projectName: project,
weaveKind: 'table' as WeaveKind,
artifactName: '',
artifactVersion: digest,
artifactRefExtra: '',
};
}
throw new Error(`Unknown protocol: ${url.protocol}`);
const callMatch = trimmed.match(RE_WEAVE_CALL_REF_PATHNAME);
if (callMatch !== null) {
const [entity, project, callId] = callMatch.slice(1);
return {
scheme: 'weave',
entityName: entity,
projectName: project,
weaveKind: 'call' as WeaveKind,
artifactName: callId,
artifactVersion: '',
artifactRefExtra: '',
};
}
const match = trimmed.match(RE_WEAVE_OBJECT_REF_PATHNAME);
if (match === null) {
throw new Error('Invalid weave ref uri: ' + ref);
}
const [
entityName,
projectName,
weaveKind,
artifactName,
artifactVersion,
artifactRefExtra,
] = match.slice(1);
return {
scheme: 'weave',
entityName,
projectName,
weaveKind: weaveKind as WeaveKind,
artifactName,
artifactVersion,
artifactRefExtra: artifactRefExtra ?? '',
};
};

export const objectRefWithExtra = (
Expand Down
82 changes: 82 additions & 0 deletions weave/tests/test_client_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
from weave import Thread, ThreadPoolExecutor, weave_client
from weave.trace.vals import MissingSelfInstanceError
from weave.trace_server.ids import generate_id
from weave.trace_server.refs_internal import extra_value_quoter
from weave.trace_server.sqlite_trace_server import SqliteTraceServer
from weave.weave_client import sanitize_object_name

from ..trace_server import trace_server_interface as tsi
from ..trace_server.trace_server_interface_util import (
Expand Down Expand Up @@ -2421,3 +2423,83 @@ def test_model_save(client):
assert isinstance(expected_predict_op, str) and expected_predict_op.startswith(
"weave:///"
)


# Minimal weave.Object subclass used by the tests that follow; `val` carries a dict payload.
class Custom(weave.Object):
    val: dict


def test_object_with_disallowed_keys(client):
    """Publishing an object named "%" yields a ref named "_", while sending the
    raw "%" as ``object_id`` straight to the server's ``obj_create`` raises."""
    raw_name = "%"
    published = Custom(name=raw_name, val={"1": 1})

    weave.publish(published)

    # The client-side publish path replaces the disallowed name.
    assert published.ref.name == "_"

    # Build the same object as a raw create request, bypassing the client.
    request_payload = {
        "obj": {
            "project_id": client._project_id(),
            "object_id": raw_name,
            "val": {"1": 1},
        }
    }
    request = tsi.ObjCreateReq.model_validate(request_payload)

    with pytest.raises(Exception):
        client.server.obj_create(request)


# Punctuation appended to the object name / dict key in
# test_objects_and_keys_with_special_characters.
chars = "+_(){}|\"'<>!@$^&*#:,.[]-=;~`"


def test_objects_and_keys_with_special_characters(client):
    """Round-trip an object, a dict key and an op whose names contain special characters.

    The same string is used as the object name, as a key inside the object's
    ``val`` dict, and as the op name. The test asserts that:

    * the object/op name in the ref is the sanitized form ("name");
    * the dict key appears in the ref extra in percent-encoded form;
    * refs built from those URIs can be fetched back with ``weave.ref(...).get()``.

    NOTE(review): the expected digests are hard-coded; presumably they have to be
    regenerated whenever the published object or op content changes — confirm.
    """
    # Prefix ":" and "/" (URI-significant characters) to the punctuation in `chars`.

    name_with_special_characters = "name: /" + chars
    dict_payload = {name_with_special_characters: "hello world"}

    obj = Custom(name=name_with_special_characters, val=dict_payload)

    weave.publish(obj)
    assert obj.ref is not None

    # Unpacking into two names also checks the project id has exactly one "/".
    entity, project = client._project_id().split("/")
    project_id = f"{entity}/{project}"
    ref_base = f"weave:///{project_id}"
    exp_name = sanitize_object_name(name_with_special_characters)
    assert exp_name == "name"
    exp_key = extra_value_quoter(name_with_special_characters)
    assert (
        exp_key
        == "name%3A%20%2F%2B_%28%29%7B%7D%7C%22%27%3C%3E%21%40%24%5E%26%2A%23%3A%2C.%5B%5D-%3D%3B~%60"
    )
    exp_digest = "2bnzTXFjtlwrtXWNLhAyvYq0XbRFfr633kKL2IkBOlI"

    exp_obj_ref = f"{ref_base}/object/{exp_name}:{exp_digest}"
    assert obj.ref.uri() == exp_obj_ref

    @weave.op
    def test(obj: Custom):
        return obj.val[name_with_special_characters]

    test.name = name_with_special_characters

    res = test(obj)

    # The op's return value is expected to carry a ref into the object's `val` dict.
    exp_res_ref = f"{exp_obj_ref}/attr/val/key/{exp_key}"
    found_ref = res.ref.uri()
    assert res == "hello world"
    assert found_ref == exp_res_ref

    # The encoded URI must resolve back to the original value.
    gotten_res = weave.ref(found_ref).get()
    assert gotten_res == "hello world"

    exp_op_digest = "WZAc2HA5GyWPr7YzHiSBncbHKMywXN3hk8onqRy2KkA"
    exp_op_ref = f"{ref_base}/op/{exp_name}:{exp_op_digest}"

    # The op ref uses the sanitized name and can be fetched and called again.
    found_ref = test.ref.uri()
    assert found_ref == exp_op_ref
    gotten_fn = weave.ref(found_ref).get()
    assert gotten_fn(obj) == "hello world"
26 changes: 9 additions & 17 deletions weave/trace/refs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import dataclasses
import urllib
import urllib.parse
from typing import Any, Union

from ..trace_server import refs_internal
Expand Down Expand Up @@ -56,7 +57,7 @@ class ObjectRef(RefWithExtra):
def uri(self) -> str:
u = f"weave:///{self.entity}/{self.project}/object/{self.name}:{self.digest}"
if self.extra:
u += "/" + "/".join(self.extra)
u += "/" + "/".join(refs_internal.extra_value_quoter(e) for e in self.extra)
return u

def get(self) -> Any:
Expand Down Expand Up @@ -106,7 +107,7 @@ class OpRef(ObjectRef):
def uri(self) -> str:
u = f"weave:///{self.entity}/{self.project}/op/{self.name}:{self.digest}"
if self.extra:
u += "/" + "/".join(self.extra)
u += "/" + "/".join(refs_internal.extra_value_quoter(e) for e in self.extra)
return u


Expand All @@ -120,7 +121,7 @@ class CallRef(RefWithExtra):
def uri(self) -> str:
u = f"weave:///{self.entity}/{self.project}/call/{self.id}"
if self.extra:
u += "/" + "/".join(self.extra)
u += "/" + "/".join(refs_internal.extra_value_quoter(e) for e in self.extra)
return u


Expand All @@ -138,27 +139,18 @@ def parse_uri(uri: str) -> AnyRef:
remaining = tuple(parts[3:])
if kind == "table":
return TableRef(entity=entity, project=project, digest=remaining[0])
elif kind == "call":
return CallRef(
entity=entity, project=project, id=remaining[0], extra=remaining[1:]
)
extra = tuple(urllib.parse.unquote(r) for r in remaining[1:])
if kind == "call":
return CallRef(entity=entity, project=project, id=remaining[0], extra=extra)
elif kind == "object":
name, version = remaining[0].split(":")
return ObjectRef(
entity=entity,
project=project,
name=name,
digest=version,
extra=remaining[1:],
entity=entity, project=project, name=name, digest=version, extra=extra
)
elif kind == "op":
name, version = remaining[0].split(":")
return OpRef(
entity=entity,
project=project,
name=name,
digest=version,
extra=remaining[1:],
entity=entity, project=project, name=name, digest=version, extra=extra
)
else:
raise ValueError(f"Unknown ref kind: {kind}")
Loading
Loading