Skip to content

Commit

Permalink
use postprocess to handle unit responses
Browse files Browse the repository at this point in the history
  • Loading branch information
lewisjkl committed Apr 11, 2024
1 parent 5428355 commit 523bfce
Show file tree
Hide file tree
Showing 7 changed files with 233 additions and 93 deletions.
1 change: 1 addition & 0 deletions build.sc
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ trait OpenApiModule
object tests extends this.ScalaTests with BaseMunitTests {
def ivyDeps = super.ivyDeps() ++ Agg(
buildDeps.smithy.build,
buildDeps.smithy.diff,
buildDeps.scalaJavaCompat
)
}
Expand Down
1 change: 1 addition & 0 deletions buildDeps.sc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ object smithy {
val smithyVersion = "1.41.1"
val model = ivy"software.amazon.smithy:smithy-model:$smithyVersion"
val build = ivy"software.amazon.smithy:smithy-build:$smithyVersion"
val diff = ivy"software.amazon.smithy:smithy-diff:$smithyVersion"
}
object cats {
val mtl = ivy"org.typelevel::cats-mtl:1.4.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ private[compiler] object IModelPostProcessor {
RequirementShiftTransformer,
ContentTypeShiftTransformer,
ReorientDefaultValueTransformer,
DropRequiredWhenDefaultValue
DropRequiredWhenDefaultValue,
EmptyStructureToUnitTransformer
)

private[this] def transform(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/* Copyright 2022 Disney Streaming
*
* Licensed under the Tomorrow Open Source Technology License, Version 1.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://disneystreaming.github.io/TOST-1.0.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package smithytranslate.compiler.internals
package postprocess

import org.typelevel.ci._

// Removes empty structures and changes shapes which target
// them to instead target Unit
private[compiler] object EmptyStructureToUnitTransformer
extends IModelPostProcessor {

def apply(model: IModel): IModel = {
val defs = model.definitions.map(d => d.id -> d).toMap

val structuresToRemove = model.definitions.flatMap {
case s: Structure if isEmptyStructure(defs, s) =>
Some(s.id)
case _ => None
}.toSet
val amendedDefs = model.definitions.flatMap {
case s: Structure if structuresToRemove(s.id) => None
case op: OperationDef =>
val changeInput: OperationDef => OperationDef = o =>
if (o.input.exists(structuresToRemove)) o.copy(input = Some(unit))
else o
val changeOutput: OperationDef => OperationDef = o =>
if (o.output.exists(structuresToRemove)) o.copy(output = Some(unit))
else o
Some(changeInput.andThen(changeOutput)(op))
case other => Some(other)
}
IModel(amendedDefs, model.suppressions)
}

private val unit =
DefId(
Namespace(List("smithy", "api")),
Name(Segment.StandardLib(ci"Unit"))
)

// consider empty if has no fields OR if has one field with Body hint (httpPayload)
private def isEmptyStructure(
defs: Map[DefId, Definition],
d: Definition
): Boolean =
d match {
case s: Structure =>
(s.localFields.isEmpty && s.parents.isEmpty && s.hints.isEmpty) || {
def isHttpPayload =
s.localFields.length == 1 && s.localFields.head.hints
.contains(Hint.Body)
def isHttpPayloadEmpty = defs
.get(s.localFields.head.tpe)
.exists(isEmptyStructure(defs, _))
isHttpPayload && isHttpPayloadEmpty
}
case _ => false
}
}
12 changes: 1 addition & 11 deletions modules/openapi/src/internals/OpenApiToIModel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -252,15 +252,6 @@ private[openapi] class OpenApiToIModel[F[_]: Parallel: TellShape: TellError](
val code = output.code
code >= 200 && code < 300
}.toList match {
case output :: Nil if output.code == 204 =>
F.pure(
Some(
(204 -> DefId(
Namespace(List("smithy", "api")),
Name.stdLib("Unit")
))
)
)
case output :: Nil =>
val code = output.code
recordRefOrMessage(output.refOrMessage, None)
Expand Down Expand Up @@ -340,8 +331,7 @@ private[openapi] class OpenApiToIModel[F[_]: Parallel: TellShape: TellError](
}
}
.map { fields =>
if (fields.isEmpty) None
else Structure(defId, fields, Vector.empty, message.hints).some
Structure(defId, fields, Vector.empty, message.hints).some
}
.flatMap(_.traverse(recordDef).map(_.as(defId)))
}
Expand Down
155 changes: 155 additions & 0 deletions modules/openapi/tests/src/ModelWrapper.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/* Copyright 2022 Disney Streaming
*
* Licensed under the Tomorrow Open Source Technology License, Version 1.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://disneystreaming.github.io/TOST-1.0.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package smithytranslate.compiler.openapi

import software.amazon.smithy.model._
import software.amazon.smithy.model.node._
import scala.jdk.CollectionConverters._
import software.amazon.smithy.build.transforms.FilterSuppressions
import software.amazon.smithy.build.TransformContext
import software.amazon.smithy.model.shapes.SmithyIdlModelSerializer
import software.amazon.smithy.model.transform.ModelTransformer
import software.amazon.smithy.model.traits.Trait
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.traits.BoxTrait
import software.amazon.smithy.diff.ModelDiff
import java.util.stream.Collectors

// In order to have nice comparisons from test reports.
class ModelWrapper(val model: Model) {

override def equals(obj: Any): Boolean = obj match {
case wrapper: ModelWrapper =>
val one = reorderMetadata(reorderFields(model))
val two = reorderMetadata(reorderFields(wrapper.model))
val diff = ModelDiff
.builder()
.oldModel(one)
.newModel(two)
.compare()
.getDifferences()
val added = diff.addedShapes().toList
val hasChanges =
diff
.changedShapes()
.toList
.asScala
.exists { changed =>
val addedTraits =
changed.addedTraits().toList.asScala
val removedTraits = changed
.removedTraits()
.toList
.asScala
val changedTraits = changed
.changedTraits()
.toList
.asScala
.filterNot { pair =>
// compare shapeId and node values to avoid issues with differing java classes
pair.getLeft.toShapeId == pair.getRight.toShapeId && pair.getLeft.toNode == pair.getRight.toNode
}
.filterNot { pair =>
// don't consider synthetic traits
pair.getLeft().toShapeId().getNamespace() == "smithy.synthetic"
}
addedTraits.nonEmpty || removedTraits.nonEmpty || changedTraits.nonEmpty
}
val removed =
diff.removedShapes().toList.asScala
added.isEmpty && !hasChanges && removed.isEmpty
case _ => false
}

private def reorderMetadata(model: Model): Model = {
implicit val nodeOrd: Ordering[Node] = (x: Node, y: Node) =>
x.hashCode() - y.hashCode()

implicit val nodeStringOrd: Ordering[StringNode] = {
val ord = Ordering[String]
(x: StringNode, y: StringNode) => ord.compare(x.getValue(), y.getValue())
}
def goNode(n: Node): Node = n match {
case array: ArrayNode =>
val elements = array.getElements().asScala.toList.sorted
Node.arrayNode(elements: _*)
case obj: ObjectNode =>
Node.objectNode(
obj.getMembers().asScala.toSeq.sortBy(_._1).toMap.asJava
)
case other => other
}
def go(metadata: Map[String, Node]): Map[String, Node] = {
val keys = metadata.keySet.toVector.sorted
keys.map { k =>
k -> goNode(metadata(k))
}.toMap
}

val builder = model.toBuilder()
val newMeta = go(model.getMetadata().asScala.toMap)
builder.clearMetadata()
builder.metadata(newMeta.asJava)
builder.build()
}

private val reorderFields: Model => Model = m => {
val structures = m.getStructureShapes().asScala.map { structShape =>
val sortedMembers =
structShape.members().asScala.toList.sortBy(_.getMemberName())
structShape.toBuilder().members(sortedMembers.asJava).build()
}
m.toBuilder().addShapes(structures.asJava).build()
}

private def update(model: Model): Model = {
val filterSuppressions: Model => Model = m =>
new FilterSuppressions().transform(
TransformContext
.builder()
.model(m)
.settings(
ObjectNode.builder().withMember("removeUnused", true).build()
)
.build()
)
(filterSuppressions andThen reorderFields)(model)
}

override def toString() =
SmithyIdlModelSerializer
.builder()
.build()
.serialize(update(model))
.asScala
.map(in => s"${in._1.toString.toUpperCase}:\n\n${in._2}")
.mkString("\n")
}

object ModelWrapper {
def apply(model: Model): ModelWrapper = {
// Remove all box traits because they are applied inconsistently depending on if you
// load from Java model or from unparsed string model
@annotation.nowarn("msg=class BoxTrait in package traits is deprecated")
val noBoxModel = ModelTransformer
.create()
.filterTraits(
model,
((_: Shape, trt: Trait) => trt.toShapeId() != BoxTrait.ID)
)
new ModelWrapper(noBoxModel)
}
}
81 changes: 0 additions & 81 deletions modules/openapi/tests/src/TestUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -191,85 +191,4 @@ object TestUtils {
)
)
}

// In order to have nice comparisons from munit reports.
class ModelWrapper(val model: Model) {
override def equals(obj: Any): Boolean = obj match {
case wrapper: ModelWrapper =>
reorderMetadata(model) == reorderMetadata(wrapper.model)
case _ => false
}

private def reorderMetadata(model: Model): Model = {
implicit val nodeOrd: Ordering[Node] = (x: Node, y: Node) =>
x.hashCode() - y.hashCode()

implicit val nodeStringOrd: Ordering[StringNode] = {
val ord = Ordering[String]
(x: StringNode, y: StringNode) =>
ord.compare(x.getValue(), y.getValue())
}
def goNode(n: Node): Node = n match {
case array: ArrayNode =>
val elements = array.getElements().asScala.toList.sorted
Node.arrayNode(elements: _*)
case obj: ObjectNode =>
Node.objectNode(
obj.getMembers().asScala.toSeq.sortBy(_._1).toMap.asJava
)
case other => other
}
def go(metadata: Map[String, Node]): Map[String, Node] = {
val keys = metadata.keySet.toVector.sorted
keys.map { k =>
k -> goNode(metadata(k))
}.toMap
}

val builder = model.toBuilder()
val newMeta = go(model.getMetadata().asScala.toMap)
builder.clearMetadata()
builder.metadata(newMeta.asJava)
builder.build()
}

private def filter(model: Model): Model = {
val filterSuppressions: Model => Model = m =>
new FilterSuppressions().transform(
TransformContext
.builder()
.model(m)
.settings(
ObjectNode.builder().withMember("removeUnused", true).build()
)
.build()
)
(filterSuppressions)(model)
}

override def toString() = {
SmithyIdlModelSerializer
.builder()
.build()
.serialize(filter(model))
.asScala
.map(in => s"${in._1.toString.toUpperCase}:\n\n${in._2}")
.mkString("\n")
}
}

object ModelWrapper {
def apply(model: Model): ModelWrapper = {
// Remove all box traits because they are applied inconsistently depending on if you
// load from Java model or from unparsed string model
@annotation.nowarn("msg=class BoxTrait in package traits is deprecated")
val noBoxModel = ModelTransformer
.create()
.filterTraits(
model,
((_: Shape, trt: Trait) => trt.toShapeId() != BoxTrait.ID).asJava
)
new ModelWrapper(noBoxModel)
}
}
}

0 comments on commit 523bfce

Please sign in to comment.