Skip to content
This repository has been archived by the owner on Feb 20, 2023. It is now read-only.

Add support for aliases in GROUP BY. Add Postgres-style date_part. #1607

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions script/testing/junit/traces/groupby.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
statement ok
CREATE TABLE foo (a INT, b INT);

statement ok
INSERT INTO foo VALUES (1, 2);

statement ok
INSERT INTO foo VALUES (1, 3);

statement ok
INSERT INTO foo VALUES (2, 2);

query I rowsort
SELECT a, sum(b) FROM foo GROUP BY a;
----
1
5
2
2

query I rowsort
SELECT a as x, sum(b) FROM foo GROUP BY x;
----
1
5
2
2

statement ok
DROP TABLE foo;
116 changes: 98 additions & 18 deletions src/binder/bind_node_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -439,9 +439,6 @@ void BindNodeVisitor::Visit(common::ManagedPointer<parser::SelectStatement> node
if (node->GetSelectLimit() != nullptr)
node->GetSelectLimit()->Accept(common::ManagedPointer(this).CastManagedPointerTo<SqlNodeVisitor>());

if (node->GetSelectGroupBy() != nullptr)
node->GetSelectGroupBy()->Accept(common::ManagedPointer(this).CastManagedPointerTo<SqlNodeVisitor>());

std::vector<common::ManagedPointer<parser::AbstractExpression>> new_select_list;

BINDER_LOG_TRACE("Gathering select columns...");
Expand Down Expand Up @@ -481,7 +478,13 @@ void BindNodeVisitor::Visit(common::ManagedPointer<parser::SelectStatement> node
node->SetSelectColumns(new_select_list);
node->SetDepth(context_->GetDepth());

if (node->GetSelectGroupBy() != nullptr) {
UnaliasGroupBy(node->GetSelectGroupBy(), node->GetSelectColumns());
node->GetSelectGroupBy()->Accept(common::ManagedPointer(this).CastManagedPointerTo<SqlNodeVisitor>());
}

if (node->GetSelectOrderBy() != nullptr) {
UnaliasOrderBy(node->GetSelectOrderBy(), node->GetSelectColumns());
UnifyOrderByExpression(node->GetSelectOrderBy(), node->GetSelectColumns());
node->GetSelectOrderBy()->Accept(common::ManagedPointer(this).CastManagedPointerTo<SqlNodeVisitor>());
}
Expand Down Expand Up @@ -793,12 +796,12 @@ void BindNodeVisitor::UnifyOrderByExpression(
common::ManagedPointer<parser::OrderByDescription> order_by_description,
const std::vector<common::ManagedPointer<parser::AbstractExpression>> &select_items) {
auto &exprs = order_by_description->GetOrderByExpressions();
auto size = order_by_description->GetOrderByExpressionsSize();
for (size_t idx = 0; idx < size; idx++) {
if (exprs[idx].Get()->GetExpressionType() == noisepage::parser::ExpressionType::VALUE_CONSTANT) {
auto constant_value_expression = exprs[idx].CastManagedPointerTo<parser::ConstantValueExpression>();
for (auto &expr : exprs) {
// Rewrite integer constant expressions to use the corresponding SELECT column expression instead.
if (expr->GetExpressionType() == noisepage::parser::ExpressionType::VALUE_CONSTANT) {
auto constant_value_expression = expr.CastManagedPointerTo<parser::ConstantValueExpression>();
type::TypeId type = constant_value_expression->GetReturnValueType();
int64_t column_id = 0;
int64_t column_id;
switch (type) {
case type::TypeId::TINYINT:
case type::TypeId::SMALLINT:
Expand All @@ -816,21 +819,98 @@ void BindNodeVisitor::UnifyOrderByExpression(
throw BINDER_EXCEPTION(fmt::format("ORDER BY position \"{}\" is not in select list", std::to_string(column_id)),
common::ErrorCode::ERRCODE_UNDEFINED_COLUMN);
}
exprs[idx] = select_items[column_id - 1];
} else if (exprs[idx].Get()->GetExpressionType() == noisepage::parser::ExpressionType::COLUMN_VALUE) {
auto column_value_expression = exprs[idx].CastManagedPointerTo<parser::ColumnValueExpression>();
std::string column_name = column_value_expression->GetColumnName();
if (!column_name.empty()) {
for (auto select_expression : select_items) {
auto abstract_select_expression = select_expression.CastManagedPointerTo<parser::AbstractExpression>();
if (abstract_select_expression->GetExpressionName() == column_name) {
exprs[idx] = select_expression;
break;
expr = select_items[column_id - 1];
}
}
}

common::ManagedPointer<parser::AbstractExpression> BindNodeVisitor::UnaliasExpression(
common::ManagedPointer<parser::AbstractExpression> expr,
const std::vector<common::ManagedPointer<parser::AbstractExpression>> &select_items) {
// get_aliased will check if the given column value expression is an alias for any of the select columns provided.
// If an alias exists, the underlying select column is returned.
// If no alias exists, nullptr is returned.
auto get_aliased = [](common::ManagedPointer<parser::ColumnValueExpression> cve,
const std::vector<common::ManagedPointer<parser::AbstractExpression>> &select_items) {
const std::string &table_name = cve->GetTableName();
const std::string &column_name = cve->GetColumnName();
if (!column_name.empty()) {
common::ManagedPointer<parser::AbstractExpression> select_expr = nullptr;
for (auto &select_expression : select_items) {
// Check if the expression specifies a table. If so, we will also check the select expression.
bool same_table = true;
if (!table_name.empty() && select_expression->GetExpressionType() == parser::ExpressionType::COLUMN_VALUE) {
const auto &tbl = select_expression.CastManagedPointerTo<parser::ColumnValueExpression>()->GetTableName();
same_table = table_name == tbl;
}
// The column name must match either the select name or the alias name.
bool same_as_name = column_name == select_expression->GetExpressionName();
bool same_as_alias = column_name == select_expression->GetAlias();
if (same_table && (same_as_name || same_as_alias)) {
// If an expression was already found, then something is ambiguous in the SQL.
if (select_expr != nullptr) {
throw BINDER_EXCEPTION(fmt::format("Ambiguous alias \"{}\"", column_name),
common::ErrorCode::ERRCODE_AMBIGUOUS_COLUMN);
}
// Otherwise, set the expression.
select_expr = select_expression;
}
}
// If an expression was found, return it.
if (select_expr != nullptr) {
return select_expr;
}
}
return common::ManagedPointer<parser::AbstractExpression>(nullptr);
};

// Check if the current expression itself is a mere alias for some SELECT column, if so, return that column.
if (expr->GetExpressionType() == parser::ExpressionType::COLUMN_VALUE) {
auto cve = expr.CastManagedPointerTo<parser::ColumnValueExpression>();
auto select_expr = get_aliased(cve, select_items);
if (select_expr != nullptr) {
return select_expr;
}
}

// Otherwise, the current expression is not just an alias, but may still contain aliases.
// For each child, if the child is a ColumnValueExpression, the child may be an alias referencing some column
// from the SELECT. In this case, replace the ColumnValueExpression child with the SELECT column.
for (size_t i = 0; i < expr->GetChildrenSize(); ++i) {
auto child = expr->GetChild(i);
if (child->GetExpressionType() == noisepage::parser::ExpressionType::COLUMN_VALUE) {
auto cve = child.CastManagedPointerTo<parser::ColumnValueExpression>();
auto select_expr = get_aliased(cve, select_items);
if (select_expr != nullptr) {
expr->SetChild(i, select_expr);
}
}
}

// Repeat this unaliasing for all of the expression's children.
for (size_t i = 0; i < expr->GetChildrenSize(); ++i) {
expr->SetChild(i, UnaliasExpression(expr->GetChild(i), select_items));
}

return expr;
}

void BindNodeVisitor::UnaliasOrderBy(
common::ManagedPointer<parser::OrderByDescription> order_by_description,
const std::vector<common::ManagedPointer<parser::AbstractExpression>> &select_items) {
auto &exprs = order_by_description->GetOrderByExpressions();
for (auto &expr : exprs) {
expr = UnaliasExpression(expr, select_items);
}
}

void BindNodeVisitor::UnaliasGroupBy(
common::ManagedPointer<parser::GroupByDescription> group_by_description,
const std::vector<common::ManagedPointer<parser::AbstractExpression>> &select_items) {
auto &exprs = group_by_description->GetColumns();
for (auto &expr : exprs) {
expr = UnaliasExpression(expr, select_items);
}
}

void BindNodeVisitor::InitTableRef(const common::ManagedPointer<parser::TableRef> node) {
Expand Down
27 changes: 16 additions & 11 deletions src/catalog/postgres/pg_proc_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,12 +364,13 @@ proc_oid_t PgProcImpl::GetProcOid(const common::ManagedPointer<transaction::Tran

void PgProcImpl::BootstrapProcs(const common::ManagedPointer<transaction::TransactionContext> txn,
const common::ManagedPointer<DatabaseCatalog> dbc) {
const auto INT = dbc->GetTypeOidForType(type::TypeId::INTEGER); // NOLINT
const auto STR = dbc->GetTypeOidForType(type::TypeId::VARCHAR); // NOLINT
const auto REAL = dbc->GetTypeOidForType(type::TypeId::REAL); // NOLINT
const auto DATE = dbc->GetTypeOidForType(type::TypeId::DATE); // NOLINT
const auto BOOL = dbc->GetTypeOidForType(type::TypeId::BOOLEAN); // NOLINT
const auto VAR = dbc->GetTypeOidForType(type::TypeId::VARIADIC); // NOLINT
const auto INT = dbc->GetTypeOidForType(type::TypeId::INTEGER); // NOLINT
const auto STR = dbc->GetTypeOidForType(type::TypeId::VARCHAR); // NOLINT
const auto REAL = dbc->GetTypeOidForType(type::TypeId::REAL); // NOLINT
const auto DATE = dbc->GetTypeOidForType(type::TypeId::DATE); // NOLINT
const auto TIMESTAMP = dbc->GetTypeOidForType(type::TypeId::TIMESTAMP); // NOLINT
const auto BOOL = dbc->GetTypeOidForType(type::TypeId::BOOLEAN); // NOLINT
const auto VAR = dbc->GetTypeOidForType(type::TypeId::VARIADIC); // NOLINT

auto create_fn = [&](const std::string &procname, const std::vector<std::string> &args,
const std::vector<type_oid_t> &arg_types, const std::vector<type_oid_t> &all_arg_types,
Expand Down Expand Up @@ -441,7 +442,8 @@ void PgProcImpl::BootstrapProcs(const common::ManagedPointer<transaction::Transa
create_fn("replication_get_last_txn_id", {}, {}, {}, INT, false);

// Other functions.
create_fn("date_part", {"date, date_part_type"}, {DATE, INT}, {DATE, INT}, INT, false);
create_fn("date_part", {"date", "date_part_type"}, {DATE, INT}, {DATE, INT}, INT, false);
create_fn("date_part", {"text", "timestamp"}, {STR, TIMESTAMP}, {STR, TIMESTAMP}, INT, false);
create_fn("version", {}, {}, {}, STR, false);

CreateProcedure(
Expand Down Expand Up @@ -481,9 +483,11 @@ void PgProcImpl::BootstrapProcContext(const common::ManagedPointer<transaction::

void PgProcImpl::BootstrapProcContexts(const common::ManagedPointer<transaction::TransactionContext> txn,
const common::ManagedPointer<DatabaseCatalog> dbc) {
constexpr auto REAL = type::TypeId::REAL; // NOLINT
constexpr auto INT = type::TypeId::INTEGER; // NOLINT
constexpr auto VAR = type::TypeId::VARCHAR; // NOLINT
constexpr auto REAL = type::TypeId::REAL; // NOLINT
constexpr auto INT = type::TypeId::INTEGER; // NOLINT
constexpr auto VAR = type::TypeId::VARCHAR; // NOLINT
constexpr auto DATE = type::TypeId::DATE; // NOLINT
constexpr auto TIMESTAMP = type::TypeId::TIMESTAMP; // NOLINT

auto create_fn = [&](std::string &&func_name, type::TypeId func_ret_type, std::vector<type::TypeId> &&arg_types,
execution::ast::Builtin builtin, bool is_exec_ctx_required) {
Expand Down Expand Up @@ -553,7 +557,8 @@ void PgProcImpl::BootstrapProcContexts(const common::ManagedPointer<transaction:
create_fn("replication_get_last_txn_id", INT, {}, execution::ast::Builtin::ReplicationGetLastTransactionId, true);

// Other functions.
create_fn("date_part", INT, {type::TypeId::DATE, INT}, execution::ast::Builtin::DatePart, false);
create_fn("date_part", INT, {DATE, INT}, execution::ast::Builtin::DatePart, false);
create_fn("date_part", INT, {VAR, TIMESTAMP}, execution::ast::Builtin::DatePartPostgres, false);
create_fn("version", VAR, {}, execution::ast::Builtin::Version, true);

create_fn("nprunnersemitint", INT, {INT, INT, INT, INT}, execution::ast::Builtin::NpRunnersEmitInt, true);
Expand Down
36 changes: 28 additions & 8 deletions src/execution/sema/sema_builtin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,25 +241,44 @@ void Sema::CheckBuiltinStringLikeCall(ast::CallExpr *call) {
}

void Sema::CheckBuiltinDateFunctionCall(ast::CallExpr *call, ast::Builtin builtin) {
if (!CheckArgCountAtLeast(call, 1)) {
if (!CheckArgCount(call, 2)) {
return;
}
// First arg must be a date.
auto date_kind = ast::BuiltinType::Date;
auto integer_kind = ast::BuiltinType::Integer;
if (!call->Arguments()[0]->GetType()->IsSpecificBuiltin(date_kind)) {
ReportIncorrectCallArg(call, 0, GetBuiltinType(date_kind));
return;
}
auto str_kind = ast::BuiltinType::StringVal;
auto timestamp_kind = ast::BuiltinType::Timestamp;

switch (builtin) {
case ast::Builtin::DatePart:
case ast::Builtin::DatePart: {
if (!call->Arguments()[0]->GetType()->IsSpecificBuiltin(date_kind)) {
ReportIncorrectCallArg(call, 0, GetBuiltinType(date_kind));
return;
}
if (!call->Arguments()[1]->GetType()->IsSpecificBuiltin(integer_kind)) {
ReportIncorrectCallArg(call, 1, GetBuiltinType(integer_kind));
return;
}
call->SetType(GetBuiltinType(ast::BuiltinType::Integer));
return;
}
case ast::Builtin::DatePartPostgres: {
if (!call->Arguments()[0]->GetType()->IsSpecificBuiltin(str_kind)) {
ReportIncorrectCallArg(call, 0, GetBuiltinType(str_kind));
return;
}
if (!call->Arguments()[1]->GetType()->IsSpecificBuiltin(timestamp_kind)) {
ReportIncorrectCallArg(call, 1, GetBuiltinType(timestamp_kind));
return;
}
auto text = call->Arguments()[0]->As<ast::CallExpr>()->Arguments()[0]->As<ast::LitExpr>()->StringVal().GetView();
if (text == "year") {
call->SetType(GetBuiltinType(ast::BuiltinType::Integer));
} else {
UNREACHABLE("Case not handled.");
}
return;
}
default:
// TODO(Amadou): Support other date function.
UNREACHABLE("Impossible date function");
Expand Down Expand Up @@ -3299,7 +3318,8 @@ void Sema::CheckBuiltinCall(ast::CallExpr *call) {
CheckBuiltinStringLikeCall(call);
break;
}
case ast::Builtin::DatePart: {
case ast::Builtin::DatePart:
case ast::Builtin::DatePartPostgres: {
CheckBuiltinDateFunctionCall(call, builtin);
break;
}
Expand Down
40 changes: 30 additions & 10 deletions src/execution/vm/bytecode_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -636,19 +636,38 @@ void BytecodeGenerator::VisitSqlStringLikeCall(ast::CallExpr *call) {
}

void BytecodeGenerator::VisitBuiltinDateFunctionCall(ast::CallExpr *call, ast::Builtin builtin) {
auto dest = GetExecutionResult()->GetOrCreateDestination(call->GetType());
auto input = VisitExpressionForSQLValue(call->Arguments()[0]);
auto date_type =
sql::DatePartType(call->Arguments()[1]->As<ast::CallExpr>()->Arguments()[0]->As<ast::LitExpr>()->Int64Val());
switch (builtin) {
case ast::Builtin::DatePart: {
auto dest = GetExecutionResult()->GetOrCreateDestination(call->GetType());
auto input = VisitExpressionForSQLValue(call->Arguments()[0]);
auto date_type =
sql::DatePartType(call->Arguments()[1]->As<ast::CallExpr>()->Arguments()[0]->As<ast::LitExpr>()->Int64Val());

switch (date_type) {
case sql::DatePartType::YEAR:
GetEmitter()->Emit(Bytecode::ExtractYearFromDate, dest, input);
break;
default:
UNREACHABLE("Unimplemented DatePartType");
}
GetExecutionResult()->SetDestination(dest);
break;
}
case ast::Builtin::DatePartPostgres: {
auto dest = GetExecutionResult()->GetOrCreateDestination(call->GetType());
auto text = call->Arguments()[0]->As<ast::CallExpr>()->Arguments()[0]->As<ast::LitExpr>()->StringVal().GetView();
auto timestamp = VisitExpressionForSQLValue(call->Arguments()[1]);

switch (date_type) {
case sql::DatePartType::YEAR:
GetEmitter()->Emit(Bytecode::ExtractYearFromDate, dest, input);
if (text == "year") {
GetEmitter()->Emit(Bytecode::ExtractYearFromTimestamp, dest, timestamp);
} else {
UNREACHABLE("Unimplemented DatePartPostgres case.");
}
break;
}
default:
UNREACHABLE("Unimplemented DatePartType");
UNREACHABLE("Unimplemented Date builtin");
}
GetExecutionResult()->SetDestination(dest);
}

void BytecodeGenerator::VisitBuiltinTableIterCall(ast::CallExpr *call, ast::Builtin builtin) {
Expand Down Expand Up @@ -2682,7 +2701,8 @@ void BytecodeGenerator::VisitBuiltinCallExpr(ast::CallExpr *call) {
VisitSqlStringLikeCall(call);
break;
}
case ast::Builtin::DatePart: {
case ast::Builtin::DatePart:
case ast::Builtin::DatePartPostgres: {
VisitBuiltinDateFunctionCall(call, builtin);
break;
}
Expand Down
7 changes: 7 additions & 0 deletions src/execution/vm/vm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2590,6 +2590,13 @@ void VM::Interpret(const uint8_t *ip, Frame *frame) { // NOLINT
DISPATCH_NEXT();
}

OP(ExtractYearFromTimestamp) : {
auto *result = frame->LocalAt<sql::Integer *>(READ_LOCAL_ID());
auto *input = frame->LocalAt<sql::TimestampVal *>(READ_LOCAL_ID());
OpExtractYearFromTimestamp(result, input);
DISPATCH_NEXT();
}

// -------------------------------------------------------
// Replication functions
// -------------------------------------------------------
Expand Down
Loading