diff --git a/script/testing/junit/traces/groupby.test b/script/testing/junit/traces/groupby.test new file mode 100644 index 0000000000..6e0d7e8d94 --- /dev/null +++ b/script/testing/junit/traces/groupby.test @@ -0,0 +1,30 @@ +statement ok +CREATE TABLE foo (a INT, b INT); + +statement ok +INSERT INTO foo VALUES (1, 2); + +statement ok +INSERT INTO foo VALUES (1, 3); + +statement ok +INSERT INTO foo VALUES (2, 2); + +query I rowsort +SELECT a, sum(b) FROM foo GROUP BY a; +---- +1 +5 +2 +2 + +query I rowsort +SELECT a as x, sum(b) FROM foo GROUP BY x; +---- +1 +5 +2 +2 + +statement ok +DROP TABLE foo; diff --git a/src/binder/bind_node_visitor.cpp b/src/binder/bind_node_visitor.cpp index 0121e42a95..e26eea6bf1 100644 --- a/src/binder/bind_node_visitor.cpp +++ b/src/binder/bind_node_visitor.cpp @@ -439,9 +439,6 @@ void BindNodeVisitor::Visit(common::ManagedPointer node if (node->GetSelectLimit() != nullptr) node->GetSelectLimit()->Accept(common::ManagedPointer(this).CastManagedPointerTo()); - if (node->GetSelectGroupBy() != nullptr) - node->GetSelectGroupBy()->Accept(common::ManagedPointer(this).CastManagedPointerTo()); - std::vector> new_select_list; BINDER_LOG_TRACE("Gathering select columns..."); @@ -481,7 +478,13 @@ void BindNodeVisitor::Visit(common::ManagedPointer node node->SetSelectColumns(new_select_list); node->SetDepth(context_->GetDepth()); + if (node->GetSelectGroupBy() != nullptr) { + UnaliasGroupBy(node->GetSelectGroupBy(), node->GetSelectColumns()); + node->GetSelectGroupBy()->Accept(common::ManagedPointer(this).CastManagedPointerTo()); + } + if (node->GetSelectOrderBy() != nullptr) { + UnaliasOrderBy(node->GetSelectOrderBy(), node->GetSelectColumns()); UnifyOrderByExpression(node->GetSelectOrderBy(), node->GetSelectColumns()); node->GetSelectOrderBy()->Accept(common::ManagedPointer(this).CastManagedPointerTo()); } @@ -793,12 +796,12 @@ void BindNodeVisitor::UnifyOrderByExpression( common::ManagedPointer order_by_description, const std::vector> &select_items) { auto &exprs = order_by_description->GetOrderByExpressions(); - auto size = order_by_description->GetOrderByExpressionsSize(); - for (size_t idx = 0; idx < size; idx++) { - if (exprs[idx].Get()->GetExpressionType() == noisepage::parser::ExpressionType::VALUE_CONSTANT) { - auto constant_value_expression = exprs[idx].CastManagedPointerTo(); + for (auto &expr : exprs) { + // Rewrite integer constant expressions to use the corresponding SELECT column expression instead. + if (expr->GetExpressionType() == noisepage::parser::ExpressionType::VALUE_CONSTANT) { + auto constant_value_expression = expr.CastManagedPointerTo(); type::TypeId type = constant_value_expression->GetReturnValueType(); - int64_t column_id = 0; + int64_t column_id; switch (type) { case type::TypeId::TINYINT: case type::TypeId::SMALLINT: @@ -816,21 +819,98 @@ void BindNodeVisitor::UnifyOrderByExpression( throw BINDER_EXCEPTION(fmt::format("ORDER BY position \"{}\" is not in select list", std::to_string(column_id)), common::ErrorCode::ERRCODE_UNDEFINED_COLUMN); } - exprs[idx] = select_items[column_id - 1]; - } else if (exprs[idx].Get()->GetExpressionType() == noisepage::parser::ExpressionType::COLUMN_VALUE) { - auto column_value_expression = exprs[idx].CastManagedPointerTo(); - std::string column_name = column_value_expression->GetColumnName(); - if (!column_name.empty()) { - for (auto select_expression : select_items) { - auto abstract_select_expression = select_expression.CastManagedPointerTo(); - if (abstract_select_expression->GetExpressionName() == column_name) { - exprs[idx] = select_expression; - break; + expr = select_items[column_id - 1]; + } + } +} + +common::ManagedPointer BindNodeVisitor::UnaliasExpression( + common::ManagedPointer expr, + const std::vector> &select_items) { + // get_aliased will check if the given column value expression is an alias for any of the select columns provided. + // If an alias exists, the underlying select column is returned. + // If no alias exists, nullptr is returned. + auto get_aliased = [](common::ManagedPointer cve, + const std::vector> &select_items) { + const std::string &table_name = cve->GetTableName(); + const std::string &column_name = cve->GetColumnName(); + if (!column_name.empty()) { + common::ManagedPointer select_expr = nullptr; + for (auto &select_expression : select_items) { + // Check if the expression specifies a table. If so, we will also check the select expression. + bool same_table = true; + if (!table_name.empty() && select_expression->GetExpressionType() == parser::ExpressionType::COLUMN_VALUE) { + const auto &tbl = select_expression.CastManagedPointerTo()->GetTableName(); + same_table = table_name == tbl; + } + // The column name must match either the select name or the alias name. + bool same_as_name = column_name == select_expression->GetExpressionName(); + bool same_as_alias = column_name == select_expression->GetAlias(); + if (same_table && (same_as_name || same_as_alias)) { + // If an expression was already found, then something is ambiguous in the SQL. + if (select_expr != nullptr) { + throw BINDER_EXCEPTION(fmt::format("Ambiguous alias \"{}\"", column_name), + common::ErrorCode::ERRCODE_AMBIGUOUS_COLUMN); } + // Otherwise, set the expression. + select_expr = select_expression; } } + // If an expression was found, return it. + if (select_expr != nullptr) { + return select_expr; + } + } + return common::ManagedPointer(nullptr); + }; + + // Check if the current expression itself is a mere alias for some SELECT column, if so, return that column. + if (expr->GetExpressionType() == parser::ExpressionType::COLUMN_VALUE) { + auto cve = expr.CastManagedPointerTo(); + auto select_expr = get_aliased(cve, select_items); + if (select_expr != nullptr) { + return select_expr; } } + + // Otherwise, the current expression is not just an alias, but may still contain aliases. + // For each child, if the child is a ColumnValueExpression, the child may be an alias referencing some column + // from the SELECT. In this case, replace the ColumnValueExpression child with the SELECT column. + for (size_t i = 0; i < expr->GetChildrenSize(); ++i) { + auto child = expr->GetChild(i); + if (child->GetExpressionType() == noisepage::parser::ExpressionType::COLUMN_VALUE) { + auto cve = child.CastManagedPointerTo(); + auto select_expr = get_aliased(cve, select_items); + if (select_expr != nullptr) { + expr->SetChild(i, select_expr); + } + } + } + + // Repeat this unaliasing for all of the expression's children. + for (size_t i = 0; i < expr->GetChildrenSize(); ++i) { + expr->SetChild(i, UnaliasExpression(expr->GetChild(i), select_items)); + } + + return expr; +} + +void BindNodeVisitor::UnaliasOrderBy( + common::ManagedPointer order_by_description, + const std::vector> &select_items) { + auto &exprs = order_by_description->GetOrderByExpressions(); + for (auto &expr : exprs) { + expr = UnaliasExpression(expr, select_items); + } +} + +void BindNodeVisitor::UnaliasGroupBy( + common::ManagedPointer group_by_description, + const std::vector> &select_items) { + auto &exprs = group_by_description->GetColumns(); + for (auto &expr : exprs) { + expr = UnaliasExpression(expr, select_items); + } } void BindNodeVisitor::InitTableRef(const common::ManagedPointer node) { diff --git a/src/catalog/postgres/pg_proc_impl.cpp b/src/catalog/postgres/pg_proc_impl.cpp index f7a50cd2a5..27e1600362 100644 --- a/src/catalog/postgres/pg_proc_impl.cpp +++ b/src/catalog/postgres/pg_proc_impl.cpp @@ -364,12 +364,13 @@ proc_oid_t PgProcImpl::GetProcOid(const common::ManagedPointer txn, const common::ManagedPointer dbc) { - const auto INT = dbc->GetTypeOidForType(type::TypeId::INTEGER); // NOLINT - const auto STR = dbc->GetTypeOidForType(type::TypeId::VARCHAR); // NOLINT - const auto REAL = dbc->GetTypeOidForType(type::TypeId::REAL); // NOLINT - const auto DATE = dbc->GetTypeOidForType(type::TypeId::DATE); // NOLINT - const auto BOOL = dbc->GetTypeOidForType(type::TypeId::BOOLEAN); // NOLINT - const auto VAR = dbc->GetTypeOidForType(type::TypeId::VARIADIC); // NOLINT + const auto INT = dbc->GetTypeOidForType(type::TypeId::INTEGER); // NOLINT + const auto STR = dbc->GetTypeOidForType(type::TypeId::VARCHAR); // NOLINT + const auto REAL = dbc->GetTypeOidForType(type::TypeId::REAL); // NOLINT + const auto DATE = dbc->GetTypeOidForType(type::TypeId::DATE); // NOLINT + const auto TIMESTAMP = dbc->GetTypeOidForType(type::TypeId::TIMESTAMP); // NOLINT + const auto BOOL = dbc->GetTypeOidForType(type::TypeId::BOOLEAN); // NOLINT + const auto VAR = dbc->GetTypeOidForType(type::TypeId::VARIADIC); // NOLINT auto create_fn = [&](const std::string &procname, const std::vector &args, const std::vector &arg_types, const std::vector &all_arg_types, @@ -441,7 +442,8 @@ void PgProcImpl::BootstrapProcs(const common::ManagedPointer txn, const common::ManagedPointer dbc) { - constexpr auto REAL = type::TypeId::REAL; // NOLINT - constexpr auto INT = type::TypeId::INTEGER; // NOLINT - constexpr auto VAR = type::TypeId::VARCHAR; // NOLINT + constexpr auto REAL = type::TypeId::REAL; // NOLINT + constexpr auto INT = type::TypeId::INTEGER; // NOLINT + constexpr auto VAR = type::TypeId::VARCHAR; // NOLINT + constexpr auto DATE = type::TypeId::DATE; // NOLINT + constexpr auto TIMESTAMP = type::TypeId::TIMESTAMP; // NOLINT auto create_fn = [&](std::string &&func_name, type::TypeId func_ret_type, std::vector &&arg_types, execution::ast::Builtin builtin, bool is_exec_ctx_required) { @@ -553,7 +557,8 @@ void PgProcImpl::BootstrapProcContexts(const common::ManagedPointerArguments()[0]->GetType()->IsSpecificBuiltin(date_kind)) { - ReportIncorrectCallArg(call, 0, GetBuiltinType(date_kind)); - return; - } + auto str_kind = ast::BuiltinType::StringVal; + auto timestamp_kind = ast::BuiltinType::Timestamp; switch (builtin) { - case ast::Builtin::DatePart: + case ast::Builtin::DatePart: { + if (!call->Arguments()[0]->GetType()->IsSpecificBuiltin(date_kind)) { + ReportIncorrectCallArg(call, 0, GetBuiltinType(date_kind)); + return; + } if (!call->Arguments()[1]->GetType()->IsSpecificBuiltin(integer_kind)) { ReportIncorrectCallArg(call, 1, GetBuiltinType(integer_kind)); return; } call->SetType(GetBuiltinType(ast::BuiltinType::Integer)); return; + } + case ast::Builtin::DatePartPostgres: { + if (!call->Arguments()[0]->GetType()->IsSpecificBuiltin(str_kind)) { + ReportIncorrectCallArg(call, 0, GetBuiltinType(str_kind)); + return; + } + if (!call->Arguments()[1]->GetType()->IsSpecificBuiltin(timestamp_kind)) { + ReportIncorrectCallArg(call, 1, GetBuiltinType(timestamp_kind)); + return; + } + auto text = call->Arguments()[0]->As()->Arguments()[0]->As()->StringVal().GetView(); + if (text == "year") { + call->SetType(GetBuiltinType(ast::BuiltinType::Integer)); + } else { + UNREACHABLE("Case not handled."); + } + return; + } default: // TODO(Amadou): Support other date function. UNREACHABLE("Impossible date function"); @@ -3299,7 +3318,8 @@ void Sema::CheckBuiltinCall(ast::CallExpr *call) { CheckBuiltinStringLikeCall(call); break; } - case ast::Builtin::DatePart: { + case ast::Builtin::DatePart: + case ast::Builtin::DatePartPostgres: { CheckBuiltinDateFunctionCall(call, builtin); break; } diff --git a/src/execution/vm/bytecode_generator.cpp b/src/execution/vm/bytecode_generator.cpp index 74c7f2bb5c..ea022b22ce 100644 --- a/src/execution/vm/bytecode_generator.cpp +++ b/src/execution/vm/bytecode_generator.cpp @@ -636,19 +636,38 @@ void BytecodeGenerator::VisitSqlStringLikeCall(ast::CallExpr *call) { } void BytecodeGenerator::VisitBuiltinDateFunctionCall(ast::CallExpr *call, ast::Builtin builtin) { - auto dest = GetExecutionResult()->GetOrCreateDestination(call->GetType()); - auto input = VisitExpressionForSQLValue(call->Arguments()[0]); - auto date_type = - sql::DatePartType(call->Arguments()[1]->As()->Arguments()[0]->As()->Int64Val()); + switch (builtin) { + case ast::Builtin::DatePart: { + auto dest = GetExecutionResult()->GetOrCreateDestination(call->GetType()); + auto input = VisitExpressionForSQLValue(call->Arguments()[0]); + auto date_type = + sql::DatePartType(call->Arguments()[1]->As()->Arguments()[0]->As()->Int64Val()); + + switch (date_type) { + case sql::DatePartType::YEAR: + GetEmitter()->Emit(Bytecode::ExtractYearFromDate, dest, input); + break; + default: + UNREACHABLE("Unimplemented DatePartType"); + } + GetExecutionResult()->SetDestination(dest); + break; + } + case ast::Builtin::DatePartPostgres: { + auto dest = GetExecutionResult()->GetOrCreateDestination(call->GetType()); + auto text = call->Arguments()[0]->As()->Arguments()[0]->As()->StringVal().GetView(); + auto timestamp = VisitExpressionForSQLValue(call->Arguments()[1]); - switch (date_type) { - case sql::DatePartType::YEAR: - GetEmitter()->Emit(Bytecode::ExtractYearFromDate, dest, input); + if (text == "year") { + GetEmitter()->Emit(Bytecode::ExtractYearFromTimestamp, dest, timestamp); + } else { + UNREACHABLE("Unimplemented DatePartPostgres case."); + } break; + } default: - UNREACHABLE("Unimplemented DatePartType"); + UNREACHABLE("Unimplemented Date builtin"); } - GetExecutionResult()->SetDestination(dest); } void BytecodeGenerator::VisitBuiltinTableIterCall(ast::CallExpr *call, ast::Builtin builtin) { @@ -2682,7 +2701,8 @@ void BytecodeGenerator::VisitBuiltinCallExpr(ast::CallExpr *call) { VisitSqlStringLikeCall(call); break; } - case ast::Builtin::DatePart: { + case ast::Builtin::DatePart: + case ast::Builtin::DatePartPostgres: { VisitBuiltinDateFunctionCall(call, builtin); break; } diff --git a/src/execution/vm/vm.cpp b/src/execution/vm/vm.cpp index ff745e7b02..8dd2d8f82d 100644 --- a/src/execution/vm/vm.cpp +++ b/src/execution/vm/vm.cpp @@ -2590,6 +2590,13 @@ void VM::Interpret(const uint8_t *ip, Frame *frame) { // NOLINT DISPATCH_NEXT(); } + OP(ExtractYearFromTimestamp) : { + auto *result = frame->LocalAt(READ_LOCAL_ID()); + auto *input = frame->LocalAt(READ_LOCAL_ID()); + OpExtractYearFromTimestamp(result, input); + DISPATCH_NEXT(); + } + // ------------------------------------------------------- // Replication functions // ------------------------------------------------------- diff --git a/src/include/binder/bind_node_visitor.h b/src/include/binder/bind_node_visitor.h index 25d6815dd0..60bf888882 100644 --- a/src/include/binder/bind_node_visitor.h +++ b/src/include/binder/bind_node_visitor.h @@ -107,13 +107,31 @@ class BindNodeVisitor final : public SqlNodeVisitor { static void InitTableRef(common::ManagedPointer node); /** - * Change the type of exprs_ of order_by_description from ConstantValueExpression to ColumnValueExpression. - * @param order_by_description OrderByDescription - * @param select_items select columns + * Replace integer constants in an ORDER BY with the corresponding expressions from the SELECT columns. + * @param order_by_description The ORDER BY that may contain integer constants. + * @param select_items The SELECT columns that the integer constants refer to. */ void UnifyOrderByExpression(common::ManagedPointer order_by_description, const std::vector> &select_items); + /** + * Rewrite aliases in an expression with the corresponding expressions from the SELECT columns. + * + * @param expression The expression that may need rewriting. This expression is mutated! + * @param select_items The SELECT columns that may have aliases. + */ + common::ManagedPointer UnaliasExpression( + common::ManagedPointer expression, + const std::vector> &select_items); + + /** Unalias expressions in the GROUP BY. */ + void UnaliasGroupBy(common::ManagedPointer group_by_description, + const std::vector> &select_items); + + /** Unalias expressions in the ORDER BY. */ + void UnaliasOrderBy(common::ManagedPointer order_by_description, + const std::vector> &select_items); + void ValidateDatabaseName(const std::string &db_name); /** diff --git a/src/include/execution/ast/builtins.h b/src/include/execution/ast/builtins.h index 86406acc03..7c212b337d 100644 --- a/src/include/execution/ast/builtins.h +++ b/src/include/execution/ast/builtins.h @@ -34,6 +34,7 @@ namespace noisepage::execution::ast { /* SQL Functions */ \ F(Like, like) \ F(DatePart, datePart) \ + F(DatePartPostgres, datePartPostgres) \ \ /* Thread State Container */ \ F(ExecutionContextAddRowsAffected, execCtxAddRowsAffected) \ diff --git a/src/include/execution/vm/bytecode_handlers.h b/src/include/execution/vm/bytecode_handlers.h index 3f149abdf9..2dbd40eb6a 100644 --- a/src/include/execution/vm/bytecode_handlers.h +++ b/src/include/execution/vm/bytecode_handlers.h @@ -2113,6 +2113,16 @@ VM_OP_WARM void OpExtractYearFromDate(noisepage::execution::sql::Integer *result } } +VM_OP_WARM void OpExtractYearFromTimestamp(noisepage::execution::sql::Integer *result, + noisepage::execution::sql::TimestampVal *input) { + if (input->is_null_) { + result->is_null_ = true; + } else { + result->is_null_ = false; + result->val_ = input->val_.ExtractYear(); + } +} + VM_OP_WARM void OpAbortTxn(noisepage::execution::exec::ExecutionContext *exec_ctx) { exec_ctx->GetTxn()->SetMustAbort(); throw noisepage::ABORT_EXCEPTION("transaction aborted"); diff --git a/src/include/execution/vm/bytecodes.h b/src/include/execution/vm/bytecodes.h index c16bf03f21..efe483f8b5 100644 --- a/src/include/execution/vm/bytecodes.h +++ b/src/include/execution/vm/bytecodes.h @@ -735,8 +735,9 @@ namespace noisepage::execution::vm { F(Position, OperandType::Local, OperandType::Local, OperandType::Local, OperandType::Local) \ F(InitCap, OperandType::Local, OperandType::Local, OperandType::Local) \ \ - /* Date Functions */ \ + /* Date and timestamp functions. */ \ F(ExtractYearFromDate, OperandType::Local, OperandType::Local) \ + F(ExtractYearFromTimestamp, OperandType::Local, OperandType::Local) \ \ F(AbortTxn, OperandType::Local) \ \ diff --git a/src/include/parser/select_statement.h b/src/include/parser/select_statement.h index 28a9902522..d99cc7eb01 100644 --- a/src/include/parser/select_statement.h +++ b/src/include/parser/select_statement.h @@ -61,19 +61,13 @@ class OrderByDescription { */ void Accept(common::ManagedPointer v) { v->Visit(common::ManagedPointer(this)); } - /** - * @return order by types - */ + /** @return The types of the ORDER BY terms. */ std::vector GetOrderByTypes() { return types_; } - /** - * @return number of order by expressions - */ + /** @return The number of ORDER BY terms. */ size_t GetOrderByExpressionsSize() const { return exprs_.size(); } - /** - * @return order by expression - */ + /** @return Mutable reference to the expressions representing the ORDER BY terms. */ std::vector> &GetOrderByExpressions() { return exprs_; } /** @@ -255,8 +249,8 @@ class GroupByDescription { */ void Accept(common::ManagedPointer v) { v->Visit(common::ManagedPointer(this)); } - /** @return group by columns */ - const std::vector> &GetColumns() { return columns_; } + /** @return Mutable reference to the expressions representing the GROUP BY terms. */ + std::vector> &GetColumns() { return columns_; } /** @return having clause */ common::ManagedPointer GetHaving() { return common::ManagedPointer(having_); }