diff --git a/partiql-ast/src/main/java/org/partiql/ast/DataType.java b/partiql-ast/src/main/java/org/partiql/ast/DataType.java index 179b0c8a2..76278705c 100644 --- a/partiql-ast/src/main/java/org/partiql/ast/DataType.java +++ b/partiql-ast/src/main/java/org/partiql/ast/DataType.java @@ -34,7 +34,7 @@ public static class StructField extends AstNode { public final boolean isOptional; - @Nullable + @NotNull public final List constraints; @Nullable public final String comment; @@ -43,7 +43,7 @@ public StructField( @NotNull Identifier name, @NotNull DataType type, boolean isOptional, - @Nullable List constraints, + @NotNull List constraints, @Nullable String comment) { this.name = name; this.type = type; diff --git a/partiql-plan/src/main/java/org/partiql/plan/Action.java b/partiql-plan/src/main/java/org/partiql/plan/Action.java index 4c6cdd2e1..d486975fc 100644 --- a/partiql-plan/src/main/java/org/partiql/plan/Action.java +++ b/partiql-plan/src/main/java/org/partiql/plan/Action.java @@ -2,6 +2,7 @@ import org.jetbrains.annotations.NotNull; import org.partiql.plan.rex.Rex; +import org.partiql.spi.catalog.Table; /** * A PartiQL statement action within a plan. @@ -19,4 +20,10 @@ public interface Query extends Action { @NotNull public Rex getRex(); } + + // A better way to segment is to have an object interface in SPI + public interface CreateTable extends Action { + @NotNull + public Table getTable(); + } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/DdlField.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/DdlField.kt new file mode 100644 index 000000000..4b66f8d21 --- /dev/null +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/DdlField.kt @@ -0,0 +1,51 @@ +package org.partiql.planner.internal + +import org.partiql.planner.internal.ir.Statement +import org.partiql.planner.internal.ir.statementDDLAttribute +import org.partiql.planner.internal.typer.CompilerType +import org.partiql.spi.catalog.Identifier +import org.partiql.types.Field +import org.partiql.types.PType +import org.partiql.types.shape.PShape + +/** + * An implementation for [Field] that is used by DDL + * to hold additional information in the struct field. + * It is identical to [Statement.DDL.Attribute] + */ +internal data class DdlField( + val name: Identifier, + val type: PShape, + val isNullable: Boolean, + val isOptional: Boolean, + val constraints: List, + val isPrimaryKey: Boolean, + val isUnique: Boolean, + val comment: String? +) : Field { + + override fun getName(): String { + return name.getIdentifier().getText() + } + + override fun getType(): PType { + return type + } + + fun toAttr() = + statementDDLAttribute( + this.name, + this.type, + this.isNullable, + this.isOptional, + this.isPrimaryKey, + this.isUnique, + this.constraints, + this.comment + ) + + companion object { + fun fromAttr(attr: Statement.DDL.Attribute): DdlField = + DdlField(attr.name, attr.type, attr.isNullable, attr.isOptional, attr.constraints, attr.isPrimaryKey, attr.isUnique, attr.comment) + } +} diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt index ad5f69a85..6b2b32bcb 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt @@ -62,6 +62,12 @@ import org.partiql.planner.internal.ir.builder.RexOpTupleUnionBuilder import org.partiql.planner.internal.ir.builder.RexOpVarGlobalBuilder import org.partiql.planner.internal.ir.builder.RexOpVarLocalBuilder import org.partiql.planner.internal.ir.builder.RexOpVarUnresolvedBuilder +import org.partiql.planner.internal.ir.builder.StatementDdlAttributeBuilder +import org.partiql.planner.internal.ir.builder.StatementDdlBuilder +import org.partiql.planner.internal.ir.builder.StatementDdlCommandCreateTableBuilder +import org.partiql.planner.internal.ir.builder.StatementDdlConstraintCheckBuilder +import org.partiql.planner.internal.ir.builder.StatementDdlPartitionByAttrListBuilder +import org.partiql.planner.internal.ir.builder.StatementDdlTablePropertyBuilder import org.partiql.planner.internal.ir.builder.StatementQueryBuilder import org.partiql.planner.internal.ir.visitor.PlanVisitor import org.partiql.planner.internal.typer.CompilerType @@ -73,6 +79,7 @@ import org.partiql.spi.function.Function import org.partiql.value.PartiQLValue import org.partiql.value.PartiQLValueExperimental import kotlin.random.Random +import org.partiql.types.shape.PShape internal abstract class PlanNode { @JvmField @@ -177,6 +184,7 @@ internal sealed class Ref : PlanNode() { internal sealed class Statement : PlanNode() { public override fun accept(visitor: PlanVisitor, ctx: C): R = when (this) { is Query -> visitor.visitStatementQuery(this, ctx) + is DDL -> visitor.visitStatementDDL(this, ctx) } internal data class Query( @@ -195,6 +203,169 @@ internal sealed class Statement : PlanNode() { internal fun builder(): StatementQueryBuilder = StatementQueryBuilder() } } + + internal data class DDL( + @JvmField + internal val command: Command, + ) : Statement() { + override val children: List by lazy { + val kids = mutableListOf() + kids.add(command) + kids.filterNotNull() + } + + + override fun accept(visitor: PlanVisitor, ctx: C): R = + visitor.visitStatementDDL(this, ctx) + + sealed class Command : PlanNode() { + override fun accept(visitor: PlanVisitor, ctx: C): R = when (this) { + is CreateTable -> visitor.visitStatementDDLCommandCreateTable(this, ctx) + } + + internal data class CreateTable( + @JvmField + internal val name: Identifier, + @JvmField + internal val attributes: List, + @JvmField + internal val tblConstraints: List, + @JvmField + internal val partitionBy: PartitionBy?, + @JvmField + internal val tableProperties: List, + @JvmField + internal val primaryKey: List, + @JvmField + internal val unique: List, + ) : Command() { + override val children: List by lazy { + val kids = mutableListOf() + kids.addAll(attributes) + kids.addAll(tblConstraints) + partitionBy?.let { kids.add(it) } + kids.addAll(tableProperties) + kids.filterNotNull() + } + + + override fun accept(visitor: PlanVisitor, ctx: C): R = + visitor.visitStatementDDLCommandCreateTable(this, ctx) + + internal companion object { + @JvmStatic + internal fun builder(): StatementDdlCommandCreateTableBuilder = + StatementDdlCommandCreateTableBuilder() + } + } + } + + internal data class Attribute( + @JvmField + internal val name: Identifier, + @JvmField + internal val type: PShape, + @JvmField + internal val isNullable: Boolean, + @JvmField + internal val isOptional: Boolean, + @JvmField + internal val isPrimaryKey: Boolean, + @JvmField + internal val isUnique: Boolean, + @JvmField + internal val constraints: List, + @JvmField + val comment: String? + ) : PlanNode() { + override val children: List by lazy { + val kids = mutableListOf() + kids.addAll(constraints) + kids.filterNotNull() + } + + override fun accept(visitor: PlanVisitor, ctx: C): R = + visitor.visitStatementDDLAttribute(this, ctx) + + internal companion object { + @JvmStatic + internal fun builder(): StatementDdlAttributeBuilder = StatementDdlAttributeBuilder() + } + } + + internal sealed class Constraint : PlanNode() { + override fun accept(visitor: PlanVisitor, ctx: C): R = when (this) { + is Check -> visitor.visitStatementDDLConstraintCheck(this, ctx) + } + + internal data class Check( + @JvmField + internal val expression: Rex, + @JvmField + val sql: String + ) : Constraint() { + override val children: List by lazy { + val kids = mutableListOf() + kids.add(expression) + kids.filterNotNull() + } + + + override fun accept(visitor: PlanVisitor, ctx: C): R = + visitor.visitStatementDDLConstraintCheck(this, ctx) + + internal companion object { + @JvmStatic + internal fun builder(): StatementDdlConstraintCheckBuilder = + StatementDdlConstraintCheckBuilder() + } + } + } + + internal sealed class PartitionBy : PlanNode() { + override fun accept(visitor: PlanVisitor, ctx: C): R = when (this) { + is AttrList -> visitor.visitStatementDDLPartitionByAttrList(this, ctx) + } + + internal data class AttrList( + @JvmField + internal val attrs: List, + ) : PartitionBy() { + override val children: List = emptyList() + + override fun accept(visitor: PlanVisitor, ctx: C): R = + visitor.visitStatementDDLPartitionByAttrList(this, ctx) + + internal companion object { + @JvmStatic + internal fun builder(): StatementDdlPartitionByAttrListBuilder = + StatementDdlPartitionByAttrListBuilder() + } + } + } + + internal data class TableProperty( + @JvmField + internal val name: String, + @JvmField + internal val `value`: String, + ) : PlanNode() { + override val children: List = emptyList() + + override fun accept(visitor: PlanVisitor, ctx: C): R = + visitor.visitStatementDDLTableProperty(this, ctx) + + internal companion object { + @JvmStatic + internal fun builder(): StatementDdlTablePropertyBuilder = StatementDdlTablePropertyBuilder() + } + } + + internal companion object { + @JvmStatic + internal fun builder(): StatementDdlBuilder = StatementDdlBuilder() + } + } } internal data class Rex( diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/normalize/NormalizeFromSource.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/normalize/NormalizeFromSource.kt index 8d9168833..225ba6864 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/normalize/NormalizeFromSource.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/normalize/NormalizeFromSource.kt @@ -23,6 +23,7 @@ import org.partiql.ast.FromExpr import org.partiql.ast.FromJoin import org.partiql.ast.FromTableRef import org.partiql.ast.FromType +import org.partiql.ast.Query import org.partiql.ast.QueryBody import org.partiql.ast.Statement import org.partiql.ast.expr.Expr @@ -33,7 +34,10 @@ import org.partiql.planner.internal.helpers.toBinder */ internal object NormalizeFromSource : AstPass { - override fun apply(statement: Statement): Statement = statement.accept(Visitor, 0) as Statement + override fun apply(statement: Statement): Statement = when (statement) { + is Query -> statement.accept(Visitor, 0) as Statement + else -> statement + } private object Visitor : AstRewriter() { diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/normalize/NormalizeGroupBy.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/normalize/NormalizeGroupBy.kt index bd23b6a4e..4228eb604 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/normalize/NormalizeGroupBy.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/normalize/NormalizeGroupBy.kt @@ -19,6 +19,7 @@ import org.partiql.ast.Ast.groupByKey import org.partiql.ast.AstNode import org.partiql.ast.AstRewriter import org.partiql.ast.GroupBy +import org.partiql.ast.Query import org.partiql.ast.Statement import org.partiql.ast.expr.Expr import org.partiql.planner.internal.helpers.toBinder @@ -28,7 +29,10 @@ import org.partiql.planner.internal.helpers.toBinder */ internal object NormalizeGroupBy : AstPass { - override fun apply(statement: Statement) = Visitor.visitStatement(statement, 0) as Statement + override fun apply(statement: Statement) = when (statement) { + is Query -> Visitor.visitStatement(statement, 0) as Statement + else -> statement + } private object Visitor : AstRewriter() { diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/AstToPlan.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/AstToPlan.kt index 986fc19c7..8797a56c5 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/AstToPlan.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/AstToPlan.kt @@ -18,11 +18,17 @@ package org.partiql.planner.internal.transforms import org.partiql.ast.AstNode import org.partiql.ast.AstVisitor +import org.partiql.ast.DataType import org.partiql.ast.Query +import org.partiql.ast.ddl.CreateTable import org.partiql.ast.expr.ExprQuerySet +import org.partiql.errors.TypeCheckException import org.partiql.planner.internal.Env import org.partiql.planner.internal.ir.statementQuery +import org.partiql.planner.internal.typer.CompilerType +import org.partiql.planner.internal.typer.PlanTyper.Companion.toCType import org.partiql.spi.catalog.Identifier +import org.partiql.types.PType import org.partiql.ast.Identifier as AstIdentifier import org.partiql.ast.IdentifierChain as AstIdentifierChain import org.partiql.ast.Statement as AstStatement @@ -42,6 +48,10 @@ internal object AstToPlan { override fun defaultReturn(node: AstNode, env: Env) = throw IllegalArgumentException("Unsupported statement") + override fun visitCreateTable(node: CreateTable, env: Env): PlanStatement { + return DdlConverter.apply(node, env) + } + override fun visitQuery(node: Query, env: Env): PlanStatement { val rex = when (val expr = node.expr) { is ExprQuerySet -> RelConverter.apply(expr, env) @@ -72,4 +82,118 @@ internal object AstToPlan { true -> Identifier.Part.delimited(identifier.symbol) false -> Identifier.Part.regular(identifier.symbol) } + + fun visitType(type: DataType): CompilerType { + return when (type.code()) { + // + // TODO CHAR_VARYING, CHARACTER_LARGE_OBJECT, CHAR_LARGE_OBJECT + DataType.CHARACTER, DataType.CHAR -> { + val length = type.length ?: 1 + assertGtZeroAndCreate(PType.CHAR, "length", length, PType::character) + } + DataType.CHARACTER_VARYING, DataType.VARCHAR -> { + val length = type.length ?: 1 + assertGtZeroAndCreate(PType.VARCHAR, "length", length, PType::varchar) + } + DataType.CLOB -> assertGtZeroAndCreate(PType.CLOB, "length", type.length ?: Int.MAX_VALUE, PType::clob) + DataType.STRING -> PType.string() + // + // TODO BINARY_LARGE_OBJECT + DataType.BLOB -> assertGtZeroAndCreate(PType.BLOB, "length", type.length ?: Int.MAX_VALUE, PType::blob) + // + DataType.BIT -> error("BIT is not supported yet.") + DataType.BIT_VARYING -> error("BIT VARYING is not supported yet.") + // - + DataType.NUMERIC -> { + val p = type.precision + val s = type.scale + when { + p == null && s == null -> PType.decimal(38, 0) + p != null && s != null -> { + assertParamCompToZero(PType.NUMERIC, "precision", p, false) + assertParamCompToZero(PType.NUMERIC, "scale", s, true) + if (s > p) { + throw TypeCheckException("Numeric scale cannot be greater than precision.") + } + PType.decimal(type.precision!!, type.scale!!) + } + p != null && s == null -> { + assertParamCompToZero(PType.NUMERIC, "precision", p, false) + PType.decimal(p, 0) + } + else -> error("Precision can never be null while scale is specified.") + } + } + DataType.DEC, DataType.DECIMAL -> { + val p = type.precision + val s = type.scale + when { + p == null && s == null -> PType.decimal(38, 0) + p != null && s != null -> { + assertParamCompToZero(PType.DECIMAL, "precision", p, false) + assertParamCompToZero(PType.DECIMAL, "scale", s, true) + if (s > p) { + throw TypeCheckException("Decimal scale cannot be greater than precision.") + } + PType.decimal(p, s) + } + p != null && s == null -> { + assertParamCompToZero(PType.DECIMAL, "precision", p, false) + PType.decimal(p, 0) + } + else -> error("Precision can never be null while scale is specified.") + } + } + DataType.BIGINT, DataType.INT8, DataType.INTEGER8 -> PType.bigint() + DataType.INT4, DataType.INTEGER4, DataType.INTEGER, DataType.INT -> PType.integer() + DataType.INT2, DataType.SMALLINT -> PType.smallint() + DataType.TINYINT -> PType.tinyint() // TODO define in parser + // - + DataType.FLOAT -> PType.real() + DataType.REAL -> PType.real() + DataType.DOUBLE_PRECISION -> PType.doublePrecision() + // + DataType.BOOL -> PType.bool() + // + DataType.DATE -> PType.date() + DataType.TIME -> assertGtEqZeroAndCreate(PType.TIME, "precision", type.precision ?: 0, PType::time) + DataType.TIME_WITH_TIME_ZONE -> assertGtEqZeroAndCreate(PType.TIMEZ, "precision", type.precision ?: 0, PType::timez) + DataType.TIMESTAMP -> assertGtEqZeroAndCreate(PType.TIMESTAMP, "precision", type.precision ?: 6, PType::timestamp) + DataType.TIMESTAMP_WITH_TIME_ZONE -> assertGtEqZeroAndCreate(PType.TIMESTAMPZ, "precision", type.precision ?: 6, PType::timestampz) + // + DataType.INTERVAL -> error("INTERVAL is not supported yet.") + // + DataType.STRUCT -> PType.struct() + DataType.TUPLE -> PType.struct() + // + DataType.LIST -> PType.array() + DataType.BAG -> PType.bag() + // + DataType.USER_DEFINED -> TODO("Custom type not supported ") + else -> error("Unsupported DataType type: $type") + }.toCType() + } + + private fun assertGtZeroAndCreate(type: Int, param: String, value: Int, create: (Int) -> PType): PType { + assertParamCompToZero(type, param, value, false) + return create.invoke(value) + } + + private fun assertGtEqZeroAndCreate(type: Int, param: String, value: Int, create: (Int) -> PType): PType { + assertParamCompToZero(type, param, value, true) + return create.invoke(value) + } + + /** + * @param allowZero when FALSE, this asserts that [value] > 0. If TRUE, this asserts that [value] >= 0. + */ + private fun assertParamCompToZero(type: Int, param: String, value: Int, allowZero: Boolean) { + val (result, compString) = when (allowZero) { + true -> (value >= 0) to "greater than" + false -> (value > 0) to "greater than or equal to" + } + if (!result) { + throw TypeCheckException("$type $param must be an integer value $compString 0.") + } + } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/DdlConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/DdlConverter.kt new file mode 100644 index 000000000..e5c9dd435 --- /dev/null +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/DdlConverter.kt @@ -0,0 +1,204 @@ + +package org.partiql.planner.internal.transforms + +import org.partiql.ast.AstNode +import org.partiql.ast.AstVisitor +import org.partiql.ast.DataType +import org.partiql.ast.ddl.AttributeConstraint.Check +import org.partiql.ast.ddl.AttributeConstraint.Null +import org.partiql.ast.ddl.AttributeConstraint.Unique +import org.partiql.ast.ddl.ColumnDefinition +import org.partiql.ast.ddl.CreateTable +import org.partiql.ast.ddl.Ddl +import org.partiql.ast.ddl.KeyValue +import org.partiql.ast.ddl.PartitionBy +import org.partiql.ast.ddl.TableConstraint +import org.partiql.ast.sql.sql +import org.partiql.planner.internal.DdlField +import org.partiql.planner.internal.Env +import org.partiql.planner.internal.ir.PlanNode +import org.partiql.planner.internal.ir.Statement +import org.partiql.planner.internal.ir.statementDDL +import org.partiql.planner.internal.ir.statementDDLAttribute +import org.partiql.planner.internal.ir.statementDDLCommandCreateTable +import org.partiql.planner.internal.ir.statementDDLConstraintCheck +import org.partiql.planner.internal.ir.statementDDLPartitionByAttrList +import org.partiql.planner.internal.ir.statementDDLTableProperty +import org.partiql.planner.internal.transforms.AstToPlan.convert +import org.partiql.planner.internal.transforms.AstToPlan.visitType +import org.partiql.planner.internal.typer.PlanTyper.Companion.toCType +import org.partiql.spi.catalog.Identifier +import org.partiql.types.PType +import org.partiql.types.shape.PShape + +internal object DdlConverter { + internal fun apply(statement: Ddl, env: Env): Statement.DDL = ToDdl.visitDdl(statement, env) + + /** + * Consider this as the first step to lowering create table statement to [PShape] + * + * Post this processing: + * We made sure that: + * 1. At column level + * - No multiple declaration of primary key constraint associated with one attribute declaration + * - i.e., `FOO INT2 PRIMARY KEY PRIMARY KEY` will be rejected + * - No optional attribute is declared as priamry key at column level + * - i.e., FOO OPTIONAL INT2 PRIMARY KEY will be rejected + * - Nullability and optionality is deducted from column level constraint + * - Comment is attached to the PType via PTrait + * 2. At table level + * - No multiple declaration of primary key constraint assciated with table + * ``` + * CREATE TABLE ... ( + * ... + * PRIMARY KEY (foo) + * PRIMARY KEY (bar) + * ) + * ``` + * will be rejected. + * - All unique constraints declared at table level will be concatenated to a single list + */ + private object ToDdl : AstVisitor() { + + override fun defaultReturn(node: AstNode?, ctx: Env): PlanNode { + throw IllegalArgumentException("unsupported DDL node: $node") + } + + override fun visitDdl(node: Ddl, ctx: Env): Statement.DDL { + return when (node) { + is CreateTable -> statementDDL(visitCreateTable(node, ctx)) + else -> throw IllegalArgumentException("Unsupported DDL Command: $node") + } + } + + override fun visitCreateTable(node: CreateTable, ctx: Env): Statement.DDL.Command.CreateTable { + val tableName = convert(node.name) + val attributes = node.columns.map { visitColumnDefinition(it, ctx) } + // Table Level PK + val pk = node.constraints.filterIsInstance() + .filter { it.isPrimaryKey } + .let { + when (it.size) { + 0 -> emptyList() + 1 -> it.first().columns.map { convert(it) } + else -> throw IllegalArgumentException("multiple PK") + } + } + + val unique = node.constraints.filterIsInstance() + .filter { !it.isPrimaryKey } + .fold(emptyList()) { acc, constr -> + acc + constr.columns.map { convert(it) } + } + + val partitionBy = node.partitionBy?.let { visitPartitionBy(it, ctx) } + val tableProperty = node.tableProperties.map { visitKeyValue(it, ctx) } + + return statementDDLCommandCreateTable( + tableName, + attributes, + emptyList(), + partitionBy, + tableProperty, + pk, + unique + ) + } + + // !!! The planning stage ignores the constraint name for now + override fun visitColumnDefinition(node: ColumnDefinition, ctx: Env): Statement.DDL.Attribute { + val name = convert(node.name) + val type = visitType(node.dataType, ctx) + + // Validation and reducing for nullable constraint + // If there are one or more nullable constraint, last one wins + // otherwise, nullable by default + val nullableConstraint = node.constraints + .filterIsInstance() + .reduceOrNull { acc, next -> next } + ?.isNullable ?: true + + // Validation and reducing for PK constraints + // Rule: No multiple PK constraint + val isPk = node.constraints + .filterIsInstance() + .filter { it.isPrimaryKey } + .let { + if (it.size > 1) { + throw IllegalArgumentException("Multiple primary key constraint declarations are not allowed.") + } else it + }.any() + + // validation -- No optional attribute declared as primary key + if (isPk && node.isOptional) throw IllegalArgumentException("Optional attribute as primary key is not supported.") + + // final nullability decision: + // if nullableConstraint was concluded to be not null, + // then not null + // if nullableConstraint was concluded to be nullable, and there is a valid PK constraint + // then not null + // if nullableConstraint was concluded to be nullable, and there is no valid PK constraint + // then nullable + val nullable = nullableConstraint && !isPk + + // Uniqueness decision + // if associated with unique constraint or Primary key constraint, true + // else false + val isUnique = node + .constraints + .filterIsInstance() + .any() || isPk + + val additionalConstrs = node.constraints + .filter { it !is Null && it !is Unique } + .map { it.accept(this, ctx) as Statement.DDL.Constraint } + + return statementDDLAttribute(name, type, nullable, node.isOptional, isPk, isUnique, additionalConstrs, node.comment) + } + + override fun visitCheck(node: Check, ctx: Env): Statement.DDL.Constraint = + statementDDLConstraintCheck( + RexConverter.apply(node.searchCondition, ctx), + node.searchCondition.sql() + ) + + override fun visitPartitionBy(node: PartitionBy, ctx: Env): Statement.DDL.PartitionBy { + return statementDDLPartitionByAttrList(node.columns.map { convert(it) }) + } + + override fun visitKeyValue(node: KeyValue, ctx: Env): Statement.DDL.TableProperty { + return statementDDLTableProperty(node.key, node.value) + } + + private fun visitType(node: DataType, ctx: Env): PShape { + // Struct requires special process in DDL + return if (node.code() == DataType.STRUCT) { + val fields = node.fields.map { field -> + val name = convert(field.name) + val type = visitType(field.type, ctx) + // No support for nested PK or UNIQUE + val hasUnique = field.constraints + .filterIsInstance() + .any() + if (hasUnique) { + throw IllegalArgumentException("Associating Primary Key Constraint or Unique Constraint on Struct Field is not supported") + } + + val isNullable = field.constraints + .filterIsInstance() + .reduceOrNull { acc, next -> next } + ?.isNullable ?: true + + val additionalConsts = field.constraints + .filterNot { it is Null } + .map { it.accept(this, ctx) as Statement.DDL.Constraint } + + DdlField(name, type, isNullable, field.isOptional, additionalConsts, false, false, field.comment) + } + PShape(PType.row(fields)) + } else { + PShape(visitType(node).getDelegate()) + } + } + } +} diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/InlineCheckConstraintExtractor.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/InlineCheckConstraintExtractor.kt new file mode 100644 index 000000000..b845641cf --- /dev/null +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/InlineCheckConstraintExtractor.kt @@ -0,0 +1,72 @@ +package org.partiql.planner.internal.transforms +import org.partiql.planner.internal.ir.PlanNode +import org.partiql.planner.internal.ir.Rex +import org.partiql.planner.internal.ir.Statement +import org.partiql.planner.internal.ir.visitor.PlanBaseVisitor +import org.partiql.types.shape.PShape +import org.partiql.types.shape.trait.ConstraintTrait +import org.partiql.types.shape.trait.RangeTrait +import org.partiql.value.NumericValue +import org.partiql.value.PartiQLValueExperimental + +/** + * Lowers an inline check constraint to PShape when applicable. + * + * For example, + * + * `Foo INT2 CHECK(Foo >= 0 AND FOO <=10)` will be lowered to a PShape with trait range(0,10). + * + * We can extend the lowering capability in the future if we want to, but for now, + * it only attempts to lower if the expression is + * 1. Simple expression in which the operator is either GTE or LTE. + * 2. Multiple simple expressions chained by and operator. + */ +internal object InlineCheckConstraintExtractor : PlanBaseVisitor() { + // Unable to lower the check constraint to PShape + override fun defaultReturn(node: PlanNode, ctx: PShape): PShape = ctx + + override fun visitStatementDDLConstraintCheck(node: Statement.DDL.Constraint.Check, ctx: PShape): PShape { + val lowered = visitRex(node.expression, ctx) + // No lowering happened, then wrap the PShape with a generic Constraint trait + return if (lowered == ctx) { + ConstraintTrait(ctx, node.sql) + } else lowered + } + + override fun visitRex(node: Rex, ctx: PShape): PShape { + return node.op.accept(this, ctx) + } + + override fun visitRexOpCallStatic(node: Rex.Op.Call.Static, ctx: PShape): PShape { + return when (node.fn.name) { + "gte" -> getRhsAsNumericOrNull(node.args[1])?.let { + RangeTrait(ctx, it, null) + } ?: ctx + "lte" -> getRhsAsNumericOrNull(node.args[1])?.let { + RangeTrait(ctx, null, it) + } ?: ctx + "and" -> handleAnd(node, ctx) + else -> super.visitRexOpCall(node, ctx) + } + } + + @OptIn(PartiQLValueExperimental::class) + private fun getRhsAsNumericOrNull(rex: Rex) = + when (val op = rex.op) { + is Rex.Op.Lit -> { + when (val v = op.value) { + is NumericValue<*> -> v.value + else -> null + } + } + else -> null + } + + private fun handleAnd(node: Rex.Op.Call.Static, ctx: PShape): PShape { + val lhs = node.args.first().accept(this, ctx) + // No lowering happened for lhs, do not attempt to lower the right-hand side + if (lhs == ctx) { return ctx } + val rhs = node.args[1].accept(this, lhs) + return rhs + } +} diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt index 8d1537e38..ca8ed834b 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt @@ -13,21 +13,30 @@ import org.partiql.plan.rex.RexCase import org.partiql.plan.rex.RexStruct import org.partiql.plan.rex.RexType import org.partiql.plan.rex.RexVar +import org.partiql.planner.internal.DdlField import org.partiql.planner.internal.PlannerFlag +import org.partiql.planner.internal.ir.PlanNode import org.partiql.planner.internal.ir.Rel import org.partiql.planner.internal.ir.SetQuantifier +import org.partiql.planner.internal.ir.Statement import org.partiql.planner.internal.ir.visitor.PlanBaseVisitor +import org.partiql.spi.catalog.Table import org.partiql.spi.errors.PErrorListener import org.partiql.spi.value.Datum import org.partiql.types.Field import org.partiql.types.PType +import org.partiql.types.shape.PShape +import org.partiql.types.shape.trait.MetadataTrait +import org.partiql.types.shape.trait.NotNullTrait +import org.partiql.types.shape.trait.PrimaryKeyTrait +import org.partiql.types.shape.trait.RequiredTrait +import org.partiql.types.shape.trait.UniqueTrait import org.partiql.value.DecimalValue import org.partiql.value.PartiQLValueExperimental import org.partiql.planner.internal.ir.PartiQLPlan as IPlan import org.partiql.planner.internal.ir.PlanNode as INode import org.partiql.planner.internal.ir.Rel as IRel import org.partiql.planner.internal.ir.Rex as IRex -import org.partiql.planner.internal.ir.Statement as IStatement /** * This produces a V1 plan from the internal plan IR. @@ -45,12 +54,114 @@ internal class PlanTransform(private val flags: Set) { */ fun transform(internal: IPlan, listener: PErrorListener): Plan { val signal = flags.contains(PlannerFlag.SIGNAL_MODE) - val query = (internal.statement as IStatement.Query) - val visitor = Visitor(listener, signal) - val root = visitor.visitRex(query.root, query.root.type) - val action = Action.Query { root } - // TODO replace with standard implementations (or just remove plan transform altogether when possible). - return Plan { action } + when (internal.statement) { + is Statement.DDL -> { + val query = internal.statement + val visitor = DDLVisitor(listener, signal) + val action = visitor.visitStatementDDL(query, Unit) + // TODO replace with standard implementations (or just remove plan transform altogether when possible). + return Plan { action } + } + is Statement.Query -> { + val query = internal.statement + val visitor = Visitor(listener, signal) + val root = visitor.visitRex(query.root, query.root.type) + val action = Action.Query { root } + // TODO replace with standard implementations (or just remove plan transform altogether when possible). + return Plan { action } + } + } + } + + // For now: Break down to two separate visitors + private class DDLVisitor( + private val listener: PErrorListener, + private val signal: Boolean, + ) : PlanBaseVisitor() { + override fun defaultReturn(node: PlanNode, ctx: Unit): Any? { + TODO("Translation not supported for ${node::class.simpleName}") + } + + // DDL + override fun visitStatementDDL(node: Statement.DDL, ctx: Unit): Action { + return when (val command = node.command) { + is Statement.DDL.Command.CreateTable -> { + visitStatementDDLCommandCreateTable(command, ctx) + } + } + } + + override fun visitStatementDDLCommandCreateTable(node: Statement.DDL.Command.CreateTable, ctx: Unit): Action.CreateTable { + val fields = node.attributes.map { + val shape = visitStatementDDLAttribute(it, ctx) + Field.of(it.name.getIdentifier().getText(), shape) + } + val row = PShape(PType.row(fields)) + val schema = PShape(PType.bag(row)).let { + if (node.primaryKey.isNotEmpty()) + PrimaryKeyTrait(it, node.primaryKey.map { it.getIdentifier().getText() }) + else it + }.let { + if (node.unique.isNotEmpty()) + UniqueTrait(it, node.unique.map { it.getIdentifier().getText() }) + else it + }.let { + var shape = it + if (node.tableProperties.isNotEmpty()) { + node.tableProperties.forEach { prop -> + shape = MetadataTrait(shape, prop.name, prop.value) + } + } + shape + }.let { it -> + when (val partition = node.partitionBy) { + is Statement.DDL.PartitionBy.AttrList -> { + val names = buildString { + append("[") + append(partition.attrs.joinToString(",") { it.getIdentifier().getText() }) + append("]") + } + MetadataTrait(it, "partition", names) + } + null -> it + } + } + return Action.CreateTable { + Table.builder() + .name(node.name.getIdentifier().getText()) + .schema(schema) + .build() + } + } + + override fun visitStatementDDLAttribute(node: Statement.DDL.Attribute, ctx: Unit): PShape { + val ddlField = DdlField.fromAttr(node) + return visitDdlField(ddlField, ctx) + } + + private fun visitDdlField(ddlField: DdlField, ctx: Unit): PShape { + val baseShape = ddlField.type + val fieldReduced = when (baseShape.code()) { + PType.ROW -> { + val fields = baseShape.fields.map { + val shape = visitDdlField(it as DdlField, ctx) + Field.of(it.name.getIdentifier().getText(), shape) + } + PType.row(fields) + } + else -> baseShape + }.let { PShape(it) } + val constraintReduced = ddlField + .constraints + .fold(fieldReduced) { acc, constr -> + InlineCheckConstraintExtractor.visitStatementDDLConstraint(constr, acc) + }.let { if (!ddlField.isNullable) NotNullTrait(it) else it } + .let { if (!ddlField.isOptional) RequiredTrait(it) else it } + val metadataReduced = ddlField.comment?.let { + MetadataTrait(constraintReduced, "comment", it) + } ?: constraintReduced + return metadataReduced + } } private class Visitor( diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt index 04a87d147..7dadb2b21 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt @@ -61,7 +61,6 @@ import org.partiql.ast.expr.PathStep import org.partiql.ast.expr.Scope import org.partiql.ast.expr.TrimSpec import org.partiql.ast.expr.TruthValue -import org.partiql.errors.TypeCheckException import org.partiql.planner.internal.Env import org.partiql.planner.internal.ir.Rel import org.partiql.planner.internal.ir.Rex @@ -89,8 +88,8 @@ import org.partiql.planner.internal.ir.rexOpSubquery import org.partiql.planner.internal.ir.rexOpTupleUnion import org.partiql.planner.internal.ir.rexOpVarLocal import org.partiql.planner.internal.ir.rexOpVarUnresolved +import org.partiql.planner.internal.transforms.AstToPlan.visitType import org.partiql.planner.internal.typer.CompilerType -import org.partiql.planner.internal.typer.PlanTyper.Companion.toCType import org.partiql.planner.internal.utils.DateTimeUtils import org.partiql.spi.catalog.Identifier import org.partiql.types.PType @@ -1073,120 +1072,6 @@ internal object RexConverter { return rex(ANY, rexOpCastUnresolved(type, arg)) } - private fun visitType(type: DataType): CompilerType { - return when (type.code()) { - // - // TODO CHAR_VARYING, CHARACTER_LARGE_OBJECT, CHAR_LARGE_OBJECT - DataType.CHARACTER, DataType.CHAR -> { - val length = type.length ?: 1 - assertGtZeroAndCreate(PType.CHAR, "length", length, PType::character) - } - DataType.CHARACTER_VARYING, DataType.VARCHAR -> { - val length = type.length ?: 1 - assertGtZeroAndCreate(PType.VARCHAR, "length", length, PType::varchar) - } - DataType.CLOB -> assertGtZeroAndCreate(PType.CLOB, "length", type.length ?: Int.MAX_VALUE, PType::clob) - DataType.STRING -> PType.string() - // - // TODO BINARY_LARGE_OBJECT - DataType.BLOB -> assertGtZeroAndCreate(PType.BLOB, "length", type.length ?: Int.MAX_VALUE, PType::blob) - // - DataType.BIT -> error("BIT is not supported yet.") - DataType.BIT_VARYING -> error("BIT VARYING is not supported yet.") - // - - DataType.NUMERIC -> { - val p = type.precision - val s = type.scale - when { - p == null && s == null -> PType.decimal(38, 0) - p != null && s != null -> { - assertParamCompToZero(PType.NUMERIC, "precision", p, false) - assertParamCompToZero(PType.NUMERIC, "scale", s, true) - if (s > p) { - throw TypeCheckException("Numeric scale cannot be greater than precision.") - } - PType.decimal(type.precision!!, type.scale!!) - } - p != null && s == null -> { - assertParamCompToZero(PType.NUMERIC, "precision", p, false) - PType.decimal(p, 0) - } - else -> error("Precision can never be null while scale is specified.") - } - } - DataType.DEC, DataType.DECIMAL -> { - val p = type.precision - val s = type.scale - when { - p == null && s == null -> PType.decimal(38, 0) - p != null && s != null -> { - assertParamCompToZero(PType.DECIMAL, "precision", p, false) - assertParamCompToZero(PType.DECIMAL, "scale", s, true) - if (s > p) { - throw TypeCheckException("Decimal scale cannot be greater than precision.") - } - PType.decimal(p, s) - } - p != null && s == null -> { - assertParamCompToZero(PType.DECIMAL, "precision", p, false) - PType.decimal(p, 0) - } - else -> error("Precision can never be null while scale is specified.") - } - } - DataType.BIGINT, DataType.INT8, DataType.INTEGER8 -> PType.bigint() - DataType.INT4, DataType.INTEGER4, DataType.INTEGER, DataType.INT -> PType.integer() - DataType.INT2, DataType.SMALLINT -> PType.smallint() - DataType.TINYINT -> PType.tinyint() // TODO define in parser - // - - DataType.FLOAT -> PType.real() - DataType.REAL -> PType.real() - DataType.DOUBLE_PRECISION -> PType.doublePrecision() - // - DataType.BOOL -> PType.bool() - // - DataType.DATE -> PType.date() - DataType.TIME -> assertGtEqZeroAndCreate(PType.TIME, "precision", type.precision ?: 0, PType::time) - DataType.TIME_WITH_TIME_ZONE -> assertGtEqZeroAndCreate(PType.TIMEZ, "precision", type.precision ?: 0, PType::timez) - DataType.TIMESTAMP -> assertGtEqZeroAndCreate(PType.TIMESTAMP, "precision", type.precision ?: 6, PType::timestamp) - DataType.TIMESTAMP_WITH_TIME_ZONE -> assertGtEqZeroAndCreate(PType.TIMESTAMPZ, "precision", type.precision ?: 6, PType::timestampz) - // - DataType.INTERVAL -> error("INTERVAL is not supported yet.") - // - DataType.STRUCT -> PType.struct() - DataType.TUPLE -> PType.struct() - // - DataType.LIST -> PType.array() - DataType.BAG -> PType.bag() - // - DataType.USER_DEFINED -> TODO("Custom type not supported ") - else -> error("Unsupported DataType type: $type") - }.toCType() - } - - private fun assertGtZeroAndCreate(type: Int, param: String, value: Int, create: (Int) -> PType): PType { - assertParamCompToZero(type, param, value, false) - return create.invoke(value) - } - - private fun assertGtEqZeroAndCreate(type: Int, param: String, value: Int, create: (Int) -> PType): PType { - assertParamCompToZero(type, param, value, true) - return create.invoke(value) - } - - /** - * @param allowZero when FALSE, this asserts that [value] > 0. If TRUE, this asserts that [value] >= 0. - */ - private fun assertParamCompToZero(type: Int, param: String, value: Int, allowZero: Boolean) { - val (result, compString) = when (allowZero) { - true -> (value >= 0) to "greater than" - false -> (value > 0) to "greater than or equal to" - } - if (!result) { - throw TypeCheckException("$type $param must be an integer value $compString 0.") - } - } - override fun visitExprSessionAttribute(node: ExprSessionAttribute, ctx: Env): Rex { val type = ANY val fn = node.sessionAttribute.name().lowercase() diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt index 260876cd6..95789c7b8 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt @@ -16,6 +16,7 @@ package org.partiql.planner.internal.typer +import org.partiql.planner.internal.DdlField import org.partiql.planner.internal.Env import org.partiql.planner.internal.PErrors import org.partiql.planner.internal.exclude.ExcludeRepr @@ -49,6 +50,9 @@ import org.partiql.planner.internal.ir.rexOpPivot import org.partiql.planner.internal.ir.rexOpStruct import org.partiql.planner.internal.ir.rexOpStructField import org.partiql.planner.internal.ir.rexOpSubquery +import org.partiql.planner.internal.ir.statementDDL +import org.partiql.planner.internal.ir.statementDDLCommandCreateTable +import org.partiql.planner.internal.ir.statementDDLPartitionByAttrList import org.partiql.planner.internal.ir.statementQuery import org.partiql.planner.internal.ir.util.PlanRewriter import org.partiql.spi.Context @@ -57,6 +61,7 @@ import org.partiql.spi.errors.PError import org.partiql.spi.errors.PErrorListener import org.partiql.types.Field import org.partiql.types.PType +import org.partiql.types.shape.PShape import org.partiql.value.BoolValue import org.partiql.value.MissingValue import org.partiql.value.PartiQLValueExperimental @@ -78,12 +83,13 @@ internal class PlanTyper(private val env: Env, config: Context) { * Rewrite the statement with inferred types and resolved variables */ fun resolve(statement: Statement): Statement { - if (statement !is Statement.Query) { - throw IllegalArgumentException("PartiQLPlanner only supports Query statements") + return when (statement) { + is Statement.DDL -> statement.accept(DdlTyper(), emptyList()) as Statement + is Statement.Query -> { + val root = statement.root.type(emptyList(), emptyList(), Strategy.GLOBAL) + return statementQuery(root) + } } - // root TypeEnv has no bindings - val root = statement.root.type(emptyList(), emptyList(), Strategy.GLOBAL) - return statementQuery(root) } internal companion object { fun PType.static(): CompilerType = CompilerType(this) @@ -1233,6 +1239,158 @@ internal class PlanTyper(private val env: Env, config: Context) { } } + /** + * Consider this as the secondary pass to lower to PShape. + * + * Post this pass: + * 1. Constraints associated with collection type are moved from attribute level to table level. + * 2. Side effect of PRIMARY KEY leads to unique and not null of attribute. + * + * We also verfied that: + * 1. primary key: + * - Only one PRIMARY KEY constraint declaration across the table declaration. + * - PRIMARY KEY cannot be asscoiated with optional attribute + * - PRIMARY KEY can only be associated with attribute which has scalar type + * - Duplicated attribute not allowed in Primary key declaration. + */ + internal inner class DdlTyper : PlanRewriter>() { + override fun visitStatementDDL(node: Statement.DDL, ctx: List): Statement.DDL { + when (val command = node.command) { + is Statement.DDL.Command.CreateTable -> { + val createTable = visitStatementDDLCommandCreateTable(command, ctx) + return statementDDL(createTable) + } + } + } + + override fun visitStatementDDLCommandCreateTable( + node: Statement.DDL.Command.CreateTable, + ctx: List + ): Statement.DDL.Command.CreateTable { + val attrs = node.attributes.map { attr -> + visitStatementDDLAttribute(attr, ctx) + } + + // Make sure that no duplicated PK across Attribute + // i.e., + // FOO INT2 PRIMARY KEY + // BAR INT2 PRIMARY KEY + val (attributePK, attributeUnique) = node.attributes.fold(mutableListOf() to mutableListOf()) { acc, attr -> + val pk = acc.first + val unique = acc.second + if (attr.isPrimaryKey) { + if (pk.isNotEmpty()) { + throw IllegalArgumentException("Multiple primary key constraint declarations are not allowed") + } else pk.add(attr.name.getIdentifier().getText()) + } + if (attr.isUnique) { + unique.add(attr.name.getIdentifier().getText()) + } + pk to unique + } + + // Make sure no duplicated PK across attribute and table level + if (node.primaryKey.isNotEmpty() && attributePK.isNotEmpty()) { + throw IllegalArgumentException("Multiple primary key constraint declarations are not allowed") + } + + // We make sure that the table level unique or primary key constraint does not refers to a non-existing attribute + // Only top level attribute can be associated with a primary key or unique constraint + val tableUniqueContr = node.unique.map { uniqueAttr -> + isDeclaredAttribute(uniqueAttr, attrs.map { DdlField.fromAttr(it) }) ?: throw IllegalArgumentException("Unresolved ref") + } + + val tablePrimaryContr = node.primaryKey.map { pkAtrr -> + isDeclaredAttribute(pkAtrr, attrs.map { DdlField.fromAttr(it) }) ?: throw IllegalArgumentException("Unresolved ref") + } + + // ALSO: For PK + // Thing like PRIMARY KEY (FOO, FOO) will be rejected + if (tablePrimaryContr.toSet().size != tablePrimaryContr.size) { + throw IllegalArgumentException("Attribute appears multiple times in primary key constraint") + } + val finalPK = (tablePrimaryContr + attributePK) + + val finalUnique = (tableUniqueContr + attributeUnique + finalPK).toSet().toList() + + val nonScalarTypeCode = listOf(PType.ROW, PType.STRUCT, PType.ARRAY, PType.DYNAMIC, PType.UNKNOWN, PType.VARIANT) + + val finalAttrs = attrs.map { + val name = it.name.getIdentifier().getText() + if (finalPK.contains(name)) { + when { + it.isOptional -> throw IllegalArgumentException("Optional Attribute $name can not be declared as Primary Key.") + it.type.code() in nonScalarTypeCode -> throw IllegalArgumentException("Primary Key can only be associated with scalar typed attribute") + else -> it.copy(isNullable = false) + } + } + else it + } + + val partitionBy = when (val partitionAttr = node.partitionBy) { + is Statement.DDL.PartitionBy.AttrList -> { + val attrListResolved = partitionAttr.attrs.map { attr -> + isDeclaredAttribute(attr, attrs.map { DdlField.fromAttr(it) }) ?: throw IllegalArgumentException("Unresolved ref") + } + statementDDLPartitionByAttrList(attrListResolved.map { Identifier.of(Identifier.Part.delimited(it)) }) + } + null -> null + } + return statementDDLCommandCreateTable( + node.name, + finalAttrs, + node.tblConstraints, + partitionBy, + node.tableProperties, + finalPK.map { Identifier.of(Identifier.Part.delimited(it)) }, + finalUnique.map { Identifier.of(Identifier.Part.delimited(it)) } + ) + } + + override fun visitStatementDDLAttribute(node: Statement.DDL.Attribute, ctx: List): Statement.DDL.Attribute { + val ddlField = DdlField.fromAttr(node) + return visitDdlField(ddlField, ctx).toAttr() + } + + private fun visitDdlField(node: DdlField, ctx: List): DdlField { + // Make sure those check constraints do not refer to out of scope variable + // and the check constraints resolves to boolean + val resolvedConstraint = node.constraints.map { + it.accept(this, listOf(node)) as Statement.DDL.Constraint + } + val type = if (node.type.code() == PType.ROW) { + val nested = node.type.fields.map { field -> + visitDdlField(field as DdlField, ctx) + } + PType.row(nested) + } else { + node.type + }.let { PShape(it) } + return node.copy(type = type, constraints = resolvedConstraint) + } + + // TODO: Identifier case sensitivity + private fun isDeclaredAttribute(identifier: Identifier, declareAttrs: List): String? { + declareAttrs.forEach { declared -> + if (identifier.matches(declared.name, identifier.getIdentifier().isRegular())) + // Storing as Declared to work around identifier case sensitivity at the moment + return declared.name + } + return null + } + + override fun visitStatementDDLConstraintCheck( + node: Statement.DDL.Constraint.Check, + ctx: List + ): Statement.DDL.Constraint.Check { + val bindings = ctx.map { Rel.Binding(it.name.getIdentifier().getText(), it.type.toCType()) } + val typed = node.expression.type(bindings, emptyList()) + if (typed.type.code() != PType.BOOL) { + throw IllegalArgumentException("Check Constraint - Search condition inferred as a non-boolean type") + } + return node.copy(typed, node.sql) + } + } // HELPERS private fun Rel.type(stack: List, strategy: Strategy = Strategy.LOCAL): Rel = diff --git a/partiql-planner/src/main/resources/partiql_plan_internal.ion b/partiql-planner/src/main/resources/partiql_plan_internal.ion index efc9b4560..ea82d9259 100644 --- a/partiql-planner/src/main/resources/partiql_plan_internal.ion +++ b/partiql-planner/src/main/resources/partiql_plan_internal.ion @@ -9,6 +9,7 @@ imports::{ fn_instance::'org.partiql.spi.function.Function.Instance', agg_signature::'org.partiql.spi.function.Aggregation', table::'org.partiql.spi.catalog.Table', + shape::'org.partiql.types.shape.PShape', ], } @@ -59,6 +60,50 @@ statement::[ query::{ root: rex, }, + d_d_l::{ + command:[ + create_table::{ + // Table Name, unresolved + name: identifier, + attributes: list::[attribute], + tbl_constraints: list::[constraint], + partition_by: optional::partition_by, + table_properties: list::[table_property], + primary_key: list::[identifier], + unique: list::[identifier] + }, + ], + _:[ + // Here we flattern non-check constraint + // into field associated with attribute + // !! Note: Constraint name will be ignored in the plan data model + attribute ::{ + name: identifier, + type: shape, + isNullable: bool, + isOptional: bool, + isPrimaryKey: bool, + isUnique: bool, + constraints: list::[constraint], + comment: optional::string + }, + constraint::[ + check::{ + expression: rex, + sql: string + } + ], + partition_by::[ + attr_list::{ + attrs: list::[identifier] + } + ], + table_property::{ + name: string, + value: string, + }, + ] + } ] // [ ALL | DISTINCT ] diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/ddl/DDLTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/ddl/DDLTest.kt new file mode 100644 index 000000000..af47ea876 --- /dev/null +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/ddl/DDLTest.kt @@ -0,0 +1,626 @@ +package org.partiql.planner.ddl + +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import org.partiql.parser.PartiQLParser +import org.partiql.plan.Action +import org.partiql.plan.Plan +import org.partiql.planner.PartiQLPlanner +import org.partiql.planner.internal.TestCatalog +import org.partiql.spi.catalog.Session +import org.partiql.spi.catalog.Table +import org.partiql.spi.errors.PErrorException +import org.partiql.types.Field +import org.partiql.types.PType +import org.partiql.types.shape.PShape +import java.math.BigDecimal + +class DDLTest { + + val parser = PartiQLParser.builder().build() + val planner = PartiQLPlanner.standard() + val session = Session.builder() + .catalog("default") + .catalogs( + TestCatalog.builder() + .name("default") + .build() + ) + .namespace("SCHEMA") + .build() + fun plan(ddl: String): Plan { + val statement = parser.parse(ddl).statements.first() + return planner.plan(statement, session).plan + } + + private fun tableAssertionAndReturnFields( + table: Table, + tableName: String, + primaryKey: List = emptyList(), + unique: List = emptyList(), + metadata: Map = emptyMap(), + ): List { + Assertions.assertEquals(tableName, table.getName().getName()) + val schema = table.getSchema() as? PShape + ?: throw AssertionError("Expect Schema to be a PShape") + schema as? PShape ?: throw AssertionError("Expect Schema to be a PShape") + Assertions.assertEquals(PType.BAG, schema.code()) { + "Expect Schema to be a Bag Type" + } + if (primaryKey.isNotEmpty()) { + Assertions.assertTrue(primaryKey.containsAll(schema.primaryKey())) + Assertions.assertTrue(schema.primaryKey().containsAll(primaryKey)) + } + if (unique.isNotEmpty()) { + Assertions.assertTrue(unique.containsAll(schema.unique())) + Assertions.assertTrue(schema.unique().containsAll(unique)) + } + val struct = schema.typeParameter + Assertions.assertEquals(PType.ROW, struct.code()) + return struct.fields.toList() + } + + private fun assertField( + field: Field, + fieldName: String, + code: Int, + isNullable: Boolean, + isOptional: Boolean, + maxValue: Number? = null, + minValue: Number? = null, + precision: Int? = null, + scale: Int? = null, + length: Int? = null, + meta: Map = emptyMap(), + ) { + Assertions.assertEquals(fieldName, field.name) + val shape = field.type as PShape + Assertions.assertEquals(code, shape.code()) + Assertions.assertEquals(isOptional, shape.isOptional) + Assertions.assertEquals(isNullable, shape.isNullable) + if (maxValue != null) { + Assertions.assertTrue(BigDecimal(maxValue.toString()).compareTo(BigDecimal(shape.maxValue().toString())) == 0) { + """ + Expected maxValue to be $maxValue but was: ${shape.maxValue()} + """.trimIndent() + } + } + if (minValue != null) { + Assertions.assertTrue(BigDecimal(minValue.toString()).compareTo(BigDecimal(shape.minValue().toString())) == 0) { + """ + Expected minValue to be $minValue but was: ${shape.minValue()} + """.trimIndent() + } + } + if (precision != null) { + Assertions.assertEquals(precision, shape.precision) + } + if (scale != null) { + Assertions.assertEquals(scale, shape.scale) + } + if (length != null) { + Assertions.assertEquals(length, shape.length) + } + Assertions.assertEquals(meta, shape.meta()) + } + + private fun assertInt2NullableOptional(field: Field, name: String, nullable: Boolean, optional: Boolean) { + assertField( + field, + name, + PType.SMALLINT, + nullable, optional, + Short.MAX_VALUE, Short.MIN_VALUE + ) + } + + @Test + fun createTableBasicTest() { + val ddl = """ + CREATE TABLE foo ( + int2_attr INT2, + int4_attr INT4, + int8_attr INT8, + decimal_attr DECIMAL(10,5), + float_attr REAL, + char_attr CHAR(1), + varchar_attr VARCHAR(255), + timestamp_attr TIMESTAMP(2), + date_attr DATE + ); + """.trimIndent() + val plan = plan(ddl) + val createTable = plan.action as Action.CreateTable + val fields = tableAssertionAndReturnFields( + createTable.table, + "foo" + ) + val int2_attr = fields[0] + assertField( + int2_attr, + "int2_attr", + PType.SMALLINT, + true, false, + Short.MAX_VALUE, Short.MIN_VALUE + ) + val int4_attr = fields[1] + assertField( + int4_attr, + "int4_attr", + PType.INTEGER, + true, false, + Int.MAX_VALUE, Int.MIN_VALUE + ) + val int8_attr = fields[2] + assertField( + int8_attr, + "int8_attr", + PType.BIGINT, + true, false, + Long.MAX_VALUE, Long.MIN_VALUE + ) + val decimal_attr = fields[3] + assertField( + decimal_attr, + "decimal_attr", + PType.DECIMAL, + true, false, + precision = 10, scale = 5, + ) + val float_attr = fields[4] + assertField( + float_attr, + "float_attr", + PType.REAL, + true, false, + ) + val char_attr = fields[5] + assertField( + char_attr, + "char_attr", + PType.CHAR, + true, false, + length = 1 + ) + val varchar_attr = fields[6] + assertField( + varchar_attr, + "varchar_attr", + PType.VARCHAR, + true, false, + length = 255 + ) + val timestamp_attr = fields[7] + assertField( + timestamp_attr, + "timestamp_attr", + PType.TIMESTAMP, + true, false, + precision = 2 + ) + val date_attr = fields[8] + assertField( + date_attr, + "date_attr", + PType.DATE, + true, false, + ) + } + + @Test + fun createTableStructTest() { + val ddl = """ + CREATE TABLE foo ( + struct_attr STRUCT< + int2_attr: INT2, + int4_attr: INT4, + int8_attr: INT8, + decimal_attr: DECIMAL(10,5), + float_attr: REAL, + char_attr: CHAR(1), + varchar_attr: VARCHAR(255), + timestamp_attr: TIMESTAMP(2), + date_attr: DATE + > + ); + """.trimIndent() + val plan = plan(ddl) + val createTable = plan.action as Action.CreateTable + val fields = tableAssertionAndReturnFields( + createTable.table, + "foo" + ) + val struct_attr = fields[0] + assertField(struct_attr, "struct_attr", PType.ROW, true, false ) + val structFields = struct_attr.type.fields.toList() + + val int2_attr = structFields[0] + assertField( + int2_attr, + "int2_attr", + PType.SMALLINT, + true, false, + Short.MAX_VALUE, Short.MIN_VALUE + ) + val int4_attr = structFields[1] + assertField( + int4_attr, + "int4_attr", + PType.INTEGER, + true, false, + Int.MAX_VALUE, Int.MIN_VALUE + ) + val int8_attr = structFields[2] + assertField( + int8_attr, + "int8_attr", + PType.BIGINT, + true, false, + Long.MAX_VALUE, Long.MIN_VALUE + ) + val decimal_attr = structFields[3] + assertField( + decimal_attr, + "decimal_attr", + PType.DECIMAL, + true, false, + precision = 10, scale = 5, + ) + val float_attr = structFields[4] + assertField( + float_attr, + "float_attr", + PType.REAL, + true, false, + ) + val char_attr = structFields[5] + assertField( + char_attr, + "char_attr", + PType.CHAR, + true, false, + length = 1 + ) + val varchar_attr = structFields[6] + assertField( + varchar_attr, + "varchar_attr", + PType.VARCHAR, + true, false, + length = 255 + ) + val timestamp_attr = structFields[7] + assertField( + timestamp_attr, + "timestamp_attr", + PType.TIMESTAMP, + true, false, + precision = 2 + ) + val date_attr = structFields[8] + assertField( + date_attr, + "date_attr", + PType.DATE, + true, false, + ) + } + + @Test + fun createTableNotNullTest() { + val ddl = """ + CREATE TABLE foo ( + attr1 INT2 NOT NULL, + attr2 INT2 NOT NULL NULL, + attr3 INT2 NOT NULL NULL NOT NULL, + attr4 INT2 NULL NOT NULL + ); + """.trimIndent() + + val plan = plan(ddl) + val createTable = plan.action as Action.CreateTable + val fields = tableAssertionAndReturnFields( + createTable.table, + "foo" + ) + + val attr1 = fields[0] + assertInt2NullableOptional(attr1, "attr1", false, false) + val attr2 = fields[1] + assertInt2NullableOptional(attr2, "attr2", true, false) + val attr3 = fields[2] + assertInt2NullableOptional(attr3, "attr3", false, false) + val attr4 = fields[3] + assertInt2NullableOptional(attr4, "attr4", false, false) + } + + @Test + fun createTableOptionalTest() { + val ddl = """ + CREATE TABLE foo ( + attr1 OPTIONAL INT2, + attr2 OPTIONAL INT2 NOT NULL + ); + """.trimIndent() + val plan = plan(ddl) + val createTable = plan.action as Action.CreateTable + val fields = tableAssertionAndReturnFields( + createTable.table, + "foo" + ) + + val attr1 = fields[0] + assertInt2NullableOptional(attr1, "attr1", true, true) + val attr2 = fields[1] + assertInt2NullableOptional(attr2, "attr2", false, true) + } + + @Test + fun createTableCommentTest() { + val ddl = """ + CREATE TABLE foo ( + attr1 INT2 COMMENT 'attr1' + ); + """.trimIndent() + val plan = plan(ddl) + val createTable = plan.action as Action.CreateTable + val fields = tableAssertionAndReturnFields( + createTable.table, + "foo" + ) + + val attr1 = fields[0] + assertField( + attr1, + "attr1", + PType.SMALLINT, + true, false, + Short.MAX_VALUE, Short.MIN_VALUE, + meta = mapOf("comment" to "attr1") + ) + } + + @Test + fun createTableCheckConstraintLoweredTest() { + val ddl = """ + CREATE TABLE foo ( + attr1 INT2 CHECK(attr1 >= 0), + attr2 INT2 CHECK(attr2 >=0 and attr2 <= 10), + attr3 INT2 CHECK(attr3 >= 0) CHECK (attr3 <= 10), + -- Leading to a empty value set + attr4 INT2 CHECK(attr4 >= 1000000) + ); + """.trimIndent() + val plan = plan(ddl) + val createTable = plan.action as Action.CreateTable + val fields = tableAssertionAndReturnFields( + createTable.table, + "foo" + ) + + val attr1 = fields[0] + assertField( + attr1, + "attr1", + PType.SMALLINT, + true, false, + Short.MAX_VALUE, 0, + ) + + val attr2 = fields[1] + assertField( + attr2, + "attr2", + PType.SMALLINT, + true, false, + 10, 0, + ) + + val attr3 = fields[2] + assertField( + attr3, + "attr3", + PType.SMALLINT, + true, false, + 10, 0, + ) + + val attr4 = fields[3] + assertField( + attr4, + "attr4", + PType.SMALLINT, + true, false, + Short.MAX_VALUE, 1000000, + ) + } + + @Test + fun createTablePrimaryKeyInlineTest() { + val ddl = """ + CREATE TABLE foo ( + attr1 INT2 PRIMARY KEY + ); + """.trimIndent() + val plan = plan(ddl) + val createTable = plan.action as Action.CreateTable + val fields = tableAssertionAndReturnFields( + createTable.table, + "foo", + listOf("attr1") + ) + + // Side effect: Nullable is false + val attr1 = fields[0] + assertField( + attr1, + "attr1", + PType.SMALLINT, + false, false, + Short.MAX_VALUE, Short.MIN_VALUE, + ) + } + + @Test + fun createTablePrimaryKeyTableTest() { + val ddl = """ + CREATE TABLE foo ( + attr1 INT2, + -- case insensitive + PRIMARY KEY (ATTR1) + ); + """.trimIndent() + val plan = plan(ddl) + val createTable = plan.action as Action.CreateTable + val fields = tableAssertionAndReturnFields( + createTable.table, + "foo", + listOf("attr1") + ) + + // Side effect: Nullable is false + val attr1 = fields[0] + assertField( + attr1, + "attr1", + PType.SMALLINT, + false, false, + Short.MAX_VALUE, Short.MIN_VALUE, + ) + } + + @Test + fun createTableUniqueKey() { + val ddl = """ + CREATE TABLE foo ( + attr1 INT2 NOT NULL, + attr2 INT2 UNIQUE, + attr3 INT2 PRIMARY KEY, + attr4 INT2, + -- Duplicated declaration + UNIQUE (ATTR2), + UNIQUE (attr1), + UNIQUE (attr1, attr4) + ); + """.trimIndent() + val plan = plan(ddl) + val createTable = plan.action as Action.CreateTable + tableAssertionAndReturnFields( + createTable.table, + "foo", + listOf("attr3"), + // Side effect: attr3 is primary key, therefore it is unique + listOf("attr1", "attr2", "attr3", "attr4") + ) + } + + @Test + fun createTableMetadata() { + val ddl = """ + CREATE TABLE foo ( + attr1 INT2 + ) + TBLPROPERTIES('key' = 'value') + PARTITION BY (attr1); + """.trimIndent() + val plan = plan(ddl) + val createTable = plan.action as Action.CreateTable + tableAssertionAndReturnFields( + createTable.table, + "foo", + metadata = mapOf("key" to "value", "partition" to "[attr1]") + ) + } + + @Test + fun negative_createTable_primaryKey_1() { + val ddl = """ + CREATE TABLE foo ( + attr1 INT2 PRIMARY KEY, + attr2 INT2 PRIMARY KEY, + ); + """.trimIndent() + assertThrows { + plan(ddl) + } + } + + @Test + fun negative_createTable_primaryKey_2() { + val ddl = """ + CREATE TABLE foo ( + attr1 INT2 PRIMARY KEY, + attr2 INT2, + PRIMARY KEY(attr2) + ); + """.trimIndent() + assertThrows { + plan(ddl) + } + } + + @Test + fun negative_createTable_primaryKey_3() { + val ddl = """ + CREATE TABLE foo ( + attr1 INT2, + attr2 INT2, + PRIMARY KEY(attr3) + ); + """.trimIndent() + assertThrows { + plan(ddl) + } + } + + @Test + fun negative_createTable_primaryKey_4() { + val ddl = """ + CREATE TABLE foo ( + "attr1" INT2, + attr2 INT2, + PRIMARY KEY("ATTR1") + ); + """.trimIndent() + assertThrows { + plan(ddl) + } + } + + @Test + fun negative_createTable_primaryKey_5() { + val ddl = """ + CREATE TABLE foo ( + attr1 INT2, + PRIMARY KEY(attr1, attr1) + ); + """.trimIndent() + assertThrows { + plan(ddl) + } + } + + @Test + fun negative_createTable_primaryKey_6() { + val ddl = """ + CREATE TABLE foo ( + attr1 OPTIONAL INT2, + PRIMARY KEY(attr1) + ); + """.trimIndent() + assertThrows { + plan(ddl) + } + } + + @Test + fun negative_createTable_primaryKey_7() { + val ddl = """ + CREATE TABLE foo ( + attr1 ARRAY, + PRIMARY KEY(attr1) + ); + """.trimIndent() + } +} diff --git a/partiql-types/src/main/java/org/partiql/types/shape/PShape.java b/partiql-types/src/main/java/org/partiql/types/shape/PShape.java new file mode 100644 index 000000000..55c38945c --- /dev/null +++ b/partiql-types/src/main/java/org/partiql/types/shape/PShape.java @@ -0,0 +1,136 @@ +package org.partiql.types.shape; + + +import org.jetbrains.annotations.NotNull; +import org.partiql.types.Field; +import org.partiql.types.PType; +import org.partiql.types.shape.trait.UniqueTrait; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +/** + * TODO: Improve API Economic. + */ +public class PShape extends PType { + private final PType type; + + public PShape(PType type) { + super(type.code()); + this.type = type; + } + + public Number maxValue() { + Number number = null; + switch (this.code()) { + case TINYINT: + number = Byte.MAX_VALUE; + break; + case SMALLINT: + number = Short.MAX_VALUE; + break; + case INTEGER: + number = Integer.MAX_VALUE; + break; + case BIGINT: + number = Long.MAX_VALUE; + break; + default: + throw new UnsupportedOperationException("Retrieving max value not supported for type: " + this.name()); + } + return number; + } + + public Number minValue() { + Number number = null; + switch (this.code()) { + case TINYINT: + number = Byte.MIN_VALUE; + break; + case SMALLINT: + number = Short.MIN_VALUE; + break; + case INTEGER: + number = Integer.MIN_VALUE; + break; + case BIGINT: + number = Long.MIN_VALUE; + break; + default: + throw new UnsupportedOperationException("Retrieving max value not supported for type: " + this.name()); + } + return number; + } + + public boolean isNullable() { + return true; + } + + public boolean isOptional() { + return true; + } + + public Map meta() { + return new HashMap<>(); + } + + public Collection primaryKey() { + return new ArrayList<>(); + } + + public Collection unique() { + return new ArrayList<>(); + } + + @Override + public @NotNull Collection getFields() throws UnsupportedOperationException { + return type.getFields(); + } + + @Override + public int getPrecision() throws UnsupportedOperationException { + return type.getPrecision(); + } + + @Override + public int getLength() throws UnsupportedOperationException { + return type.getLength(); + } + + @Override + public int getScale() throws UnsupportedOperationException { + return type.getScale(); + } + + @Override + public @NotNull PType getTypeParameter() throws UnsupportedOperationException { + return type.getTypeParameter(); + } + + @Override + public @NotNull String name() { + return super.name(); + } + + @Override + public @NotNull String toString() { + return type.toString(); + } + + @Override + // TODO: Revisit Equals and hasCode function + public boolean equals(Object obj) { + if (obj == null || getClass() != obj.getClass()) return false; + PShape that = (PShape) obj; + return Objects.equals(type, that.type); + } + + // TODO: Revisit Equals and hasCode function + @Override + public int hashCode() { + return type.hashCode(); + } +} diff --git a/partiql-types/src/main/java/org/partiql/types/shape/trait/ConstraintTrait.java b/partiql-types/src/main/java/org/partiql/types/shape/trait/ConstraintTrait.java new file mode 100644 index 000000000..559675569 --- /dev/null +++ b/partiql-types/src/main/java/org/partiql/types/shape/trait/ConstraintTrait.java @@ -0,0 +1,34 @@ +package org.partiql.types.shape.trait; + +import org.partiql.types.shape.PShape; + +import java.util.Objects; + +/** + * TODO: Improve API Economic. + *

+ * TODO: Equals and HashCode. + */ +public class ConstraintTrait extends PTrait { + private final String expression; + + public ConstraintTrait(PShape shape, String expression) { + super(shape); + this.expression = expression; + } + + @Override + // TODO: Revisit Equals and hasCode function + public boolean equals(Object obj) { + if (obj == null || getClass() != obj.getClass()) return false; + ConstraintTrait that = (ConstraintTrait) obj; + if (!Objects.equals(expression, that.expression)) return false; + return Objects.equals(shape, that.shape); + } + + // TODO: Revisit Equals and hasCode function + @Override + public int hashCode() { + return shape.hashCode(); + } +} diff --git a/partiql-types/src/main/java/org/partiql/types/shape/trait/MetadataTrait.java b/partiql-types/src/main/java/org/partiql/types/shape/trait/MetadataTrait.java new file mode 100644 index 000000000..c44a01a1e --- /dev/null +++ b/partiql-types/src/main/java/org/partiql/types/shape/trait/MetadataTrait.java @@ -0,0 +1,49 @@ +package org.partiql.types.shape.trait; + +import org.jetbrains.annotations.NotNull; +import org.partiql.types.shape.PShape; + +import java.util.Map; +import java.util.Objects; + +/** + * TODO: Improve API Economic. + *

+ * TODO: Equals and HashCode. + */ +public class MetadataTrait extends PTrait { + @NotNull + private final String name; + @NotNull + private final String value; + + public MetadataTrait(PShape shape, @NotNull String name, @NotNull String value) { + super(shape); + this.name = name; + this.value = value; + } + + @Override + public Map meta() { + Map map = super.meta(); + map.put(name, value); + return map; + } + + + @Override + // TODO: Revisit Equals and hasCode function + public boolean equals(Object obj) { + if (obj == null || getClass() != obj.getClass()) return false; + MetadataTrait that = (MetadataTrait) obj; + if (!Objects.equals(name, that.name)) return false; + if (!Objects.equals(value, that.value)) return false; + return Objects.equals(shape, that.shape); + } + + // TODO: Revisit Equals and hasCode function + @Override + public int hashCode() { + return shape.hashCode(); + } +} diff --git a/partiql-types/src/main/java/org/partiql/types/shape/trait/NotNullTrait.java b/partiql-types/src/main/java/org/partiql/types/shape/trait/NotNullTrait.java new file mode 100644 index 000000000..95553a23d --- /dev/null +++ b/partiql-types/src/main/java/org/partiql/types/shape/trait/NotNullTrait.java @@ -0,0 +1,35 @@ +package org.partiql.types.shape.trait; + +import org.partiql.types.shape.PShape; + +import java.util.Objects; + +/** + * TODO: Improve API Economic. + *

+ * TODO: Equals and HashCode. + */ +public class NotNullTrait extends PTrait { + public NotNullTrait(PShape shape) { + super(shape); + } + + @Override + public boolean isNullable() { + return false; + } + + @Override + // TODO: Revisit Equals and hasCode function + public boolean equals(Object obj) { + if (obj == null || getClass() != obj.getClass()) return false; + NotNullTrait that = (NotNullTrait) obj; + return Objects.equals(shape, that.shape); + } + + // TODO: Revisit Equals and hasCode function + @Override + public int hashCode() { + return shape.hashCode(); + } +} diff --git a/partiql-types/src/main/java/org/partiql/types/shape/trait/PTrait.java b/partiql-types/src/main/java/org/partiql/types/shape/trait/PTrait.java new file mode 100644 index 000000000..298a3bca5 --- /dev/null +++ b/partiql-types/src/main/java/org/partiql/types/shape/trait/PTrait.java @@ -0,0 +1,108 @@ +package org.partiql.types.shape.trait; + +import org.jetbrains.annotations.NotNull; +import org.partiql.types.Field; +import org.partiql.types.PType; +import org.partiql.types.shape.PShape; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Map; +import java.util.Objects; + +/** + * TODO: Improve API Economic. + *

+ * TODO: Equals and HashCode. + */ +public abstract class PTrait extends PShape { + PShape shape; + + protected PTrait(PShape shape) { + super(shape); + this.shape = shape; + } + + @Override + public Number maxValue() { + return shape.maxValue(); + } + + @Override + public Number minValue() { + return shape.minValue(); + } + + @Override + public boolean isNullable() { + return shape.isNullable(); + } + + @Override + public boolean isOptional() { + return shape.isOptional(); + } + + @Override + public Map meta() { + return shape.meta(); + } + + @Override + public Collection primaryKey() { + return shape.primaryKey(); + } + + public Collection unique() { + return shape.unique(); + } + + @Override + public @NotNull Collection getFields() throws UnsupportedOperationException { + return shape.getFields(); + } + + @Override + public int getPrecision() throws UnsupportedOperationException { + return shape.getPrecision(); + } + + @Override + public int getLength() throws UnsupportedOperationException { + return shape.getLength(); + } + + @Override + public int getScale() throws UnsupportedOperationException { + return shape.getScale(); + } + + @Override + public @NotNull PType getTypeParameter() throws UnsupportedOperationException { + return shape.getTypeParameter(); + } + + @Override + public @NotNull String name() { + return shape.name(); + } + + @Override + public @NotNull String toString() { + return shape.toString(); + } + + @Override + // TODO: Revisit Equals and hasCode function + public boolean equals(Object obj) { + if (obj == null || getClass() != obj.getClass()) return false; + PTrait that = (PTrait) obj; + return Objects.equals(shape, that.shape); + } + + // TODO: Revisit Equals and hasCode function + @Override + public int hashCode() { + return shape.hashCode(); + } +} diff --git a/partiql-types/src/main/java/org/partiql/types/shape/trait/PrimaryKeyTrait.java b/partiql-types/src/main/java/org/partiql/types/shape/trait/PrimaryKeyTrait.java new file mode 100644 index 000000000..6f3cd1eb9 --- /dev/null +++ b/partiql-types/src/main/java/org/partiql/types/shape/trait/PrimaryKeyTrait.java @@ -0,0 +1,42 @@ +package org.partiql.types.shape.trait; + +import org.partiql.types.shape.PShape; + +import java.util.Collection; +import java.util.List; +import java.util.Objects; + +/** + * TODO: Improve API Economic. + *

+ * TODO: Equals and HashCode. + */ +public class PrimaryKeyTrait extends PTrait { + private final List identifier; + + public PrimaryKeyTrait(PShape shape, List identifier) { + super(shape); + this.identifier = identifier; + } + + @Override + public Collection primaryKey() { + return identifier; + } + + @Override + // TODO: Revisit Equals and hasCode function + public boolean equals(Object obj) { + if (obj == null || getClass() != obj.getClass()) return false; + PrimaryKeyTrait that = (PrimaryKeyTrait) obj; + if (!identifier.containsAll(that.identifier)) return false; + if (!that.identifier.containsAll(identifier)) return false; + return Objects.equals(shape, that.shape); + } + + // TODO: Revisit Equals and hasCode function + @Override + public int hashCode() { + return shape.hashCode(); + } +} diff --git a/partiql-types/src/main/java/org/partiql/types/shape/trait/RangeTrait.java b/partiql-types/src/main/java/org/partiql/types/shape/trait/RangeTrait.java new file mode 100644 index 000000000..b3aa5bf41 --- /dev/null +++ b/partiql-types/src/main/java/org/partiql/types/shape/trait/RangeTrait.java @@ -0,0 +1,54 @@ +package org.partiql.types.shape.trait; + +import org.partiql.types.shape.PShape; + +import java.util.Objects; + +/** + * TODO: Improve API Economic. + *

+ * TODO: Equals and HashCode. + */ +public class RangeTrait extends PTrait { + private final Number minValue; + private final Number maxValue; + + public RangeTrait(PShape shape, Number minValue, Number maxValue) { + super(shape); + this.minValue = minValue; + this.maxValue = maxValue; + } + + @Override + public Number minValue() { + if (minValue == null) { + return super.minValue(); + } + + return Math.max(shape.minValue().doubleValue(), minValue.doubleValue()); + } + + @Override + public Number maxValue() { + if (maxValue == null) { + return super.maxValue(); + } + return Math.min(shape.maxValue().doubleValue(), maxValue.doubleValue()); + } + + @Override + // TODO: Revisit Equals and hasCode function + public boolean equals(Object obj) { + if (obj == null || getClass() != obj.getClass()) return false; + RangeTrait that = (RangeTrait) obj; + if (!Objects.equals(maxValue, that.maxValue)) return false; + if (!Objects.equals(minValue, that.minValue)) return false; + return Objects.equals(shape, that.shape); + } + + // TODO: Revisit Equals and hasCode function + @Override + public int hashCode() { + return shape.hashCode(); + } +} diff --git a/partiql-types/src/main/java/org/partiql/types/shape/trait/RequiredTrait.java b/partiql-types/src/main/java/org/partiql/types/shape/trait/RequiredTrait.java new file mode 100644 index 000000000..121210f69 --- /dev/null +++ b/partiql-types/src/main/java/org/partiql/types/shape/trait/RequiredTrait.java @@ -0,0 +1,36 @@ +package org.partiql.types.shape.trait; + +import org.partiql.types.shape.PShape; + +import java.util.Objects; + +/** + * TODO: Improve API Economic. + *

+ * TODO: Equals and HashCode. + */ +public class RequiredTrait extends PTrait { + + public RequiredTrait(PShape shape) { + super(shape); + } + + @Override + public boolean isOptional() { + return false; + } + + @Override + // TODO: Revisit Equals and hasCode function + public boolean equals(Object obj) { + if (obj == null || getClass() != obj.getClass()) return false; + RequiredTrait that = (RequiredTrait) obj; + return Objects.equals(shape, that.shape); + } + + // TODO: Revisit Equals and hasCode function + @Override + public int hashCode() { + return shape.hashCode(); + } +} diff --git a/partiql-types/src/main/java/org/partiql/types/shape/trait/UniqueTrait.java b/partiql-types/src/main/java/org/partiql/types/shape/trait/UniqueTrait.java new file mode 100644 index 000000000..0e6d39a29 --- /dev/null +++ b/partiql-types/src/main/java/org/partiql/types/shape/trait/UniqueTrait.java @@ -0,0 +1,44 @@ +package org.partiql.types.shape.trait; + +import org.partiql.types.shape.PShape; + +import java.util.Collection; +import java.util.List; +import java.util.Objects; + +/** + * TODO: Improve API Economic. + *

+ * TODO: Equals and HashCode. + */ +public class UniqueTrait extends PTrait { + private final List identifier; + + public UniqueTrait(PShape shape, List identifier) { + super(shape); + this.identifier = identifier; + } + + @Override + public Collection unique() { + Collection unique = shape.unique(); + unique.addAll(identifier); + return unique; + } + + @Override + // TODO: Revisit Equals and hasCode function + public boolean equals(Object obj) { + if (obj == null || getClass() != obj.getClass()) return false; + UniqueTrait that = (UniqueTrait) obj; + if (!identifier.containsAll(that.identifier)) return false; + if (!that.identifier.containsAll(identifier)) return false; + return Objects.equals(shape, that.shape); + } + + // TODO: Revisit Equals and hasCode function + @Override + public int hashCode() { + return shape.hashCode(); + } +}