Skip to content

Commit

Permalink
chore: add $in to orm query builder (#2180)
Browse files Browse the repository at this point in the history
  • Loading branch information
gtarpenning authored Aug 21, 2024
1 parent dd00e51 commit b75182d
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ type GtOperation = {
$gt: [Operand, Operand];
};

type InOperation = {
$in: [Operand, Operand[]];
};

type GteOperation = {
$gte: [Operand, Operand];
};
Expand All @@ -61,6 +65,7 @@ type Operation =
| EqOperation
| GtOperation
| GteOperation
| InOperation
| ContainsOperation;

type Operand =
Expand Down
50 changes: 50 additions & 0 deletions weave/tests/trace/test_client_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -2297,6 +2297,56 @@ def test_obj(val):
assert inner_res.count == count


def test_in_operation(client):
@weave.op()
def test_op(label, val):
return val

test_op(1, [1, 2, 3])
test_op(2, [1, 2, 3])
test_op(3, [5, 6, 7])
test_op(4, [8, 2, 3])

call_ids = [call.id for call in test_op.calls()]
assert len(call_ids) == 4

query = {
"$in": [
{"$getField": "id"},
[{"$literal": call_id} for call_id in call_ids[:2]],
]
}

res = get_client_trace_server(client).calls_query_stats(
tsi.CallsQueryStatsReq.model_validate(
dict(
project_id=get_client_project_id(client),
query={"$expr": query},
)
)
)
assert res.count == 2

query = {
"$in": [
{"$getField": "id"},
[{"$literal": call_id} for call_id in call_ids],
]
}
res = get_client_trace_server(client).calls_query_stream(
tsi.CallsQueryReq.model_validate(
dict(
project_id=get_client_project_id(client),
query={"$expr": query},
)
)
)
res = list(res)
assert len(res) == 4
for i in range(4):
assert res[i].id == call_ids[i]


def test_call_has_client_version(client):
@weave.op
def test():
Expand Down
5 changes: 5 additions & 0 deletions weave/trace_server/calls_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,10 @@ def process_operation(operation: tsi_query.Operation) -> str:
lhs_part = process_operand(operation.gte_[0])
rhs_part = process_operand(operation.gte_[1])
cond = f"({lhs_part} >= {rhs_part})"
elif isinstance(operation, tsi_query.InOperation):
lhs_part = process_operand(operation.in_[0])
rhs_part = ",".join(process_operand(op) for op in operation.in_[1])
cond = f"({lhs_part} IN ({rhs_part}))"
elif isinstance(operation, tsi_query.ContainsOperation):
lhs_part = process_operand(operation.contains_.input)
rhs_part = process_operand(operation.contains_.substr)
Expand Down Expand Up @@ -644,6 +648,7 @@ def process_operand(operand: "tsi_query.Operand") -> str:
tsi_query.EqOperation,
tsi_query.GtOperation,
tsi_query.GteOperation,
tsi_query.InOperation,
tsi_query.ContainsOperation,
),
):
Expand Down
7 changes: 7 additions & 0 deletions weave/trace_server/interface/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,11 @@ class GteOperation(BaseModel):
gte_: typing.Tuple["Operand", "Operand"] = Field(alias="$gte")


# https://www.mongodb.com/docs/manual/reference/operator/aggregation/in/

This comment has been minimized.

Copy link
@tssweeney

tssweeney Aug 25, 2024

Collaborator

How would someone specify "x in field y"?

class InOperation(BaseModel):
in_: typing.Tuple["Operand", list["Operand"]] = Field(alias="$in")


# This is not technically in the Mongo spec. Mongo has:
# https://www.mongodb.com/docs/manual/reference/operator/aggregation/regexMatch/,
# however, rather than support a full regex match right now, we will
Expand All @@ -143,6 +148,7 @@ class ContainsSpec(BaseModel):
EqOperation,
GtOperation,
GteOperation,
InOperation,
ContainsOperation,
]
Operand = typing.Union[
Expand All @@ -159,4 +165,5 @@ class ContainsSpec(BaseModel):
EqOperation.model_rebuild()
GtOperation.model_rebuild()
GteOperation.model_rebuild()
InOperation.model_rebuild()
ContainsOperation.model_rebuild()
5 changes: 5 additions & 0 deletions weave/trace_server/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,10 @@ def process_operation(operation: tsi_query.Operation) -> str:
lhs_part = process_operand(operation.gte_[0])
rhs_part = process_operand(operation.gte_[1])
cond = f"({lhs_part} >= {rhs_part})"
elif isinstance(operation, tsi_query.InOperation):
lhs_part = process_operand(operation.in_[0])
rhs_part = ",".join(process_operand(op) for op in operation.in_[1])
cond = f"({lhs_part} IN ({rhs_part}))"
elif isinstance(operation, tsi_query.ContainsOperation):
lhs_part = process_operand(operation.contains_.input)
rhs_part = process_operand(operation.contains_.substr)
Expand Down Expand Up @@ -665,6 +669,7 @@ def process_operand(operand: tsi_query.Operand) -> str:
tsi_query.EqOperation,
tsi_query.GtOperation,
tsi_query.GteOperation,
tsi_query.InOperation,
tsi_query.ContainsOperation,
),
):
Expand Down
5 changes: 5 additions & 0 deletions weave/trace_server/sqlite_trace_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,10 @@ def process_operation(operation: tsi_query.Operation) -> str:
lhs_part = process_operand(operation.gte_[0])
rhs_part = process_operand(operation.gte_[1])
cond = f"({lhs_part} >= {rhs_part})"
elif isinstance(operation, tsi_query.InOperation):
lhs_part = process_operand(operation.in_[0])
rhs_part = ",".join(process_operand(op) for op in operation.in_[1])
cond = f"({lhs_part} IN ({rhs_part}))"
elif isinstance(operation, tsi_query.ContainsOperation):
lhs_part = process_operand(operation.contains_.input)
rhs_part = process_operand(operation.contains_.substr)
Expand Down Expand Up @@ -368,6 +372,7 @@ def process_operand(operand: tsi_query.Operand) -> str:
tsi_query.EqOperation,
tsi_query.GtOperation,
tsi_query.GteOperation,
tsi_query.InOperation,
tsi_query.ContainsOperation,
),
):
Expand Down

0 comments on commit b75182d

Please sign in to comment.