From b36fd1c01bb3a2489f43e7276a6e18ef728085db Mon Sep 17 00:00:00 2001 From: Pavel Marek Date: Tue, 15 Oct 2024 21:26:08 +0200 Subject: [PATCH] Migrate some passes to Mini passes (#11191) Gets ready for avoiding IR traversal by introducing _mini passes_ as proposed by #10981: - creates [MiniPassFactory](https://github.com/enso-org/enso/pull/11191/commits/762045a35745ccb08421ab08135882bac3479de3) (that extends common `IRProcessingPass`) to transform an `IR` element to another `IR` element - modifies `PassManager` to recognize such _mini passes_ and treat them in a special way - by using `MiniIRPass.compile` - `MiniIRPass.compile` is using `IR.mapExpressions` to traverse the `IR` - alternative approach [withNewChildren](https://github.com/enso-org/enso/pull/11191/commits/1abc70d33c47ca1700157de3649041dbac01067d) rejected for now, see _future work_ for details - unlike _mega passes_ `IRMiniPass.compile` **does not recursively** traverse, but with 0964711ba944fccee20e7f2f0d5a3e8c07127355 it invokes each _mini pass_ at constant stack depth - way better for profiling - `MiniIRPass.prepare` _works on edges_ since ffd27dfe9b6fbbd59cccb35633cb802be097815c - there is `IRMiniPass prepare(parent, child)` to collect information while pre-order traversing from a particular `IR` parent to a particular `IR` child - `PassManager` rewritten to group _subsequent mini passes_ together by `MiniIRPass.combine` and really and traverse the `IR` just once - done in 2736a76 - converted to _mini pass_: `LambdaShorthandToLambda`, `OperatorToFunction`, `SectionsToBinOp` and `TailCall` - tested for 1:1 compatibility by [converting original code to test code](https://github.com/enso-org/enso/pull/11191/commits/f54ba6d1623b7742048360baa12b7d5e776022cc) and _comparing `IR` produced by old and new_ implementations --- .../org/enso/compiler/dump/IRDumperPass.java | 21 +- .../enso/compiler/pass/ChainedMiniPass.java | 51 ++ .../enso/compiler/pass/IRProcessingPass.java | 16 + .../org/enso/compiler/pass/MiniIRPass.java | 136 ++++ .../enso/compiler/pass/MiniPassFactory.java | 31 + .../enso/compiler/pass/MiniPassTraverser.java | 89 +++ .../pass/analyse/PassPersistance.java | 4 +- .../analyse/PrivateConstructorAnalysis.java | 22 +- .../pass/analyse/PrivateModuleAnalysis.java | 22 +- .../pass/analyse/PrivateSymbolsAnalysis.java | 19 +- .../enso/compiler/pass/analyse/TailCall.java | 393 +++++++++++ .../pass/analyse/types/TypeInference.java | 21 +- .../pass/desugar/OperatorToFunctionMini.java | 44 ++ .../pass/desugar/SectionsToBinOp.java | 253 +++++++ .../main/scala/org/enso/compiler/Passes.scala | 10 +- .../scala/org/enso/compiler/pass/IRPass.scala | 9 +- .../org/enso/compiler/pass/PassManager.scala | 106 ++- .../compiler/pass/analyse/AliasAnalysis.scala | 7 +- .../analyse/CachePreferenceAnalysis.scala | 7 +- .../pass/analyse/DataflowAnalysis.scala | 7 +- .../compiler/pass/desugar/ComplexType.scala | 11 +- .../pass/desugar/FunctionBinding.scala | 9 +- .../pass/desugar/GenerateMethodBodies.scala | 7 +- .../desugar/LambdaShorthandToLambda.scala | 410 +---------- .../desugar/LambdaShorthandToLambdaMini.scala | 352 ++++++++++ .../pass/desugar/NestedPatternMatch.scala | 7 +- .../pass/desugar/OperatorToFunction.scala | 79 +-- .../pass/lint/ShadowedPatternFields.scala | 7 +- .../compiler/pass/lint/UnusedBindings.scala | 7 +- .../pass/optimise/LambdaConsolidate.scala | 9 +- .../optimise/UnreachableMatchBranches.scala | 7 +- .../pass/resolve/IgnoredBindings.scala | 7 +- .../pass/resolve/SuspendedArguments.scala | 7 +- .../compiler/pass/resolve/TypeFunctions.scala | 9 +- .../pass/resolve/TypeSignatures.scala | 7 +- .../org/enso/compiler/test/CompilerTests.java | 21 +- .../test/mini/passes/MiniPassTester.java | 63 ++ .../org/enso/compiler/test/MiniPassTest.scala | 134 ++++ .../org/enso/compiler/test/PassesTest.scala | 2 +- .../compiler/test/pass/PassManagerTest.scala | 8 +- .../test/pass/analyse/TailCallMegaPass.scala} | 61 +- .../test/pass/analyse/TailCallTest.scala | 653 +++++++++++------- .../LambdaShorthandToLambdaMegaPass.scala | 447 ++++++++++++ .../desugar/LambdaShorthandToLambdaTest.scala | 264 ++++--- .../pass/desugar/OperatorToFunctionTest.scala | 190 ++++- .../desugar/SectionsToBinOpMegaPass.scala} | 13 +- .../pass/desugar/SectionsToBinOpTest.scala | 133 ++-- .../interpreter/runtime/IrToTruffle.scala | 2 +- 48 files changed, 3088 insertions(+), 1106 deletions(-) create mode 100644 engine/runtime-compiler/src/main/java/org/enso/compiler/pass/ChainedMiniPass.java create mode 100644 engine/runtime-compiler/src/main/java/org/enso/compiler/pass/IRProcessingPass.java create mode 100644 engine/runtime-compiler/src/main/java/org/enso/compiler/pass/MiniIRPass.java create mode 100644 engine/runtime-compiler/src/main/java/org/enso/compiler/pass/MiniPassFactory.java create mode 100644 engine/runtime-compiler/src/main/java/org/enso/compiler/pass/MiniPassTraverser.java create mode 100644 engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/TailCall.java create mode 100644 engine/runtime-compiler/src/main/java/org/enso/compiler/pass/desugar/OperatorToFunctionMini.java create mode 100644 engine/runtime-compiler/src/main/java/org/enso/compiler/pass/desugar/SectionsToBinOp.java create mode 100644 engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/LambdaShorthandToLambdaMini.scala create mode 100644 engine/runtime-integration-tests/src/test/java/org/enso/compiler/test/mini/passes/MiniPassTester.java create mode 100644 engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/MiniPassTest.scala rename engine/{runtime-compiler/src/main/scala/org/enso/compiler/pass/analyse/TailCall.scala => runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/analyse/TailCallMegaPass.scala} (90%) create mode 100644 engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/desugar/LambdaShorthandToLambdaMegaPass.scala rename engine/{runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/SectionsToBinOp.scala => runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/desugar/SectionsToBinOpMegaPass.scala} (96%) diff --git a/engine/runtime-compiler/src/main/java/org/enso/compiler/dump/IRDumperPass.java b/engine/runtime-compiler/src/main/java/org/enso/compiler/dump/IRDumperPass.java index f0e7772b2629..acd86fccbe80 100644 --- a/engine/runtime-compiler/src/main/java/org/enso/compiler/dump/IRDumperPass.java +++ b/engine/runtime-compiler/src/main/java/org/enso/compiler/dump/IRDumperPass.java @@ -3,39 +3,28 @@ import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; -import java.util.UUID; import org.enso.compiler.context.InlineContext; import org.enso.compiler.context.ModuleContext; import org.enso.compiler.core.IR; import org.enso.compiler.core.ir.Expression; import org.enso.compiler.core.ir.Module; import org.enso.compiler.pass.IRPass; +import org.enso.compiler.pass.IRProcessingPass; import scala.collection.immutable.Seq; /** A pass that just dumps IR to the local {@code ir-dumps} directory. See {@link IRDumper}. */ public class IRDumperPass implements IRPass { public static final IRDumperPass INSTANCE = new IRDumperPass(); - private UUID uuid; private IRDumperPass() {} @Override - public UUID key() { - return uuid; - } - - @Override - public void org$enso$compiler$pass$IRPass$_setter_$key_$eq(UUID v) { - this.uuid = v; - } - - @Override - public Seq precursorPasses() { + public Seq precursorPasses() { return nil(); } @Override - public Seq invalidatedPasses() { + public Seq invalidatedPasses() { return nil(); } @@ -68,8 +57,8 @@ public T updateMetadataInDuplicate(T sourceIr, T copyOfIr) { } @SuppressWarnings("unchecked") - private static scala.collection.immutable.List nil() { + private static scala.collection.immutable.List nil() { Object obj = scala.collection.immutable.Nil$.MODULE$; - return (scala.collection.immutable.List) obj; + return (scala.collection.immutable.List) obj; } } diff --git a/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/ChainedMiniPass.java b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/ChainedMiniPass.java new file mode 100644 index 000000000000..1fceeecbcdd2 --- /dev/null +++ b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/ChainedMiniPass.java @@ -0,0 +1,51 @@ +package org.enso.compiler.pass; + +import java.util.Objects; +import org.enso.compiler.core.IR; +import org.enso.compiler.core.ir.Expression; + +/** Utility class for chaining mini passes together. */ +final class ChainedMiniPass extends MiniIRPass { + private final MiniIRPass firstPass; + private final MiniIRPass secondPass; + + private ChainedMiniPass(MiniIRPass firstPass, MiniIRPass secondPass) { + this.firstPass = firstPass; + this.secondPass = secondPass; + } + + static MiniIRPass chain(MiniIRPass firstPass, MiniIRPass secondPass) { + if (firstPass == null) { + return secondPass; + } + return new ChainedMiniPass(firstPass, secondPass); + } + + @Override + public MiniIRPass prepare(IR parent, Expression current) { + var first = firstPass.prepare(parent, current); + var second = secondPass.prepare(parent, current); + if (first == firstPass && second == secondPass) { + return this; + } else { + return new ChainedMiniPass(first, second); + } + } + + @Override + public Expression transformExpression(Expression ir) { + var fstIr = firstPass.transformExpression(ir); + var sndIr = secondPass.transformExpression(fstIr); + return sndIr; + } + + @Override + public boolean checkPostCondition(IR ir) { + return firstPass.checkPostCondition(ir) && secondPass.checkPostCondition(ir); + } + + @Override + public String toString() { + return Objects.toString(firstPass) + ":" + Objects.toString(secondPass); + } +} diff --git a/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/IRProcessingPass.java b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/IRProcessingPass.java new file mode 100644 index 000000000000..2d38422d488f --- /dev/null +++ b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/IRProcessingPass.java @@ -0,0 +1,16 @@ +package org.enso.compiler.pass; + +import org.enso.compiler.core.ir.ProcessingPass; +import scala.collection.immutable.Seq; + +/** + * A generic {@link IR} processing pass. Currently with two subclasses: classical {@link IRPass mega + * IR processing pass} and {@link MiniPassFactory}. + */ +public interface IRProcessingPass extends ProcessingPass { + /** The passes that this pass depends _directly_ on to run. */ + public Seq precursorPasses(); + + /** The passes that are invalidated by running this pass. */ + public Seq invalidatedPasses(); +} diff --git a/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/MiniIRPass.java b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/MiniIRPass.java new file mode 100644 index 000000000000..af921af41cca --- /dev/null +++ b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/MiniIRPass.java @@ -0,0 +1,136 @@ +package org.enso.compiler.pass; + +import java.util.function.Function; +import org.enso.compiler.core.IR; +import org.enso.compiler.core.ir.Expression; +import org.enso.compiler.core.ir.Module; + +/** + * Mini IR pass operates on a single IR element at a time. The {@link org.enso.compiler.Compiler} + * traverses the whole IR tree in DFS. It works in two phases. + * + *

Note that the current implementation is limited to traverse only {@link + * org.enso.compiler.core.ir.Expression} elements as provided by {@link + * IR#mapExpressions(Function)}. Hence, the additional method {@link #transformModule(Module)}. + * + *

In the first, prepare phase, the compiler traverses from the root to the leaves and + * calls the {@link #prepare(Expression)} method on the mini pass. During this phase, the mini pass + * can gather information about the current IR element, but not modify it. + * + *

In the second, transform phase, the compiler returns from the leaves to the root and + * calls the {@link #transformExpression(Expression)} method on the mini pass. During this phase, + * the mini pass is free to transform the current IR element. The children of the current IR element + * are already transformed. + * + *

For each IR element: + * + *

    + *
  1. The {@link #prepare(Expression)} method is called to prepare the pass for the current IR + * element. This method is called when the {@link org.enso.compiler.Compiler} traverses the IR + * tree from top to bottom. This is useful for mini passes that need to build some information + * about the current IR element before transforming it. The mini pass must not modify the IR + * element neither attach any metadata to it in this method. By returning {@code null} from + * this method, the mini pass signals to the compiler that it wishes to not process the + * subtree of the current IR element. + *
  2. The {@link #transformExpression(Expression)} method is called to transform the current IR + * element. This method is called when the {@link org.enso.compiler.Compiler} traverses the + * element from bottom to top. All the children of the current IR element are already + * transformed when this method is called. + *
+ * + *

Inspired by: Miniphases: compilation + * using modular and efficient tree transformations. PDF available at infoscience.epfl.ch + */ +public abstract class MiniIRPass { + /** + * Prepare the pass for the provided IR element. This method is called when the {@link + * org.enso.compiler.Compiler} traverses the IR element from top to bottom. + * + *

The mini pass is free to gather any information about the elements it encounters (via this + * method) and use it in the {@link #transformExpression(Expression)} method. Note however, that + * it is not wise to store the references to the IR or their children for later comparison in the + * {@link #transformExpression(Expression) transform phase}, as the IR tree will most likely be + * transformed during the compilation process. + * + *

TL;DR; Do no store references to the IR elements or their children in this method. + * + * @param parent the the parent of the edge + * @param child the child expression element to be be processed. + * @return an instance of the pass to process the child's element subtree + */ + public MiniIRPass prepare(IR parent, Expression child) { + return this; + } + + /** + * Transform the provided IR element. Children of the IR element are already transformed when this + * method is called. This method is called when the {@link org.enso.compiler.Compiler} traverses + * the IR element from bottom to top. + * + *

The pass should not do any traversal in this method. + * + * @param expr Expression IR element to be transformed by this pass. + * @return The transformed Expression IR, or the same IR if no transformation is needed. Must not + * return null. + */ + public abstract Expression transformExpression(Expression expr); + + /** + * Transforms the module IR. This is the last method that is called. + * + * @see #transformExpression(Expression) + */ + public Module transformModule(Module moduleIr) { + return moduleIr; + } + + public boolean checkPostCondition(IR ir) { + return true; + } + + /** + * Name of the mini pass. + * + * @return by default it returns name of the implementing class + */ + @Override + public String toString() { + return getClass().getName(); + } + + /** + * Combines two mini IR passes into one that delegates to both of them. + * + * @param first first mini pass (can be {@code null}) + * @param second second mini pass + * @return a combined pass that calls both non-{@code null} of the provided passes + */ + public static MiniIRPass combine(MiniIRPass first, MiniIRPass second) { + return ChainedMiniPass.chain(first, second); + } + + /** + * Takes an IR element of given type {@code irType} and transforms it by provided {@link + * MiniIRPass}. When assertions are on, the resulting IR is checked with {@link + * #checkPostCondition} method of provided {@code miniPass}. + * + * @param the in and out type of IR + * @param irType class of the requested IR type + * @param ir the IR element (not {@code null}) + * @param miniPass the pass to apply + * @return the transformed IR + */ + public static T compile(Class irType, T ir, MiniIRPass miniPass) { + var newIr = MiniPassTraverser.compileDeep(ir, miniPass); + assert irType.isInstance(newIr) + : "Expected " + + irType.getName() + + " but got " + + newIr.getClass().getName() + + " by " + + miniPass; + assert miniPass.checkPostCondition(newIr) : "Post condition failed for " + miniPass; + return irType.cast(newIr); + } +} diff --git a/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/MiniPassFactory.java b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/MiniPassFactory.java new file mode 100644 index 000000000000..8c98d1d87bf4 --- /dev/null +++ b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/MiniPassFactory.java @@ -0,0 +1,31 @@ +package org.enso.compiler.pass; + +import org.enso.compiler.context.InlineContext; +import org.enso.compiler.context.ModuleContext; + +/** + * Mini IR pass operates on a single IR element at a time. The {@link org.enso.compiler.Compiler} + * traverses the whole IR tree in DFS. The actual work is done by {@link MiniIRPass} implementation. + * This factory only contains a collection of factory methods to create such {@link MiniIRPass IR + * mini passes}. If a mini pass supports only inline compilation, its {@link + * #createForModuleCompilation(ModuleContext)} method should return null. + */ +public interface MiniPassFactory extends IRProcessingPass { + /** + * Creates an instance of mini pass that is capable of transforming IR elements in the context of + * a module. + * + * @param moduleContext A mini pass can optionally save reference to this module context. + * @return May return {@code null} if module compilation is not supported. + */ + MiniIRPass createForModuleCompilation(ModuleContext moduleContext); + + /** + * Creates an instance of mini pass that is capable of transforming IR elements in the context of + * an inline compilation. + * + * @param inlineContext A mini pass can optionally save reference to this inline context. + * @return Must not return {@code null}. Inline compilation should always be supported. + */ + MiniIRPass createForInlineCompilation(InlineContext inlineContext); +} diff --git a/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/MiniPassTraverser.java b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/MiniPassTraverser.java new file mode 100644 index 000000000000..4243b0fce642 --- /dev/null +++ b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/MiniPassTraverser.java @@ -0,0 +1,89 @@ +package org.enso.compiler.pass; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.LinkedList; +import java.util.List; +import org.enso.compiler.core.IR; +import org.enso.compiler.core.ir.Expression; +import org.enso.compiler.core.ir.Module; + +/** Implementation of {@link MiniIRPass#compile}. */ +final class MiniPassTraverser { + private final MiniIRPass miniPass; + private List in; + private final List out; + private int outIndex; + + private MiniPassTraverser(MiniIRPass miniPass, List out, int outIndex) { + this.miniPass = miniPass; + this.out = out; + this.outIndex = outIndex; + } + + private boolean enqueue(Collection queue) { + if (in == null) { + var ir = out.get(outIndex); + in = enqueueSubExpressions(queue, ir, miniPass); + return !in.isEmpty(); + } else { + return false; + } + } + + private void convertExpression() { + if (outIndex != -1) { + var oldIr = out.get(outIndex); + var index = new int[1]; + var newIr = oldIr.mapExpressions((old) -> (Expression) in.get(index[0]++)); + var transformedIr = + switch (newIr) { + case Module m -> miniPass.transformModule(m); + case Expression e -> miniPass.transformExpression(e); + default -> throw new IllegalArgumentException("" + oldIr); + }; + if (oldIr != transformedIr) { + out.set(outIndex, transformedIr); + } + outIndex = -1; + } + } + + static IR compileDeep(IR root, MiniIRPass miniPass) { + var result = new IR[] {root}; + var rootTask = new MiniPassTraverser(miniPass, Arrays.asList(result), 0); + var stackOfPendingIrs = new LinkedList(); + stackOfPendingIrs.add(rootTask); + while (!stackOfPendingIrs.isEmpty()) { + if (stackOfPendingIrs.peekLast().enqueue(stackOfPendingIrs)) { + // continue descent + continue; + } + var deepestIr = stackOfPendingIrs.removeLast(); + deepestIr.convertExpression(); + } + assert result[0] != null; + return result[0]; + } + + /** + * @param queue queue to put objects in + * @param ir IR to process + * @param miniPass process with this mini pass + * @return {@code true} if the has been modified with new tries to process first + */ + private static List enqueueSubExpressions( + Collection queue, IR ir, MiniIRPass miniPass) { + var childExpressions = new ArrayList(); + var i = new int[1]; + ir.mapExpressions( + (ch) -> { + var preparedMiniPass = miniPass.prepare(ir, ch); + childExpressions.add(ch); + queue.add(new MiniPassTraverser(preparedMiniPass, childExpressions, i[0]++)); + return ch; + }); + return childExpressions; + } +} diff --git a/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/PassPersistance.java b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/PassPersistance.java index 42792f80032d..1e9cf07989a8 100644 --- a/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/PassPersistance.java +++ b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/PassPersistance.java @@ -47,7 +47,7 @@ @Persistable(clazz = GlobalNames$.class, id = 1205) @Persistable(clazz = IgnoredBindings$.class, id = 1206) @Persistable(clazz = Patterns$.class, id = 1207) -@Persistable(clazz = TailCall$.class, id = 1208) +@Persistable(clazz = TailCall.class, id = 1208) @Persistable(clazz = TypeNames$.class, id = 1209) @Persistable(clazz = TypeSignatures$.class, id = 1210) @Persistable(clazz = DocumentationComments$.class, id = 1211) @@ -66,7 +66,7 @@ @Persistable(clazz = Graph.Link.class, id = 1266, allowInlining = false) @Persistable(clazz = TypeInference.class, id = 1280) @Persistable(clazz = FramePointerAnalysis$.class, id = 1281) -@Persistable(clazz = TailCall$TailPosition$Tail$.class, id = 1282) +@Persistable(clazz = TailCall.TailPosition.class, id = 1282) public final class PassPersistance { private PassPersistance() {} diff --git a/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/PrivateConstructorAnalysis.java b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/PrivateConstructorAnalysis.java index e44aca531aa4..443f65d59492 100644 --- a/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/PrivateConstructorAnalysis.java +++ b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/PrivateConstructorAnalysis.java @@ -1,7 +1,6 @@ package org.enso.compiler.pass.analyse; import java.util.List; -import java.util.UUID; import org.enso.compiler.context.InlineContext; import org.enso.compiler.context.ModuleContext; import org.enso.compiler.core.IR; @@ -11,6 +10,7 @@ import org.enso.compiler.core.ir.expression.errors.Syntax.InconsistentConstructorVisibility$; import org.enso.compiler.core.ir.module.scope.Definition; import org.enso.compiler.pass.IRPass; +import org.enso.compiler.pass.IRProcessingPass; import scala.collection.immutable.Seq; import scala.jdk.javaapi.CollectionConverters; @@ -21,31 +21,19 @@ public final class PrivateConstructorAnalysis implements IRPass { public static final PrivateConstructorAnalysis INSTANCE = new PrivateConstructorAnalysis(); - private UUID uuid; - private PrivateConstructorAnalysis() {} @Override - public void org$enso$compiler$pass$IRPass$_setter_$key_$eq(UUID v) { - this.uuid = v; - } - - @Override - public UUID key() { - return uuid; - } - - @Override - public Seq precursorPasses() { - List passes = List.of(PrivateModuleAnalysis.INSTANCE); + public Seq precursorPasses() { + List passes = List.of(PrivateModuleAnalysis.INSTANCE); return CollectionConverters.asScala(passes).toList(); } @Override @SuppressWarnings("unchecked") - public Seq invalidatedPasses() { + public Seq invalidatedPasses() { Object obj = scala.collection.immutable.Nil$.MODULE$; - return (scala.collection.immutable.List) obj; + return (scala.collection.immutable.List) obj; } @Override diff --git a/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/PrivateModuleAnalysis.java b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/PrivateModuleAnalysis.java index a553d253dcba..8b772a3a0f7a 100644 --- a/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/PrivateModuleAnalysis.java +++ b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/PrivateModuleAnalysis.java @@ -2,7 +2,6 @@ import java.util.ArrayList; import java.util.List; -import java.util.UUID; import org.enso.compiler.context.InlineContext; import org.enso.compiler.context.ModuleContext; import org.enso.compiler.core.IR; @@ -13,6 +12,7 @@ import org.enso.compiler.core.ir.module.scope.Import; import org.enso.compiler.data.BindingsMap; import org.enso.compiler.pass.IRPass; +import org.enso.compiler.pass.IRProcessingPass; import org.enso.pkg.QualifiedName; import scala.Option; import scala.collection.immutable.Seq; @@ -31,31 +31,21 @@ */ public final class PrivateModuleAnalysis implements IRPass { public static final PrivateModuleAnalysis INSTANCE = new PrivateModuleAnalysis(); - private UUID uuid; private PrivateModuleAnalysis() {} @Override - public void org$enso$compiler$pass$IRPass$_setter_$key_$eq(UUID v) { - this.uuid = v; - } - - @Override - public UUID key() { - return uuid; - } - - @Override - public Seq precursorPasses() { - List passes = List.of(BindingAnalysis$.MODULE$, ImportSymbolAnalysis$.MODULE$); + public Seq precursorPasses() { + List passes = + List.of(BindingAnalysis$.MODULE$, ImportSymbolAnalysis$.MODULE$); return CollectionConverters.asScala(passes).toList(); } @Override @SuppressWarnings("unchecked") - public Seq invalidatedPasses() { + public Seq invalidatedPasses() { Object obj = scala.collection.immutable.Nil$.MODULE$; - return (scala.collection.immutable.List) obj; + return (scala.collection.immutable.List) obj; } @Override diff --git a/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/PrivateSymbolsAnalysis.java b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/PrivateSymbolsAnalysis.java index c9f9d58070c4..e34e4c13029b 100644 --- a/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/PrivateSymbolsAnalysis.java +++ b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/PrivateSymbolsAnalysis.java @@ -1,7 +1,6 @@ package org.enso.compiler.pass.analyse; import java.util.List; -import java.util.UUID; import org.enso.compiler.context.InlineContext; import org.enso.compiler.context.ModuleContext; import org.enso.compiler.core.IR; @@ -17,6 +16,7 @@ import org.enso.compiler.data.BindingsMap.ResolvedModule; import org.enso.compiler.data.BindingsMap.ResolvedName; import org.enso.compiler.pass.IRPass; +import org.enso.compiler.pass.IRProcessingPass; import org.enso.compiler.pass.resolve.Patterns$; import org.enso.pkg.QualifiedName; import scala.collection.immutable.Seq; @@ -30,30 +30,19 @@ */ public class PrivateSymbolsAnalysis implements IRPass { public static final PrivateSymbolsAnalysis INSTANCE = new PrivateSymbolsAnalysis(); - private UUID uuid; private PrivateSymbolsAnalysis() {} @Override - public void org$enso$compiler$pass$IRPass$_setter_$key_$eq(UUID v) { - this.uuid = v; - } - - @Override - public UUID key() { - return uuid; - } - - @Override - public Seq precursorPasses() { - List passes = + public Seq precursorPasses() { + List passes = List.of( PrivateModuleAnalysis.INSTANCE, PrivateConstructorAnalysis.INSTANCE, Patterns$.MODULE$); return CollectionConverters.asScala(passes).toList(); } @Override - public Seq invalidatedPasses() { + public Seq invalidatedPasses() { return nil(); } diff --git a/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/TailCall.java b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/TailCall.java new file mode 100644 index 000000000000..b328f666abae --- /dev/null +++ b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/TailCall.java @@ -0,0 +1,393 @@ +package org.enso.compiler.pass.analyse; + +import java.util.List; +import org.enso.compiler.context.InlineContext; +import org.enso.compiler.context.ModuleContext; +import org.enso.compiler.core.CompilerError; +import org.enso.compiler.core.CompilerStub; +import org.enso.compiler.core.IR; +import org.enso.compiler.core.ir.CallArgument; +import org.enso.compiler.core.ir.Expression; +import org.enso.compiler.core.ir.Function; +import org.enso.compiler.core.ir.Literal; +import org.enso.compiler.core.ir.MetadataStorage.MetadataPair; +import org.enso.compiler.core.ir.Module; +import org.enso.compiler.core.ir.Name; +import org.enso.compiler.core.ir.Type; +import org.enso.compiler.core.ir.Warning; +import org.enso.compiler.core.ir.expression.Application; +import org.enso.compiler.core.ir.expression.Case; +import org.enso.compiler.core.ir.expression.Comment; +import org.enso.compiler.core.ir.module.scope.Definition; +import org.enso.compiler.core.ir.module.scope.definition.*; +import org.enso.compiler.pass.IRPass; +import org.enso.compiler.pass.IRProcessingPass; +import org.enso.compiler.pass.MiniIRPass; +import org.enso.compiler.pass.MiniPassFactory; +import org.enso.compiler.pass.desugar.*; +import org.enso.compiler.pass.resolve.ExpressionAnnotations; +import org.enso.compiler.pass.resolve.ExpressionAnnotations$; +import org.enso.compiler.pass.resolve.GlobalNames$; +import org.enso.compiler.pass.resolve.ModuleAnnotations.Annotations; +import scala.Option; +import scala.collection.immutable.Seq; +import scala.jdk.javaapi.CollectionConverters; + +/** + * This pass performs tail call analysis on the Enso IR. + * + *

It is responsible for marking every single expression with whether it is in tail position. + * This allows the code generator to correctly create the Truffle nodes. If the expression is in + * tail position, [[TailPosition.Tail]] metadata is attached to it, otherwise, nothing is attached. + * + *

This pass requires the context to provide: + * + *

- The tail position of its expression, where relevant. + */ +public final class TailCall implements MiniPassFactory { + public static final TailCall INSTANCE = new TailCall(); + private static final MetadataPair TAIL_META = + new MetadataPair<>(INSTANCE, TailPosition.Tail); + + private TailCall() {} + + @Override + public Seq precursorPasses() { + List passes = + List.of( + FunctionBinding$.MODULE$, + GenerateMethodBodies$.MODULE$, + SectionsToBinOp.INSTANCE, + OperatorToFunction$.MODULE$, + LambdaShorthandToLambda$.MODULE$, + GlobalNames$.MODULE$); + return CollectionConverters.asScala(passes).toList(); + } + + @Override + public Seq invalidatedPasses() { + return CollectionConverters.asScala(List.of()).toList(); + } + + @Override + public MiniIRPass createForInlineCompilation(InlineContext inlineContext) { + var opt = inlineContext.isInTailPosition(); + if (opt.isEmpty()) { + throw new CompilerError( + "Information about the tail position for an inline expression " + + "must be known by the point of tail call analysis."); + } + return mini(Boolean.TRUE.equals(opt.get())); + } + + @Override + public MiniIRPass createForModuleCompilation(ModuleContext moduleContext) { + return mini(false); + } + + private static final Mini IN_TAIL_POS = new Mini(true); + private static final Mini NOT_IN_TAIL_POS = new Mini(false); + + static Mini mini(boolean isInTailPos) { + return isInTailPos ? IN_TAIL_POS : NOT_IN_TAIL_POS; + } + + /** Expresses the tail call state of an IR Node. */ + @SuppressWarnings("unchecked") + public static final class TailPosition implements IRPass.IRMetadata { + public static final TailPosition Tail = new TailPosition(); + + private TailPosition() {} + + /** A boolean representation of the expression's tail state. */ + public boolean isTail() { + return true; + } + + @Override + public String metadataName() { + return "TailCall.TailPosition.Tail"; + } + + @Override + public Option duplicate() { + return Option.apply(this); + } + + /** + * @inheritdoc + */ + @Override + public TailPosition prepareForSerialization(CompilerStub compiler) { + return this; + } + + /** + * @inheritdoc + */ + @Override + public Option restoreFromSerialization(CompilerStub compiler) { + return Option.apply(this); + } + } + + /** + * Checks if the provided `expression` is annotated with a tail call annotation. + * + * @param expression the expression to check + * @return `true` if `expression` is annotated with `@Tail_Call`, otherwise `false` + */ + public static final boolean isTailAnnotated(Expression expression) { + var meta = expression.passData().get(ExpressionAnnotations$.MODULE$); + if (meta.isEmpty()) { + return false; + } + var anns = (Annotations) meta.get(); + return anns.annotations().exists(a -> ExpressionAnnotations.tailCallName().equals(a.name())); + } + + private static final class Mini extends MiniIRPass { + private final boolean isInTailPos; + + Mini(boolean in) { + isInTailPos = in; + } + + @Override + public Module transformModule(Module m) { + m.bindings().map(this::updateModuleBinding); + return m; + } + + /** + * Performs tail call analysis on a top-level definition in a module. + * + * @param moduleDefinition the top-level definition to analyse + * @return `definition`, annotated with tail call information + */ + private Void updateModuleBinding(Definition moduleDefinition) { + switch (moduleDefinition) { + case Method.Conversion method -> markAsTail(method); + case Method.Explicit method -> markAsTail(method); + case Method.Binding b -> throw new CompilerError( + "Sugared method definitions should not occur during tail call " + "analysis."); + case Definition.Type t -> markAsTail(t); + case Definition.SugaredType st -> throw new CompilerError( + "Complex type definitions should not be present during " + "tail call analysis."); + case Comment.Documentation cd -> throw new CompilerError( + "Documentation should not exist as an entity during tail call analysis."); + case Type.Ascription ta -> throw new CompilerError( + "Type signatures should not exist at the top level during " + "tail call analysis."); + case Name.BuiltinAnnotation ba -> throw new CompilerError( + "Annotations should already be associated by the point of " + "tail call analysis."); + case Name.GenericAnnotation ann -> markAsTail(ann); + default -> {} + } + return null; + } + + private void markAsTailConditionally(IR ir) { + if (isInTailPos) { + markAsTail(ir); + } + } + + private Void markAsTail(IR ir) { + ir.passData().update(TAIL_META); + return null; + } + + @Override + public Expression transformExpression(Expression ir) { + switch (ir) { + case Literal l -> {} + case Application.Prefix p -> { + markAsTailConditionally(p); + // Note [Call Argument Tail Position] + p.arguments().foreach(a -> markAsTail(a)); + } + case Case.Expr e -> { + if (isInTailPos) { + markAsTail(ir); + // Note [Analysing Branches in Case Expressions] + e.branches().foreach(b -> markAsTail(b)); + } + } + default -> markAsTailConditionally(ir); + } + if (!isInTailPos && isTailAnnotated(ir)) { + ir.getDiagnostics().add(new Warning.WrongTco(ir.identifiedLocation())); + } + return ir; + } + + @Override + public Mini prepare(IR parent, Expression child) { + var isChildTailCandidate = + switch (parent) { + case Module m -> true; + case Expression e -> { + var tailCandidates = new java.util.IdentityHashMap(); + collectTailCandidatesExpression(e, tailCandidates); + yield tailCandidates.containsKey(child); + } + default -> false; + }; + return new Mini(isChildTailCandidate); + } + + /** + * Performs tail call analysis on an arbitrary expression. + * + * @param expression the expression to check + * @return `expression`, annotated with tail position metadata + */ + private void collectTailCandidatesExpression( + Expression expression, java.util.Map tailCandidates) { + switch (expression) { + case Function function -> collectTailCandicateFunction(function, tailCandidates); + case Case caseExpr -> collectTailCandidatesCase(caseExpr, tailCandidates); + case Application app -> collectTailCandidatesApplication(app, tailCandidates); + case Name name -> collectTailCandidatesName(name, tailCandidates); + case Comment c -> throw new CompilerError( + "Comments should not be present during tail call analysis."); + case Expression.Block b -> { + if (isInTailPos) { + tailCandidates.put(b.returnValue(), true); + } + } + default -> {} + } + } + + /** + * Performs tail call analysis on an occurrence of a name. + * + * @param name the name to check + * @return `name`, annotated with tail position metadata + */ + private void collectTailCandidatesName(Name name, java.util.Map tailCandidates) { + if (isInTailPos) { + tailCandidates.put(name, true); + } + } + + /** + * Performs tail call analysis on an application. + * + * @param application the application to check + * @return `application`, annotated with tail position metadata + */ + private void collectTailCandidatesApplication( + Application application, java.util.Map tailCandidates) { + switch (application) { + case Application.Prefix p -> p.arguments() + .foreach(a -> collectTailCandidatesCallArg(a, tailCandidates)); + case Application.Force f -> { + if (isInTailPos) { + tailCandidates.put(f.target(), true); + } + } + case Application.Sequence s -> {} + case Application.Typeset ts -> {} + default -> throw new CompilerError("Unexpected binary operator."); + } + } + + /** + * Performs tail call analysis on a call site argument. + * + * @param argument the argument to check + * @return `argument`, annotated with tail position metadata + */ + private Void collectTailCandidatesCallArg( + CallArgument argument, java.util.Map tailCandidates) { + switch (argument) { + case CallArgument.Specified ca -> + // Note [Call Argument Tail Position] + tailCandidates.put(ca.value(), true); + default -> {} + } + return null; + } + + /* Note [Call Argument Tail Position] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * In order to efficiently deal with Enso's ability to suspend function + * arguments, we behave as if all arguments to a function are passed as + * thunks. This means that the _function_ becomes responsible for deciding + * when to evaluate its arguments. + * + * Conceptually, this results in a desugaring as follows: + * + * ``` + * foo a b c + * ``` + * + * Becomes: + * + * ``` + * foo ({} -> a) ({} -> b) ({} -> c) + * ``` + * + * Quite obviously, the arguments `a`, `b` and `c` are in tail position in + * these closures, and hence should be marked as tail. + */ + + /** + * Performs tail call analysis on a case expression. + * + * @param caseExpr the case expression to check + * @return `caseExpr`, annotated with tail position metadata + */ + private void collectTailCandidatesCase( + Case caseExpr, java.util.Map tailCandidates) { + switch (caseExpr) { + case Case.Expr expr -> { + if (isInTailPos) { + // Note [Analysing Branches in Case Expressions] + expr.branches() + .foreach( + b -> { + tailCandidates.put(b.expression(), true); + return null; + }); + } + } + default -> throw new CompilerError("Unexpected case branch."); + } + } + + /* Note [Analysing Branches in Case Expressions] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * When performing tail call analysis on a case expression it is very + * important to recognise that the branches of a case expression should all + * have the same tail call state. The branches should only be marked as being + * in tail position when the case expression _itself_ is in tail position. + * + * As only one branch is ever executed, it is hence safe to mark _all_ + * branches as being in tail position if the case expression is. + */ + + /** + * Body of the function may be in tail position. + * + * @param function the function to check + * @return `function`, annotated with tail position metadata + */ + private void collectTailCandicateFunction( + Function function, java.util.Map tailCandidates) { + var canBeTCO = function.canBeTCO(); + var markAsTail = (!canBeTCO && isInTailPos) || canBeTCO; + switch (function) { + case Function.Lambda l -> { + if (markAsTail) { + tailCandidates.put(l.body(), true); + } + } + default -> throw new CompilerError( + "Function sugar should not be present during tail call analysis."); + } + } + } +} diff --git a/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/types/TypeInference.java b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/types/TypeInference.java index 0dc640cbac8d..ded8086c8862 100644 --- a/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/types/TypeInference.java +++ b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/analyse/types/TypeInference.java @@ -2,7 +2,6 @@ import java.util.List; import java.util.Objects; -import java.util.UUID; import org.enso.compiler.context.InlineContext; import org.enso.compiler.context.ModuleContext; import org.enso.compiler.core.IR; @@ -12,6 +11,7 @@ import org.enso.compiler.core.ir.module.scope.Definition; import org.enso.compiler.core.ir.module.scope.definition.Method; import org.enso.compiler.pass.IRPass; +import org.enso.compiler.pass.IRProcessingPass; import org.enso.compiler.pass.analyse.BindingAnalysis$; import org.enso.compiler.pass.resolve.FullyQualifiedNames$; import org.enso.compiler.pass.resolve.GlobalNames$; @@ -92,21 +92,10 @@ protected void encounteredInvocationOfNonFunctionType( .add(new Warning.NotInvokable(relatedIr.identifiedLocation(), type.toString())); } }; - private UUID uuid; @Override - public void org$enso$compiler$pass$IRPass$_setter_$key_$eq(UUID v) { - this.uuid = v; - } - - @Override - public UUID key() { - return uuid; - } - - @Override - public Seq precursorPasses() { - List passes = + public Seq precursorPasses() { + List passes = List.of( BindingAnalysis$.MODULE$, GlobalNames$.MODULE$, @@ -119,8 +108,8 @@ public Seq precursorPasses() { @Override @SuppressWarnings("unchecked") - public Seq invalidatedPasses() { - return (Seq) Seq$.MODULE$.empty(); + public Seq invalidatedPasses() { + return (Seq) Seq$.MODULE$.empty(); } @Override diff --git a/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/desugar/OperatorToFunctionMini.java b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/desugar/OperatorToFunctionMini.java new file mode 100644 index 000000000000..d33225c59c81 --- /dev/null +++ b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/desugar/OperatorToFunctionMini.java @@ -0,0 +1,44 @@ +package org.enso.compiler.pass.desugar; + +import org.enso.compiler.core.IR; +import org.enso.compiler.core.ir.CallArgument; +import org.enso.compiler.core.ir.Expression; +import org.enso.compiler.core.ir.expression.Application; +import org.enso.compiler.core.ir.expression.Operator; +import org.enso.compiler.pass.MiniIRPass; +import scala.collection.mutable.ListBuffer; + +public class OperatorToFunctionMini extends MiniIRPass { + OperatorToFunctionMini() {} + + @Override + public Expression transformExpression(Expression ir) { + if (ir instanceof Operator.Binary binOp) { + ListBuffer args = new ListBuffer<>(); + args.addOne(binOp.left()); + args.addOne(binOp.right()); + return new Application.Prefix( + binOp.operator(), + args.toList(), + false, + binOp.location().isDefined() ? binOp.location().get() : null, + binOp.passData(), + binOp.diagnostics()); + } + return ir; + } + + @Override + public boolean checkPostCondition(IR ir) { + boolean[] isChildOperator = {false}; + ir.children() + .foreach( + child -> { + if (child instanceof Operator.Binary) { + isChildOperator[0] = true; + } + return null; + }); + return !isChildOperator[0]; + } +} diff --git a/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/desugar/SectionsToBinOp.java b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/desugar/SectionsToBinOp.java new file mode 100644 index 000000000000..36cd49c72b1e --- /dev/null +++ b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/desugar/SectionsToBinOp.java @@ -0,0 +1,253 @@ +package org.enso.compiler.pass.desugar; + +import java.util.List; +import org.enso.compiler.context.InlineContext; +import org.enso.compiler.context.ModuleContext; +import org.enso.compiler.core.ir.CallArgument; +import org.enso.compiler.core.ir.DefinitionArgument; +import org.enso.compiler.core.ir.Expression; +import org.enso.compiler.core.ir.Function; +import org.enso.compiler.core.ir.MetadataStorage; +import org.enso.compiler.core.ir.Name; +import org.enso.compiler.core.ir.expression.Application; +import org.enso.compiler.core.ir.expression.Section; +import org.enso.compiler.pass.IRProcessingPass; +import org.enso.compiler.pass.MiniIRPass; +import org.enso.compiler.pass.MiniPassFactory; +import org.enso.compiler.pass.analyse.AliasAnalysis$; +import org.enso.compiler.pass.analyse.CachePreferenceAnalysis$; +import org.enso.compiler.pass.analyse.DataflowAnalysis$; +import org.enso.compiler.pass.analyse.DemandAnalysis$; +import org.enso.compiler.pass.analyse.TailCall; +import org.enso.compiler.pass.lint.UnusedBindings$; +import scala.Option; +import scala.collection.immutable.Seq; +import scala.jdk.javaapi.CollectionConverters; + +public final class SectionsToBinOp implements MiniPassFactory { + + public static final SectionsToBinOp INSTANCE = new SectionsToBinOp(); + + private SectionsToBinOp() {} + + @Override + public Seq precursorPasses() { + List passes = List.of(GenerateMethodBodies$.MODULE$); + return CollectionConverters.asScala(passes).toList(); + } + + @Override + public Seq invalidatedPasses() { + List passes = + List.of( + AliasAnalysis$.MODULE$, + CachePreferenceAnalysis$.MODULE$, + DataflowAnalysis$.MODULE$, + DemandAnalysis$.MODULE$, + TailCall.INSTANCE, + UnusedBindings$.MODULE$); + return CollectionConverters.asScala(passes).toList(); + } + + @Override + public MiniIRPass createForModuleCompilation(ModuleContext moduleContext) { + var ctx = InlineContext.fromModuleContext(moduleContext); + return new Mini(ctx); + } + + @Override + public MiniIRPass createForInlineCompilation(InlineContext inlineContext) { + return new Mini(inlineContext); + } + + private static final class Mini extends MiniIRPass { + + private final InlineContext ctx; + + private Mini(InlineContext ctx) { + this.ctx = ctx; + } + + public Expression transformExpression(Expression ir) { + var freshNameSupply = ctx.freshNameSupply().get(); + return switch (ir) { + case Section.Left sectionLeft -> { + var arg = sectionLeft.arg(); + var op = sectionLeft.operator(); + var loc = sectionLeft.location().isDefined() ? sectionLeft.location().get() : null; + var passData = sectionLeft.passData(); + var rightArgName = freshNameSupply.newName(false, Option.empty()); + var rightCallArg = new CallArgument.Specified(Option.empty(), rightArgName, null, meta()); + var rightDefArg = + new DefinitionArgument.Specified( + rightArgName.duplicate(true, true, true, false), + Option.empty(), + Option.empty(), + false, + null, + meta()); + + if (arg.value() instanceof Name.Blank) { + var leftArgName = freshNameSupply.newName(false, Option.empty()); + var leftCallArg = new CallArgument.Specified(Option.empty(), leftArgName, null, meta()); + var leftDefArg = + new DefinitionArgument.Specified( + leftArgName.duplicate(true, true, true, false), + Option.empty(), + Option.empty(), + false, + null, + meta()); + var opCall = + new Application.Prefix( + op, + cons(leftCallArg, cons(rightCallArg, nil())), + false, + null, + passData, + sectionLeft.diagnostics()); + + var rightLam = + new Function.Lambda(cons(rightDefArg, nil()), opCall, null, true, meta()); + + yield new Function.Lambda(cons(leftDefArg, nil()), rightLam, loc, true, meta()); + } else { + yield new Application.Prefix( + op, cons(arg, nil()), false, loc, passData, sectionLeft.diagnostics()); + } + } + + case Section.Sides sectionSides -> { + var op = sectionSides.operator(); + var loc = sectionSides.location().isDefined() ? sectionSides.location().get() : null; + var passData = sectionSides.passData(); + var leftArgName = freshNameSupply.newName(false, Option.empty()); + var leftCallArg = new CallArgument.Specified(Option.empty(), leftArgName, null, meta()); + var leftDefArg = + new DefinitionArgument.Specified( + leftArgName.duplicate(true, true, true, false), + Option.empty(), + Option.empty(), + false, + null, + meta()); + + var rightArgName = freshNameSupply.newName(false, Option.empty()); + var rightCallArg = new CallArgument.Specified(Option.empty(), rightArgName, null, meta()); + var rightDefArg = + new DefinitionArgument.Specified( + rightArgName.duplicate(true, true, true, false), + Option.empty(), + Option.empty(), + false, + null, + meta()); + + var opCall = + new Application.Prefix( + op, + cons(leftCallArg, cons(rightCallArg, nil())), + false, + null, + passData, + sectionSides.diagnostics()); + + var rightLambda = + new Function.Lambda(cons(rightDefArg, nil()), opCall, null, true, meta()); + + yield new Function.Lambda(cons(leftDefArg, nil()), rightLambda, loc, true, meta()); + } + + /* Note [Blanks in Sections] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * While the naiive compositional translation of `(- _)` first translates + * the section into a function applying `-` to two arguments, one of which + * is a blank, the compositional nature of the blanks translation actually + * works against us here. + * + * As the `LambdaShorthandToLambda` pass can only operate on the + * application with the blanks, it can't know to push the blank outside + * that application chain. To that end, we have to handle this case + * specially here instead. What we want it to translate to is as follows: + * + * `(- _)` == `x -> (- x)` == `x -> y -> y - x` + * + * We implement this special case here. + * + * The same is true of left sections. + */ + + case Section.Right sectionRight -> { + var arg = sectionRight.arg(); + var op = sectionRight.operator(); + var loc = sectionRight.location().isDefined() ? sectionRight.location().get() : null; + var passData = sectionRight.passData(); + var leftArgName = freshNameSupply.newName(false, Option.empty()); + var leftCallArg = new CallArgument.Specified(Option.empty(), leftArgName, null, meta()); + var leftDefArg = + new DefinitionArgument.Specified( + leftArgName.duplicate(true, true, true, false), + Option.empty(), + Option.empty(), + false, + null, + meta()); + + if (arg.value() instanceof Name.Blank) { + // Note [Blanks in Sections] + var rightArgName = freshNameSupply.newName(false, Option.empty()); + var rightCallArg = + new CallArgument.Specified(Option.empty(), rightArgName, null, meta()); + var rightDefArg = + new DefinitionArgument.Specified( + rightArgName.duplicate(true, true, true, false), + Option.empty(), + Option.empty(), + false, + null, + meta()); + + var opCall = + new Application.Prefix( + op, + cons(leftCallArg, cons(rightCallArg, nil())), + false, + null, + passData, + sectionRight.diagnostics()); + + var leftLam = new Function.Lambda(cons(leftDefArg, nil()), opCall, null, true, meta()); + + yield new Function.Lambda(cons(rightDefArg, nil()), leftLam, loc, true, meta()); + } else { + var opCall = + new Application.Prefix( + op, + cons(leftCallArg, cons(arg, nil())), + false, + null, + passData, + sectionRight.diagnostics()); + + yield new Function.Lambda(cons(leftDefArg, nil()), opCall, loc, true, meta()); + } + } + default -> ir; + }; + } + + private static MetadataStorage meta() { + return new MetadataStorage(); + } + + @SuppressWarnings("unchecked") + private static scala.collection.immutable.List nil() { + return (scala.collection.immutable.List) scala.collection.immutable.Nil$.MODULE$; + } + + private static scala.collection.immutable.List cons( + T head, scala.collection.immutable.List tail) { + return scala.collection.immutable.$colon$colon$.MODULE$.apply(head, tail); + } + } +} diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/Passes.scala b/engine/runtime-compiler/src/main/scala/org/enso/compiler/Passes.scala index 34ef9011b995..acdf7361e1e6 100644 --- a/engine/runtime-compiler/src/main/scala/org/enso/compiler/Passes.scala +++ b/engine/runtime-compiler/src/main/scala/org/enso/compiler/Passes.scala @@ -18,7 +18,7 @@ import org.enso.compiler.pass.optimise.{ } import org.enso.compiler.pass.resolve._ import org.enso.compiler.pass.{ - IRPass, + IRProcessingPass, PassConfiguration, PassGroup, PassManager @@ -41,7 +41,7 @@ class Passes(config: CompilerConfig) { val globalTypingPasses = new PassGroup( List( MethodDefinitions, - SectionsToBinOp, + SectionsToBinOp.INSTANCE, OperatorToFunction, LambdaShorthandToLambda, ImportSymbolAnalysis, @@ -86,7 +86,7 @@ class Passes(config: CompilerConfig) { AliasAnalysis, DemandAnalysis, AliasAnalysis, - TailCall, + TailCall.INSTANCE, Patterns ) ++ (if (config.privateCheckEnabled) { List(PrivateSymbolsAnalysis.INSTANCE) @@ -123,7 +123,7 @@ class Passes(config: CompilerConfig) { ) /** The ordered representation of all passes run by the compiler. */ - val allPassOrdering: List[IRPass] = passOrdering.flatMap(_.passes) + val allPassOrdering: List[IRProcessingPass] = passOrdering.flatMap(_.passes) /** Configuration for the passes. */ private val passConfig: PassConfiguration = PassConfiguration( @@ -142,7 +142,7 @@ class Passes(config: CompilerConfig) { * @param pass the pass to get the precursors for * @return the precursors to the first instance of `pass` */ - def getPrecursors(pass: IRPass): Option[PassGroup] = { + def getPrecursors(pass: IRProcessingPass): Option[PassGroup] = { val allPasses = passOrdering.flatMap(_.passes) val result = allPasses.takeWhile(_ != pass) if (result.length != allPasses.length) { diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/IRPass.scala b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/IRPass.scala index cb4a1c996d0b..4d898059d691 100644 --- a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/IRPass.scala +++ b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/IRPass.scala @@ -19,10 +19,7 @@ import scala.reflect.ClassTag * its header the requirements it has for pass configuration and for passes * that must run before it. */ -trait IRPass extends ProcessingPass { - - /** An identifier for the pass. Useful for keying it in maps. */ - val key: UUID @Identifier = IRPass.genId +trait IRPass extends IRProcessingPass with ProcessingPass { /** The type of the metadata object that the pass writes to the IR. */ type Metadata <: ProcessingPass.Metadata @@ -31,10 +28,10 @@ trait IRPass extends ProcessingPass { type Config <: IRPass.Configuration /** The passes that this pass depends _directly_ on to run. */ - val precursorPasses: Seq[IRPass] + val precursorPasses: Seq[IRProcessingPass] /** The passes that are invalidated by running this pass. */ - val invalidatedPasses: Seq[IRPass] + val invalidatedPasses: Seq[IRProcessingPass] /** Executes the pass on the provided `ir`, and returns a possibly transformed * or annotated version of `ir`. diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/PassManager.scala b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/PassManager.scala index d3779ae3267d..5bb0732e5734 100644 --- a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/PassManager.scala +++ b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/PassManager.scala @@ -1,5 +1,6 @@ package org.enso.compiler.pass +import org.slf4j.LoggerFactory import org.enso.compiler.context.{InlineContext, ModuleContext} import org.enso.compiler.core.ir.{Expression, Module} import org.enso.compiler.core.CompilerError @@ -20,7 +21,8 @@ class PassManager( passes: List[PassGroup], passConfiguration: PassConfiguration ) { - val allPasses = verifyPassOrdering(passes.flatMap(_.passes)) + private val logger = LoggerFactory.getLogger(classOf[PassManager]) + val allPasses = verifyPassOrdering(passes.flatMap(_.passes)) /** Computes a valid pass ordering for the compiler. * @@ -28,8 +30,10 @@ class PassManager( * @throws CompilerError if a valid pass ordering cannot be computed * @return a valid pass ordering for the compiler, based on `passes` */ - private def verifyPassOrdering(passes: List[IRPass]): List[IRPass] = { - var validPasses: Set[IRPass] = Set() + private def verifyPassOrdering( + passes: List[IRProcessingPass] + ): List[IRProcessingPass] = { + var validPasses: Set[IRProcessingPass] = Set() passes.foreach(pass => { val prereqsSatisfied = @@ -89,18 +93,65 @@ class PassManager( val passesWithIndex = passGroup.passes.zipWithIndex - passesWithIndex.foldLeft(ir) { + logger.debug( + "runPassesOnModule[{}@{}]", + moduleContext.getName(), + moduleContext.module.getCompilationStage() + ) + var pendingMiniPasses: List[MiniPassFactory] = List() + def flushMiniPass(in: Module): Module = { + if (pendingMiniPasses.nonEmpty) { + val miniPasses = pendingMiniPasses.map(factory => + factory.createForModuleCompilation(newContext) + ) + val combinedPass = miniPasses.fold(null)(MiniIRPass.combine) + logger.trace(" flushing pending mini pass: {}", combinedPass) + pendingMiniPasses = List() + MiniIRPass.compile(classOf[Module], in, combinedPass) + } else { + in + } + } + val res = passesWithIndex.foldLeft(ir) { case (intermediateIR, (pass, index)) => { - // TODO [AA, MK] This is a possible race condition. - passConfiguration - .get(pass) - .foreach(c => - c.shouldWriteToContext = isLastRunOf(index, pass, passGroup) - ) - - pass.runModule(intermediateIR, newContext) + pass match { + case miniFactory: MiniPassFactory => + logger.trace( + " mini collected: {}", + pass + ) + val combiningPreventedBy = pendingMiniPasses.find { p => + p.invalidatedPasses.contains(miniFactory) + } + val irForRemainingMiniPasses = if (combiningPreventedBy.isDefined) { + logger.trace( + " pass {} forces flush before {}", + combiningPreventedBy.orNull, + miniFactory + ) + flushMiniPass(intermediateIR) + } else { + intermediateIR + } + pendingMiniPasses = pendingMiniPasses.appended(miniFactory) + irForRemainingMiniPasses + case megaPass: IRPass => + // TODO [AA, MK] This is a possible race condition. + passConfiguration + .get(megaPass) + .foreach(c => + c.shouldWriteToContext = isLastRunOf(index, megaPass, passGroup) + ) + val flushedIR = flushMiniPass(intermediateIR) + logger.trace( + " mega running: {}", + megaPass + ) + megaPass.runModule(flushedIR, newContext) + } } } + flushMiniPass(res) } /** Executes all passes on the [[Expression]]. @@ -141,14 +192,21 @@ class PassManager( passesWithIndex.foldLeft(ir) { case (intermediateIR, (pass, index)) => { - // TODO [AA, MK] This is a possible race condition. - passConfiguration - .get(pass) - .foreach(c => - c.shouldWriteToContext = isLastRunOf(index, pass, passGroup) - ) - - pass.runExpression(intermediateIR, newContext) + + pass match { + case miniFactory: MiniPassFactory => + val miniPass = miniFactory.createForInlineCompilation(newContext) + MiniIRPass.compile(classOf[Expression], intermediateIR, miniPass) + case megaPass: IRPass => + // TODO [AA, MK] This is a possible race condition. + passConfiguration + .get(megaPass) + .foreach(c => + c.shouldWriteToContext = isLastRunOf(index, megaPass, passGroup) + ) + megaPass.runExpression(intermediateIR, newContext) + } + } } } @@ -189,8 +247,10 @@ class PassManager( * information from `sourceIr` */ def runMetadataUpdate(sourceIr: Module, copyOfIr: Module): Module = { - allPasses.foldLeft(copyOfIr) { (module, pass) => - pass.updateMetadataInDuplicate(sourceIr, module) + allPasses.foldLeft(copyOfIr) { + case (module, megaPass: IRPass) => + megaPass.updateMetadataInDuplicate(sourceIr, module) + case (module, _) => module } } } @@ -199,4 +259,4 @@ class PassManager( * * @param passes the passes in the group */ -class PassGroup(val passes: List[IRPass]) +class PassGroup(val passes: List[IRProcessingPass]) diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/analyse/AliasAnalysis.scala b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/analyse/AliasAnalysis.scala index a50e66726574..a5d465d5eeef 100644 --- a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/analyse/AliasAnalysis.scala +++ b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/analyse/AliasAnalysis.scala @@ -28,6 +28,7 @@ import org.enso.compiler.core.ir.{ } import org.enso.compiler.core.{CompilerError, IR} import org.enso.compiler.pass.IRPass +import org.enso.compiler.pass.IRProcessingPass import org.enso.compiler.pass.analyse.alias.graph.Graph import org.enso.compiler.pass.analyse.alias.graph.GraphOccurrence import org.enso.compiler.pass.analyse.alias.graph.Graph.Scope @@ -74,15 +75,15 @@ case object AliasAnalysis extends IRPass { override type Metadata = alias.AliasMetadata override type Config = Configuration - override lazy val precursorPasses: Seq[IRPass] = List( + override lazy val precursorPasses: Seq[IRProcessingPass] = List( FunctionBinding, GenerateMethodBodies, - SectionsToBinOp, + SectionsToBinOp.INSTANCE, OperatorToFunction, LambdaShorthandToLambda ) - override lazy val invalidatedPasses: Seq[IRPass] = + override lazy val invalidatedPasses: Seq[IRProcessingPass] = List(DataflowAnalysis, UnusedBindings) /** Performs alias analysis on a module. diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/analyse/CachePreferenceAnalysis.scala b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/analyse/CachePreferenceAnalysis.scala index 2f5cf2a77196..b822e824b514 100644 --- a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/analyse/CachePreferenceAnalysis.scala +++ b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/analyse/CachePreferenceAnalysis.scala @@ -17,6 +17,7 @@ import org.enso.compiler.core.ir.module.scope.definition import org.enso.compiler.core.ir.expression.{Application, Comment, Error} import org.enso.compiler.core.ir.MetadataStorage._ import org.enso.compiler.pass.IRPass +import org.enso.compiler.pass.IRProcessingPass import org.enso.compiler.pass.desugar._ import java.util @@ -38,16 +39,16 @@ case object CachePreferenceAnalysis extends IRPass { override type Metadata = WeightInfo /** Run desugaring passes first. */ - override lazy val precursorPasses: Seq[IRPass] = List( + override lazy val precursorPasses: Seq[IRProcessingPass] = List( ComplexType, FunctionBinding, GenerateMethodBodies, LambdaShorthandToLambda, OperatorToFunction, - SectionsToBinOp + SectionsToBinOp.INSTANCE ) - override lazy val invalidatedPasses: Seq[IRPass] = List() + override lazy val invalidatedPasses: Seq[IRProcessingPass] = List() /** Performs the cache preference analysis on a module. * diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/analyse/DataflowAnalysis.scala b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/analyse/DataflowAnalysis.scala index f58851df86f8..2147865f8c94 100644 --- a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/analyse/DataflowAnalysis.scala +++ b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/analyse/DataflowAnalysis.scala @@ -29,6 +29,7 @@ import org.enso.compiler.core.ir.{ } import org.enso.compiler.core.ir.MetadataStorage._ import org.enso.compiler.pass.IRPass +import org.enso.compiler.pass.IRProcessingPass import org.enso.compiler.pass.analyse.DataflowAnalysis.DependencyInfo.Type.asStatic import org.enso.compiler.pass.analyse.alias.graph.GraphOccurrence @@ -53,13 +54,13 @@ case object DataflowAnalysis extends IRPass { override type Metadata = DependencyInfo override type Config = IRPass.Configuration.Default - override lazy val precursorPasses: Seq[IRPass] = List( + override lazy val precursorPasses: Seq[IRProcessingPass] = List( AliasAnalysis, DemandAnalysis, - TailCall + TailCall.INSTANCE ) - override lazy val invalidatedPasses: Seq[IRPass] = List() + override lazy val invalidatedPasses: Seq[IRProcessingPass] = List() /** Executes the dataflow analysis process on an Enso module. * diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/ComplexType.scala b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/ComplexType.scala index b48321718c53..3869b49851b0 100644 --- a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/ComplexType.scala +++ b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/ComplexType.scala @@ -19,6 +19,7 @@ import org.enso.compiler.core.ir.module.scope.definition import org.enso.compiler.core.ir.expression.Error import org.enso.compiler.core.CompilerError import org.enso.compiler.pass.IRPass +import org.enso.compiler.pass.IRProcessingPass import org.enso.compiler.pass.analyse.{ AliasAnalysis, DataflowAnalysis, @@ -51,8 +52,10 @@ case object ComplexType extends IRPass { override type Metadata = IRPass.Metadata.Empty override type Config = IRPass.Configuration.Default - override lazy val precursorPasses: Seq[IRPass] = List(ModuleAnnotations) - override lazy val invalidatedPasses: Seq[IRPass] = + override lazy val precursorPasses: Seq[IRProcessingPass] = List( + ModuleAnnotations + ) + override lazy val invalidatedPasses: Seq[IRProcessingPass] = List( AliasAnalysis, DataflowAnalysis, @@ -64,8 +67,8 @@ case object ComplexType extends IRPass { LambdaShorthandToLambda, NestedPatternMatch, OperatorToFunction, - SectionsToBinOp, - TailCall, + SectionsToBinOp.INSTANCE, + TailCall.INSTANCE, UnusedBindings ) diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/FunctionBinding.scala b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/FunctionBinding.scala index 174963fa0ca7..406d12b10306 100644 --- a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/FunctionBinding.scala +++ b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/FunctionBinding.scala @@ -17,6 +17,7 @@ import org.enso.compiler.core.ir.{ import org.enso.compiler.core.ir.MetadataStorage.MetadataPair import org.enso.compiler.core.CompilerError import org.enso.compiler.pass.IRPass +import org.enso.compiler.pass.IRProcessingPass import org.enso.compiler.pass.analyse.{ AliasAnalysis, DataflowAnalysis, @@ -42,8 +43,8 @@ case object FunctionBinding extends IRPass { override type Metadata = IRPass.Metadata.Empty override type Config = IRPass.Configuration.Default - override lazy val precursorPasses: Seq[IRPass] = List(ComplexType) - override lazy val invalidatedPasses: Seq[IRPass] = List( + override lazy val precursorPasses: Seq[IRProcessingPass] = List(ComplexType) + override lazy val invalidatedPasses: Seq[IRProcessingPass] = List( AliasAnalysis, DataflowAnalysis, DemandAnalysis, @@ -53,8 +54,8 @@ case object FunctionBinding extends IRPass { LambdaShorthandToLambda, NestedPatternMatch, OperatorToFunction, - SectionsToBinOp, - TailCall + SectionsToBinOp.INSTANCE, + TailCall.INSTANCE ) /** The name of the conversion method, as a reserved name for methods. */ diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/GenerateMethodBodies.scala b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/GenerateMethodBodies.scala index c1f39e1ea116..b949bd457cdc 100644 --- a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/GenerateMethodBodies.scala +++ b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/GenerateMethodBodies.scala @@ -15,6 +15,7 @@ import org.enso.compiler.core.ir.expression.errors import org.enso.compiler.core.CompilerError import org.enso.compiler.core.ir.expression.Foreign import org.enso.compiler.pass.IRPass +import org.enso.compiler.pass.IRProcessingPass import org.enso.compiler.pass.analyse.{ AliasAnalysis, DataflowAnalysis, @@ -48,14 +49,14 @@ case object GenerateMethodBodies extends IRPass { override type Metadata = IRPass.Metadata.Empty override type Config = IRPass.Configuration.Default - override lazy val precursorPasses: Seq[IRPass] = + override lazy val precursorPasses: Seq[IRProcessingPass] = List(ComplexType, FunctionBinding) - override lazy val invalidatedPasses: Seq[IRPass] = List( + override lazy val invalidatedPasses: Seq[IRProcessingPass] = List( AliasAnalysis, DataflowAnalysis, LambdaConsolidate, NestedPatternMatch, - TailCall, + TailCall.INSTANCE, UnusedBindings ) diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/LambdaShorthandToLambda.scala b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/LambdaShorthandToLambda.scala index 8d4f3416e4f3..aec05bcd2c18 100644 --- a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/LambdaShorthandToLambda.scala +++ b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/LambdaShorthandToLambda.scala @@ -1,18 +1,7 @@ package org.enso.compiler.pass.desugar -import org.enso.compiler.context.{FreshNameSupply, InlineContext, ModuleContext} -import org.enso.compiler.core.ir.{ - CallArgument, - DefinitionArgument, - Expression, - Function, - IdentifiedLocation, - Module, - Name -} +import org.enso.compiler.context.{InlineContext, ModuleContext} import org.enso.compiler.core.CompilerError -import org.enso.compiler.core.ir.expression.{Application, Case, Operator} -import org.enso.compiler.pass.IRPass import org.enso.compiler.pass.analyse.{ AliasAnalysis, DataflowAnalysis, @@ -26,412 +15,51 @@ import org.enso.compiler.pass.resolve.{ IgnoredBindings, OverloadsResolution } +import org.enso.compiler.pass.{IRProcessingPass, MiniPassFactory} -/** This pass translates `_` arguments at application sites to lambda functions. - * - * This pass has no configuration. - * - * This pass requires the context to provide: - * - * - A [[FreshNameSupply]] +/** Implementation moved to `LambdaShorthandToLambdaMegaPass` test. */ -case object LambdaShorthandToLambda extends IRPass { - override type Metadata = IRPass.Metadata.Empty - override type Config = IRPass.Configuration.Default - - override lazy val precursorPasses: Seq[IRPass] = List( +case object LambdaShorthandToLambda extends MiniPassFactory { + override lazy val precursorPasses: Seq[IRProcessingPass] = List( ComplexType, DocumentationComments, FunctionBinding, GenerateMethodBodies, OperatorToFunction, - SectionsToBinOp + SectionsToBinOp.INSTANCE ) - override lazy val invalidatedPasses: Seq[IRPass] = List( + override lazy val invalidatedPasses: Seq[IRProcessingPass] = List( AliasAnalysis, DataflowAnalysis, DemandAnalysis, IgnoredBindings, LambdaConsolidate, OverloadsResolution, - TailCall, + TailCall.INSTANCE, UnusedBindings ) - /** Desugars underscore arguments to lambdas for a module. - * - * @param ir the Enso IR to process - * @param moduleContext a context object that contains the information needed - * to process a module - * @return `ir`, possibly having made transformations or annotations to that - * IR. - */ - override def runModule( - ir: Module, + override def createForModuleCompilation( moduleContext: ModuleContext - ): Module = { - val new_bindings = ir.bindings.map { case a => - a.mapExpressions( - runExpression( - _, - InlineContext( - moduleContext, - freshNameSupply = moduleContext.freshNameSupply, - compilerConfig = moduleContext.compilerConfig - ) - ) + ): LambdaShorthandToLambdaMini = { + val freshNameSupply = moduleContext.freshNameSupply.getOrElse( + throw new CompilerError( + "Desugaring underscore arguments to lambdas requires a fresh name " + + "supply." ) - } - ir.copy(bindings = new_bindings) + ) + new LambdaShorthandToLambdaMini(freshNameSupply) } - /** Desugars underscore arguments to lambdas for an arbitrary expression. - * - * @param ir the Enso IR to process - * @param inlineContext a context object that contains the information needed - * for inline evaluation - * @return `ir`, possibly having made transformations or annotations to that - * IR. - */ - override def runExpression( - ir: Expression, + override def createForInlineCompilation( inlineContext: InlineContext - ): Expression = { + ): LambdaShorthandToLambdaMini = { val freshNameSupply = inlineContext.freshNameSupply.getOrElse( throw new CompilerError( "Desugaring underscore arguments to lambdas requires a fresh name " + "supply." ) ) - - desugarExpression(ir, freshNameSupply) - } - - // === Pass Internals ======================================================= - - /** Performs lambda shorthand desugaring on an arbitrary expression. - * - * @param ir the expression to desugar - * @param freshNameSupply the compiler's fresh name supply - * @return `ir`, with any lambda shorthand arguments desugared - */ - def desugarExpression( - ir: Expression, - freshNameSupply: FreshNameSupply - ): Expression = { - ir.transformExpressions { - case app: Application => desugarApplication(app, freshNameSupply) - case caseExpr: Case.Expr => desugarCaseExpr(caseExpr, freshNameSupply) - case name: Name => desugarName(name, freshNameSupply) - } - } - - /** Desugars an arbitrary name occurrence, turning isolated occurrences of - * `_` into the `id` function. - * - * @param name the name to desugar - * @param supply the compiler's fresh name supply - * @return `name`, desugared where necessary - */ - private def desugarName(name: Name, supply: FreshNameSupply): Expression = { - name match { - case blank: Name.Blank => - val newName = supply.newName() - - new Function.Lambda( - List( - DefinitionArgument.Specified( - name = Name.Literal( - newName.name, - isMethod = false, - identifiedLocation = null - ), - ascribedType = None, - defaultValue = None, - suspended = false, - identifiedLocation = null - ) - ), - newName, - blank.identifiedLocation - ) - case _ => name - } - } - - /** Desugars lambda shorthand arguments to an arbitrary function application. - * - * @param application the function application to desugar - * @param freshNameSupply the compiler's supply of fresh names - * @return `application`, with any lambda shorthand arguments desugared - */ - private def desugarApplication( - application: Application, - freshNameSupply: FreshNameSupply - ): Expression = { - application match { - case p @ Application.Prefix(fn, args, _, _, _) => - // Determine which arguments are lambda shorthand - val argIsUnderscore = determineLambdaShorthand(args) - - // Generate a new name for the arg value for each shorthand arg - val updatedArgs = - args - .zip(argIsUnderscore) - .map(updateShorthandArg(_, freshNameSupply)) - .map { case s @ CallArgument.Specified(_, value, _, _) => - s.copy(value = desugarExpression(value, freshNameSupply)) - } - - // Generate a definition arg instance for each shorthand arg - val defArgs = updatedArgs.zip(argIsUnderscore).map { - case (arg, isShorthand) => generateDefinitionArg(arg, isShorthand) - } - val actualDefArgs = defArgs.collect { case Some(defArg) => - defArg - } - - // Determine whether or not the function itself is shorthand - val functionIsShorthand = fn.isInstanceOf[Name.Blank] - val (updatedFn, updatedName) = if (functionIsShorthand) { - val newFn = freshNameSupply - .newName() - .copy( - location = fn.location, - passData = fn.passData, - diagnostics = fn.diagnostics - ) - val newName = newFn.name - (newFn, Some(newName)) - } else { - val newFn = desugarExpression(fn, freshNameSupply) - (newFn, None) - } - - val processedApp = p.copy( - function = updatedFn, - arguments = updatedArgs - ) - - // Wrap the app in lambdas from right to left, 1 lambda per shorthand - // arg - val appResult = - actualDefArgs.foldRight(processedApp: Expression)((arg, body) => - new Function.Lambda(List(arg), body, identifiedLocation = null) - ) - - // If the function is shorthand, do the same - val resultExpr = if (functionIsShorthand) { - new Function.Lambda( - List( - DefinitionArgument.Specified( - Name.Literal( - updatedName.get, - isMethod = false, - fn.identifiedLocation() - ), - None, - None, - suspended = false, - null - ) - ), - appResult, - identifiedLocation = null - ) - } else appResult - - resultExpr match { - case lam: Function.Lambda => lam.copy(location = p.location) - case result => result - } - case f @ Application.Force(tgt, _, _) => - f.copy(target = desugarExpression(tgt, freshNameSupply)) - case vector @ Application.Sequence(items, _, _) => - var bindings: List[Name] = List() - val newItems = items.map { - case blank: Name.Blank => - val name = freshNameSupply - .newName() - .copy( - location = blank.location, - passData = blank.passData, - diagnostics = blank.diagnostics - ) - bindings ::= name - name - case it => desugarExpression(it, freshNameSupply) - } - val newVec = vector.copy(newItems) - val locWithoutId = - if (newVec.identifiedLocation eq null) null - else new IdentifiedLocation(newVec.identifiedLocation.location()) - bindings.foldLeft(newVec: Expression) { (body, bindingName) => - val defArg = DefinitionArgument.Specified( - bindingName, - ascribedType = None, - defaultValue = None, - suspended = false, - identifiedLocation = null - ) - new Function.Lambda(List(defArg), body, locWithoutId) - } - case tSet @ Application.Typeset(expr, _, _) => - tSet.copy(expression = expr.map(desugarExpression(_, freshNameSupply))) - case _: Operator => - throw new CompilerError( - "Operators should be desugared by the point of underscore " + - "to lambda conversion." - ) - } - } - - /** Determines, positionally, which of the application arguments are lambda - * shorthand arguments. - * - * @param args the application arguments - * @return a list containing `true` for a given position if the arg in that - * position is lambda shorthand, otherwise `false` - */ - private def determineLambdaShorthand( - args: List[CallArgument] - ): List[Boolean] = { - args.map { case CallArgument.Specified(_, value, _, _) => - value match { - case _: Name.Blank => true - case _ => false - } - } - } - - /** Generates a new name to replace a shorthand argument, as well as the - * corresponding definition argument. - * - * @param argAndIsShorthand the arguments, and whether or not the argument in - * the corresponding position is shorthand - * @return the above described pair for a given position if the argument in - * a given position is shorthand, otherwise [[None]]. - */ - private def updateShorthandArg( - argAndIsShorthand: (CallArgument, Boolean), - freshNameSupply: FreshNameSupply - ): CallArgument = { - val arg = argAndIsShorthand._1 - val isShorthand = argAndIsShorthand._2 - - arg match { - case s @ CallArgument.Specified(_, value, _, _) => - if (isShorthand) { - val newName = freshNameSupply - .newName() - .copy( - location = value.location, - passData = value.passData, - diagnostics = value.diagnostics - ) - - s.copy(value = newName) - } else s - } - } - - /** Generates a corresponding definition argument to a call argument that was - * previously lambda shorthand. - * - * @param arg the argument to generate a corresponding def argument to - * @param isShorthand whether or not `arg` was shorthand - * @return a corresponding definition argument if `arg` `isShorthand`, - * otherwise [[None]] - */ - private def generateDefinitionArg( - arg: CallArgument, - isShorthand: Boolean - ): Option[DefinitionArgument] = { - if (isShorthand) { - arg match { - case specified @ CallArgument.Specified(_, value, _, passData) => - // Note [Safe Casting to Name.Literal] - val defArgName = - Name.Literal( - value.asInstanceOf[Name.Literal].name, - isMethod = false, - identifiedLocation = null - ) - - Some( - new DefinitionArgument.Specified( - defArgName, - None, - None, - suspended = false, - null, - passData.duplicate, - specified.diagnosticsCopy - ) - ) - } - } else None - } - - /* Note [Safe Casting to Name.Literal] - * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - * This cast is entirely safe here as, by construction in - * `updateShorthandArg`, any arg for which `isShorthand` is true has its - * value as an `Name.Literal`. - */ - - /** Performs desugaring of lambda shorthand arguments in a case expression. - * - * In the case where a user writes `case _ of`, this gets converted into a - * lambda (`x -> case x of`). - * - * @param caseExpr the case expression to desugar - * @param freshNameSupply the compiler's supply of fresh names - * @return `caseExpr`, with any lambda shorthand desugared - */ - private def desugarCaseExpr( - caseExpr: Case.Expr, - freshNameSupply: FreshNameSupply - ): Expression = { - val newBranches = caseExpr.branches.map( - _.mapExpressions(expr => desugarExpression(expr, freshNameSupply)) - ) - - caseExpr.scrutinee match { - case nameBlank: Name.Blank => - val scrutineeName = - freshNameSupply - .newName() - .copy( - location = nameBlank.location, - passData = nameBlank.passData, - diagnostics = nameBlank.diagnostics - ) - - val lambdaArg = DefinitionArgument.Specified( - scrutineeName.copy(id = null), - None, - None, - suspended = false, - null - ) - - val newCaseExpr = caseExpr.copy( - scrutinee = scrutineeName, - branches = newBranches - ) - - new Function.Lambda( - caseExpr, - List(lambdaArg), - newCaseExpr, - caseExpr.identifiedLocation - ) - case x => - caseExpr.copy( - scrutinee = desugarExpression(x, freshNameSupply), - branches = newBranches - ) - } + new LambdaShorthandToLambdaMini(freshNameSupply) } } diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/LambdaShorthandToLambdaMini.scala b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/LambdaShorthandToLambdaMini.scala new file mode 100644 index 000000000000..9243a9988424 --- /dev/null +++ b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/LambdaShorthandToLambdaMini.scala @@ -0,0 +1,352 @@ +package org.enso.compiler.pass.desugar + +import org.enso.compiler.context.FreshNameSupply +import org.enso.compiler.core.CompilerError +import org.enso.compiler.core.IR +import org.enso.compiler.core.ir.expression.{Application, Case, Operator} +import org.enso.compiler.core.ir.{ + CallArgument, + DefinitionArgument, + Expression, + Function, + IdentifiedLocation, + Name +} +import org.enso.compiler.pass.MiniIRPass + +class LambdaShorthandToLambdaMini( + protected val freshNameSupply: FreshNameSupply, + private val shouldSkipBlanks: Boolean = false +) extends MiniIRPass { + + override def prepare( + parent: IR, + current: Expression + ): LambdaShorthandToLambdaMini = { + if (shouldSkipBlanks(parent)) { + new LambdaShorthandToLambdaMini(freshNameSupply, true) + } else { + this + } + } + + private def shouldSkipBlanks(parent: IR): Boolean = { + parent match { + case Application.Prefix(fn, args, _, _, _) => + val hasBlankArg = args.exists { + case CallArgument.Specified(_, _: Name.Blank, _, _) => true + case _ => false + } + val hasBlankFn = fn.isInstanceOf[Name.Blank] + hasBlankArg || hasBlankFn + case Application.Sequence(items, _, _) => + val hasBlankItem = items.exists { + case _: Name.Blank => true + case _ => false + } + hasBlankItem + case Case.Expr(_: Name.Blank, _, _, _, _) => + true + case _ => false + } + } + + override def transformExpression(ir: Expression): Expression = { + val newIr = ir match { + case name: Name => desugarName(name) + case app: Application => desugarApplication(app) + case caseExpr: Case.Expr => desugarCaseExpr(caseExpr) + case _ => ir + } + newIr + } + + /** Desugars an arbitrary name occurrence, turning isolated occurrences of + * `_` into the `id` function. + * + * @param name the name to desugar + * @return `name`, desugared where necessary + */ + private def desugarName(name: Name): Expression = { + name match { + case blank: Name.Blank if !shouldSkipBlanks => + val newName = freshNameSupply.newName() + + new Function.Lambda( + List( + DefinitionArgument.Specified( + name = Name.Literal( + newName.name, + isMethod = false, + null + ), + ascribedType = None, + defaultValue = None, + suspended = false, + identifiedLocation = null + ) + ), + newName, + blank.location.orNull + ) + case _ => name + } + } + + /** Desugars lambda shorthand arguments to an arbitrary function application. + * + * @param application the function application to desugar + * @return `application`, with any lambda shorthand arguments desugared + */ + private def desugarApplication( + application: Application + ): Expression = { + application match { + case p @ Application.Prefix(fn, args, _, _, _) => + // Determine which arguments are lambda shorthand + val argIsUnderscore = determineLambdaShorthand(args) + + // Generate a new name for the arg value for each shorthand arg + val updatedArgs = + args + .zip(argIsUnderscore) + .map(updateShorthandArg) + + // Generate a definition arg instance for each shorthand arg + val defArgs = updatedArgs.zip(argIsUnderscore).map { + case (arg, isShorthand) => generateDefinitionArg(arg, isShorthand) + } + val actualDefArgs = defArgs.collect { case Some(defArg) => + defArg + } + + // Determine whether or not the function itself is shorthand + val functionIsShorthand = fn.isInstanceOf[Name.Blank] + val (updatedFn, updatedName) = if (functionIsShorthand) { + val newFn = freshNameSupply + .newName() + .copy( + location = fn.location, + passData = fn.passData, + diagnostics = fn.diagnostics + ) + val newName = newFn.name + (newFn, Some(newName)) + } else { + (fn, None) + } + + val processedApp = p.copy( + function = updatedFn, + arguments = updatedArgs + ) + + // Wrap the app in lambdas from right to left, 1 lambda per shorthand + // arg + val appResult = + actualDefArgs.foldRight(processedApp: Expression)((arg, body) => + new Function.Lambda(List(arg), body, null) + ) + + // If the function is shorthand, do the same + val resultExpr = if (functionIsShorthand) { + new Function.Lambda( + List( + DefinitionArgument.Specified( + Name + .Literal( + updatedName.get, + isMethod = false, + fn.location.orNull + ), + None, + None, + suspended = false, + null + ) + ), + appResult, + null + ) + } else appResult + + resultExpr match { + case lam: Function.Lambda => lam.copy(location = p.location) + case result => result + } + + case vector @ Application.Sequence(items, _, _) => + var bindings: List[Name] = List() + val newItems = items.map { + case blank: Name.Blank => + val name = freshNameSupply + .newName() + .copy( + location = blank.location, + passData = blank.passData, + diagnostics = blank.diagnostics + ) + bindings ::= name + name + case it => it + } + val newVec = vector.copy(newItems) + val locWithoutId = + newVec.location.map(l => new IdentifiedLocation(l.location())) + bindings.foldLeft(newVec: Expression) { (body, bindingName) => + val defArg = DefinitionArgument.Specified( + bindingName, + ascribedType = None, + defaultValue = None, + suspended = false, + identifiedLocation = null + ) + new Function.Lambda(List(defArg), body, locWithoutId.orNull) + } + + case _: Operator => + throw new CompilerError( + "Operators should be desugared by the point of underscore " + + "to lambda conversion." + ) + } + } + + /** Determines, positionally, which of the application arguments are lambda + * shorthand arguments. + * + * @param args the application arguments + * @return a list containing `true` for a given position if the arg in that + * position is lambda shorthand, otherwise `false` + */ + private def determineLambdaShorthand( + args: List[CallArgument] + ): List[Boolean] = { + args.map { case CallArgument.Specified(_, value, _, _) => + value match { + case _: Name.Blank => true + case _ => false + } + } + } + + /** Generates a new name to replace a shorthand argument, as well as the + * corresponding definition argument. + * + * @param argAndIsShorthand the arguments, and whether or not the argument in + * the corresponding position is shorthand + * @return the above described pair for a given position if the argument in + * a given position is shorthand, otherwise [[None]]. + */ + private def updateShorthandArg( + argAndIsShorthand: (CallArgument, Boolean) + ): CallArgument = { + val arg = argAndIsShorthand._1 + val isShorthand = argAndIsShorthand._2 + + arg match { + case s @ CallArgument.Specified(_, value, _, _) => + if (isShorthand) { + val newName = freshNameSupply + .newName() + .copy( + location = value.location, + passData = value.passData, + diagnostics = value.diagnostics + ) + + s.copy(value = newName) + } else s + } + } + + /** Generates a corresponding definition argument to a call argument that was + * previously lambda shorthand. + * + * @param arg the argument to generate a corresponding def argument to + * @param isShorthand whether or not `arg` was shorthand + * @return a corresponding definition argument if `arg` `isShorthand`, + * otherwise [[None]] + */ + private def generateDefinitionArg( + arg: CallArgument, + isShorthand: Boolean + ): Option[DefinitionArgument] = { + if (isShorthand) { + arg match { + case specified @ CallArgument.Specified(_, value, _, passData) => + // Note [Safe Casting to Name.Literal] + val defArgName = + Name.Literal( + value.asInstanceOf[Name.Literal].name, + isMethod = false, + null + ) + + Some( + new DefinitionArgument.Specified( + defArgName, + None, + None, + suspended = false, + null, + passData.duplicate, + specified.diagnosticsCopy + ) + ) + } + } else None + } + + /* Note [Safe Casting to Name.Literal] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * This cast is entirely safe here as, by construction in + * `updateShorthandArg`, any arg for which `isShorthand` is true has its + * value as an `Name.Literal`. + */ + + /** Performs desugaring of lambda shorthand arguments in a case expression. + * + * In the case where a user writes `case _ of`, this gets converted into a + * lambda (`x -> case x of`). + * + * @param caseExpr the case expression to desugar + * @return `caseExpr`, with any lambda shorthand desugared + */ + private def desugarCaseExpr( + caseExpr: Case.Expr + ): Expression = { + caseExpr.scrutinee match { + case nameBlank: Name.Blank => + val scrutineeName = + freshNameSupply + .newName() + .copy( + location = nameBlank.location, + passData = nameBlank.passData, + diagnostics = nameBlank.diagnostics + ) + + val lambdaArg = DefinitionArgument.Specified( + scrutineeName.copy(id = null), + None, + None, + suspended = false, + null + ) + + val newCaseExpr = caseExpr.copy( + scrutinee = scrutineeName + ) + + new Function.Lambda( + caseExpr, + List(lambdaArg), + newCaseExpr, + caseExpr.location.orNull + ) + + case _ => caseExpr + } + } +} diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/NestedPatternMatch.scala b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/NestedPatternMatch.scala index ef83b112f67e..3ea59736b386 100644 --- a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/NestedPatternMatch.scala +++ b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/NestedPatternMatch.scala @@ -11,6 +11,7 @@ import org.enso.compiler.core.ir.{ import org.enso.compiler.core.ir.expression.{errors, Case} import org.enso.compiler.core.CompilerError import org.enso.compiler.pass.IRPass +import org.enso.compiler.pass.IRProcessingPass import org.enso.compiler.pass.analyse.{ AliasAnalysis, DataflowAnalysis, @@ -73,19 +74,19 @@ case object NestedPatternMatch extends IRPass { override type Metadata = IRPass.Metadata.Empty override type Config = IRPass.Configuration.Default - override lazy val precursorPasses: Seq[IRPass] = List( + override lazy val precursorPasses: Seq[IRProcessingPass] = List( ComplexType, DocumentationComments, FunctionBinding, GenerateMethodBodies, LambdaShorthandToLambda ) - override lazy val invalidatedPasses: Seq[IRPass] = List( + override lazy val invalidatedPasses: Seq[IRProcessingPass] = List( AliasAnalysis, DataflowAnalysis, DemandAnalysis, IgnoredBindings, - TailCall + TailCall.INSTANCE ) /** Desugars nested pattern matches in a module. diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/OperatorToFunction.scala b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/OperatorToFunction.scala index c7d868c565a3..91f078e14026 100644 --- a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/OperatorToFunction.scala +++ b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/OperatorToFunction.scala @@ -1,86 +1,35 @@ package org.enso.compiler.pass.desugar import org.enso.compiler.context.{InlineContext, ModuleContext} -import org.enso.compiler.core.ir.expression.{Application, Operator} -import org.enso.compiler.core.ir.{Expression, Module} -import org.enso.compiler.pass.IRPass +import org.enso.compiler.pass.{IRProcessingPass, MiniPassFactory} import org.enso.compiler.pass.analyse.{ AliasAnalysis, DataflowAnalysis, DemandAnalysis } -/** This pass converts usages of operators to calls to standard functions. - * - * This pass requires the context to provide: - * - * - Nothing +/** Implementation moved to `OperatorToFunctionTest` */ -case object OperatorToFunction extends IRPass { +case object OperatorToFunction extends MiniPassFactory { - /** A purely desugaring pass has no analysis output. */ - override type Metadata = IRPass.Metadata.Empty - override type Config = IRPass.Configuration.Default - - override lazy val precursorPasses: Seq[IRPass] = List( + override lazy val precursorPasses: Seq[IRProcessingPass] = List( GenerateMethodBodies, - SectionsToBinOp + SectionsToBinOp.INSTANCE ) - override lazy val invalidatedPasses: Seq[IRPass] = List( + override lazy val invalidatedPasses: Seq[IRProcessingPass] = List( AliasAnalysis, DataflowAnalysis, - DemandAnalysis + DemandAnalysis, + LambdaShorthandToLambda ) - /** Executes the conversion pass. - * - * @param ir the Enso IR to process - * @param moduleContext a context object that contains the information needed - * to process a module - * @return `ir`, possibly having made transformations or annotations to that - * IR. - */ - override def runModule( - ir: Module, + override def createForModuleCompilation( moduleContext: ModuleContext - ): Module = { - val new_bindings = ir.bindings.map { a => - a.mapExpressions( - runExpression( - _, - new InlineContext( - moduleContext, - compilerConfig = moduleContext.compilerConfig - ) - ) - ) - } - ir.copy(bindings = new_bindings) - } + ): OperatorToFunctionMini = + new OperatorToFunctionMini() - /** Executes the conversion pass in an inline context. - * - * @param ir the Enso IR to process - * @param inlineContext a context object that contains the information needed - * for inline evaluation - * @return `ir`, possibly having made transformations or annotations to that - * IR. - */ - override def runExpression( - ir: Expression, + override def createForInlineCompilation( inlineContext: InlineContext - ): Expression = - ir.transformExpressions { case operatorBinary: Operator.Binary => - new Application.Prefix( - operatorBinary.operator, - List( - operatorBinary.left.mapExpressions(runExpression(_, inlineContext)), - operatorBinary.right.mapExpressions(runExpression(_, inlineContext)) - ), - hasDefaultsSuspended = false, - operatorBinary.identifiedLocation, - operatorBinary.passData, - operatorBinary.diagnostics - ) - } + ): OperatorToFunctionMini = + new OperatorToFunctionMini() } diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/lint/ShadowedPatternFields.scala b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/lint/ShadowedPatternFields.scala index b6c9b659448b..f1c277bd431a 100644 --- a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/lint/ShadowedPatternFields.scala +++ b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/lint/ShadowedPatternFields.scala @@ -7,6 +7,7 @@ import org.enso.compiler.core.ir.{Expression, Module, Name, Pattern} import org.enso.compiler.core.ir.expression.{errors, warnings, Case} import org.enso.compiler.core.CompilerError import org.enso.compiler.pass.IRPass +import org.enso.compiler.pass.IRProcessingPass import org.enso.compiler.pass.analyse.{ AliasAnalysis, DataflowAnalysis, @@ -34,16 +35,16 @@ case object ShadowedPatternFields extends IRPass { override type Metadata = IRPass.Metadata.Empty override type Config = IRPass.Configuration.Default - override lazy val precursorPasses: Seq[IRPass] = List( + override lazy val precursorPasses: Seq[IRProcessingPass] = List( GenerateMethodBodies ) - override lazy val invalidatedPasses: Seq[IRPass] = List( + override lazy val invalidatedPasses: Seq[IRProcessingPass] = List( AliasAnalysis, DataflowAnalysis, DemandAnalysis, IgnoredBindings, NestedPatternMatch, - TailCall + TailCall.INSTANCE ) /** Lints for shadowed pattern fields. diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/lint/UnusedBindings.scala b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/lint/UnusedBindings.scala index f79c5b2e6241..21b9df88ba5b 100644 --- a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/lint/UnusedBindings.scala +++ b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/lint/UnusedBindings.scala @@ -15,6 +15,7 @@ import org.enso.compiler.core.ir.{ import org.enso.compiler.core.ir.expression.{errors, warnings, Case, Foreign} import org.enso.compiler.core.CompilerError import org.enso.compiler.pass.IRPass +import org.enso.compiler.pass.IRProcessingPass import org.enso.compiler.pass.analyse.AliasAnalysis import org.enso.compiler.pass.analyse.alias.{AliasMetadata => AliasInfo} import org.enso.compiler.pass.desugar._ @@ -32,7 +33,7 @@ case object UnusedBindings extends IRPass { override type Metadata = IRPass.Metadata.Empty override type Config = IRPass.Configuration.Default - override lazy val precursorPasses: Seq[IRPass] = List( + override lazy val precursorPasses: Seq[IRProcessingPass] = List( ComplexType, GenerateMethodBodies, IgnoredBindings, @@ -40,9 +41,9 @@ case object UnusedBindings extends IRPass { LambdaShorthandToLambda, NestedPatternMatch, OperatorToFunction, - SectionsToBinOp + SectionsToBinOp.INSTANCE ) - override lazy val invalidatedPasses: Seq[IRPass] = List() + override lazy val invalidatedPasses: Seq[IRProcessingPass] = List() /** Lints a module. * diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/optimise/LambdaConsolidate.scala b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/optimise/LambdaConsolidate.scala index 1ce2863b592f..6f2ab1e20910 100644 --- a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/optimise/LambdaConsolidate.scala +++ b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/optimise/LambdaConsolidate.scala @@ -15,6 +15,7 @@ import org.enso.compiler.core.ir.{ import org.enso.compiler.core.ir.expression.warnings import org.enso.compiler.core.ir.expression.errors import org.enso.compiler.pass.IRPass +import org.enso.compiler.pass.IRProcessingPass import org.enso.compiler.pass.analyse.alias.graph.{ GraphOccurrence, Graph => AliasGraph @@ -63,7 +64,7 @@ case object LambdaConsolidate extends IRPass { override type Metadata = IRPass.Metadata.Empty override type Config = IRPass.Configuration.Default - override lazy val precursorPasses: Seq[IRPass] = List( + override lazy val precursorPasses: Seq[IRProcessingPass] = List( AliasAnalysis, ComplexType, FunctionBinding, @@ -71,13 +72,13 @@ case object LambdaConsolidate extends IRPass { IgnoredBindings, LambdaShorthandToLambda, OperatorToFunction, - SectionsToBinOp + SectionsToBinOp.INSTANCE ) - override lazy val invalidatedPasses: Seq[IRPass] = List( + override lazy val invalidatedPasses: Seq[IRProcessingPass] = List( AliasAnalysis, DataflowAnalysis, DemandAnalysis, - TailCall + TailCall.INSTANCE ) /** Performs lambda consolidation on a module. diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/optimise/UnreachableMatchBranches.scala b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/optimise/UnreachableMatchBranches.scala index 92851f2a7180..00f9ab6734ea 100644 --- a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/optimise/UnreachableMatchBranches.scala +++ b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/optimise/UnreachableMatchBranches.scala @@ -11,6 +11,7 @@ import org.enso.compiler.core.ir.{ import org.enso.compiler.core.ir.expression.{errors, warnings, Case} import org.enso.compiler.core.CompilerError import org.enso.compiler.pass.IRPass +import org.enso.compiler.pass.IRProcessingPass import org.enso.compiler.pass.analyse.{ AliasAnalysis, DataflowAnalysis, @@ -46,20 +47,20 @@ case object UnreachableMatchBranches extends IRPass { override type Metadata = IRPass.Metadata.Empty override type Config = IRPass.Configuration.Default - override lazy val precursorPasses: Seq[IRPass] = List( + override lazy val precursorPasses: Seq[IRProcessingPass] = List( ComplexType, DocumentationComments, FunctionBinding, GenerateMethodBodies, LambdaShorthandToLambda ) - override lazy val invalidatedPasses: Seq[IRPass] = List( + override lazy val invalidatedPasses: Seq[IRProcessingPass] = List( AliasAnalysis, DataflowAnalysis, DemandAnalysis, IgnoredBindings, NestedPatternMatch, - TailCall + TailCall.INSTANCE ) /** Runs unreachable branch optimisation on a module. diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/resolve/IgnoredBindings.scala b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/resolve/IgnoredBindings.scala index 9509da9578d1..32a4277e794e 100644 --- a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/resolve/IgnoredBindings.scala +++ b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/resolve/IgnoredBindings.scala @@ -15,6 +15,7 @@ import org.enso.compiler.core.ir.MetadataStorage._ import org.enso.compiler.core.ir.expression.{errors, Case} import org.enso.compiler.core.CompilerError import org.enso.compiler.pass.IRPass +import org.enso.compiler.pass.IRProcessingPass import org.enso.compiler.pass.analyse.{ AliasAnalysis, DataflowAnalysis, @@ -42,17 +43,17 @@ case object IgnoredBindings extends IRPass { override type Metadata = State override type Config = IRPass.Configuration.Default - override lazy val precursorPasses: Seq[IRPass] = List( + override lazy val precursorPasses: Seq[IRProcessingPass] = List( ComplexType, GenerateMethodBodies, LambdaShorthandToLambda, NestedPatternMatch ) - override lazy val invalidatedPasses: Seq[IRPass] = List( + override lazy val invalidatedPasses: Seq[IRProcessingPass] = List( AliasAnalysis, DataflowAnalysis, DemandAnalysis, - TailCall + TailCall.INSTANCE ) /** Desugars ignored bindings for a module. diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/resolve/SuspendedArguments.scala b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/resolve/SuspendedArguments.scala index 449471851718..26629b1c9478 100644 --- a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/resolve/SuspendedArguments.scala +++ b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/resolve/SuspendedArguments.scala @@ -16,6 +16,7 @@ import org.enso.compiler.core.ir.module.scope.Definition import org.enso.compiler.core.ir.module.scope.definition import org.enso.compiler.core.CompilerError import org.enso.compiler.pass.IRPass +import org.enso.compiler.pass.IRProcessingPass import org.enso.compiler.pass.analyse._ import org.enso.compiler.pass.desugar.ComplexType import org.enso.compiler.pass.lint.UnusedBindings @@ -51,18 +52,18 @@ case object SuspendedArguments extends IRPass { override type Metadata = IRPass.Metadata.Empty override type Config = IRPass.Configuration.Default - override lazy val precursorPasses: Seq[IRPass] = List( + override lazy val precursorPasses: Seq[IRProcessingPass] = List( ComplexType, TypeSignatures, LambdaConsolidate ) - override lazy val invalidatedPasses: Seq[IRPass] = List( + override lazy val invalidatedPasses: Seq[IRProcessingPass] = List( AliasAnalysis, CachePreferenceAnalysis, DataflowAnalysis, DataflowAnalysis, DemandAnalysis, - TailCall, + TailCall.INSTANCE, UnusedBindings ) diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/resolve/TypeFunctions.scala b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/resolve/TypeFunctions.scala index a1ba77a68799..1d1ed8837636 100644 --- a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/resolve/TypeFunctions.scala +++ b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/resolve/TypeFunctions.scala @@ -17,6 +17,7 @@ import org.enso.compiler.core.ir.expression.Error import org.enso.compiler.core.CompilerError import org.enso.compiler.core.ir.expression.{Application, Operator} import org.enso.compiler.pass.IRPass +import org.enso.compiler.pass.IRProcessingPass import org.enso.compiler.pass.analyse._ import org.enso.compiler.pass.desugar.{ LambdaShorthandToLambda, @@ -38,19 +39,19 @@ case object TypeFunctions extends IRPass { override type Metadata = IRPass.Metadata.Empty override type Config = IRPass.Configuration.Default - override lazy val precursorPasses: Seq[IRPass] = List( + override lazy val precursorPasses: Seq[IRProcessingPass] = List( IgnoredBindings, LambdaShorthandToLambda, OperatorToFunction, - SectionsToBinOp + SectionsToBinOp.INSTANCE ) - override lazy val invalidatedPasses: Seq[IRPass] = List( + override lazy val invalidatedPasses: Seq[IRProcessingPass] = List( AliasAnalysis, CachePreferenceAnalysis, DataflowAnalysis, DemandAnalysis, - TailCall, + org.enso.compiler.pass.analyse.TailCall.INSTANCE, UnusedBindings ) diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/resolve/TypeSignatures.scala b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/resolve/TypeSignatures.scala index a3fe0293c34b..0d608c1a329c 100644 --- a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/resolve/TypeSignatures.scala +++ b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/resolve/TypeSignatures.scala @@ -16,6 +16,7 @@ import org.enso.compiler.core.ir.MetadataStorage._ import org.enso.compiler.core.ir.expression.{errors, Comment, Error} import org.enso.compiler.core.{CompilerError, IR} import org.enso.compiler.pass.IRPass +import org.enso.compiler.pass.IRProcessingPass import org.enso.compiler.pass.analyse._ import org.enso.compiler.pass.lint.UnusedBindings @@ -36,16 +37,16 @@ case object TypeSignatures extends IRPass { override type Metadata = Signature override type Config = IRPass.Configuration.Default - override lazy val precursorPasses: Seq[IRPass] = List( + override lazy val precursorPasses: Seq[IRProcessingPass] = List( TypeFunctions, ModuleAnnotations ) - override lazy val invalidatedPasses: Seq[IRPass] = List( + override lazy val invalidatedPasses: Seq[IRProcessingPass] = List( AliasAnalysis, CachePreferenceAnalysis, DataflowAnalysis, DemandAnalysis, - TailCall, + org.enso.compiler.pass.analyse.TailCall.INSTANCE, UnusedBindings ) diff --git a/engine/runtime-integration-tests/src/test/java/org/enso/compiler/test/CompilerTests.java b/engine/runtime-integration-tests/src/test/java/org/enso/compiler/test/CompilerTests.java index 3a3affc082ba..70e25be4190e 100644 --- a/engine/runtime-integration-tests/src/test/java/org/enso/compiler/test/CompilerTests.java +++ b/engine/runtime-integration-tests/src/test/java/org/enso/compiler/test/CompilerTests.java @@ -1,7 +1,7 @@ package org.enso.compiler.test; -import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.fail; import java.io.File; import java.io.IOException; @@ -19,15 +19,20 @@ protected static Module parse(CharSequence code) { return ir; } - public static void assertIR(String msg, Module old, Module now) throws IOException { + public static void assertIR(String msg, IR old, IR now) throws IOException { + assertEqualsIR(msg, null, old, now); + } + + public static void assertEqualsIR(String msg, String testName, IR old, IR now) + throws IOException { Function filter = f -> simplifyIR(f, true, true, false); String ir1 = filter.apply(old); String ir2 = filter.apply(now); if (!ir1.equals(ir2)) { - String name = findTestMethodName(); var home = new File(System.getProperty("java.io.tmpdir")).toPath(); - var file1 = home.resolve(name + ".1"); - var file2 = home.resolve(name + ".2"); + var fname = testName == null ? findTestMethodName() : testName; + var file1 = home.resolve(fname + ".1"); + var file2 = home.resolve(fname + ".2"); Files.writeString( file1, ir1, @@ -40,17 +45,19 @@ public static void assertIR(String msg, Module old, Module now) throws IOExcepti StandardOpenOption.TRUNCATE_EXISTING, StandardOpenOption.CREATE, StandardOpenOption.WRITE); - assertEquals(msg, file1, file2); + fail("IRs contained in files " + file1 + " and " + file2 + " should equal: " + msg); } } + private static int counter = 'A'; + private static String findTestMethodName() { for (var e : new Exception().getStackTrace()) { if (e.getMethodName().startsWith("test")) { return e.getMethodName(); } } - throw new IllegalStateException(); + return "test" + (char) counter++; } /** diff --git a/engine/runtime-integration-tests/src/test/java/org/enso/compiler/test/mini/passes/MiniPassTester.java b/engine/runtime-integration-tests/src/test/java/org/enso/compiler/test/mini/passes/MiniPassTester.java new file mode 100644 index 000000000000..85e774b634d6 --- /dev/null +++ b/engine/runtime-integration-tests/src/test/java/org/enso/compiler/test/mini/passes/MiniPassTester.java @@ -0,0 +1,63 @@ +package org.enso.compiler.test.mini.passes; + +import java.io.IOException; +import org.enso.compiler.PackageRepository; +import org.enso.compiler.context.FreshNameSupply; +import org.enso.compiler.context.ModuleContext; +import org.enso.compiler.core.EnsoParser; +import org.enso.compiler.core.ir.Module; +import org.enso.compiler.data.CompilerConfig; +import org.enso.compiler.pass.IRPass; +import org.enso.compiler.pass.MiniIRPass; +import org.enso.compiler.pass.MiniPassFactory; +import org.enso.compiler.pass.PassConfiguration; +import org.enso.compiler.test.CompilerTests; +import org.enso.pkg.QualifiedName; +import scala.Option; + +/** + * A tester class that asserts that a {@link MiniIRPass} has the same result as its corresponding + * {@link IRPass} in a module compilation. + */ +public abstract class MiniPassTester { + protected void compareModuleCompilation( + Module ir, ModuleContext moduleContext, IRPass megaPass, MiniPassFactory miniPassFactory) { + var miniPass = miniPassFactory.createForModuleCompilation(moduleContext); + if (miniPass == null) { + throw new IllegalArgumentException("Mini pass does not support module compilation"); + } + var megaPassResult = megaPass.runModule(ir, moduleContext); + var miniPassResult = MiniIRPass.compile(Module.class, ir, miniPass); + try { + CompilerTests.assertIR( + "Mini pass and mega pass results are equal", megaPassResult, miniPassResult); + } catch (IOException e) { + throw new AssertionError(e); + } + } + + protected Module parse(String code) { + var modIr = EnsoParser.compile(code); + return modIr; + } + + protected ModuleContext buildModuleContext(QualifiedName moduleName) { + var compilerConf = defaultCompilerConfig(); + Option freshNameSupply = Option.empty(); + Option passConfig = Option.empty(); + Option pkgRepo = Option.empty(); + var isGeneratingDocs = false; + var runtimeMod = org.enso.interpreter.runtime.Module.empty(moduleName, null); + return ModuleContext.apply( + runtimeMod.asCompilerModule(), + compilerConf, + freshNameSupply, + passConfig, + isGeneratingDocs, + pkgRepo); + } + + private static CompilerConfig defaultCompilerConfig() { + return CompilerConfig.apply(false, true, true, false, false, false, false, Option.empty()); + } +} diff --git a/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/MiniPassTest.scala b/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/MiniPassTest.scala new file mode 100644 index 000000000000..d26dcd80ca9a --- /dev/null +++ b/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/MiniPassTest.scala @@ -0,0 +1,134 @@ +package org.enso.compiler.test + +import org.enso.compiler.context.{InlineContext, ModuleContext} +import org.enso.compiler.core.EnsoParser +import org.enso.compiler.core.ir.{Expression, Module} +import org.enso.compiler.pass.{IRPass, MiniIRPass, MiniPassFactory, PassManager} + +trait MiniPassTest extends CompilerTest { + def testName: String + + /** Configuration for mini pass + */ + def miniPassFactory: MiniPassFactory + + def megaPass: IRPass + + /** Configuration for mega pass + */ + def megaPassManager: PassManager + + /** Tests module compilation in both mega pass and mini pass. + * @param code Source code of the whole module to compile. + * @param createContext Function that creates module context. For both mega pass and minipass, + * there will be a new context created. + * @param testSpec Body of the test. Receives module compiled either by mega pass or by mini pass. + */ + def assertModuleCompilation( + code: String, + createContext: () => ModuleContext, + testSpec: Module => Unit, + compareIR: Boolean = false + ): Unit = { + val megaIr = withClue("Mega pass module compilation") { + val ctx = createContext() + processModuleWithMegaPass(code, ctx) + } + val miniIr = withClue("Mini pass module compilation") { + val ctx = createContext() + processModuleWithMiniPass(code, ctx) + } + if (compareIR) { + CompilerTests.assertIR("Should be the same", megaIr, miniIr) + } + withClue("Mega pass module spec execution") { + testSpec(megaIr) + } + withClue("Mini pass module spec execution") { + testSpec(miniIr) + } + } + + /** Tests inline compilation in both mega pass and mini pass. + * @param code Source code to compile. + * @param createContext Function that creates inline context. For both mega pass and minipass, + * there will be a new context created. + * @param testSpec Body of the test. Receives expression compiled either by mega pass or by mini pass. + */ + def assertInlineCompilation( + code: String, + createContext: () => InlineContext, + testSpec: Expression => Unit, + compareIR: Boolean = false + ): Unit = { + val megaIr = withClue("Mega pass inline compilation: ") { + val ctx = createContext() + preprocessExpressionWithMegaPass(code, ctx) + } + val miniIr = withClue("Mini pass inline compilation: ") { + val ctx = createContext() + preprocessExpressionWithMiniPass(code, ctx) + } + if (compareIR) { + CompilerTests.assertIR("Should be the same", megaIr, miniIr) + } + withClue("Mega pass inline spec execution") { + testSpec(megaIr) + } + withClue("Mini pass inline spec execution") { + testSpec(miniIr) + } + } + + private def processModuleWithMegaPass( + source: String, + moduleCtx: ModuleContext + ): Module = { + val module = parseModule(source) + val preprocessedModule = + megaPassManager.runPassesOnModule(module, moduleCtx) + megaPass.runModule(preprocessedModule, moduleCtx) + } + + private def processModuleWithMiniPass( + source: String, + moduleCtx: ModuleContext + ): Module = { + val module = parseModule(source) + val miniPass = miniPassFactory.createForModuleCompilation(moduleCtx) + val preprocessedModule = + megaPassManager.runPassesOnModule(module, moduleCtx) + MiniIRPass.compile(classOf[Module], preprocessedModule, miniPass) + } + + def preprocessExpressionWithMegaPass( + expression: String, + inlineCtx: InlineContext + ): Expression = { + val expr = parseExpression(expression) + val preprocessedExpr = + megaPassManager.runPassesInline(expr, inlineCtx) + megaPass.runExpression(preprocessedExpr, inlineCtx) + } + + def preprocessExpressionWithMiniPass( + expression: String, + inlineCtx: InlineContext + ): Expression = { + val expr = parseExpression(expression) + val miniPass = miniPassFactory.createForInlineCompilation(inlineCtx) + val preprocessedExpr = + megaPassManager.runPassesInline(expr, inlineCtx) + MiniIRPass.compile(classOf[Expression], preprocessedExpr, miniPass) + } + + private def parseModule(source: String): Module = { + EnsoParser.compile(source) + } + + private def parseExpression(source: String): Expression = { + val exprIrOpt = EnsoParser.compileInline(source) + exprIrOpt shouldBe defined + exprIrOpt.get + } +} diff --git a/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/PassesTest.scala b/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/PassesTest.scala index d25e2707a583..2ef5cf27e8e7 100644 --- a/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/PassesTest.scala +++ b/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/PassesTest.scala @@ -61,7 +61,7 @@ class PassesTest extends CompilerTest { BindingAnalysis, ModuleNameConflicts, MethodDefinitions, - SectionsToBinOp, + SectionsToBinOp.INSTANCE, OperatorToFunction, LambdaShorthandToLambda, ImportSymbolAnalysis, diff --git a/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/PassManagerTest.scala b/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/PassManagerTest.scala index e957029de4dd..ca0c2a02189c 100644 --- a/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/PassManagerTest.scala +++ b/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/PassManagerTest.scala @@ -3,7 +3,7 @@ package org.enso.compiler.test.pass import org.enso.compiler.Passes import org.enso.compiler.core.CompilerError import org.enso.compiler.pass.{ - IRPass, + IRProcessingPass, PassConfiguration, PassGroup, PassManager @@ -24,11 +24,11 @@ class PassManagerTest extends CompilerTest { // === Test Setup =========================================================== - val invalidOrdering: List[IRPass] = List( + val invalidOrdering: List[IRProcessingPass] = List( ComplexType, FunctionBinding, GenerateMethodBodies, - SectionsToBinOp, + SectionsToBinOp.INSTANCE, OperatorToFunction, LambdaShorthandToLambda, IgnoredBindings, @@ -36,7 +36,7 @@ class PassManagerTest extends CompilerTest { LambdaConsolidate, OverloadsResolution, DemandAnalysis, - TailCall, + TailCall.INSTANCE, AliasAnalysis, DataflowAnalysis, UnusedBindings diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/analyse/TailCall.scala b/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/analyse/TailCallMegaPass.scala similarity index 90% rename from engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/analyse/TailCall.scala rename to engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/analyse/TailCallMegaPass.scala index d69066b28388..8ba7958c00c2 100644 --- a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/analyse/TailCall.scala +++ b/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/analyse/TailCallMegaPass.scala @@ -1,4 +1,4 @@ -package org.enso.compiler.pass.analyse +package org.enso.compiler.test.pass.analyse import org.enso.compiler.context.{InlineContext, ModuleContext} import org.enso.compiler.core.Implicits.{AsDiagnostics, AsMetadata} @@ -28,40 +28,45 @@ import org.enso.compiler.core.ir.{ } import org.enso.compiler.core.{CompilerError, IR} import org.enso.compiler.pass.IRPass +import org.enso.compiler.pass.IRProcessingPass +import org.enso.compiler.pass.analyse.TailCall +import org.enso.compiler.pass.analyse.TailCall.TailPosition import org.enso.compiler.pass.desugar._ import org.enso.compiler.pass.resolve.{ExpressionAnnotations, GlobalNames} -/** This pass performs tail call analysis on the Enso IR. +/** Original implementation of [[org.enso.compiler.pass.analyse.TailCall]]. + * Now server as the test verification of the new [[org.enso.compiler.pass.analyse.TailCallMini]]. + * + * This pass performs tail call analysis on the Enso IR. * * It is responsible for marking every single expression with whether it is in - * tail position. This allows the code generator to correctly create the + * tail position or not. This allows the code generator to correctly create the * Truffle nodes. - * If the expression is in tail position, [[TailPosition.Tail]] metadata is attached - * to it, otherwise, nothing is attached. * * This pass requires the context to provide: * * - The tail position of its expression, where relevant. */ -case object TailCall extends IRPass { +case object TailCallMegaPass extends IRPass { /** The annotation metadata type associated with IR nodes by this pass. */ override type Metadata = TailPosition override type Config = IRPass.Configuration.Default - override lazy val precursorPasses: Seq[IRPass] = List( + override lazy val precursorPasses: Seq[IRProcessingPass] = List( FunctionBinding, GenerateMethodBodies, - SectionsToBinOp, + SectionsToBinOp.INSTANCE, OperatorToFunction, LambdaShorthandToLambda, GlobalNames ) - override lazy val invalidatedPasses: Seq[IRPass] = List() + override lazy val invalidatedPasses: Seq[IRProcessingPass] = List() - private lazy val TAIL_META = new MetadataPair(this, TailPosition.Tail) + private lazy val TAIL_META = + new MetadataPair(TailCall.INSTANCE, TailPosition.Tail) /** Analyses tail call state for expressions in a module. * @@ -481,42 +486,6 @@ case object TailCall extends IRPass { } } - /** Expresses the tail call state of an IR Node. */ - sealed trait TailPosition extends IRPass.IRMetadata { - - /** A boolean representation of the expression's tail state. */ - def isTail: Boolean - } - object TailPosition { - - /** The expression is in a tail position and can be tail call optimised. - * If the expression is not in tail-call position, it has no metadata attached. - */ - final case object Tail extends TailPosition { - override val metadataName: String = "TailCall.TailPosition.Tail" - override def isTail: Boolean = true - - override def duplicate(): Option[IRPass.IRMetadata] = Some(Tail) - - /** @inheritdoc */ - override def prepareForSerialization(compiler: Compiler): Tail.type = this - - /** @inheritdoc */ - override def restoreFromSerialization( - compiler: Compiler - ): Option[Tail.type] = Some(this) - } - - /** Implicitly converts the tail position data into a boolean. - * - * @param tailPosition the tail position value - * @return the boolean value corresponding to `tailPosition` - */ - implicit def toBool(tailPosition: TailPosition): Boolean = { - tailPosition.isTail - } - } - /** Checks if the provided `expression` is annotated with a tail call * annotation. * diff --git a/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/analyse/TailCallTest.scala b/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/analyse/TailCallTest.scala index 3a8d2dc2c18a..fd4d86c4ced2 100644 --- a/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/analyse/TailCallTest.scala +++ b/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/analyse/TailCallTest.scala @@ -3,24 +3,34 @@ package org.enso.compiler.test.pass.analyse import org.enso.compiler.Passes import org.enso.compiler.context.{FreshNameSupply, InlineContext, ModuleContext} import org.enso.compiler.core.Implicits.AsMetadata -import org.enso.compiler.core.ir.{ - Expression, - Function, - Module, - Pattern, - Warning -} +import org.enso.compiler.core.ir.{Expression, Function, Pattern, Warning} import org.enso.compiler.core.ir.module.scope.definition import org.enso.compiler.core.ir.expression.Application import org.enso.compiler.core.ir.expression.Case import org.enso.compiler.pass.PassConfiguration._ import org.enso.compiler.pass.analyse.TailCall.TailPosition import org.enso.compiler.pass.analyse.{AliasAnalysis, TailCall} -import org.enso.compiler.pass.{PassConfiguration, PassGroup, PassManager} -import org.enso.compiler.test.CompilerTest +import org.enso.compiler.pass.{ + IRPass, + MiniPassFactory, + PassConfiguration, + PassGroup, + PassManager +} +import org.enso.compiler.test.MiniPassTest import org.enso.compiler.context.LocalScope -class TailCallTest extends CompilerTest { +class TailCallTest extends MiniPassTest { + override def testName: String = "Tail call" + + override def miniPassFactory: MiniPassFactory = { + TailCall.INSTANCE + } + + override def megaPass: IRPass = TailCallMegaPass + + override def megaPassManager: PassManager = + new PassManager(List(precursorPasses), passConfiguration) // === Test Setup =========================================================== @@ -45,53 +55,16 @@ class TailCallTest extends CompilerTest { val passes = new Passes(defaultConfig) - val precursorPasses: PassGroup = passes.getPrecursors(TailCall).get + val precursorPasses: PassGroup = passes.getPrecursors(TailCall.INSTANCE).get val passConfiguration: PassConfiguration = PassConfiguration( AliasAnalysis -->> AliasAnalysis.Configuration() ) - implicit val passManager: PassManager = - new PassManager(List(precursorPasses), passConfiguration) - - /** Adds an extension method to analyse an Enso module. - * - * @param ir the ir to analyse - */ - implicit class AnalyseModule(ir: Module) { - - /** Performs tail call analysis on [[ir]]. - * - * @param context the module context in which analysis takes place - * @return [[ir]], with tail call analysis metadata attached - */ - def analyse(implicit context: ModuleContext) = { - TailCall.runModule(ir, context) - } - } - - /** Adds an extension method to preprocess source code as an Enso expression. - * - * @param ir the ir to analyse - */ - implicit class AnalyseExpresion(ir: Expression) { - - /** Performs tail call analysis on [[ir]]. - * - * @param context the inline context in which analysis takes place - * @return [[ir]], with tail call analysis metadata attached - */ - def analyse(implicit context: InlineContext): Expression = { - TailCall.runExpression(ir, context) - } - } - // === The Tests ============================================================ "Tail call analysis on modules" should { - implicit val ctx: ModuleContext = mkModuleContext - - val ir = + val code = """ |Foo.bar = a -> b -> c -> | d = a + b @@ -103,18 +76,42 @@ class TailCallTest extends CompilerTest { |type MyAtom a b c | |Foo.from (that : Bar) = undefined - |""".stripMargin.preprocessModule.analyse + |""".stripMargin "mark methods as tail" in { - ir.bindings.head.getMetadata(TailCall) shouldEqual Some(TailPosition.Tail) + assertModuleCompilation( + code, + () => mkModuleContext, + ir => { + ir.bindings.head.getMetadata(TailCall.INSTANCE) shouldEqual Some( + TailPosition.Tail + ) + } + ) } "mark atoms as tail" in { - ir.bindings(1).getMetadata(TailCall) shouldEqual Some(TailPosition.Tail) + assertModuleCompilation( + code, + () => mkModuleContext, + ir => { + ir.bindings(1).getMetadata(TailCall.INSTANCE) shouldEqual Some( + TailPosition.Tail + ) + } + ) } "mark conversions as tail" in { - ir.bindings(2).getMetadata(TailCall) shouldEqual Some(TailPosition.Tail) + assertModuleCompilation( + code, + () => mkModuleContext, + ir => { + ir.bindings(2).getMetadata(TailCall.INSTANCE) shouldEqual Some( + TailPosition.Tail + ) + } + ) } } @@ -125,305 +122,435 @@ class TailCallTest extends CompilerTest { |""".stripMargin "mark the expression as tail if the context requires it" in { - implicit val ctx: InlineContext = mkTailContext - val ir = code.preprocessExpression.get.analyse - - ir.getMetadata(TailCall) shouldEqual Some(TailPosition.Tail) + assertInlineCompilation( + code, + () => mkTailContext, + ir => { + ir.getMetadata(TailCall.INSTANCE) shouldEqual Some(TailPosition.Tail) + }, + true + ) } "not mark the expression as tail if the context doesn't require it" in { - implicit val ctx: InlineContext = mkNoTailContext - val ir = code.preprocessExpression.get.analyse - - ir.getMetadata(TailCall) shouldEqual None + assertInlineCompilation( + code, + () => mkNoTailContext, + ir => { + ir.getMetadata(TailCall.INSTANCE) shouldEqual None + } + ) } "mark the value of a tail assignment as non-tail" in { - implicit val ctx: InlineContext = mkTailContext - val binding = + assertInlineCompilation( """ |foo = a b - |""".stripMargin.preprocessExpression.get.analyse - .asInstanceOf[Expression.Binding] - binding.getMetadata(TailCall) shouldEqual Some(TailPosition.Tail) - binding.expression.getMetadata(TailCall) shouldEqual None - + |""".stripMargin, + () => mkTailContext, + ir => { + val binding = ir.asInstanceOf[Expression.Binding] + binding.getMetadata(TailCall.INSTANCE) shouldEqual Some( + TailPosition.Tail + ) + binding.expression.getMetadata(TailCall.INSTANCE) shouldEqual None + }, + true + ) } } "Tail call analysis on functions" should { - implicit val ctx: InlineContext = mkTailContext - - val ir = + val code = """ |a -> b -> c -> | d = @Tail_Call (a + b) | e = a * c | @Tail_Call (d + e) - |""".stripMargin.preprocessExpression.get.analyse - .asInstanceOf[Function.Lambda] - - val fnBody = ir.body.asInstanceOf[Expression.Block] + |""".stripMargin "mark the last expression of the function as tail" in { - fnBody.returnValue.getMetadata(TailCall) shouldEqual Some( - TailPosition.Tail + assertInlineCompilation( + code, + () => mkTailContext, + ir => { + val lambda = ir.asInstanceOf[Function.Lambda] + val fnBody = lambda.body.asInstanceOf[Expression.Block] + fnBody.returnValue.getMetadata(TailCall.INSTANCE) shouldEqual Some( + TailPosition.Tail + ) + } ) } "mark the other expressions in the function as not tail" in { - fnBody.expressions.foreach(expr => - expr.getMetadata(TailCall) shouldEqual None + assertInlineCompilation( + code, + () => mkTailContext, + ir => { + val lambda = ir.asInstanceOf[Function.Lambda] + val fnBody = lambda.body.asInstanceOf[Expression.Block] + fnBody.expressions.foreach { expr => + expr.getMetadata(TailCall.INSTANCE) shouldEqual None + } + } ) } "warn about misplaced @TailCall annotations" in { - fnBody - .expressions(0) - .asInstanceOf[Expression.Binding] - .expression - .diagnosticsList - .count(_.isInstanceOf[Warning.WrongTco]) shouldEqual 1 - - fnBody.returnValue.diagnosticsList - .count(_.isInstanceOf[Warning.WrongTco]) shouldEqual 0 + assertInlineCompilation( + code, + () => mkTailContext, + ir => { + val lambda = ir.asInstanceOf[Function.Lambda] + val fnBody = lambda.body.asInstanceOf[Expression.Block] + fnBody + .expressions(0) + .asInstanceOf[Expression.Binding] + .expression + .diagnosticsList + .count(_.isInstanceOf[Warning.WrongTco]) shouldEqual 1 + fnBody.returnValue.diagnosticsList + .count(_.isInstanceOf[Warning.WrongTco]) shouldEqual 0 + }, + true + ) } } "Tail call analysis on local functions" should { - implicit val ctx: ModuleContext = mkModuleContext - - val ir = - """ - |adder_two = - | if 0 == 0 then 0 else - | @Tail_Call adder_two - |""".stripMargin.preprocessModule.analyse - - val fnBody = ir.bindings.head - .asInstanceOf[definition.Method] - .body - .asInstanceOf[Function.Lambda] - .body - "handle application involving local functions" in { - fnBody - .asInstanceOf[Expression.Block] - .returnValue - .asInstanceOf[Application.Prefix] - .arguments(2) - .value - .asInstanceOf[Expression.Block] - .returnValue - .asInstanceOf[Application.Prefix] - .function - .diagnosticsList - .count(_.isInstanceOf[Warning.WrongTco]) shouldEqual 0 + val code = """ + |adder_two = + | if 0 == 0 then 0 else + | @Tail_Call adder_two + |""".stripMargin + + assertModuleCompilation( + code, + () => mkModuleContext, + ir => { + val fnBody = ir.bindings.head + .asInstanceOf[definition.Method] + .body + .asInstanceOf[Function.Lambda] + .body + fnBody + .asInstanceOf[Expression.Block] + .returnValue + .asInstanceOf[Application.Prefix] + .arguments(2) + .value + .asInstanceOf[Expression.Block] + .returnValue + .asInstanceOf[Application.Prefix] + .function + .diagnosticsList + .count(_.isInstanceOf[Warning.WrongTco]) shouldEqual 0 + } + ) } - } "Tail call analysis on case expressions" should { "not mark any portion of the branch functions as tail by default" in { - implicit val ctx: ModuleContext = mkModuleContext - - val ir = + val code = """ |Foo.bar = a -> | x = case a of | Lambda fn arg -> fn arg | | x - |""".stripMargin.preprocessModule.analyse - - val caseExpr = ir.bindings.head - .asInstanceOf[definition.Method] - .body - .asInstanceOf[Function.Lambda] - .body - .asInstanceOf[Expression.Block] - .expressions - .head - .asInstanceOf[Expression.Binding] - .expression - .asInstanceOf[Expression.Block] - .returnValue - .asInstanceOf[Case.Expr] - - caseExpr.getMetadata(TailCall) shouldEqual None - caseExpr.branches.foreach(branch => { - val branchExpression = - branch.expression.asInstanceOf[Application.Prefix] - - branchExpression.getMetadata(TailCall) shouldEqual None - }) + |""".stripMargin + + assertModuleCompilation( + code, + () => mkModuleContext, + ir => { + val caseExpr = ir.bindings.head + .asInstanceOf[definition.Method] + .body + .asInstanceOf[Function.Lambda] + .body + .asInstanceOf[Expression.Block] + .expressions + .head + .asInstanceOf[Expression.Binding] + .expression + .asInstanceOf[Expression.Block] + .returnValue + .asInstanceOf[Case.Expr] + + caseExpr.getMetadata(TailCall.INSTANCE) shouldEqual None + caseExpr.branches.foreach(branch => { + val branchExpression = + branch.expression.asInstanceOf[Application.Prefix] + + branchExpression.getMetadata(TailCall.INSTANCE) shouldEqual None + }) + } + ) } "only mark the branches as tail if the expression is in tail position" in { - implicit val ctx: ModuleContext = mkModuleContext - - val ir = + val code = """ |Foo.bar = a -> | case a of | Lambda fn arg -> fn arg - |""".stripMargin.preprocessModule.analyse - - val caseExpr = ir.bindings.head - .asInstanceOf[definition.Method] - .body - .asInstanceOf[Function.Lambda] - .body - .asInstanceOf[Expression.Block] - .returnValue - .asInstanceOf[Expression.Block] - .returnValue - .asInstanceOf[Case.Expr] - - caseExpr.getMetadata(TailCall) shouldEqual Some( - TailPosition.Tail + |""".stripMargin + + assertModuleCompilation( + code, + () => mkModuleContext, + ir => { + val caseExpr = ir.bindings.head + .asInstanceOf[definition.Method] + .body + .asInstanceOf[Function.Lambda] + .body + .asInstanceOf[Expression.Block] + .returnValue + .asInstanceOf[Expression.Block] + .returnValue + .asInstanceOf[Case.Expr] + + caseExpr.getMetadata(TailCall.INSTANCE) shouldEqual Some( + TailPosition.Tail + ) + caseExpr.branches.foreach(branch => { + val branchExpression = + branch.expression.asInstanceOf[Application.Prefix] + + branchExpression.getMetadata(TailCall.INSTANCE) shouldEqual Some( + TailPosition.Tail + ) + }) + }, + true ) - caseExpr.branches.foreach(branch => { - val branchExpression = - branch.expression.asInstanceOf[Application.Prefix] - - branchExpression.getMetadata(TailCall) shouldEqual Some( - TailPosition.Tail - ) - }) } "mark patters and pattern elements as not tail" in { - implicit val ctx: InlineContext = mkTailContext - - val ir = + val code = """ |case x of | Cons a b -> a + b - |""".stripMargin.preprocessExpression.get.analyse - .asInstanceOf[Expression.Block] - .returnValue - .asInstanceOf[Case.Expr] - - val caseBranch = ir.branches.head - val pattern = caseBranch.pattern.asInstanceOf[Pattern.Constructor] - val patternConstructor = pattern.constructor - - pattern.getMetadata(TailCall) shouldEqual None - patternConstructor.getMetadata(TailCall) shouldEqual None - pattern.fields.foreach(f => { - f.getMetadata(TailCall) shouldEqual None - - f.asInstanceOf[Pattern.Name] - .name - .getMetadata(TailCall) shouldEqual None - }) + |""".stripMargin + + assertInlineCompilation( + code, + () => mkTailContext, + ir => { + val caseExpr = ir + .asInstanceOf[Expression.Block] + .returnValue + .asInstanceOf[Case.Expr] + val caseBranch = caseExpr.branches.head + val pattern = caseBranch.pattern.asInstanceOf[Pattern.Constructor] + val patternConstructor = pattern.constructor + + pattern.getMetadata(TailCall.INSTANCE) shouldEqual None + patternConstructor.getMetadata(TailCall.INSTANCE) shouldEqual None + pattern.fields.foreach(f => { + f.getMetadata(TailCall.INSTANCE) shouldEqual None + + f.asInstanceOf[Pattern.Name] + .name + .getMetadata(TailCall.INSTANCE) shouldEqual None + }) + } + ) } } "Tail call analysis on function calls" should { - implicit val ctx: ModuleContext = mkModuleContext - - val tailCall = - """ - |Foo.bar = - | IO.println "AAAAA" - |""".stripMargin.preprocessModule.analyse.bindings.head - .asInstanceOf[definition.Method] - val tailCallBody = tailCall.body - .asInstanceOf[Function.Lambda] - .body - .asInstanceOf[Expression.Block] - - val nonTailCall = - """ - |Foo.bar = - | a = b c d - | a - |""".stripMargin.preprocessModule.analyse.bindings.head - .asInstanceOf[definition.Method] - val nonTailCallBody = nonTailCall.body - .asInstanceOf[Function.Lambda] - .body - .asInstanceOf[Expression.Block] - - "mark the arguments as tail" in { - nonTailCallBody.expressions.head - .asInstanceOf[Expression.Binding] - .expression - .asInstanceOf[Application.Prefix] - .arguments - .foreach(arg => - arg.getMetadata(TailCall) shouldEqual Some( - TailPosition.Tail - ) - ) - - tailCallBody.returnValue - .asInstanceOf[Application.Prefix] - .arguments - .foreach(arg => - arg.getMetadata(TailCall) shouldEqual Some( - TailPosition.Tail - ) - ) + "work on function that has tail call return value" in { + assertModuleCompilation( + """ + |Foo.bar = + | IO.println "AAAAA" + |""".stripMargin, + () => mkModuleContext, + ir => { + val tailCall = ir.bindings.head.asInstanceOf[definition.Method] + val tailCallBody = tailCall.body + .asInstanceOf[Function.Lambda] + .body + .asInstanceOf[Expression.Block] + + withClue("Mark the arguments as tail") { + tailCallBody.returnValue + .asInstanceOf[Application.Prefix] + .arguments + .foreach(arg => + arg.getMetadata(TailCall.INSTANCE) shouldEqual Some( + TailPosition.Tail + ) + ) + } + + withClue( + "Mark the function call as tail if it is in a tail position" + ) { + tailCallBody.returnValue.getMetadata( + TailCall.INSTANCE + ) shouldEqual Some( + TailPosition.Tail + ) + } + } + ) } - "mark the function call as tail if it is in a tail position" in { - tailCallBody.returnValue.getMetadata(TailCall) shouldEqual Some( - TailPosition.Tail + "work on function that has not-tail call return value" in { + assertModuleCompilation( + """ + |Foo.bar = + | a = b c d + | a + |""".stripMargin, + () => mkModuleContext, + ir => { + val nonTailCall = ir.bindings.head.asInstanceOf[definition.Method] + val nonTailCallBody = nonTailCall.body + .asInstanceOf[Function.Lambda] + .body + .asInstanceOf[Expression.Block] + + withClue("Mark the arguments as tail") { + nonTailCallBody + .expressions(0) + .asInstanceOf[Expression.Binding] + .expression + .asInstanceOf[Application.Prefix] + .arguments + .foreach(arg => + arg.getMetadata(TailCall.INSTANCE) shouldEqual Some( + TailPosition.Tail + ) + ) + } + + withClue( + "Mark the function call as not tail if it is in a tail position" + ) { + nonTailCallBody.expressions.head + .asInstanceOf[Expression.Binding] + .expression + .getMetadata(TailCall.INSTANCE) shouldEqual None + } + } ) } "mark the function call as not tail if it is in a tail position" in { - nonTailCallBody.expressions.head - .asInstanceOf[Expression.Binding] - .expression - .getMetadata(TailCall) shouldEqual None + assertModuleCompilation( + """ + |Foo.bar = + | a = b c d + | a + |""".stripMargin, + () => mkModuleContext, + ir => { + val nonTailCall = ir.bindings.head.asInstanceOf[definition.Method] + val nonTailCallBody = nonTailCall.body + .asInstanceOf[Function.Lambda] + .body + .asInstanceOf[Expression.Block] + nonTailCallBody.expressions.head + .asInstanceOf[Expression.Binding] + .expression + .getMetadata(TailCall.INSTANCE) shouldEqual None + } + ) } } "Tail call analysis on blocks" should { - implicit val ctx: ModuleContext = mkModuleContext - - val ir = + val code = """ |Foo.bar = a -> b -> c -> | d = a + b | mul = a -> b -> a * b | mul c d - |""".stripMargin.preprocessModule.analyse.bindings.head - .asInstanceOf[definition.Method] - - val block = ir.body - .asInstanceOf[Function.Lambda] - .body - .asInstanceOf[Expression.Block] + |""".stripMargin "mark the bodies of bound functions as tail properly" in { - block - .expressions(1) - .asInstanceOf[Expression.Binding] - .expression - .asInstanceOf[Function.Lambda] - .body - .getMetadata(TailCall) shouldEqual Some(TailPosition.Tail) + assertModuleCompilation( + code, + () => mkModuleContext, + ir => { + val method = ir.bindings.head.asInstanceOf[definition.Method] + val block = method.body + .asInstanceOf[Function.Lambda] + .body + .asInstanceOf[Expression.Block] + + block + .expressions(1) + .asInstanceOf[Expression.Binding] + .expression + .asInstanceOf[Function.Lambda] + .body + .getMetadata(TailCall.INSTANCE) shouldEqual Some(TailPosition.Tail) + } + ) } "mark the block expressions as not tail" in { - block.expressions.foreach(expr => - expr.getMetadata(TailCall) shouldEqual None + assertModuleCompilation( + code, + () => mkModuleContext, + ir => { + val method = ir.bindings.head.asInstanceOf[definition.Method] + val block = method.body + .asInstanceOf[Function.Lambda] + .body + .asInstanceOf[Expression.Block] + + block.expressions.foreach(expr => + expr.getMetadata(TailCall.INSTANCE) shouldEqual None + ) + } ) } "mark the final expression of the block as tail" in { - block.returnValue.getMetadata(TailCall) shouldEqual Some( - TailPosition.Tail + assertModuleCompilation( + code, + () => mkModuleContext, + ir => { + val method = ir.bindings.head.asInstanceOf[definition.Method] + val block = method.body + .asInstanceOf[Function.Lambda] + .body + .asInstanceOf[Expression.Block] + + block.returnValue.getMetadata(TailCall.INSTANCE) shouldEqual Some( + TailPosition.Tail + ) + } ) } "mark the block as tail if it is in a tail position" in { - block.getMetadata(TailCall) shouldEqual Some(TailPosition.Tail) + assertModuleCompilation( + code, + () => mkModuleContext, + ir => { + val method = ir.bindings.head.asInstanceOf[definition.Method] + val block = method.body + .asInstanceOf[Function.Lambda] + .body + .asInstanceOf[Expression.Block] + + block.getMetadata(TailCall.INSTANCE) shouldEqual Some( + TailPosition.Tail + ) + }, + true + ) } } } diff --git a/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/desugar/LambdaShorthandToLambdaMegaPass.scala b/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/desugar/LambdaShorthandToLambdaMegaPass.scala new file mode 100644 index 000000000000..8d0d81347322 --- /dev/null +++ b/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/desugar/LambdaShorthandToLambdaMegaPass.scala @@ -0,0 +1,447 @@ +package org.enso.compiler.test.pass.desugar + +import org.enso.compiler.context.{FreshNameSupply, InlineContext, ModuleContext} +import org.enso.compiler.core.CompilerError +import org.enso.compiler.core.ir.{ + CallArgument, + DefinitionArgument, + Expression, + Function, + IdentifiedLocation, + Module, + Name +} +import org.enso.compiler.core.ir.expression.{Application, Case, Operator} +import org.enso.compiler.pass.{IRPass, IRProcessingPass} +import org.enso.compiler.pass.analyse.{ + AliasAnalysis, + DataflowAnalysis, + DemandAnalysis, + TailCall +} +import org.enso.compiler.pass.desugar.{ + ComplexType, + FunctionBinding, + GenerateMethodBodies, + OperatorToFunction, + SectionsToBinOp +} +import org.enso.compiler.pass.lint.UnusedBindings +import org.enso.compiler.pass.optimise.LambdaConsolidate +import org.enso.compiler.pass.resolve.{ + DocumentationComments, + IgnoredBindings, + OverloadsResolution +} + +/** Original implementation of [[org.enso.compiler.pass.desugar.LambdaShorthandToLambda]]. + * Now serves as the test verification of the new [[org.enso.compiler.pass.desugar.LambdaShorthandToLambdaMini]] mini pass version. + * + * This pass translates `_` arguments at application sites to lambda functions. + * + * This pass has no configuration. + * + * This pass requires the context to provide: + * + * - A [[FreshNameSupply]] + */ +case object LambdaShorthandToLambdaMegaPass extends IRPass { + override type Metadata = IRPass.Metadata.Empty + override type Config = IRPass.Configuration.Default + + override lazy val precursorPasses: Seq[IRProcessingPass] = List( + ComplexType, + DocumentationComments, + FunctionBinding, + GenerateMethodBodies, + OperatorToFunction, + SectionsToBinOp.INSTANCE + ) + override lazy val invalidatedPasses: Seq[IRProcessingPass] = List( + AliasAnalysis, + DataflowAnalysis, + DemandAnalysis, + IgnoredBindings, + LambdaConsolidate, + OverloadsResolution, + TailCall.INSTANCE, + UnusedBindings + ) + + /** Desugars underscore arguments to lambdas for a module. + * + * @param ir the Enso IR to process + * @param moduleContext a context object that contains the information needed + * to process a module + * @return `ir`, possibly having made transformations or annotations to that + * IR. + */ + override def runModule( + ir: Module, + moduleContext: ModuleContext + ): Module = { + val new_bindings = ir.bindings.map { case a => + a.mapExpressions( + runExpression( + _, + InlineContext( + moduleContext, + freshNameSupply = moduleContext.freshNameSupply, + compilerConfig = moduleContext.compilerConfig + ) + ) + ) + } + ir.copy(bindings = new_bindings) + } + + /** Desugars underscore arguments to lambdas for an arbitrary expression. + * + * @param ir the Enso IR to process + * @param inlineContext a context object that contains the information needed + * for inline evaluation + * @return `ir`, possibly having made transformations or annotations to that + * IR. + */ + override def runExpression( + ir: Expression, + inlineContext: InlineContext + ): Expression = { + val freshNameSupply = inlineContext.freshNameSupply.getOrElse( + throw new CompilerError( + "Desugaring underscore arguments to lambdas requires a fresh name " + + "supply." + ) + ) + + desugarExpression(ir, freshNameSupply) + } + + // === Pass Internals ======================================================= + + /** Performs lambda shorthand desugaring on an arbitrary expression. + * + * @param ir the expression to desugar + * @param freshNameSupply the compiler's fresh name supply + * @return `ir`, with any lambda shorthand arguments desugared + */ + def desugarExpression( + ir: Expression, + freshNameSupply: FreshNameSupply + ): Expression = { + ir.transformExpressions { + case app: Application => desugarApplication(app, freshNameSupply) + case caseExpr: Case.Expr => desugarCaseExpr(caseExpr, freshNameSupply) + case name: Name => desugarName(name, freshNameSupply) + } + } + + /** Desugars an arbitrary name occurrence, turning isolated occurrences of + * `_` into the `id` function. + * + * @param name the name to desugar + * @param supply the compiler's fresh name supply + * @return `name`, desugared where necessary + */ + private def desugarName(name: Name, supply: FreshNameSupply): Expression = { + name match { + case blank: Name.Blank => + val newName = supply.newName() + + new Function.Lambda( + List( + DefinitionArgument.Specified( + name = Name.Literal( + newName.name, + isMethod = false, + null + ), + ascribedType = None, + defaultValue = None, + suspended = false, + identifiedLocation = null + ) + ), + newName, + blank.location.orNull + ) + case _ => name + } + } + + /** Desugars lambda shorthand arguments to an arbitrary function application. + * + * @param application the function application to desugar + * @param freshNameSupply the compiler's supply of fresh names + * @return `application`, with any lambda shorthand arguments desugared + */ + private def desugarApplication( + application: Application, + freshNameSupply: FreshNameSupply + ): Expression = { + application match { + case p @ Application.Prefix(fn, args, _, _, _) => + // Determine which arguments are lambda shorthand + val argIsUnderscore = determineLambdaShorthand(args) + + // Generate a new name for the arg value for each shorthand arg + val updatedArgs = + args + .zip(argIsUnderscore) + .map(updateShorthandArg(_, freshNameSupply)) + .map { case s @ CallArgument.Specified(_, value, _, _) => + s.copy(value = desugarExpression(value, freshNameSupply)) + } + + // Generate a definition arg instance for each shorthand arg + val defArgs = updatedArgs.zip(argIsUnderscore).map { + case (arg, isShorthand) => generateDefinitionArg(arg, isShorthand) + } + val actualDefArgs = defArgs.collect { case Some(defArg) => + defArg + } + + // Determine whether or not the function itself is shorthand + val functionIsShorthand = fn.isInstanceOf[Name.Blank] + val (updatedFn, updatedName) = if (functionIsShorthand) { + val newFn = freshNameSupply + .newName() + .copy( + location = fn.location, + passData = fn.passData, + diagnostics = fn.diagnostics + ) + val newName = newFn.name + (newFn, Some(newName)) + } else { + val newFn = desugarExpression(fn, freshNameSupply) + (newFn, None) + } + + val processedApp = p.copy( + function = updatedFn, + arguments = updatedArgs + ) + + // Wrap the app in lambdas from right to left, 1 lambda per shorthand + // arg + val appResult = + actualDefArgs.foldRight(processedApp: Expression)((arg, body) => + new Function.Lambda(List(arg), body, null) + ) + + // If the function is shorthand, do the same + val resultExpr = if (functionIsShorthand) { + new Function.Lambda( + List( + DefinitionArgument.Specified( + Name + .Literal( + updatedName.get, + isMethod = false, + fn.location.orNull + ), + None, + None, + suspended = false, + null + ) + ), + appResult, + null + ) + } else appResult + + resultExpr match { + case lam: Function.Lambda => lam.copy(location = p.location) + case result => result + } + case f @ Application.Force(tgt, _, _) => + f.copy(target = desugarExpression(tgt, freshNameSupply)) + case vector @ Application.Sequence(items, _, _) => + var bindings: List[Name] = List() + val newItems = items.map { + case blank: Name.Blank => + val name = freshNameSupply + .newName() + .copy( + location = blank.location, + passData = blank.passData, + diagnostics = blank.diagnostics + ) + bindings ::= name + name + case it => desugarExpression(it, freshNameSupply) + } + val newVec = vector.copy(newItems) + val locWithoutId = + newVec.location.map(l => new IdentifiedLocation(l.location())) + bindings.foldLeft(newVec: Expression) { (body, bindingName) => + val defArg = DefinitionArgument.Specified( + bindingName, + ascribedType = None, + defaultValue = None, + suspended = false, + identifiedLocation = null + ) + new Function.Lambda(List(defArg), body, locWithoutId.orNull) + } + case tSet @ Application.Typeset(expr, _, _) => + tSet.copy(expression = expr.map(desugarExpression(_, freshNameSupply))) + case _: Operator => + throw new CompilerError( + "Operators should be desugared by the point of underscore " + + "to lambda conversion." + ) + } + } + + /** Determines, positionally, which of the application arguments are lambda + * shorthand arguments. + * + * @param args the application arguments + * @return a list containing `true` for a given position if the arg in that + * position is lambda shorthand, otherwise `false` + */ + private def determineLambdaShorthand( + args: List[CallArgument] + ): List[Boolean] = { + args.map { case CallArgument.Specified(_, value, _, _) => + value match { + case _: Name.Blank => true + case _ => false + } + } + } + + /** Generates a new name to replace a shorthand argument, as well as the + * corresponding definition argument. + * + * @param argAndIsShorthand the arguments, and whether or not the argument in + * the corresponding position is shorthand + * @return the above described pair for a given position if the argument in + * a given position is shorthand, otherwise [[None]]. + */ + private def updateShorthandArg( + argAndIsShorthand: (CallArgument, Boolean), + freshNameSupply: FreshNameSupply + ): CallArgument = { + val arg = argAndIsShorthand._1 + val isShorthand = argAndIsShorthand._2 + + arg match { + case s @ CallArgument.Specified(_, value, _, _) => + if (isShorthand) { + val newName = freshNameSupply + .newName() + .copy( + location = value.location, + passData = value.passData, + diagnostics = value.diagnostics + ) + + s.copy(value = newName) + } else s + } + } + + /** Generates a corresponding definition argument to a call argument that was + * previously lambda shorthand. + * + * @param arg the argument to generate a corresponding def argument to + * @param isShorthand whether or not `arg` was shorthand + * @return a corresponding definition argument if `arg` `isShorthand`, + * otherwise [[None]] + */ + private def generateDefinitionArg( + arg: CallArgument, + isShorthand: Boolean + ): Option[DefinitionArgument] = { + if (isShorthand) { + arg match { + case specified @ CallArgument.Specified(_, value, _, passData) => + // Note [Safe Casting to Name.Literal] + val defArgName = + Name.Literal( + value.asInstanceOf[Name.Literal].name, + isMethod = false, + null + ) + + Some( + new DefinitionArgument.Specified( + defArgName, + None, + None, + suspended = false, + null, + passData.duplicate, + specified.diagnosticsCopy + ) + ) + } + } else None + } + + /* Note [Safe Casting to Name.Literal] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * This cast is entirely safe here as, by construction in + * `updateShorthandArg`, any arg for which `isShorthand` is true has its + * value as an `Name.Literal`. + */ + + /** Performs desugaring of lambda shorthand arguments in a case expression. + * + * In the case where a user writes `case _ of`, this gets converted into a + * lambda (`x -> case x of`). + * + * @param caseExpr the case expression to desugar + * @param freshNameSupply the compiler's supply of fresh names + * @return `caseExpr`, with any lambda shorthand desugared + */ + private def desugarCaseExpr( + caseExpr: Case.Expr, + freshNameSupply: FreshNameSupply + ): Expression = { + val newBranches = caseExpr.branches.map( + _.mapExpressions(expr => desugarExpression(expr, freshNameSupply)) + ) + + caseExpr.scrutinee match { + case nameBlank: Name.Blank => + val scrutineeName = + freshNameSupply + .newName() + .copy( + location = nameBlank.location, + passData = nameBlank.passData, + diagnostics = nameBlank.diagnostics + ) + + val lambdaArg = DefinitionArgument.Specified( + scrutineeName.copy(id = null), + None, + None, + suspended = false, + null + ) + + val newCaseExpr = caseExpr.copy( + scrutinee = scrutineeName, + branches = newBranches + ) + + new Function.Lambda( + caseExpr, + List(lambdaArg), + newCaseExpr, + caseExpr.location.orNull + ) + case x => + caseExpr.copy( + scrutinee = desugarExpression(x, freshNameSupply), + branches = newBranches + ) + } + } +} diff --git a/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/desugar/LambdaShorthandToLambdaTest.scala b/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/desugar/LambdaShorthandToLambdaTest.scala index 6d2c7708a293..8201be729048 100644 --- a/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/desugar/LambdaShorthandToLambdaTest.scala +++ b/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/desugar/LambdaShorthandToLambdaTest.scala @@ -2,6 +2,7 @@ package org.enso.compiler.test.pass.desugar import org.enso.compiler.Passes import org.enso.compiler.context.{FreshNameSupply, InlineContext} +import org.enso.compiler.core.IR import org.enso.compiler.core.ir.{ CallArgument, DefinitionArgument, @@ -12,8 +13,13 @@ import org.enso.compiler.core.ir.{ } import org.enso.compiler.core.ir.expression.{Application, Case} import org.enso.compiler.pass.desugar.LambdaShorthandToLambda -import org.enso.compiler.pass.{PassConfiguration, PassGroup, PassManager} -import org.enso.compiler.test.CompilerTest +import org.enso.compiler.pass.{ + MiniIRPass, + PassConfiguration, + PassGroup, + PassManager +} +import org.enso.compiler.test.{CompilerTest, CompilerTests} class LambdaShorthandToLambdaTest extends CompilerTest { @@ -42,7 +48,13 @@ class LambdaShorthandToLambdaTest extends CompilerTest { * @return [[ir]], with all lambda shorthand desugared */ def desugar(implicit inlineContext: InlineContext): Expression = { - LambdaShorthandToLambda.runExpression(ir, inlineContext) + LambdaShorthandToLambdaMegaPass.runExpression(ir, inlineContext) + } + + def desugarMini(implicit inlineContext: InlineContext): Expression = { + val miniPass = + LambdaShorthandToLambda.createForInlineCompilation(inlineContext) + MiniIRPass.compile(classOf[Expression], ir, miniPass) } } @@ -54,6 +66,20 @@ class LambdaShorthandToLambdaTest extends CompilerTest { buildInlineContext(freshNameSupply = Some(new FreshNameSupply)) } + private def desugarWithMegaPass( + code: String + ): IR = { + implicit val ctx: InlineContext = mkInlineContext + code.preprocessExpression.get.desugar + } + + private def desugarWithMiniPass( + code: String + ): IR = { + implicit val ctx: InlineContext = mkInlineContext + code.preprocessExpression.get.desugarMini + } + // === The Tests ============================================================ "Desugaring of underscore arguments" should { @@ -335,107 +361,123 @@ class LambdaShorthandToLambdaTest extends CompilerTest { "Nested underscore arguments" should { "work for applications" in { - implicit val ctx: InlineContext = mkInlineContext - - val ir = + val code = """ |a _ (fn _ c) - |""".stripMargin.preprocessExpression.get.desugar - - ir shouldBe an[Function.Lambda] - ir.asInstanceOf[Function.Lambda].body shouldBe an[Application.Prefix] - val irBody = ir - .asInstanceOf[Function.Lambda] - .body - .asInstanceOf[Application.Prefix] - - irBody - .arguments(1) - .asInstanceOf[CallArgument.Specified] - .value shouldBe an[Function.Lambda] - val lamArg = irBody - .arguments(1) - .asInstanceOf[CallArgument.Specified] - .value - .asInstanceOf[Function.Lambda] - val lamArgArgName = - lamArg.arguments.head.asInstanceOf[DefinitionArgument.Specified].name - - lamArg.body shouldBe an[Application.Prefix] - val lamArgBody = lamArg.body.asInstanceOf[Application.Prefix] - val lamArgBodyArg1Name = lamArgBody.arguments.head - .asInstanceOf[CallArgument.Specified] - .value - .asInstanceOf[Name.Literal] - - lamArgArgName.name shouldEqual lamArgBodyArg1Name.name + |""".stripMargin + val megaIr = desugarWithMegaPass(code) + val miniIr = desugarWithMiniPass(code) + + for ((ir, msg) <- List((miniIr, "Mini IR"), (megaIr, "Mega IR"))) { + withClue("Processed by " + msg) { + ir shouldBe an[Function.Lambda] + ir.asInstanceOf[Function.Lambda].body shouldBe an[Application.Prefix] + val irBody = ir + .asInstanceOf[Function.Lambda] + .body + .asInstanceOf[Application.Prefix] + + irBody + .arguments(1) + .asInstanceOf[CallArgument.Specified] + .value shouldBe an[Function.Lambda] + val lamArg = irBody + .arguments(1) + .asInstanceOf[CallArgument.Specified] + .value + .asInstanceOf[Function.Lambda] + val lamArgArgName = + lamArg.arguments.head + .asInstanceOf[DefinitionArgument.Specified] + .name + + lamArg.body shouldBe an[Application.Prefix] + val lamArgBody = lamArg.body.asInstanceOf[Application.Prefix] + val lamArgBodyArg1Name = lamArgBody.arguments.head + .asInstanceOf[CallArgument.Specified] + .value + .asInstanceOf[Name.Literal] + + lamArgArgName.name shouldEqual lamArgBodyArg1Name.name + } + } } "work in named applications" in { - implicit val ctx: InlineContext = mkInlineContext - - val ir = + val code = """ |a _ (fn (t = _) c) - |""".stripMargin.preprocessExpression.get.desugar - - ir shouldBe an[Function.Lambda] - ir.asInstanceOf[Function.Lambda].body shouldBe an[Application.Prefix] - val irBody = ir - .asInstanceOf[Function.Lambda] - .body - .asInstanceOf[Application.Prefix] - - irBody - .arguments(1) - .asInstanceOf[CallArgument.Specified] - .value shouldBe an[Function.Lambda] - val lamArg = irBody - .arguments(1) - .asInstanceOf[CallArgument.Specified] - .value - .asInstanceOf[Function.Lambda] - val lamArgArgName = - lamArg.arguments.head.asInstanceOf[DefinitionArgument.Specified].name - - lamArg.body shouldBe an[Application.Prefix] - val lamArgBody = lamArg.body.asInstanceOf[Application.Prefix] - val lamArgBodyArg1Name = lamArgBody.arguments.head - .asInstanceOf[CallArgument.Specified] - .value - .asInstanceOf[Name.Literal] - - lamArgArgName.name shouldEqual lamArgBodyArg1Name.name + |""".stripMargin + val megaIr = desugarWithMegaPass(code) + val miniIr = desugarWithMiniPass(code) + + for ((ir, msg) <- List((miniIr, "Mini IR"), (megaIr, "Mega IR"))) { + withClue("Processed by " + msg) { + ir shouldBe an[Function.Lambda] + ir.asInstanceOf[Function.Lambda].body shouldBe an[Application.Prefix] + val irBody = ir + .asInstanceOf[Function.Lambda] + .body + .asInstanceOf[Application.Prefix] + + irBody + .arguments(1) + .asInstanceOf[CallArgument.Specified] + .value shouldBe an[Function.Lambda] + val lamArg = irBody + .arguments(1) + .asInstanceOf[CallArgument.Specified] + .value + .asInstanceOf[Function.Lambda] + val lamArgArgName = + lamArg.arguments.head + .asInstanceOf[DefinitionArgument.Specified] + .name + + lamArg.body shouldBe an[Application.Prefix] + val lamArgBody = lamArg.body.asInstanceOf[Application.Prefix] + val lamArgBodyArg1Name = lamArgBody.arguments.head + .asInstanceOf[CallArgument.Specified] + .value + .asInstanceOf[Name.Literal] + + lamArgArgName.name shouldEqual lamArgBodyArg1Name.name + } + } } "work in function argument defaults" in { - implicit val ctx: InlineContext = mkInlineContext - - val ir = + val code = """ |\a (b = f _ 1) -> f a - |""".stripMargin.preprocessExpression.get.desugar - - ir shouldBe an[Function.Lambda] - val irFn = ir.asInstanceOf[Function.Lambda] - val bArg = - irFn.arguments.tail.head.asInstanceOf[DefinitionArgument.Specified] - - bArg.defaultValue shouldBe defined - bArg.defaultValue.get shouldBe an[Function.Lambda] - val default = bArg.defaultValue.get.asInstanceOf[Function.Lambda] - val defaultArgName = default.arguments.head - .asInstanceOf[DefinitionArgument.Specified] - .name - - default.body shouldBe an[Application.Prefix] - val defBody = default.body.asInstanceOf[Application.Prefix] - val defBodyArg1Name = defBody.arguments.head - .asInstanceOf[CallArgument.Specified] - .value - .asInstanceOf[Name.Literal] - - defaultArgName.name shouldEqual defBodyArg1Name.name + |""".stripMargin + val megaIr = desugarWithMegaPass(code) + val miniIr = desugarWithMiniPass(code) + + for ((ir, msg) <- List((miniIr, "Mini IR"), (megaIr, "Mega IR"))) { + withClue("Processed by " + msg) { + ir shouldBe an[Function.Lambda] + val irFn = ir.asInstanceOf[Function.Lambda] + val bArg = + irFn.arguments.tail.head.asInstanceOf[DefinitionArgument.Specified] + + bArg.defaultValue shouldBe defined + bArg.defaultValue.get shouldBe an[Function.Lambda] + val default = bArg.defaultValue.get.asInstanceOf[Function.Lambda] + val defaultArgName = default.arguments.head + .asInstanceOf[DefinitionArgument.Specified] + .name + + default.body shouldBe an[Application.Prefix] + val defBody = default.body.asInstanceOf[Application.Prefix] + val defBodyArg1Name = defBody.arguments.head + .asInstanceOf[CallArgument.Specified] + .value + .asInstanceOf[Name.Literal] + + defaultArgName.name shouldEqual defBodyArg1Name.name + } + } } "work for case expressions" in { @@ -575,4 +617,42 @@ class LambdaShorthandToLambdaTest extends CompilerTest { secondLamArgName shouldEqual appArg2Name } } + + "Mini lambda shorthand pass" should { + "Produce same results as mega pass" in { + val codeInputs = List( + "_.length", + "foo a _ b _", + "foo (a = _) b _", + "_ a b", + "if _ then a", + """ + |case _ of + | Nil -> 0 + |""".stripMargin, + "(10 + _)", + "(_ +)", + "(_ + _)", + "(+ _)", + "[1, _, (3 + 4), _]", + """ + |case _ of + | Nil -> f _ b + |""".stripMargin, + "x = _", + "\\ x=_ -> x", + "(_ + 5) 5", + "(f _ _ b) b" + ) + + codeInputs.zipWithIndex.foreach { case (code, idx) => + val testName = "test-" + idx + val msg = s"Code that failed to compile: `$code`" + val megaIr = desugarWithMegaPass(code) + val miniIr = desugarWithMiniPass(code) + CompilerTests.assertEqualsIR(msg, testName, megaIr, miniIr) + } + } + } + } diff --git a/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/desugar/OperatorToFunctionTest.scala b/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/desugar/OperatorToFunctionTest.scala index adee7d6678d7..780145c77ea5 100644 --- a/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/desugar/OperatorToFunctionTest.scala +++ b/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/desugar/OperatorToFunctionTest.scala @@ -1,5 +1,8 @@ package org.enso.compiler.test.pass.desugar +import org.enso.compiler.Passes +import org.enso.compiler.context.{FreshNameSupply, InlineContext, ModuleContext} +import org.enso.compiler.core.ir.Module import org.enso.compiler.core.ir.{ CallArgument, Empty, @@ -9,10 +12,38 @@ import org.enso.compiler.core.ir.{ Name } import org.enso.compiler.core.ir.expression.{Application, Operator} -import org.enso.compiler.pass.desugar.OperatorToFunction -import org.enso.compiler.test.CompilerTest +import org.enso.compiler.pass.analyse.{ + AliasAnalysis, + DataflowAnalysis, + DemandAnalysis +} +import org.enso.compiler.pass.{ + IRPass, + IRProcessingPass, + MiniIRPass, + MiniPassFactory, + PassConfiguration, + PassManager +} +import org.enso.compiler.pass.desugar.{ + GenerateMethodBodies, + OperatorToFunction, + SectionsToBinOp +} +import org.enso.compiler.test.MiniPassTest + +class OperatorToFunctionTest extends MiniPassTest { + override def testName: String = "OperatorToFunction" -class OperatorToFunctionTest extends CompilerTest { + override def miniPassFactory: MiniPassFactory = OperatorToFunction + + override def megaPass: IRPass = OperatorToFunctionTestPass + + override def megaPassManager: PassManager = { + val passes = new Passes(defaultConfig) + val precursors = passes.getPrecursors(OperatorToFunction).get + new PassManager(List(precursors), PassConfiguration()) + } // === Utilities ============================================================ @@ -50,6 +81,16 @@ class OperatorToFunctionTest extends CompilerTest { } // === The Tests ============================================================ + val opName = + Name.Literal("=:=", isMethod = true, null) + val left = Empty(null) + val right = Empty(null) + val rightArg = CallArgument.Specified(None, Empty(null), null) + + val (operator, operatorFn) = genOprAndFn(opName, left, right) + + val oprArg = CallArgument.Specified(None, operator, null) + val oprFnArg = CallArgument.Specified(None, operatorFn, null) "Operators" should { val opName = @@ -70,15 +111,59 @@ class OperatorToFunctionTest extends CompilerTest { CallArgument.Specified(None, operatorFn, identifiedLocation = null) "be translated to functions" in { - OperatorToFunction.runExpression(operator, ctx) shouldEqual operatorFn + OperatorToFunctionTestPass.runExpression( + operator, + ctx + ) shouldEqual operatorFn } -// "be translated in module contexts" in { -// val moduleInput = operator.asModuleDefs -// val moduleOutput = operatorFn.asModuleDefs -// -// OperatorToFunction.runModule(moduleInput, modCtx) shouldEqual moduleOutput -// } + "be translated recursively in synthetic IR" in { + val recursiveIR = + Operator.Binary(oprArg, opName, rightArg, null) + val recursiveIRResult = Application.Prefix( + opName, + List(oprFnArg, rightArg), + hasDefaultsSuspended = false, + null + ) + + OperatorToFunctionTestPass.runExpression( + recursiveIR, + ctx + ) shouldEqual recursiveIRResult + } + + "be translated recursively" in { + val code = + """ + |main = + | a = 1 + 2 + | nested_method x y = x + y + | nested_method (3 * 4) a + |""".stripMargin + assertModuleCompilation( + code, + () => + buildModuleContext( + freshNameSupply = Some(new FreshNameSupply()) + ), + ir => { + ir.preorder().foreach { + case _: Operator.Binary => fail("Operator.Binary found") + case _ => + } + } + ) + } + } + + "Operators mini pass" should { + "be translated to functions" in { + val miniPass = OperatorToFunction.createForInlineCompilation(ctx) + val miniRes = + MiniIRPass.compile(classOf[Expression], operator, miniPass) + miniRes shouldEqual operatorFn + } "be translated recursively" in { val recursiveIR = @@ -90,10 +175,87 @@ class OperatorToFunctionTest extends CompilerTest { identifiedLocation = null ) - OperatorToFunction.runExpression( - recursiveIR, - ctx - ) shouldEqual recursiveIRResult + val miniPass = OperatorToFunction.createForInlineCompilation(ctx) + val miniRes = + MiniIRPass.compile(classOf[Expression], recursiveIR, miniPass) + miniRes shouldEqual recursiveIRResult + } + } +} + +/** Copied from the original implementation in `OperatorToFunction` + * This pass converts usages of operators to calls to standard functions. + * + * This pass requires the context to provide: + * + * - Nothing + */ +case object OperatorToFunctionTestPass extends IRPass { + + /** A purely desugaring pass has no analysis output. */ + override type Metadata = IRPass.Metadata.Empty + override type Config = IRPass.Configuration.Default + + override lazy val precursorPasses: Seq[IRProcessingPass] = List( + GenerateMethodBodies, + SectionsToBinOp.INSTANCE + ) + override lazy val invalidatedPasses: Seq[IRProcessingPass] = List( + AliasAnalysis, + DataflowAnalysis, + DemandAnalysis + ) + + /** Executes the conversion pass. + * + * @param ir the Enso IR to process + * @param moduleContext a context object that contains the information needed + * to process a module + * @return `ir`, possibly having made transformations or annotations to that + * IR. + */ + override def runModule( + ir: Module, + moduleContext: ModuleContext + ): Module = { + val new_bindings = ir.bindings.map { a => + a.mapExpressions( + runExpression( + _, + new InlineContext( + moduleContext, + compilerConfig = moduleContext.compilerConfig + ) + ) + ) + } + ir.copy(bindings = new_bindings) + } + + /** Executes the conversion pass in an inline context. + * + * @param ir the Enso IR to process + * @param inlineContext a context object that contains the information needed + * for inline evaluation + * @return `ir`, possibly having made transformations or annotations to that + * IR. + */ + override def runExpression( + ir: Expression, + inlineContext: InlineContext + ): Expression = { + ir.transformExpressions { case operatorBinary: Operator.Binary => + new Application.Prefix( + operatorBinary.operator, + List( + operatorBinary.left.mapExpressions(runExpression(_, inlineContext)), + operatorBinary.right.mapExpressions(runExpression(_, inlineContext)) + ), + hasDefaultsSuspended = false, + operatorBinary.location.orNull, + operatorBinary.passData, + operatorBinary.diagnostics + ) } } } diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/SectionsToBinOp.scala b/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/desugar/SectionsToBinOpMegaPass.scala similarity index 96% rename from engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/SectionsToBinOp.scala rename to engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/desugar/SectionsToBinOpMegaPass.scala index 6563675b7566..5434577d8faa 100644 --- a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/desugar/SectionsToBinOp.scala +++ b/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/desugar/SectionsToBinOpMegaPass.scala @@ -1,4 +1,4 @@ -package org.enso.compiler.pass.desugar +package org.enso.compiler.test.pass.desugar import org.enso.compiler.context.{FreshNameSupply, InlineContext, ModuleContext} import org.enso.compiler.core.ir.{ @@ -12,6 +12,7 @@ import org.enso.compiler.core.ir.{ import org.enso.compiler.core.CompilerError import org.enso.compiler.core.ir.expression.{Application, Section} import org.enso.compiler.pass.IRPass +import org.enso.compiler.pass.IRProcessingPass import org.enso.compiler.pass.analyse._ import org.enso.compiler.pass.lint.UnusedBindings @@ -23,19 +24,19 @@ import org.enso.compiler.pass.lint.UnusedBindings * * - A [[FreshNameSupply]]. */ -case object SectionsToBinOp extends IRPass { +case object SectionsToBinOpMegaPass extends IRPass { override type Metadata = IRPass.Metadata.Empty override type Config = IRPass.Configuration.Default - override lazy val precursorPasses: Seq[IRPass] = List( - GenerateMethodBodies + override lazy val precursorPasses: Seq[IRProcessingPass] = List( + org.enso.compiler.pass.desugar.GenerateMethodBodies ) - override lazy val invalidatedPasses: Seq[IRPass] = List( + override lazy val invalidatedPasses: Seq[IRProcessingPass] = List( AliasAnalysis, CachePreferenceAnalysis, DataflowAnalysis, DemandAnalysis, - TailCall, + TailCall.INSTANCE, UnusedBindings ) diff --git a/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/desugar/SectionsToBinOpTest.scala b/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/desugar/SectionsToBinOpTest.scala index e4d68e9a5fe9..d3d4238d7ff5 100644 --- a/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/desugar/SectionsToBinOpTest.scala +++ b/engine/runtime-integration-tests/src/test/scala/org/enso/compiler/test/pass/desugar/SectionsToBinOpTest.scala @@ -5,50 +5,43 @@ import org.enso.compiler.context.{FreshNameSupply, InlineContext} import org.enso.compiler.core.ir.{ CallArgument, DefinitionArgument, - Expression, Function, Literal, Name } +import org.enso.compiler.pass.{ + IRPass, + MiniPassFactory, + PassConfiguration, + PassGroup, + PassManager +} + +import org.enso.compiler.core.IR import org.enso.compiler.core.ir.expression.Application import org.enso.compiler.pass.desugar.SectionsToBinOp -import org.enso.compiler.pass.{PassConfiguration, PassGroup, PassManager} -import org.enso.compiler.test.CompilerTest -class SectionsToBinOpTest extends CompilerTest { +import org.enso.compiler.test.MiniPassTest + +class SectionsToBinOpTest extends MiniPassTest { + override def testName: String = "Section To Bin Op" + override def miniPassFactory: MiniPassFactory = SectionsToBinOp.INSTANCE - // === Test Configuration =================================================== + override def megaPass: IRPass = SectionsToBinOpMegaPass - val passes = new Passes(defaultConfig) + override def megaPassManager: PassManager = + new PassManager(List(precursorPasses), passConfiguration) - val precursorPasses: PassGroup = passes.getPrecursors(SectionsToBinOp).get + // === Test Setup =========================================================== + val passes = new Passes(defaultConfig) val passConfiguration: PassConfiguration = PassConfiguration() + val precursorPasses: PassGroup = + passes.getPrecursors(SectionsToBinOp.INSTANCE).get implicit val passManager: PassManager = new PassManager(List(precursorPasses), passConfiguration) - /** Adds an extension method for running desugaring on the input IR. - * - * @param ir the IR to desugar - */ - implicit class DesugarExpression(ir: Expression) { - - /** Runs section desugaring on [[ir]]. - * - * @param inlineContext the inline context in which the desugaring takes - * place - * @return [[ir]], with all sections desugared - */ - def desugar(implicit inlineContext: InlineContext): Expression = { - SectionsToBinOp.runExpression(ir, inlineContext) - } - } - - /** Makes an inline context. - * - * @return a new inline context - */ def mkInlineContext: InlineContext = { buildInlineContext(freshNameSupply = Some(new FreshNameSupply)) } @@ -56,14 +49,21 @@ class SectionsToBinOpTest extends CompilerTest { // === The Tests ============================================================ "Operator section desugaring" should { - "work for left sections" in { - implicit val ctx: InlineContext = mkInlineContext - val ir = - """ - |(1 +) - |""".stripMargin.preprocessExpression.get.desugar + "work for left sections" in { + val code = """ + |(1 +) + |""".stripMargin + + assertInlineCompilation( + code, + () => mkInlineContext, + assertWorkForLeftSection, + true + ) + } + def assertWorkForLeftSection(ir: IR) = { ir shouldBe an[Application.Prefix] ir.location shouldBe defined @@ -81,13 +81,20 @@ class SectionsToBinOpTest extends CompilerTest { } "work for sides sections" in { - implicit val ctx: InlineContext = mkInlineContext - - val ir = + val code = """ |(+) - |""".stripMargin.preprocessExpression.get.desugar + |""".stripMargin + + assertInlineCompilation( + code, + () => mkInlineContext, + assertWorksForSidesSection, + true + ) + } + def assertWorksForSidesSection(ir: IR) = { ir shouldBe an[Function.Lambda] ir.location shouldBe defined @@ -123,13 +130,21 @@ class SectionsToBinOpTest extends CompilerTest { } "work for right sections" in { - implicit val ctx: InlineContext = mkInlineContext - - val ir = + val code = """ |(+ 1) - |""".stripMargin.preprocessExpression.get.desugar + |""".stripMargin + + assertInlineCompilation( + code, + () => mkInlineContext, + assertWorkForRightSections, + true + ) + } + + def assertWorkForRightSections(ir: IR) = { ir shouldBe an[Function.Lambda] ir.location shouldBe defined @@ -152,13 +167,21 @@ class SectionsToBinOpTest extends CompilerTest { } "work when the section is nested" in { - implicit val ctx: InlineContext = mkInlineContext - - val ir = + val code = """ |x -> (x +) - |""".stripMargin.preprocessExpression.get.desugar - .asInstanceOf[Function.Lambda] + |""".stripMargin + + assertInlineCompilation( + code, + () => mkInlineContext, + assertWorkWhenTheSectionIsNested, + true + ) + } + + def assertWorkWhenTheSectionIsNested(x: IR) = { + val ir = x.asInstanceOf[Function.Lambda] ir.body .asInstanceOf[Application.Prefix] @@ -170,13 +193,21 @@ class SectionsToBinOpTest extends CompilerTest { } "flip the arguments when a right section's argument is a blank" in { - implicit val ctx: InlineContext = mkInlineContext - - val ir = + val code = """ |(+ _) - |""".stripMargin.preprocessExpression.get.desugar + |""".stripMargin + + assertInlineCompilation( + code, + () => mkInlineContext, + assertFlipTheArgumentsWhenARightSection, + true + ) + + } + def assertFlipTheArgumentsWhenARightSection(ir: IR) = { ir shouldBe an[Function.Lambda] ir.location shouldBe defined val irFn = ir.asInstanceOf[Function.Lambda] diff --git a/engine/runtime/src/main/scala/org/enso/interpreter/runtime/IrToTruffle.scala b/engine/runtime/src/main/scala/org/enso/interpreter/runtime/IrToTruffle.scala index 968e20996501..f2d24d34da3d 100644 --- a/engine/runtime/src/main/scala/org/enso/interpreter/runtime/IrToTruffle.scala +++ b/engine/runtime/src/main/scala/org/enso/interpreter/runtime/IrToTruffle.scala @@ -985,7 +985,7 @@ class IrToTruffle( expression: Expression ): BaseNode.TailStatus = { val isTailPosition = - expression.getMetadata(TailCall).isDefined + expression.getMetadata(TailCall.INSTANCE).isDefined val isTailAnnotated = TailCall.isTailAnnotated(expression) if (isTailPosition) { if (isTailAnnotated) {