Skip to content

Commit

Permalink
[NU-1836] Invoke extension methods statically (#7119)
Browse files Browse the repository at this point in the history
  • Loading branch information
lukasz-bigorajski authored Nov 13, 2024
1 parent 287f231 commit ad5b129
Show file tree
Hide file tree
Showing 15 changed files with 415 additions and 478 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,13 @@
import org.springframework.expression.spel.support.ReflectionHelper;
import org.springframework.expression.spel.support.ReflectiveMethodExecutor;
import org.springframework.util.ReflectionUtils;
import pl.touk.nussknacker.engine.extension.ExtensionsAwareMethodInvoker;

import java.lang.reflect.Method;
import java.lang.reflect.Modifier;

//this basically changed org.springframework.expression.spel.support.ReflectiveMethodExecutor
//we want to create TypeDescriptor using SpelEspReflectionHelper.convertArguments
//which should work faster
// As an additional feature we allow to invoke defined extension methods
public class NuReflectiveMethodExecutor extends ReflectiveMethodExecutor {

private final Method method;
Expand All @@ -29,10 +27,7 @@ public class NuReflectiveMethodExecutor extends ReflectiveMethodExecutor {

private boolean argumentConversionOccurred = false;

private final ExtensionsAwareMethodInvoker methodInvoker;

public NuReflectiveMethodExecutor(ReflectiveMethodExecutor original,
ExtensionsAwareMethodInvoker methodInvoker) {
public NuReflectiveMethodExecutor(ReflectiveMethodExecutor original) {
super(original.getMethod());
this.method = original.getMethod();
if (method.isVarArgs()) {
Expand All @@ -42,7 +37,6 @@ public NuReflectiveMethodExecutor(ReflectiveMethodExecutor original,
else {
this.varargsPosition = null;
}
this.methodInvoker = methodInvoker;
}

/**
Expand Down Expand Up @@ -98,8 +92,7 @@ public TypedValue execute(EvaluationContext context, Object target, Object... ar
arguments = ReflectionHelper.setupArgumentsForVarargsInvocation(this.method.getParameterTypes(), arguments);
}
ReflectionUtils.makeAccessible(this.method);
//Nussknacker: we use custom method invoker which is aware of extension methods
Object value = methodInvoker.invoke(this.method, target, arguments);
Object value = this.method.invoke(target, arguments);
return new TypedValue(value, new TypeDescriptor(new MethodParameter(this.method, -1)).narrow(value));
}
catch (Exception ex) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,37 +1,44 @@
package pl.touk.nussknacker.engine.extension

import pl.touk.nussknacker.engine.definition.clazz.{ClassDefinitionSet, MethodDefinition}
import pl.touk.nussknacker.engine.extension.ExtensionMethod.{NoArg, SingleArg}
import pl.touk.nussknacker.engine.spel.internal.ConversionHandler

import java.util
import java.util.{List => JList}

class ArrayExt(target: Any) extends util.AbstractList[Object] {
private val asList = ConversionHandler.convertArrayToList(target)

override def get(index: Int): AnyRef = asList.get(index)
override def size(): Int = asList.size()
override def lastIndexOf(o: Any): Int = super.lastIndexOf(o)
override def contains(o: Any): Boolean = super.contains(o)
override def indexOf(o: Any): Int = super.indexOf(o)
override def containsAll(c: util.Collection[_]): Boolean = super.containsAll(c)
override def isEmpty: Boolean = super.isEmpty
def empty: Boolean = super.isEmpty

class ArrayWrapper(target: Any) extends util.AbstractList[Object] {
private val asList = ConversionHandler.convertArrayToList(target)
override def get(index: Int): AnyRef = asList.get(index)
override def size(): Int = asList.size()
}

object ArrayExt extends ExtensionMethodsHandler {

override type ExtensionMethodInvocationTarget = ArrayExt
override val invocationTargetClass: Class[ArrayExt] = classOf[ArrayExt]

override def createConverter(
object ArrayExt extends ExtensionMethodsDefinition {

private val methodRegistry: Map[String, ExtensionMethod[_]] = Map(
"get" -> SingleArg((target, arg: Int) => new ArrayWrapper(target).get(arg)),
"size" -> NoArg(target => new ArrayWrapper(target).size()),
"lastIndexOf" -> SingleArg((target, arg: Any) => new ArrayWrapper(target).lastIndexOf(arg)),
"contains" -> SingleArg((target, arg: Any) => new ArrayWrapper(target).contains(arg)),
"indexOf" -> SingleArg((target, arg: Any) => new ArrayWrapper(target).indexOf(arg)),
"containsAll" -> SingleArg((target, arg: util.Collection[_]) => new ArrayWrapper(target).containsAll(arg)),
"isEmpty" -> NoArg(target => new ArrayWrapper(target).isEmpty),
"empty" -> NoArg(target => new ArrayWrapper(target).isEmpty),
)

override def findMethod(
clazz: Class[_],
methodName: String,
argsSize: Int,
set: ClassDefinitionSet
): ToExtensionMethodInvocationTargetConverter[ArrayExt] =
(target: Any) => new ArrayExt(target)
): Option[ExtensionMethod[_]] =
if (appliesToClassInRuntime(clazz))
methodRegistry.findMethod(methodName, argsSize)
else
None

override def extractDefinitions(clazz: Class[_], set: ClassDefinitionSet): Map[String, List[MethodDefinition]] =
if (clazz.isArray) {
if (appliesToClassInRuntime(clazz)) {
set
.get(classOf[JList[_]])
.map(_.methods)
Expand All @@ -40,5 +47,6 @@ object ArrayExt extends ExtensionMethodsHandler {
Map.empty
}

override def appliesToClassInRuntime(clazz: Class[_]): Boolean = clazz.isArray
private def appliesToClassInRuntime(clazz: Class[_]): Boolean = clazz.isArray

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,46 @@ import pl.touk.nussknacker.engine.api.typed.typing
import pl.touk.nussknacker.engine.api.typed.typing.{Typed, TypedObjectWithValue, TypingResult, Unknown}
import pl.touk.nussknacker.engine.definition.clazz.{ClassDefinitionSet, FunctionalMethodDefinition, MethodDefinition}
import pl.touk.nussknacker.engine.extension.CastOrConversionExt.getConversion
import pl.touk.nussknacker.engine.extension.ExtensionMethod.SingleArg
import pl.touk.nussknacker.engine.util.Implicits.RichScalaMap
import pl.touk.nussknacker.engine.util.classes.Extensions.{ClassExtensions, ClassesExtensions}

import java.lang.{Boolean => JBoolean}
import scala.util.Try

// todo: lbg - add casting methods to UTIL
class CastOrConversionExt(target: Any, classesBySimpleName: Map[String, Class[_]]) {
class CastOrConversionExt(classesBySimpleName: Map[String, Class[_]]) {
private val castException = new ClassCastException(s"Cannot cast value to given class")

def is(className: String): Boolean =
private val methodRegistry: Map[String, ExtensionMethod[_]] = Map(
"is" -> SingleArg(is),
"to" -> SingleArg(to),
"toOrNull" -> SingleArg(toOrNull),
)

private def is(target: Any, className: String): Boolean =
getClass(className).exists(clazz => clazz.isAssignableFrom(target.getClass)) ||
getConversion(className).exists(_.canConvert(target))

def to(className: String): Any =
orElse(tryCast(className), tryConvert(className))
.getOrElse(throw new IllegalStateException(s"Cannot cast or convert value: $target to: '$className'"))
private def to(target: Any, className: String): Any =
orElse(tryCast(target, className), tryConvert(target, className)) match {
case Right(value) => value
case Left(ex) => throw new IllegalStateException(s"Cannot cast or convert value: $target to: '$className'", ex)
}

def toOrNull(className: String): Any =
orElse(tryCast(className), tryConvert(className))
private def toOrNull(target: Any, className: String): Any =
orElse(tryCast(target, className), tryConvert(target, className))
.getOrElse(null)

private def tryCast(className: String): Either[Throwable, Any] = getClass(className) match {
case Some(clazz) => Try(clazz.cast(target)).toEither
case None => Left(new ClassCastException(s"Cannot cast: [$target] to: [$className]."))
private def tryCast(target: Any, className: String): Either[Throwable, Any] = getClass(className) match {
case Some(clazz) if clazz.isInstance(target) => Try(clazz.cast(target)).toEither
case _ => Left(castException)
}

private def getClass(className: String): Option[Class[_]] =
classesBySimpleName.get(className.toLowerCase())

private def tryConvert(className: String): Either[Throwable, Any] =
private def tryConvert(target: Any, className: String): Either[Throwable, Any] =
getConversion(className)
.flatMap(_.convertEither(target))

Expand All @@ -49,46 +59,49 @@ class CastOrConversionExt(target: Any, classesBySimpleName: Map[String, Class[_]

}

object CastOrConversionExt extends ExtensionMethodsHandler {
private val isMethodName = "is"
private val toMethodName = "to"
private val toOrNullMethodName = "toOrNull"
private val castOrConversionMethods = Set(isMethodName, toMethodName, toOrNullMethodName)
private val stringClass = classOf[String]

private val conversionsRegistry: List[Conversion] = List(
ToLongConversionExt,
ToDoubleConversionExt,
ToBigDecimalConversionExt,
ToBooleanConversionExt,
object CastOrConversionExt extends ExtensionMethodsDefinition {
private[extension] val isMethodName = "is"
private[extension] val toMethodName = "to"
private[extension] val toOrNullMethodName = "toOrNull"
private val castOrConversionMethods = Set(isMethodName, toMethodName, toOrNullMethodName)
private val stringClass = classOf[String]

private val conversionsRegistry: List[Conversion[_ >: Null <: AnyRef]] = List(
ToLongConversion,
ToDoubleConversion,
ToBigDecimalConversion,
ToBooleanConversion,
ToStringConversion,
ToMapConversionExt,
ToListConversionExt,
ToMapConversion,
ToListConversion,
ToByteConversion,
ToShortConversion,
ToIntegerConversion,
ToFloatConversion,
ToBigIntegerConversion,
)

private val conversionsByType: Map[String, Conversion] = conversionsRegistry
private val conversionsByType: Map[String, Conversion[_ >: Null <: AnyRef]] = conversionsRegistry
.flatMap(c => c.resultTypeClass.classByNameAndSimpleNameLowerCase().map(n => n._1 -> c))
.toMap

override type ExtensionMethodInvocationTarget = CastOrConversionExt
override val invocationTargetClass: Class[CastOrConversionExt] = classOf[CastOrConversionExt]

def isCastOrConversionMethod(methodName: String): Boolean =
castOrConversionMethods.contains(methodName)

def allowedConversions(clazz: Class[_]): List[Conversion] = conversionsRegistry.filter(_.appliesToConversion(clazz))
def allowedConversions(clazz: Class[_]): List[Conversion[_]] =
conversionsRegistry.filter(_.appliesToConversion(clazz))

override def createConverter(
// Convert methods should visible in runtime for every class because we allow invoke convert methods on an unknown
// object in Typer, but in the runtime the same type could be known and that's why should add convert method to an
// every class.
override def findMethod(
clazz: Class[_],
methodName: String,
argsSize: Int,
set: ClassDefinitionSet
): ToExtensionMethodInvocationTargetConverter[CastOrConversionExt] = {
val classesBySimpleName = set.classDefinitionsMap.keySet.classesByNamesAndSimpleNamesLowerCase()
(target: Any) => new CastOrConversionExt(target, classesBySimpleName)
}
): Option[ExtensionMethod[_]] =
new CastOrConversionExt(set.classDefinitionsMap.keySet.classesByNamesAndSimpleNamesLowerCase()).methodRegistry
.findMethod(methodName, argsSize)

override def extractDefinitions(clazz: Class[_], set: ClassDefinitionSet): Map[String, List[MethodDefinition]] = {
val castAllowedClasses = clazz.findAllowedClassesForCastParameter(set).mapValuesNow(_.clazzName)
Expand All @@ -100,12 +113,7 @@ object CastOrConversionExt extends ExtensionMethodsHandler {
}
}

// Convert methods should visible in runtime for every class because we allow invoke convert methods on an unknown
// object in Typer, but in the runtime the same type could be known and that's why should add convert method for an
// every class.
override def appliesToClassInRuntime(clazz: Class[_]): Boolean = true

private def getConversion(className: String): Either[Throwable, Conversion] =
private def getConversion(className: String): Either[Throwable, Conversion[_]] =
conversionsByType.get(className.toLowerCase) match {
case Some(conversion) => Right(conversion)
case None => Left(new IllegalArgumentException(s"Conversion for class $className not found"))
Expand Down
Loading

0 comments on commit ad5b129

Please sign in to comment.