diff --git a/modules/proto/src/smithyproto/proto3/Compiler.scala b/modules/proto/src/smithyproto/proto3/Compiler.scala index 70f23d5..329c0c3 100644 --- a/modules/proto/src/smithyproto/proto3/Compiler.scala +++ b/modules/proto/src/smithyproto/proto3/Compiler.scala @@ -39,7 +39,7 @@ class Compiler(model: Model, allShapes: Boolean) { import ProtoIR._ - private val allRelevantShapes: Set[Shape] = { + private lazy val allRelevantShapes: Set[Shape] = { if (allShapes) { model .shapes() @@ -60,7 +60,7 @@ class Compiler(model: Model, allShapes: Boolean) { } } - private val conflictingEnumValues: Set[MemberShape] = { + private lazy val conflictingEnumValues: Set[MemberShape] = { val enumMembers = allRelevantShapes.collect { case m: MemberShape => m }.filter { m => val container = model.expectShape(m.getContainer()) @@ -280,8 +280,23 @@ class Compiler(model: Model, allShapes: Boolean) { } } + private def shouldWrapCollection(shape: Shape): Boolean = { + val hasWrapped = hasProtoWrapped(shape) + val membersTargetingThis = + model.getMemberShapes().asScala.filter(_.getTarget() == shape.getId()) + val isTargetedByWrappedMember = + membersTargetingThis.exists(hasProtoWrapped(_)) + // oneofs cannot have lists / maps fields + val isTargetedByUnionMember = + membersTargetingThis.exists(member => + model.expectShape(member.getContainer()).isUnionShape + ) + + hasWrapped || isTargetedByWrappedMember || isTargetedByUnionMember + } + override def listShape(shape: ListShape): TopLevelDefs = { - if (hasProtoWrapped(shape)) { + if (shouldWrapCollection(shape)) { shape.getMember().accept(typeVisitor()).toList.flatMap { tpe => topLevelMessage(shape, Type.ListType(tpe)) } @@ -289,7 +304,7 @@ class Compiler(model: Model, allShapes: Boolean) { } override def mapShape(shape: MapShape): TopLevelDefs = { - if (hasProtoWrapped(shape)) { + if (shouldWrapCollection(shape)) { for { keyType <- shape.getKey().accept(typeVisitor()).toList valueType <- shape.getValue().accept(typeVisitor()).toList @@ -420,7 +435,14 @@ class Compiler(model: Model, allShapes: Boolean) { // We assume the model is well-formed so the result should be non-null val targetShape = model.expectShape(m.getTarget) val numType = extractNumType(m) - val isWrapped = hasProtoWrapped(m) || hasProtoWrapped(targetShape) + val isWrapped = { + val memberHasWrapped = hasProtoWrapped(m) + val targetHasWrapped = hasProtoWrapped(targetShape) + // repeated / map fields cannot be in oneofs + val isList = targetShape.isListShape() + val isMap = targetShape.isMapShape() + memberHasWrapped || targetHasWrapped || isList || isMap + } val fieldType = targetShape .accept(typeVisitor(isWrapped = isWrapped, numType)) @@ -545,14 +567,17 @@ class Compiler(model: Model, allShapes: Boolean) { } def listShape(shape: ListShape): Option[Type] = { - shape.getMember().accept(typeVisitor()).map(Type.ListType(_)) + if (isWrapped) Some(Type.RefType(shape)) + else shape.getMember().accept(typeVisitor()).map(Type.ListType(_)) } def mapShape(shape: MapShape): Option[Type] = { - for { - key <- shape.getKey().accept(typeVisitor()) - value <- shape.getValue().accept(typeVisitor()) - } yield Type.MapType(key, value) + if (isWrapped) Some(Type.RefType(shape)) + else + for { + key <- shape.getKey().accept(typeVisitor()) + value <- shape.getValue().accept(typeVisitor()) + } yield Type.MapType(key, value) } def memberShape(shape: MemberShape): Option[Type] = { diff --git a/modules/proto/tests/src/CompilerRendererSuite.scala b/modules/proto/tests/src/CompilerRendererSuite.scala index 1df0457..8e8de67 100644 --- a/modules/proto/tests/src/CompilerRendererSuite.scala +++ b/modules/proto/tests/src/CompilerRendererSuite.scala @@ -27,9 +27,20 @@ class CompilerRendererSuite extends FunSuite { | |use alloy.proto#protoWrapped | + |list StringList { + | member: String + |} + | + |map StringMap { + | key: String, + | value: String + |} + | |union MyUnion { | name: String | id: Integer + | stringList: StringList + | stringMap: StringMap |} |""".stripMargin @@ -37,10 +48,20 @@ class CompilerRendererSuite extends FunSuite { | |package com.example; | + |message StringList { + | repeated string value = 1; + |} + | + |message StringMap { + | map value = 1; + |} + | |message MyUnion { | oneof definition { | string name = 1; | int32 id = 2; + | com.example.StringList stringList = 3; + | com.example.StringMap stringMap = 4; | } |} |""".stripMargin @@ -632,14 +653,21 @@ class CompilerRendererSuite extends FunSuite { | |structure Foo { | strings: StringMap + | @alloy.proto#protoWrapped + | wrappedStrings: StringMap |} |""".stripMargin val expected = """|syntax = "proto3"; | |package com.example; | + |message StringMap { + | map value = 1; + |} + | |message Foo { | map strings = 1; + | com.example.StringMap wrappedStrings = 2; |} |""".stripMargin convertCheck(source, Map("com/example/example.proto" -> expected)) @@ -715,6 +743,8 @@ class CompilerRendererSuite extends FunSuite { | |structure Foo { | strings: List + | @alloy.proto#protoWrapped + | wrappedStrings: List |} |""".stripMargin @@ -726,8 +756,13 @@ class CompilerRendererSuite extends FunSuite { | string name = 1; |} | + |message List { + | repeated com.example.ListItem value = 1; + |} + | |message Foo { | repeated com.example.ListItem strings = 1; + | com.example.List wrappedStrings = 2; |} |""".stripMargin convertCheck(source, Map("com/example/example.proto" -> expected))