diff --git a/modules/core/src/main/scala/playground/QueryCompilerVisitor.scala b/modules/core/src/main/scala/playground/QueryCompilerVisitor.scala index 851080b5..97756614 100644 --- a/modules/core/src/main/scala/playground/QueryCompilerVisitor.scala +++ b/modules/core/src/main/scala/playground/QueryCompilerVisitor.scala @@ -58,18 +58,22 @@ object QueryCompilerVisitor { object QueryCompilerVisitorInternal extends SchemaVisitor[QueryCompiler] { - private def checkRange[A, B]( - pc: QueryCompiler[A] + private def checkRange[B]( + pc: QueryCompiler[BigDecimal] )( tag: String )( - matchToRange: A => Option[B] + matchToRange: PartialFunction[BigDecimal, B] ) = (pc, QueryCompiler.pos).tupled.emap { case (i, range) => - matchToRange(i) - .toRightIor( + Either + .catchOnly[ArithmeticException](matchToRange.lift(i).toRight(())) + .leftWiden[Any] + .flatten + .leftMap(_ => CompilationError .error(NumberOutOfRange(i.toString, tag), range) ) + .toIor .toIorNec } @@ -85,12 +89,12 @@ object QueryCompilerVisitorInternal extends SchemaVisitor[QueryCompiler] { .typeCheck(NodeKind.Bool) { case b @ BooleanLiteral(_) => b } .map(_.value.value) case PUnit => struct(shapeId, hints, Vector.empty, _ => ()) - case PLong => checkRange(integer)("int")(_.toLongOption) - case PInt => checkRange(integer)("int")(_.toIntOption) - case PShort => checkRange(integer)("short")(_.toShortOption) - case PByte => checkRange(integer)("byte")(_.toByteOption) - case PFloat => checkRange(integer)("float")(_.toFloatOption) - case PDouble => checkRange(integer)("double")(_.toDoubleOption) + case PLong => checkRange(number)("int")(_.toLongExact) + case PInt => checkRange(number)("int")(_.toIntExact) + case PShort => checkRange(number)("short")(_.toShortExact) + case PByte => checkRange(number)("byte")(_.toByteExact) + case PFloat => checkRange(number)("float") { case i if i.isDecimalFloat => i.toFloat } + case PDouble => checkRange(number)("double") { case i if i.isDecimalDouble => i.toDouble } case PDocument => document case PBlob => (string, QueryCompiler.pos).tupled.emap { case (s, range) => @@ -101,14 +105,8 @@ object QueryCompilerVisitorInternal extends SchemaVisitor[QueryCompiler] { .toIor .toIorNec } - case PBigDecimal => - checkRange(integer)("bigdecimal") { s => - Either.catchNonFatal(BigDecimal(s)).toOption - } - case PBigInt => - checkRange(integer)("bigint") { s => - Either.catchNonFatal(BigInt(s)).toOption - } + case PBigDecimal => number + case PBigInt => checkRange(number)("bigint")(_.toBigIntExact.get) case PUUID => stringLiteral.emap { s => Either @@ -137,9 +135,10 @@ object QueryCompilerVisitorInternal extends SchemaVisitor[QueryCompiler] { } } - private val integer: QueryCompiler[String] = QueryCompiler + private val number: QueryCompiler[BigDecimal] = QueryCompiler .typeCheck(NodeKind.IntLiteral) { case i @ IntLiteral(_) => i } .map(_.value.value) + .map(BigDecimal(_)) def collection[C[_], A]( shapeId: ShapeId, diff --git a/modules/core/src/test/scala/playground/smithyql/CompilationTests.scala b/modules/core/src/test/scala/playground/smithyql/CompilationTests.scala index e061bab0..b0dc813e 100644 --- a/modules/core/src/test/scala/playground/smithyql/CompilationTests.scala +++ b/modules/core/src/test/scala/playground/smithyql/CompilationTests.scala @@ -265,6 +265,23 @@ object CompilationTests extends SimpleIOSuite with Checkers { ) } + pureTest("int with exponential syntax - in range") { + assertNoDiff( + compile { + WithSource.liftId(IntLiteral("1e2").mapK(WithSource.liftId)) + }(Schema.int), + Ior.right(100), + ) + } + + pureTest("int with exponential syntax - in range, but not an integer") { + assert( + compile { + WithSource.liftId(IntLiteral("10.1e0").mapK(WithSource.liftId)) + }(Schema.int).isLeft + ) + } + pureTest("short") { assertNoDiff( compile { @@ -282,6 +299,15 @@ object CompilationTests extends SimpleIOSuite with Checkers { ) } + pureTest("short with exponential syntax - in range") { + assertNoDiff( + compile { + WithSource.liftId(IntLiteral("1e2").mapK(WithSource.liftId)) + }(Schema.short), + Ior.right(100.toShort), + ) + } + pureTest("byte") { assertNoDiff( compile { @@ -299,8 +325,17 @@ object CompilationTests extends SimpleIOSuite with Checkers { ) } - pureTest("float") { + pureTest("byte with exponential syntax - in range") { assertNoDiff( + compile { + WithSource.liftId(IntLiteral("1e2").mapK(WithSource.liftId)) + }(Schema.byte), + Ior.right(100.toByte), + ) + } + + pureTest("float") { + assert.same( compile { WithSource.liftId(Float.MaxValue.mapK(WithSource.liftId)) }(Schema.float), @@ -316,6 +351,15 @@ object CompilationTests extends SimpleIOSuite with Checkers { ) } + pureTest("float - exponential syntax") { + assertNoDiff( + compile { + WithSource.liftId(IntLiteral("0.1e0").mapK(WithSource.liftId)) + }(Schema.float), + Ior.right(0.1f), + ) + } + pureTest("double") { assertNoDiff( compile { @@ -334,6 +378,15 @@ object CompilationTests extends SimpleIOSuite with Checkers { ) } + pureTest("double - exponential syntax") { + assertNoDiff( + compile { + WithSource.liftId(IntLiteral("0.1e0").mapK(WithSource.liftId)) + }(Schema.double), + Ior.right(0.1), + ) + } + test("bigint - OK") { forall { (bi: BigInt) => assertNoDiff( @@ -353,6 +406,15 @@ object CompilationTests extends SimpleIOSuite with Checkers { ) } + pureTest("bigint - exponential syntax") { + assertNoDiff( + compile { + WithSource.liftId(IntLiteral("1e2").mapK(WithSource.liftId)) + }(Schema.bigint), + Ior.right(BigInt(100)), + ) + } + test("bigdecimal - OK") { forall { (bd: BigDecimal) => assertNoDiff( @@ -372,6 +434,15 @@ object CompilationTests extends SimpleIOSuite with Checkers { ) } + pureTest("bigdecimal - exponential syntax") { + assertNoDiff( + compile { + WithSource.liftId(IntLiteral("1e2").mapK(WithSource.liftId)) + }(Schema.bigdecimal), + Ior.right(BigDecimal(100)), + ) + } + pureTest("boolean") { assertNoDiff( compile {