Skip to content

Commit

Permalink
Merge pull request #225 from disneystreaming/dfrancoeur/fix-enum-conf…
Browse files Browse the repository at this point in the history
…licts

Fix the prevent enum conflict algorithm
  • Loading branch information
daddykotex authored Dec 20, 2023
2 parents c9f3d62 + 07cf64c commit 2fd0853
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 63 deletions.
103 changes: 45 additions & 58 deletions modules/proto/src/smithyproto/proto3/ModelPreProcessor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -113,75 +113,62 @@ object ModelPreProcessor {
def getName(): String = "prevent-enum-conflicts"
def transform(x: TransformContext): Model = {
val currentModel = x.getModel
val enumsShapes: List[EnumShape] =
currentModel.getEnumShapes.asScala.toList
.filter(s => !Prelude.isPreludeShape(s))
val enumsShapes: List[EnumShape] = currentModel
.getEnumShapes()
.asScala
.filterNot(Prelude.isPreludeShape)
.toList

val intEnums: List[IntEnumShape] =
currentModel.getIntEnumShapes.asScala.toList.filter(s =>
!Prelude.isPreludeShape(s)
)
val intEnums: List[IntEnumShape] = currentModel
.getIntEnumShapes()
.asScala
.filterNot(Prelude.isPreludeShape)
.toList

val allEnums: List[Shape] = enumsShapes ++ intEnums

val enumLabelsPerNamespace = allEnums
.groupMapReduce(_.getId.getNamespace)(es =>
es.getMemberNames.asScala.toList
)(_ ++ _)

val enumHasConflictMap: Map[String, Boolean] = {
allEnums.flatMap { es =>
val id = es.getId
val enums = es.getMemberNames.asScala.toList

val allEnumValues =
enumLabelsPerNamespace.getOrElse(id.getNamespace, List.empty)
val allCombos = for {
e <- allEnums
memberName <- e.getMemberNames().asScala.toList
} yield (e.getId().getNamespace(), memberName)

enums.map { enumName =>
enumName -> (allEnumValues.count(_ == enumName) > 1)
val allRepeatedCombos =
allCombos
.groupBy(identity)
.view
.mapValues(_.size)
.collect {
case (k, v) if v > 1 => k
}
}.toMap
.toSet

}
def enumHasConflict(enumValue: String): Boolean = {
enumHasConflictMap.getOrElse(enumValue, false)
}
def hasConflict(member: MemberShape): Boolean = allRepeatedCombos(
(member.getId().getNamespace(), member.getMemberName())
)

val newEnumShapes: List[Shape] = enumsShapes
.map { enumShape =>
{
val b = enumShape.toBuilder
b.clearMembers()
enumShape.getAllMembers.asScala.foreach {
case (memberName, member) =>
val newMember = if (enumHasConflict(memberName)) {
renameMember(member)
} else {
member
}
b.addMember(newMember)
}
b.build()
}
val newEnumShapes: List[Shape] = enumsShapes.map { enumShape =>
val b = enumShape.toBuilder
b.clearMembers()
enumShape.members.asScala.foreach {
case member if hasConflict(member) =>
b.addMember(renameMember(member))
case member =>
b.addMember(member)
}
b.build()
}

val newIntEnumShapes = intEnums
.map { enumShape =>
{
val b = enumShape.toBuilder
b.clearMembers()
enumShape.getAllMembers.asScala.foreach {
case (memberName, member) =>
val newMember = if (enumHasConflict(memberName)) {
renameMember(member)
} else {
member
}
b.addMember(newMember)
}
b.build()
}
val newIntEnumShapes = intEnums.map { intEnumShape =>
val b = intEnumShape.toBuilder
b.clearMembers()
intEnumShape.members.asScala.foreach {
case member if hasConflict(member) =>
b.addMember(renameMember(member))
case member =>
b.addMember(member)
}
b.build()
}

val allShapes = newEnumShapes ++ newIntEnumShapes

Expand Down
82 changes: 77 additions & 5 deletions modules/proto/tests/src/ModelPrePocessorSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,75 @@ class ModelPrePocessorSpec extends FunSuite {
}
}

test("apply PreventEnumConflicts - across namespace") {
val smithyTest =
"""|$version: "2"
|namespace test
|
|enum Enum1 {
| VCONFLICT
|}
|
|enum Enum2 {
| VCONFLICT
|}
|""".stripMargin

val other =
"""|$version: "2"
|namespace a.ns
|
|enum OtherEnum {
| VCONFLICT
|}
|""".stripMargin
val original = buildModel(smithyTest, other)
val transformed =
process(original, ModelPreProcessor.transformers.PreventEnumConflicts)
def getEnumNames(m: Model, shapeId: ShapeId): List[String] = {
m.getShape(shapeId)
.toScala
.toList
.collect {
case shape: EnumShape =>
shape.getMemberNames.asScala.toList
case shape: IntEnumShape =>
shape.getMemberNames.asScala.toList
}
.flatten
}

assertEquals(
getEnumNames(original, ShapeId.from("test#Enum1")),
List("VCONFLICT")
)

assertEquals(
getEnumNames(transformed, ShapeId.from("test#Enum1")),
List("ENUM1_VCONFLICT")
)

assertEquals(
getEnumNames(original, ShapeId.from("test#Enum2")),
List("VCONFLICT")
)

assertEquals(
getEnumNames(transformed, ShapeId.from("test#Enum2")),
List("ENUM2_VCONFLICT")
)

assertEquals(
getEnumNames(original, ShapeId.from(s"a.ns#OtherEnum")),
List("VCONFLICT")
)

assertEquals(
getEnumNames(transformed, ShapeId.from(s"a.ns#OtherEnum")),
List("VCONFLICT")
)
}

private def checkTransformer(src: String, t: ProjectionTransformer)(
check: (Model, Model) => Unit
): Unit = {
Expand All @@ -328,13 +397,16 @@ class ModelPrePocessorSpec extends FunSuite {
check(original, transformed)
}

private def buildModel(src: String): Model = {
Model
private def buildModel(srcs: String*): Model = {
val assembler = Model
.assembler()
.discoverModels()
.addUnparsedModel("inlined-in-test.smithy", src)
.assemble()
.unwrap()

srcs.zipWithIndex.foreach { case (s, i) =>
assembler.addUnparsedModel(s"inlined-in-test.$i.smithy", s)
}

assembler.assemble().unwrap()
}

private def process(m: Model, t: ProjectionTransformer): Model = {
Expand Down

0 comments on commit 2fd0853

Please sign in to comment.