Skip to content

Commit

Permalink
Fix count aggregation on optional fields (#92)
Browse files Browse the repository at this point in the history
* Fix count aggregation on optional fields
  • Loading branch information
suresh-prakash authored Jul 8, 2022
1 parent 111185c commit 04a5518
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,20 @@ public void testAggregateSimple() throws IOException {
assertSizeEqual(query, "mongo/count_response.json");
}

@Test
public void testOptionalFieldCount() throws IOException {
  // Both assertions below compare against the same expected-response resource.
  final String expectedResponsePath = "mongo/optional_field_count_response.json";

  // COUNT aggregation over the nested identifier "props.seller.name", aliased as "count".
  final AggregateExpression sellerNameCount =
      AggregateExpression.of(COUNT, IdentifierExpression.of("props.seller.name"));
  final Query countQuery = Query.builder().addSelection(sellerNameCount, "count").build();

  final Iterator<Document> aggregated = collection.aggregate(countQuery);

  assertDocsEqual(aggregated, expectedResponsePath);
  assertSizeEqual(countQuery, expectedResponsePath);
}

@Test
public void testAggregateWithDuplicateSelections() throws IOException {
Query query =
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[
{
"count": 4
}
]
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import static java.util.Collections.unmodifiableMap;
import static org.hypertrace.core.documentstore.expression.operators.AggregationOperator.AVG;
import static org.hypertrace.core.documentstore.expression.operators.AggregationOperator.COUNT;
import static org.hypertrace.core.documentstore.expression.operators.AggregationOperator.DISTINCT;
import static org.hypertrace.core.documentstore.expression.operators.AggregationOperator.MAX;
import static org.hypertrace.core.documentstore.expression.operators.AggregationOperator.MIN;
Expand All @@ -26,6 +27,7 @@ final class MongoAggregateExpressionParser extends MongoSelectTypeExpressionPars
put(SUM, "$sum");
put(MIN, "$min");
put(MAX, "$max");
put(COUNT, "$push");
}
});

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.hypertrace.core.documentstore.mongo.query.transformer;

import static org.hypertrace.core.documentstore.expression.operators.AggregationOperator.COUNT;
import static org.hypertrace.core.documentstore.expression.operators.AggregationOperator.DISTINCT_COUNT;
import static org.hypertrace.core.documentstore.expression.operators.FunctionOperator.LENGTH;
import static org.hypertrace.core.documentstore.mongo.MongoUtils.encodeKey;
Expand Down Expand Up @@ -82,9 +83,10 @@ public Optional<SelectionSpec> visit(final AggregateExpression expression) {
final String encodedAlias = encodeKey(alias);
final SelectTypeExpression pairingExpression;

if (expression.getAggregator() == DISTINCT_COUNT) {
// Since MongoDB doesn't support $distinctCount in aggregations, we convert this to
// $addToSet function. So, we need to project $size(set) instead of just the alias
if (expression.getAggregator() == DISTINCT_COUNT || expression.getAggregator() == COUNT) {
// Since MongoDB doesn't support $distinctCount and $count(optional_field) in aggregations,
// we convert them to $addToSet and $push functions respectively.
// So, we need to project $size(set) or $size(list) instead of just the alias in these cases.
pairingExpression =
FunctionExpression.builder()
.operator(LENGTH)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
package org.hypertrace.core.documentstore.mongo.query.transformer;

import static java.util.Collections.unmodifiableMap;
import static org.hypertrace.core.documentstore.expression.operators.AggregationOperator.COUNT;
import static org.hypertrace.core.documentstore.expression.operators.AggregationOperator.DISTINCT;
import static org.hypertrace.core.documentstore.expression.operators.AggregationOperator.DISTINCT_COUNT;
import static org.hypertrace.core.documentstore.expression.operators.AggregationOperator.SUM;
import static org.hypertrace.core.documentstore.mongo.MongoCollection.ID_KEY;
import static org.hypertrace.core.documentstore.mongo.MongoUtils.FIELD_SEPARATOR;
import static org.hypertrace.core.documentstore.mongo.MongoUtils.encodeKey;

import java.util.EnumMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import org.hypertrace.core.documentstore.expression.impl.AggregateExpression;
import org.hypertrace.core.documentstore.expression.impl.ConstantExpression;
Expand Down Expand Up @@ -72,9 +72,6 @@
* supported
*/
final class MongoSelectionsUpdatingTransformation implements SelectTypeExpressionVisitor {
private static final Function<AggregateExpression, AggregateExpression> COUNT_HANDLER =
expression -> AggregateExpression.of(SUM, ConstantExpression.of(1));

private static final Function<AggregateExpression, AggregateExpression> DISTINCT_COUNT_HANDLER =
expression -> AggregateExpression.of(DISTINCT, expression.getExpression());

Expand All @@ -84,16 +81,15 @@ final class MongoSelectionsUpdatingTransformation implements SelectTypeExpressio
new EnumMap<>(AggregationOperator.class) {
{
put(DISTINCT_COUNT, DISTINCT_COUNT_HANDLER);
put(COUNT, COUNT_HANDLER);
}
});

private final List<GroupTypeExpression> groupTypeExpressions;
private final Set<GroupTypeExpression> groupTypeExpressions;
private final SelectionSpec source;

MongoSelectionsUpdatingTransformation(
List<GroupTypeExpression> groupTypeExpressions, SelectionSpec source) {
this.groupTypeExpressions = groupTypeExpressions;
this.groupTypeExpressions = new HashSet<>(groupTypeExpressions);
this.source = source;
}

Expand All @@ -118,16 +114,7 @@ public SelectionSpec visit(final FunctionExpression expression) {
@SuppressWarnings("unchecked")
@Override
public SelectionSpec visit(final IdentifierExpression expression) {
GroupTypeExpression matchingGroup = null;

for (final GroupTypeExpression group : groupTypeExpressions) {
if (expression.equals(group)) {
matchingGroup = group;
break;
}
}

if (matchingGroup == null) {
if (!groupTypeExpressions.contains(expression)) {
return source;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,11 +311,37 @@ public void testSimpleAggregate() {
+ " { "
+ " _id: null, "
+ " total: {"
+ " \"$sum\": 1"
+ " \"$push\": 1"
+ " }"
+ " }"
+ "}"),
BasicDBObject.parse("{" + "\"$project\": {" + " \"total\": \"$total\"" + "}" + "}"));
BasicDBObject.parse(
"{" + "\"$project\": {" + " \"total\": {\"$size\": \"$total\"}" + "}" + "}"));

testAggregation(query, pipeline);
}

/**
 * Checks the pipeline produced for a COUNT aggregation over the identifier "path" aliased as
 * "total": a {@code $group} stage that pushes "$path" into "total", followed by a {@code
 * $project} stage that takes the {@code $size} of "$total".
 */
@Test
public void testFieldCount() {
  final AggregateExpression pathCount =
      AggregateExpression.of(COUNT, IdentifierExpression.of("path"));
  final Query query = Query.builder().addSelection(pathCount, "total").build();

  // Expected $group stage: collect the values of "$path" under the alias.
  final BasicDBObject groupStage =
      BasicDBObject.parse(
          "{"
              + "\"$group\": "
              + " { "
              + " _id: null, "
              + " total: {"
              + " \"$push\": \"$path\""
              + " }"
              + " }"
              + "}");

  // Expected $project stage: expose the size of the collected list as the alias.
  final BasicDBObject projectStage =
      BasicDBObject.parse("{\"$project\": { \"total\": { \"$size\": \"$total\" }}}");

  final List<BasicDBObject> pipeline = List.of(groupStage, projectStage);

  testAggregation(query, pipeline);
}
Expand All @@ -339,7 +365,7 @@ public void testAggregateWithProjections() {
+ " { "
+ " _id: null, "
+ " total: {"
+ " \"$sum\": 1"
+ " \"$push\": 1"
+ " }"
+ " }"
+ "}"),
Expand All @@ -348,7 +374,7 @@ public void testAggregateWithProjections() {
+ "\"$project\": "
+ " {"
+ " name: 1,"
+ " total: \"$total\""
+ " total: {\"$size\": \"$total\"}"
+ " }"
+ "}"));

Expand Down

0 comments on commit 04a5518

Please sign in to comment.