Skip to content

Commit

Permalink
Merge pull request #174 from PrefectHQ/sum-tests
Browse files Browse the repository at this point in the history
Fix ORM tests and add `sum` aggregate
  • Loading branch information
cicdw authored Jan 15, 2021
2 parents c257647 + 6018262 commit 3f804d2
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 56 deletions.
20 changes: 20 additions & 0 deletions changes/pr174.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# An example changelog entry
#
# 1. Choose one (or more if a PR encompasses multiple changes) of the following headers:
# - feature
# - enhancement
# - fix
# - deprecation
# - breaking (for breaking changes)
# - migration (for database migrations)
#
# 2. Fill in one (or more) bullet points under the heading, describing the change.
# Markdown syntax may be used.
#
# 3. If you would like to be credited as helping with this release, add a
# contributor section with your name and github username.
#
# Here's an example of a PR that adds an enhancement

enhancement:
- "Add `sum` aggregate to ORM - [#174](https://github.com/PrefectHQ/server/pull/174)"
26 changes: 24 additions & 2 deletions src/prefect_server/database/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,29 @@ async def count(
result = await prefect.plugins.hasura.client.execute(query, as_box=False)
return result["data"]["count"]["aggregate"]["count"]

async def max(self, columns) -> dict:
async def sum(self, columns: List[str]) -> dict:
"""
Returns the sum of the requested columns
Args:
Returns:
- dict: the requested columns and corresponding minmums
"""
agg_type = self.model.__root_fields__.get(
"select_aggregate", f"{self.model.__hasura_type__}_aggregate"
)
query = {
"query": {
with_args(f"sum_query: {agg_type}", {"where": self.where}): {
"aggregate": {"sum": set(columns)}
}
}
}
result = await prefect.plugins.hasura.client.execute(query, as_box=False)
return result["data"]["sum_query"]["aggregate"]["sum"]

async def max(self, columns: List[str]) -> dict:
"""
Returns the maximum value of the requested columns
Expand All @@ -590,7 +612,7 @@ async def max(self, columns) -> dict:
result = await prefect.plugins.hasura.client.execute(query, as_box=False)
return result["data"]["max_query"]["aggregate"]["max"]

async def min(self, columns) -> dict:
async def min(self, columns: List[str]) -> dict:
"""
Returns the minimum value of the requested columns
Expand Down
187 changes: 133 additions & 54 deletions tests/database/test_orm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Licensed under the Prefect Community License, available at
# https://www.prefect.io/legal/prefect-community-license

from asynctest import CoroutineMock
import datetime
import json
import uuid
Expand All @@ -10,10 +9,12 @@
import pendulum
import pydantic
import pytest
from asynctest import CoroutineMock
from box import Box

from prefect import models
from prefect.engine.state import Running, Scheduled
from prefect.utilities.graphql import EnumValue

from prefect_server.database import orm


Expand Down Expand Up @@ -277,97 +278,97 @@ async def test_where_with_where(self, project_id):
assert q.where == {"id": {"_in": [project_id]}}

async def test_insert_many(self, project_id):
f1 = models.Project(name="f1")
f2 = models.Project(name="f2")
f3 = models.Project(name="f3")
ids = await models.Project.insert_many([f1, f2, f3])
p1 = models.Project(name="p1")
p2 = models.Project(name="p2")
p3 = models.Project(name="p3")
ids = await models.Project.insert_many([p1, p2, p3])
assert len(ids) == 3
assert all([await models.Project.where(id=i).first() for i in ids])

async def test_insert_dict(self, project_id):
f1 = dict(name="f1")
f2 = dict(name="f2")
f3 = dict(name="f3")
ids = await models.Project.insert_many([f1, f2, f3])
p1 = dict(name="p1")
p2 = dict(name="p2")
p3 = dict(name="p3")
ids = await models.Project.insert_many([p1, p2, p3])
assert all([await models.Project.where(id=i).first() for i in ids])

async def test_insert_dict_with_apply_schema(self, project_id):
f1 = dict(name="f1")
f2 = dict(name="f2")
f3 = dict(name="f3")
ids = await models.Project.insert_many([f1, f2, f3])
p1 = dict(name="p1")
p2 = dict(name="p2")
p3 = dict(name="p3")
ids = await models.Project.insert_many([p1, p2, p3])
assert all([await models.Project.where(id=i).first() for i in ids])

async def test_get_more_than_100_objects(self, project_id):
await models.Project.where().delete()
await models.Project.insert_many(
[{"name": str(uuid.uuid4())} for i in range(108)]
)
flows = await models.Project.where().get()
assert len(flows) == 108
projects = await models.Project.where().get()
assert len(projects) == 108


class TestModelQuery:
@pytest.fixture
async def flow_ids(self):
# delete default flows
async def project_ids(self):
# delete all projects
await models.Project.where({}).delete()

f1 = dict(name="f1")
f2 = dict(name="f2")
f3 = dict(name="f3")
return await models.Project.insert_many([f1, f2, f3])
p1 = dict(name="p1")
p2 = dict(name="p2")
p3 = dict(name="p3")
return await models.Project.insert_many([p1, p2, p3])

async def test_get(self, flow_ids):
flows = await orm.ModelQuery(model=models.Project).get()
assert len(flows) == 3
assert all([isinstance(p, models.Project) for p in flows])
assert set(p.id for p in flows) == set(flow_ids)
async def test_get(self, project_ids):
projects = await orm.ModelQuery(model=models.Project).get()
assert len(projects) == 3
assert all([isinstance(p, models.Project) for p in projects])
assert set(p.id for p in projects) == set(project_ids)

async def test_get_selection_set(
self,
flow_ids,
project_ids,
):

flows = await orm.ModelQuery(model=models.Project).get(selection_set="name")
assert set(p.name for p in flows) == {"f1", "f2", "f3"}
projects = await orm.ModelQuery(model=models.Project).get(selection_set="name")
assert set(p.name for p in projects) == {"p1", "p2", "p3"}

async def test_get_limit(self, flow_ids):
flows = await orm.ModelQuery(model=models.Project).get(limit=2)
assert len(flows) == 2
async def test_get_limit(self, project_ids):
projects = await orm.ModelQuery(model=models.Project).get(limit=2)
assert len(projects) == 2

async def test_first(self, flow_ids):
flow = await orm.ModelQuery(model=models.Project).first()
assert isinstance(flow, models.Project)
async def test_first(self, project_ids):
project = await orm.ModelQuery(model=models.Project).first()
assert isinstance(project, models.Project)

async def test_count(
self,
flow_ids,
project_ids,
):
assert await orm.ModelQuery(model=models.Project, where={}).count() == 3

async def test_count_where(
self,
flow_ids,
project_ids,
):
assert (
await models.Project.where(
{
"name": {"_neq": "f2"},
"name": {"_neq": "p2"},
}
).count()
== 2
)

async def test_update_set(
self,
flow_ids,
project_ids,
):
await models.Project.where({"id": {"_eq": flow_ids[0]}}).update(
await models.Project.where({"id": {"_eq": project_ids[0]}}).update(
set=dict(name="test")
)
names = set(p.name for p in await models.Project.where({}).get("name"))
assert names == {"test", "f2", "f3"}
assert names == {"test", "p2", "p3"}

async def test_update_increment(
self,
Expand Down Expand Up @@ -429,11 +430,78 @@ async def test_update_delete_key_obj(

async def test_delete(
self,
flow_ids,
project_ids,
):
await models.Project.where({"id": {"_eq": flow_ids[0]}}).delete()
await models.Project.where({"id": {"_eq": project_ids[0]}}).delete()
names = set(p.name for p in await models.Project.where({}).get("name"))
assert names == {"f2", "f3"}
assert names == {"p2", "p3"}


class TestAggregates:
@pytest.fixture(autouse=True)
async def flow_ids(self, tenant_id, project_id, project_id_2, flow_group_id):
await models.Flow.where().delete()
await models.Flow.insert_many(
[
dict(
tenant_id=tenant_id,
project_id=project_id,
name="a",
version=1,
flow_group_id=flow_group_id,
),
dict(
tenant_id=tenant_id,
project_id=project_id,
name="b",
version=2,
flow_group_id=flow_group_id,
),
dict(
tenant_id=tenant_id,
project_id=project_id_2,
name="a",
version=3,
flow_group_id=flow_group_id,
),
]
)

async def test_count(self):
result = await models.Flow.where().count()
assert result == 3

async def test_count_where(self, project_id):
result = await models.Flow.where({"project_id": {"_eq": project_id}}).count()
assert result == 2

async def test_count_distinct(self, project_id):
result = await models.Flow.where().count(distinct_on=[EnumValue("name")])
assert result == 2

async def test_sum(self):
result = await models.Flow.where().sum(["version"])
assert result["version"] == 6

async def test_sum_where(self):
result = await models.Flow.where({"name": {"_eq": "a"}}).sum(["version"])
assert result["version"] == 4

async def test_max(self):
result = await models.Flow.where().max(["version"])
assert result["version"] == 3

async def test_max_where(self):
result = await models.Flow.where({"name": {"_eq": "b"}}).max(["version"])
assert result["version"] == 2

async def test_min(self):
result = await models.Flow.where().min(["version"])
assert result["version"] == 1

async def test_min_where(self):
result = await models.Flow.where({"name": {"_eq": "b"}}).min(["version"])
assert result["version"] == 2


class TestRunModels:
Expand Down Expand Up @@ -495,14 +563,17 @@ async def test_get_select_root_field_graphql(self, monkeypatch):
mock = CoroutineMock()
monkeypatch.setattr("prefect_server.database.hasura.HasuraClient.execute", mock)
graphql = await self.TestModel().where().get()
assert mock.awaited_once_with(query={"query": {"select: abc(where: {})": "id"}})
mock.assert_called_once_with(
query={"query": {"select: abc(where: {})": "id"}}, as_box=False
)

async def test_get_select_aggregate_root_field_graphql(self, monkeypatch):
mock = CoroutineMock()
monkeypatch.setattr("prefect_server.database.hasura.HasuraClient.execute", mock)
graphql = await self.TestModel().where().get()
assert mock.awaited_once_with(
query={"count": {"abc_aggregate(where: {})": "id"}}
graphql = await self.TestModel().where().count()
mock.assert_called_once_with(
{"query": {"count: abc_aggregate(where: {})": {"aggregate": "count"}}},
as_box=False,
)

async def test_get_insert_root_field_graphql(self):
Expand All @@ -525,16 +596,24 @@ async def test_get_custom_select_root_field_graphql(self, monkeypatch):
mock = CoroutineMock()
monkeypatch.setattr("prefect_server.database.hasura.HasuraClient.execute", mock)
graphql = await self.TestCustomModel().where().get()
assert mock.awaited_once_with(
query={"query": {"custom_select_xyz(where: {})": "id"}}
mock.assert_called_once_with(
query={"query": {"select: custom_select_xyz(where: {})": "id"}},
as_box=False,
)

async def test_get_custom_select_aggregate_root_field_graphql(self, monkeypatch):
mock = CoroutineMock()
monkeypatch.setattr("prefect_server.database.hasura.HasuraClient.execute", mock)
graphql = await self.TestCustomModel().where().get()
assert mock.awaited_once_with(
query={"count": {"custom_select_aggregate_xyz(where: {})": "id"}}
graphql = await self.TestCustomModel().where().count()
mock.assert_called_once_with(
{
"query": {
"count: custom_select_aggregate_xyz(where: {})": {
"aggregate": "count"
}
}
},
as_box=False,
)

async def test_get_custom_insert_root_field_graphql(self):
Expand Down

0 comments on commit 3f804d2

Please sign in to comment.