From e66d17e7cce2eca95404aa63f3cf6ea6ac1e7043 Mon Sep 17 00:00:00 2001 From: Raman Sarvaria Date: Fri, 29 Apr 2022 11:15:04 -0400 Subject: [PATCH] Initial attempt to update JSON parsing to allow for optional numeric boolean values. --- .../scala/scalapb/json4s/JsonFormat.scala | 21 ++++++++++++++--- .../scala/scalapb/json4s/JsonFormatSpec.scala | 23 +++++++++++++++++++ 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/src/main/scala/scalapb/json4s/JsonFormat.scala b/src/main/scala/scalapb/json4s/JsonFormat.scala index 80771bd7..371548ec 100644 --- a/src/main/scala/scalapb/json4s/JsonFormat.scala +++ b/src/main/scala/scalapb/json4s/JsonFormat.scala @@ -9,7 +9,7 @@ import com.google.protobuf.timestamp.Timestamp import scalapb.json4s.JsonFormat.GenericCompanion import scalapb._ import org.json4s.JsonAST._ -import org.json4s.{Reader, Writer, MappingException} +import org.json4s.{MappingException, Reader, Writer} import scala.collection.mutable import scala.collection.concurrent.TrieMap @@ -464,6 +464,7 @@ object Parser { private final case class ParserConfig( isIgnoringUnknownFields: Boolean, isIgnoringOverlappingOneofFields: Boolean, + areNumericBooleanValuesAllowed: Boolean, mapEntriesAsKeyValuePairs: Boolean, formatRegistry: FormatRegistry, typeRegistry: TypeRegistry @@ -476,6 +477,7 @@ class Parser private (config: Parser.ParserConfig) { Parser.ParserConfig( isIgnoringUnknownFields = false, isIgnoringOverlappingOneofFields = false, + areNumericBooleanValuesAllowed = false, mapEntriesAsKeyValuePairs = false, JsonFormat.DefaultRegistry, TypeRegistry.empty @@ -495,6 +497,7 @@ class Parser private (config: Parser.ParserConfig) { Parser.ParserConfig( isIgnoringUnknownFields = false, isIgnoringOverlappingOneofFields = false, + areNumericBooleanValuesAllowed = false, mapEntriesAsKeyValuePairs = false, formatRegistry, typeRegistry @@ -507,6 +510,9 @@ class Parser private (config: Parser.ParserConfig) { def ignoringOverlappingOneofFields: Parser = new Parser(config.copy(isIgnoringOverlappingOneofFields = true)) + def allowNumericBooleanValues: Parser = + new Parser(config.copy(areNumericBooleanValuesAllowed = true)) + def mapEntriesAsKeyValuePairs: Parser = new Parser(config.copy(mapEntriesAsKeyValuePairs = true)) @@ -714,7 +720,8 @@ class Parser private (config: Parser.ParserConfig) { value, throw new JsonFormatException( s"Unexpected value ($value) for field ${fd.name} of ${fd.containingMessage.name}" - ) + ), + allowNumericBooleans = config.areNumericBooleanValuesAllowed ) } } @@ -902,7 +909,8 @@ object JsonFormat { def parsePrimitive( protoType: FieldDescriptorProto.Type, value: JValue, - onError: => PValue + onError: => PValue, + allowNumericBooleans: Boolean = false ): PValue = (protoType, value) match { case (Type.TYPE_UINT32 | Type.TYPE_FIXED32, JInt(x)) => @@ -966,6 +974,13 @@ object JsonFormat { case (Type.TYPE_BOOL, JBool(b)) => PBoolean(b) case (Type.TYPE_BOOL, JString("true")) => PBoolean(true) case (Type.TYPE_BOOL, JString("false")) => PBoolean(false) + case (Type.TYPE_BOOL, JInt(i)) => { + (allowNumericBooleans, i.toInt) match { + case (true, 0) => PBoolean(false) + case (true, 1) => PBoolean(true) + case (_, _) => onError + } + } case (Type.TYPE_STRING, JString(s)) => PString(s) case (Type.TYPE_BYTES, JString(s)) => PByteString( diff --git a/src/test/scala/scalapb/json4s/JsonFormatSpec.scala b/src/test/scala/scalapb/json4s/JsonFormatSpec.scala index e030fb17..75357800 100644 --- a/src/test/scala/scalapb/json4s/JsonFormatSpec.scala +++ b/src/test/scala/scalapb/json4s/JsonFormatSpec.scala @@ -603,6 +603,29 @@ class JsonFormatSpec assertRejects("optionalDouble", (minDouble.multiply(moreThanOne).toString)) } + "parser" should "parse numeric boolean values when enabled" in { + val parser = new Parser().allowNumericBooleanValues + def validateRejects(json: String): Assertion = { + a[JsonFormatException] mustBe thrownBy { + parser.fromJsonString[MyTest](json) + } + } + + parser.fromJsonString[MyTest]("""{"optBool":1}""") must be( + MyTest(optBool = Some(true)) + ) + parser.fromJsonString[MyTest]("""{"optBool":0}""") must be( + MyTest(optBool = Some(false)) + ) + + // Only 0 and 1 as integers should be parsed + validateRejects("""{"optBool":2}""") + validateRejects("""{"optBool":-1}""") + validateRejects("""{"optBool":"0"}""") + validateRejects("""{"optBool":"1"}""") + } + + val anyEnabledJavaTypeRegistry = JavaTypeRegistry .newBuilder() .add(TestProto.companion.javaDescriptor)