Skip to content

Commit

Permalink
Avoid collisions in generated symbols (#48)
Browse files Browse the repository at this point in the history
If no names collide, this does nothing.

Otherwise each generated symbol is prefixed with its index.

Co-authored-by: Jesse Wilson <[email protected]>
  • Loading branch information
swankjesse and squarejesse authored Oct 28, 2024
1 parent de1a713 commit e1c2909
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,50 @@ class BurstKotlinPluginTest {
)
}

@Test
fun burstValuesWithNameCollisions() {
val result = compile(
sourceFile = SourceFile.kotlin(
"CoffeeTest.kt",
"""
import app.cash.burst.Burst
import app.cash.burst.burstValues
import kotlin.test.Test
@Burst
class CoffeeTest {
@Test
fun test(
content: Any? = burstValues(
3, // No name is generated for the first value.
"1",
1,
1L,
"CASE_INSENSITIVE_ORDER",
String.CASE_INSENSITIVE_ORDER,
true,
"true"
)
) {
}
}
""",
),
)
assertEquals(KotlinCompilation.ExitCode.OK, result.exitCode, result.messages)

val baseClass = result.classLoader.loadClass("CoffeeTest")
assertThat(baseClass.testSuffixes).containsExactlyInAnyOrder(
"1_1",
"2_1",
"3_1",
"4_CASE_INSENSITIVE_ORDER",
"5_CASE_INSENSITIVE_ORDER",
"6_true",
"7_true",
)
}

private val Class<*>.testSuffixes: List<String>
get() = methods.mapNotNull {
when {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,6 @@ private class BurstValuesArgument(
override fun expression() = value.deepCopyWithSymbols(declarationParent)
}

/** Returns a name like `orderCoffee_Decaf_Oat` with each argument value inline. */
internal fun name(
prefix: String,
arguments: List<Argument>,
): String = arguments.joinToString(prefix = prefix, separator = "_", transform = Argument::name)

/**
* Returns all arguments for [parameter].
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,8 @@ internal class ClassSpecializer(
val valueParameters = onlyConstructor.valueParameters
if (valueParameters.isEmpty()) return // Nothing to do.

val parameterArguments = valueParameters.map { parameter ->
pluginContext.allPossibleArguments(parameter, burstApis)
}

val cartesianProduct = parameterArguments.cartesianProduct()

val indexOfDefaultSpecialization = cartesianProduct.indexOfFirst { arguments ->
arguments.all { it.isDefault }
}
val specializations = specializations(pluginContext, burstApis, valueParameters)
val indexOfDefaultSpecialization = specializations.indexOfFirst { it.isDefault }

// Make sure the constructor we're using is accessible. Drop the default arguments to prevent
// JUnit from using it.
Expand All @@ -111,65 +104,65 @@ internal class ClassSpecializer(
// Add a no-args constructor that calls the only constructor as the default specialization.
createNoArgsConstructor(
superConstructor = onlyConstructor,
arguments = cartesianProduct[indexOfDefaultSpecialization],
specialization = specializations[indexOfDefaultSpecialization],
)
} else {
// There's no default specialization. Make the class abstract so JUnit skips it.
original.modality = Modality.ABSTRACT
}

// Add a subclass for each specialization.
cartesianProduct.mapIndexed { index, arguments ->
for ((index, specialization) in specializations.withIndex()) {
// Don't generate code for the default specialization; we only want to run it once.
if (index == indexOfDefaultSpecialization) return@mapIndexed
if (index == indexOfDefaultSpecialization) continue

createSpecialization(
createSubclass(
superConstructor = onlyConstructor,
arguments = arguments,
specialization = specialization,
)
}
}

private fun createSpecialization(
private fun createSubclass(
superConstructor: IrConstructor,
arguments: List<Argument>,
specialization: Specialization,
) {
val specialization = original.factory.buildClass {
val created = original.factory.buildClass {
initDefaults(original)
visibility = PUBLIC
name = Name.identifier(name("${original.name.identifier}_", arguments))
name = Name.identifier("${original.name.identifier}_${specialization.name}")
}.apply {
superTypes = listOf(original.defaultType)
createImplicitParameterDeclarationWithWrappedDescriptor()
}

specialization.addConstructor {
created.addConstructor {
initDefaults(original)
}.apply {
irConstructorBody(pluginContext) { statements ->
statements += irDelegatingConstructorCall(
context = pluginContext,
symbol = superConstructor.symbol,
valueArgumentsCount = arguments.size,
valueArgumentsCount = specialization.arguments.size,
) {
for ((index, argument) in arguments.withIndex()) {
for ((index, argument) in specialization.arguments.withIndex()) {
putValueArgument(index, argument.expression())
}
}
statements += irInstanceInitializerCall(
context = pluginContext,
classSymbol = specialization.symbol,
classSymbol = created.symbol,
)
}
}

originalParent.addDeclaration(specialization)
specialization.addFakeOverrides(irTypeSystemContext)
originalParent.addDeclaration(created)
created.addFakeOverrides(irTypeSystemContext)
}

private fun createNoArgsConstructor(
superConstructor: IrConstructor,
arguments: List<Argument>,
specialization: Specialization,
) {
original.addConstructor {
initDefaults(original)
Expand All @@ -179,9 +172,9 @@ internal class ClassSpecializer(
statements += irDelegatingConstructorCall(
context = pluginContext,
symbol = superConstructor.symbol,
valueArgumentsCount = arguments.size,
valueArgumentsCount = specialization.arguments.size,
) {
for ((index, argument) in arguments.withIndex()) {
for ((index, argument) in specialization.arguments.withIndex()) {
putValueArgument(index, argument.expression())
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,20 +70,13 @@ internal class FunctionSpecializer(
val originalDispatchReceiver = original.dispatchReceiverParameter
?: throw BurstCompilationException("Unexpected dispatch receiver", original)

val parameterArguments = valueParameters.map { parameter ->
pluginContext.allPossibleArguments(parameter, burstApis)
}

val cartesianProduct = parameterArguments.cartesianProduct()

val indexOfDefaultSpecialization = cartesianProduct.indexOfFirst { arguments ->
arguments.all { it.isDefault }
}
val specializations = specializations(pluginContext, burstApis, valueParameters)
val indexOfDefaultSpecialization = specializations.indexOfFirst { it.isDefault }

val specializations = cartesianProduct.mapIndexed { index, arguments ->
createSpecialization(
val functions = specializations.mapIndexed { index, specialization ->
createFunction(
originalDispatchReceiver = originalDispatchReceiver,
arguments = arguments,
specialization = specialization,
isDefaultSpecialization = index == indexOfDefaultSpecialization,
)
}
Expand All @@ -94,21 +87,21 @@ internal class FunctionSpecializer(
}

// Add new declarations.
for (specialization in specializations) {
originalParent.addDeclaration(specialization)
for (function in functions) {
originalParent.addDeclaration(function)
}
}

private fun createSpecialization(
private fun createFunction(
originalDispatchReceiver: IrValueParameter,
arguments: List<Argument>,
specialization: Specialization,
isDefaultSpecialization: Boolean,
): IrSimpleFunction {
val result = original.factory.buildFun {
initDefaults(original)
name = when {
isDefaultSpecialization -> original.name
else -> Name.identifier(name("${original.name.identifier}_", arguments))
else -> Name.identifier("${original.name.identifier}_${specialization.name}")
}
returnType = original.returnType
}.apply {
Expand Down Expand Up @@ -136,7 +129,7 @@ internal class FunctionSpecializer(
callee = original.symbol,
).apply {
this.dispatchReceiver = irGet(receiverLocal)
for ((index, argument) in arguments.withIndex()) {
for ((index, argument) in specialization.arguments.withIndex()) {
putValueArgument(index, argument.expression())
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright (C) 2024 Cash App
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package app.cash.burst.kotlin

import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
import org.jetbrains.kotlin.ir.declarations.IrValueParameter
import org.jetbrains.kotlin.ir.symbols.UnsafeDuringIrConstructionAPI

internal class Specialization(
/** The argument values for this specialization. */
val arguments: List<Argument>,

/** A string like `Decaf_Oat` with each argument value named. */
val name: String,
) {
val isDefault: Boolean = arguments.all { it.isDefault }
}

@UnsafeDuringIrConstructionAPI
internal fun specializations(
pluginContext: IrPluginContext,
burstApis: BurstApis,
parameters: List<IrValueParameter>,
): List<Specialization> {
val parameterArguments = parameters.map { parameter ->
pluginContext.allPossibleArguments(parameter, burstApis)
}

val specializations = parameterArguments.cartesianProduct().map { arguments ->
Specialization(
arguments = arguments,
name = arguments.joinToString(separator = "_", transform = Argument::name),
)
}

// If all elements already have distinct names, we're done.
if (specializations.distinctBy { it.name }.size == specializations.size) {
return specializations
}

// Otherwise, prefix each with its index.
return specializations.mapIndexed { index, specialization ->
Specialization(specialization.arguments, "${index}_${specialization.name}")
}
}

0 comments on commit e1c2909

Please sign in to comment.