Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the prevent enum conflict algorithm #225

Merged
merged 2 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 45 additions & 58 deletions modules/proto/src/smithyproto/proto3/ModelPreProcessor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -113,75 +113,62 @@ object ModelPreProcessor {
def getName(): String = "prevent-enum-conflicts"
def transform(x: TransformContext): Model = {
val currentModel = x.getModel
val enumsShapes: List[EnumShape] =
currentModel.getEnumShapes.asScala.toList
.filter(s => !Prelude.isPreludeShape(s))
val enumsShapes: List[EnumShape] = currentModel
.getEnumShapes()
.asScala
.filterNot(Prelude.isPreludeShape)
.toList

val intEnums: List[IntEnumShape] =
currentModel.getIntEnumShapes.asScala.toList.filter(s =>
!Prelude.isPreludeShape(s)
)
val intEnums: List[IntEnumShape] = currentModel
.getIntEnumShapes()
.asScala
.filterNot(Prelude.isPreludeShape)
.toList

val allEnums: List[Shape] = enumsShapes ++ intEnums

val enumLabelsPerNamespace = allEnums
.groupMapReduce(_.getId.getNamespace)(es =>
es.getMemberNames.asScala.toList
)(_ ++ _)

val enumHasConflictMap: Map[String, Boolean] = {
allEnums.flatMap { es =>
val id = es.getId
val enums = es.getMemberNames.asScala.toList

val allEnumValues =
enumLabelsPerNamespace.getOrElse(id.getNamespace, List.empty)
val allCombos = for {
e <- allEnums
memberName <- e.getMemberNames().asScala.toList
} yield (e.getId().getNamespace(), memberName)

enums.map { enumName =>
enumName -> (allEnumValues.count(_ == enumName) > 1)
val allRepeatedCombos =
allCombos
.groupBy(identity)
.view
.mapValues(_.size)
.collect {
case (k, v) if v > 1 => k
}
}.toMap
.toSet

}
def enumHasConflict(enumValue: String): Boolean = {
enumHasConflictMap.getOrElse(enumValue, false)
}
def hasConflict(member: MemberShape): Boolean = allRepeatedCombos(
(member.getId().getNamespace(), member.getMemberName())
)

val newEnumShapes: List[Shape] = enumsShapes
.map { enumShape =>
{
val b = enumShape.toBuilder
b.clearMembers()
enumShape.getAllMembers.asScala.foreach {
case (memberName, member) =>
val newMember = if (enumHasConflict(memberName)) {
renameMember(member)
} else {
member
}
b.addMember(newMember)
}
b.build()
}
val newEnumShapes: List[Shape] = enumsShapes.map { enumShape =>
val b = enumShape.toBuilder
b.clearMembers()
enumShape.members.asScala.foreach {
case member if hasConflict(member) =>
b.addMember(renameMember(member))
case member =>
b.addMember(member)
}
b.build()
}

val newIntEnumShapes = intEnums
.map { enumShape =>
{
val b = enumShape.toBuilder
b.clearMembers()
enumShape.getAllMembers.asScala.foreach {
case (memberName, member) =>
val newMember = if (enumHasConflict(memberName)) {
renameMember(member)
} else {
member
}
b.addMember(newMember)
}
b.build()
}
val newIntEnumShapes = intEnums.map { intEnumShape =>
val b = intEnumShape.toBuilder
b.clearMembers()
intEnumShape.members.asScala.foreach {
case member if hasConflict(member) =>
b.addMember(renameMember(member))
case member =>
b.addMember(member)
}
b.build()
}

val allShapes = newEnumShapes ++ newIntEnumShapes

Expand Down
82 changes: 77 additions & 5 deletions modules/proto/tests/src/ModelPrePocessorSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,75 @@ class ModelPrePocessorSpec extends FunSuite {
}
}

test("apply PreventEnumConflicts - across namespace") {
val smithyTest =
"""|$version: "2"
|namespace test
|
|enum Enum1 {
| VCONFLICT
|}
|
|enum Enum2 {
| VCONFLICT
|}
|""".stripMargin

val other =
"""|$version: "2"
|namespace a.ns
|
|enum OtherEnum {
| VCONFLICT
|}
|""".stripMargin
val original = buildModel(smithyTest, other)
val transformed =
process(original, ModelPreProcessor.transformers.PreventEnumConflicts)
def getEnumNames(m: Model, shapeId: ShapeId): List[String] = {
m.getShape(shapeId)
.toScala
.toList
.collect {
case shape: EnumShape =>
shape.getMemberNames.asScala.toList
case shape: IntEnumShape =>
shape.getMemberNames.asScala.toList
}
.flatten
}

assertEquals(
getEnumNames(original, ShapeId.from("test#Enum1")),
List("VCONFLICT")
)

assertEquals(
getEnumNames(transformed, ShapeId.from("test#Enum1")),
List("ENUM1_VCONFLICT")
)

assertEquals(
getEnumNames(original, ShapeId.from("test#Enum2")),
List("VCONFLICT")
)

assertEquals(
getEnumNames(transformed, ShapeId.from("test#Enum2")),
List("ENUM2_VCONFLICT")
)

assertEquals(
getEnumNames(original, ShapeId.from(s"a.ns#OtherEnum")),
List("VCONFLICT")
)

assertEquals(
getEnumNames(transformed, ShapeId.from(s"a.ns#OtherEnum")),
List("VCONFLICT")
)
}

private def checkTransformer(src: String, t: ProjectionTransformer)(
check: (Model, Model) => Unit
): Unit = {
Expand All @@ -328,13 +397,16 @@ class ModelPrePocessorSpec extends FunSuite {
check(original, transformed)
}

private def buildModel(src: String): Model = {
Model
private def buildModel(srcs: String*): Model = {
val assembler = Model
.assembler()
.discoverModels()
.addUnparsedModel("inlined-in-test.smithy", src)
.assemble()
.unwrap()

srcs.zipWithIndex.foreach { case (s, i) =>
assembler.addUnparsedModel(s"inlined-in-test.$i.smithy", s)
}

assembler.assemble().unwrap()
}

private def process(m: Model, t: ProjectionTransformer): Model = {
Expand Down
Loading