Skip to content

Commit

Permalink
Ensure lists/maps get wrapped when needed
Browse files Browse the repository at this point in the history
  • Loading branch information
Baccata committed Jan 31, 2024
1 parent 4dc4be9 commit 3c6279d
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 10 deletions.
45 changes: 35 additions & 10 deletions modules/proto/src/smithyproto/proto3/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class Compiler(model: Model, allShapes: Boolean) {

import ProtoIR._

private val allRelevantShapes: Set[Shape] = {
private lazy val allRelevantShapes: Set[Shape] = {
if (allShapes) {
model
.shapes()
Expand All @@ -60,7 +60,7 @@ class Compiler(model: Model, allShapes: Boolean) {
}
}

private val conflictingEnumValues: Set[MemberShape] = {
private lazy val conflictingEnumValues: Set[MemberShape] = {
val enumMembers =
allRelevantShapes.collect { case m: MemberShape => m }.filter { m =>
val container = model.expectShape(m.getContainer())
Expand Down Expand Up @@ -280,16 +280,31 @@ class Compiler(model: Model, allShapes: Boolean) {
}
}

private def shouldWrapCollection(shape: Shape): Boolean = {
val hasWrapped = hasProtoWrapped(shape)
val membersTargetingThis =
model.getMemberShapes().asScala.filter(_.getTarget() == shape.getId())
val isTargetedByWrappedMember =
membersTargetingThis.exists(hasProtoWrapped(_))
// oneofs cannot have lists / maps fields
val isTargetedByUnionMember =
membersTargetingThis.exists(member =>
model.expectShape(member.getContainer()).isUnionShape
)

hasWrapped || isTargetedByWrappedMember || isTargetedByUnionMember
}

override def listShape(shape: ListShape): TopLevelDefs = {
if (hasProtoWrapped(shape)) {
if (shouldWrapCollection(shape)) {
shape.getMember().accept(typeVisitor()).toList.flatMap { tpe =>
topLevelMessage(shape, Type.ListType(tpe))
}
} else Nil
}

override def mapShape(shape: MapShape): TopLevelDefs = {
if (hasProtoWrapped(shape)) {
if (shouldWrapCollection(shape)) {
for {
keyType <- shape.getKey().accept(typeVisitor()).toList
valueType <- shape.getValue().accept(typeVisitor()).toList
Expand Down Expand Up @@ -420,7 +435,14 @@ class Compiler(model: Model, allShapes: Boolean) {
// We assume the model is well-formed so the result should be non-null
val targetShape = model.expectShape(m.getTarget)
val numType = extractNumType(m)
val isWrapped = hasProtoWrapped(m) || hasProtoWrapped(targetShape)
val isWrapped = {
val memberHasWrapped = hasProtoWrapped(m)
val targetHasWrapped = hasProtoWrapped(targetShape)
// repeated / map fields cannot be in oneofs
val isList = targetShape.isListShape()
val isMap = targetShape.isMapShape()
memberHasWrapped || targetHasWrapped || isList || isMap
}
val fieldType =
targetShape
.accept(typeVisitor(isWrapped = isWrapped, numType))
Expand Down Expand Up @@ -545,14 +567,17 @@ class Compiler(model: Model, allShapes: Boolean) {
}

def listShape(shape: ListShape): Option[Type] = {
shape.getMember().accept(typeVisitor()).map(Type.ListType(_))
if (isWrapped) Some(Type.RefType(shape))
else shape.getMember().accept(typeVisitor()).map(Type.ListType(_))
}

def mapShape(shape: MapShape): Option[Type] = {
for {
key <- shape.getKey().accept(typeVisitor())
value <- shape.getValue().accept(typeVisitor())
} yield Type.MapType(key, value)
if (isWrapped) Some(Type.RefType(shape))
else
for {
key <- shape.getKey().accept(typeVisitor())
value <- shape.getValue().accept(typeVisitor())
} yield Type.MapType(key, value)
}

def memberShape(shape: MemberShape): Option[Type] = {
Expand Down
35 changes: 35 additions & 0 deletions modules/proto/tests/src/CompilerRendererSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,41 @@ class CompilerRendererSuite extends FunSuite {
|
|use alloy.proto#protoWrapped
|
|list StringList {
| member: String
|}
|
|map StringMap {
| key: String,
| value: String
|}
|
|union MyUnion {
| name: String
| id: Integer
| stringList: StringList
| stringMap: StringMap
|}
|""".stripMargin

val expected = """|syntax = "proto3";
|
|package com.example;
|
|message StringList {
| repeated string value = 1;
|}
|
|message StringMap {
| map<string, string> value = 1;
|}
|
|message MyUnion {
| oneof definition {
| string name = 1;
| int32 id = 2;
| com.example.StringList stringList = 3;
| com.example.StringMap stringMap = 4;
| }
|}
|""".stripMargin
Expand Down Expand Up @@ -632,14 +653,21 @@ class CompilerRendererSuite extends FunSuite {
|
|structure Foo {
| strings: StringMap
| @alloy.proto#protoWrapped
| wrappedStrings: StringMap
|}
|""".stripMargin
val expected = """|syntax = "proto3";
|
|package com.example;
|
|message StringMap {
| map<string, int32> value = 1;
|}
|
|message Foo {
| map<string, int32> strings = 1;
| com.example.StringMap wrappedStrings = 2;
|}
|""".stripMargin
convertCheck(source, Map("com/example/example.proto" -> expected))
Expand Down Expand Up @@ -715,6 +743,8 @@ class CompilerRendererSuite extends FunSuite {
|
|structure Foo {
| strings: List
| @alloy.proto#protoWrapped
| wrappedStrings: List
|}
|""".stripMargin

Expand All @@ -726,8 +756,13 @@ class CompilerRendererSuite extends FunSuite {
| string name = 1;
|}
|
|message List {
| repeated com.example.ListItem value = 1;
|}
|
|message Foo {
| repeated com.example.ListItem strings = 1;
| com.example.List wrappedStrings = 2;
|}
|""".stripMargin
convertCheck(source, Map("com/example/example.proto" -> expected))
Expand Down

0 comments on commit 3c6279d

Please sign in to comment.