Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support numbers with exponential syntax outside of bigints/bigdecimals #342

Merged
merged 2 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 19 additions & 20 deletions modules/core/src/main/scala/playground/QueryCompilerVisitor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,22 @@ object QueryCompilerVisitor {

object QueryCompilerVisitorInternal extends SchemaVisitor[QueryCompiler] {

private def checkRange[A, B](
pc: QueryCompiler[A]
private def checkRange[B](
pc: QueryCompiler[BigDecimal]
)(
tag: String
)(
matchToRange: A => Option[B]
matchToRange: PartialFunction[BigDecimal, B]
) = (pc, QueryCompiler.pos).tupled.emap { case (i, range) =>
matchToRange(i)
.toRightIor(
Either
.catchOnly[ArithmeticException](matchToRange.lift(i).toRight(()))
.leftWiden[Any]
.flatten
.leftMap(_ =>
CompilationError
.error(NumberOutOfRange(i.toString, tag), range)
)
.toIor
.toIorNec
}

Expand All @@ -85,12 +89,12 @@ object QueryCompilerVisitorInternal extends SchemaVisitor[QueryCompiler] {
.typeCheck(NodeKind.Bool) { case b @ BooleanLiteral(_) => b }
.map(_.value.value)
case PUnit => struct(shapeId, hints, Vector.empty, _ => ())
case PLong => checkRange(integer)("int")(_.toLongOption)
case PInt => checkRange(integer)("int")(_.toIntOption)
case PShort => checkRange(integer)("short")(_.toShortOption)
case PByte => checkRange(integer)("byte")(_.toByteOption)
case PFloat => checkRange(integer)("float")(_.toFloatOption)
case PDouble => checkRange(integer)("double")(_.toDoubleOption)
case PLong => checkRange(number)("int")(_.toLongExact)
case PInt => checkRange(number)("int")(_.toIntExact)
case PShort => checkRange(number)("short")(_.toShortExact)
case PByte => checkRange(number)("byte")(_.toByteExact)
case PFloat => checkRange(number)("float") { case i if i.isDecimalFloat => i.toFloat }
case PDouble => checkRange(number)("double") { case i if i.isDecimalDouble => i.toDouble }
case PDocument => document
case PBlob =>
(string, QueryCompiler.pos).tupled.emap { case (s, range) =>
Expand All @@ -101,14 +105,8 @@ object QueryCompilerVisitorInternal extends SchemaVisitor[QueryCompiler] {
.toIor
.toIorNec
}
case PBigDecimal =>
checkRange(integer)("bigdecimal") { s =>
Either.catchNonFatal(BigDecimal(s)).toOption
}
case PBigInt =>
checkRange(integer)("bigint") { s =>
Either.catchNonFatal(BigInt(s)).toOption
}
case PBigDecimal => number
case PBigInt => checkRange(number)("bigint")(_.toBigIntExact.get)
case PUUID =>
stringLiteral.emap { s =>
Either
Expand Down Expand Up @@ -137,9 +135,10 @@ object QueryCompilerVisitorInternal extends SchemaVisitor[QueryCompiler] {
}
}

private val integer: QueryCompiler[String] = QueryCompiler
private val number: QueryCompiler[BigDecimal] = QueryCompiler
.typeCheck(NodeKind.IntLiteral) { case i @ IntLiteral(_) => i }
.map(_.value.value)
.map(BigDecimal(_))

def collection[C[_], A](
shapeId: ShapeId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,23 @@ object CompilationTests extends SimpleIOSuite with Checkers {
)
}

pureTest("int with exponential syntax - in range") {
assertNoDiff(
compile {
WithSource.liftId(IntLiteral("1e2").mapK(WithSource.liftId))
}(Schema.int),
Ior.right(100),
)
}

pureTest("int with exponential syntax - in range, but not an integer") {
assert(
compile {
WithSource.liftId(IntLiteral("10.1e0").mapK(WithSource.liftId))
}(Schema.int).isLeft
)
}

pureTest("short") {
assertNoDiff(
compile {
Expand All @@ -282,6 +299,15 @@ object CompilationTests extends SimpleIOSuite with Checkers {
)
}

pureTest("short with exponential syntax - in range") {
assertNoDiff(
compile {
WithSource.liftId(IntLiteral("1e2").mapK(WithSource.liftId))
}(Schema.short),
Ior.right(100.toShort),
)
}

pureTest("byte") {
assertNoDiff(
compile {
Expand All @@ -299,8 +325,17 @@ object CompilationTests extends SimpleIOSuite with Checkers {
)
}

pureTest("float") {
pureTest("byte with exponential syntax - in range") {
assertNoDiff(
compile {
WithSource.liftId(IntLiteral("1e2").mapK(WithSource.liftId))
}(Schema.byte),
Ior.right(100.toByte),
)
}

pureTest("float") {
assert.same(
compile {
WithSource.liftId(Float.MaxValue.mapK(WithSource.liftId))
}(Schema.float),
Expand All @@ -316,6 +351,15 @@ object CompilationTests extends SimpleIOSuite with Checkers {
)
}

pureTest("float - exponential syntax") {
assertNoDiff(
compile {
WithSource.liftId(IntLiteral("0.1e0").mapK(WithSource.liftId))
}(Schema.float),
Ior.right(0.1f),
)
}

pureTest("double") {
assertNoDiff(
compile {
Expand All @@ -334,6 +378,15 @@ object CompilationTests extends SimpleIOSuite with Checkers {
)
}

pureTest("double - exponential syntax") {
assertNoDiff(
compile {
WithSource.liftId(IntLiteral("0.1e0").mapK(WithSource.liftId))
}(Schema.double),
Ior.right(0.1),
)
}

test("bigint - OK") {
forall { (bi: BigInt) =>
assertNoDiff(
Expand All @@ -353,6 +406,15 @@ object CompilationTests extends SimpleIOSuite with Checkers {
)
}

pureTest("bigint - exponential syntax") {
assertNoDiff(
compile {
WithSource.liftId(IntLiteral("1e2").mapK(WithSource.liftId))
}(Schema.bigint),
Ior.right(BigInt(100)),
)
}

test("bigdecimal - OK") {
forall { (bd: BigDecimal) =>
assertNoDiff(
Expand All @@ -372,6 +434,15 @@ object CompilationTests extends SimpleIOSuite with Checkers {
)
}

pureTest("bigdecimal - exponential syntax") {
assertNoDiff(
compile {
WithSource.liftId(IntLiteral("1e2").mapK(WithSource.liftId))
}(Schema.bigdecimal),
Ior.right(BigDecimal(100)),
)
}

pureTest("boolean") {
assertNoDiff(
compile {
Expand Down