diff --git a/build.sbt b/build.sbt index 53b64ad..1c2f6c9 100644 --- a/build.sbt +++ b/build.sbt @@ -1,10 +1,10 @@ -scalaVersion := "2.13.14" +scalaVersion := "3.3.3" name := "struct-tensor" organization := "uk.ac.ed.dal" version := "0.1" -libraryDependencies += "org.scala-lang.modules" %% "scala-parser-combinators" % "1.1.2" +libraryDependencies += "org.scala-lang.modules" %% "scala-parser-combinators" % "2.0.0" libraryDependencies += "com.lihaoyi" %% "fastparse" % "3.0.2" libraryDependencies += "com.github.scopt" %% "scopt" % "4.1.0" libraryDependencies += "org.scalatest" %% "scalatest" % "3.2.11" % Test diff --git a/src/main/scala/uk/ac/ed/dal/structtensor/Main.scala b/src/main/scala/uk/ac/ed/dal/structtensor/Main.scala index f1c6077..ddb75ef 100644 --- a/src/main/scala/uk/ac/ed/dal/structtensor/Main.scala +++ b/src/main/scala/uk/ac/ed/dal/structtensor/Main.scala @@ -10,7 +10,8 @@ import codegen._ import java.io.File import scopt.OParser -object Main extends App { +object Main { + def main(args: Array[String]) = { import Optimizer._ import Utils._ @@ -201,7 +202,7 @@ object Main extends App { val rcRule = Rule( ccRule.head, SoPTimesSoP( - SoP(Seq(Prod(Seq(ccRule.head.vars2RedundancyVars)))), + SoP(Seq(Prod(Seq(ccRule.head.vars2RedundancyVars())))), rmRule.body ) ) @@ -331,3 +332,4 @@ object Main extends App { } } } +} diff --git a/src/main/scala/uk/ac/ed/dal/structtensor/codegen/Codegen.scala b/src/main/scala/uk/ac/ed/dal/structtensor/codegen/Codegen.scala index 4677d7f..0b2d669 100644 --- a/src/main/scala/uk/ac/ed/dal/structtensor/codegen/Codegen.scala +++ b/src/main/scala/uk/ac/ed/dal/structtensor/codegen/Codegen.scala @@ -25,7 +25,6 @@ object Codegen { case c: ConstantDouble => c.value.toString case a @ Arithmetic(op, i1, i2) => s"(${CPPFormat(i1)} ${op} ${CPPFormat(i2)})" - case _ => "" } case a: Access => if (a.vars.isEmpty) CPPFormat(a.name) @@ -81,7 +80,6 @@ object Codegen { } case _ => None } - case _ => None } val (begin, end, equals, usedIndices) = ( begin_end_equals_usedIndices.map(_._1), diff --git a/src/main/scala/uk/ac/ed/dal/structtensor/compiler/Compiler.scala b/src/main/scala/uk/ac/ed/dal/structtensor/compiler/Compiler.scala index 149e7a5..af85411 100644 --- a/src/main/scala/uk/ac/ed/dal/structtensor/compiler/Compiler.scala +++ b/src/main/scala/uk/ac/ed/dal/structtensor/compiler/Compiler.scala @@ -3,7 +3,7 @@ package structtensor package compiler import utils._ -import sourcecode.Macros.Chunk.Var + import uk.ac.ed.dal.structtensor.parser.Parser.variableSeq object Compiler { @@ -235,8 +235,8 @@ object Compiler { } def shift(lhs: Access, rhs: Access, exps: Seq[Exp]): (Rule, Rule, Rule) = { - val usBody = SoP(Seq(Prod(Seq(rhs.uniqueHead) ++ exps))) - val us = Rule(lhs.uniqueHead, usBody) + val usBody = SoP(Seq(Prod(Seq(rhs.uniqueHead()) ++ exps))) + val us = Rule(lhs.uniqueHead(), usBody) val redExps = exps.map { case Comparison( @@ -246,16 +246,16 @@ object Compiler { ) => Comparison( "=", - Arithmetic("-", index1.vars2RedundancyVars, index2), - variable.vars2RedundancyVars + Arithmetic("-", index1.vars2RedundancyVars(), index2), + variable.vars2RedundancyVars() ) case exp => exp } - val rmBody = SoP(Seq(Prod(Seq(rhs.redundancyHead) ++ exps ++ redExps))) - val rm = Rule(lhs.redundancyHead, rmBody) + val rmBody = SoP(Seq(Prod(Seq(rhs.redundancyHead()) ++ exps ++ redExps))) + val rm = Rule(lhs.redundancyHead(), rmBody) - val cBody = SoP(Seq(Prod(Seq(rhs.compressedHead) ++ exps))) - val c = Rule(lhs.compressedHead, cBody) + val cBody = SoP(Seq(Prod(Seq(rhs.compressedHead()) ++ exps))) + val c = Rule(lhs.compressedHead(), cBody) (us, rm, c) } @@ -289,23 +289,23 @@ object Compiler { ctx ) ) { - val us = Rule(lhs.uniqueHead, SoP(Seq(Prod(Seq(rhs.uniqueHead))))) - val rm = Rule(lhs.redundancyHead, SoP(Seq(Prod(Seq(rhs.redundancyHead))))) - val c = Rule(lhs.compressedHead, SoP(Seq(Prod(Seq(rhs.compressedHead))))) + val us = Rule(lhs.uniqueHead(), SoP(Seq(Prod(Seq(rhs.uniqueHead()))))) + val rm = Rule(lhs.redundancyHead(), SoP(Seq(Prod(Seq(rhs.redundancyHead()))))) + val c = Rule(lhs.compressedHead(), SoP(Seq(Prod(Seq(rhs.compressedHead()))))) (us, rm, c) } else { val us = Rule( - lhs.uniqueHead, - SoP(Seq(Prod(Seq(rhs.uniqueHead)), Prod(Seq(rhs.redundancyHead)))) + lhs.uniqueHead(), + SoP(Seq(Prod(Seq(rhs.uniqueHead())), Prod(Seq(rhs.redundancyHead())))) ) - val rm = Rule(lhs.redundancyHead, emptySoP()) + val rm = Rule(lhs.redundancyHead(), emptySoP()) val c = Rule( - lhs.compressedHead, + lhs.compressedHead(), SoP( Seq( - Prod(Seq(rhs.compressedHead)), + Prod(Seq(rhs.compressedHead())), Prod( - Seq(rhs.redundancyHead, rhs.vars2RedundancyVars.compressedHead()) + Seq(rhs.redundancyHead(), rhs.vars2RedundancyVars().compressedHead()) ) ) ) @@ -347,16 +347,16 @@ object Compiler { vars2 ) ) { - val usBody = SoP(Seq(Prod(Seq(acc1.uniqueHead, acc2.uniqueHead)))) - val us = Rule(lhs.uniqueHead, usBody) + val usBody = SoP(Seq(Prod(Seq(acc1.uniqueHead(), acc2.uniqueHead())))) + val us = Rule(lhs.uniqueHead(), usBody) val rmBody = SoP( Seq( - Prod(Seq(acc1.redundancyHead, acc2.redundancyHead)), + Prod(Seq(acc1.redundancyHead(), acc2.redundancyHead())), Prod( Seq( - acc1.uniqueHead, - acc2.redundancyHead + acc1.uniqueHead(), + acc2.redundancyHead() ) ++ vectorizeComparisonMultiplication( "=", vars1, @@ -365,8 +365,8 @@ object Compiler { ), Prod( Seq( - acc1.redundancyHead, - acc2.uniqueHead + acc1.redundancyHead(), + acc2.uniqueHead() ) ++ vectorizeComparisonMultiplication( "=", vars2, @@ -375,49 +375,49 @@ object Compiler { ) ) ) - val rm = Rule(lhs.redundancyHead, rmBody) + val rm = Rule(lhs.redundancyHead(), rmBody) val cBody = SoP( - Seq(Prod(Seq(acc1.compressedHead, acc2.compressedHead))) + Seq(Prod(Seq(acc1.compressedHead(), acc2.compressedHead()))) ) - val c = Rule(lhs.compressedHead, cBody) + val c = Rule(lhs.compressedHead(), cBody) (us, rm, c) } else if (lhs.vars.toSet == vars1.union(vars2).toSet) { val usBody = SoP( Seq( - Prod(Seq(acc1.uniqueHead, acc2.uniqueHead)), - Prod(Seq(acc1.uniqueHead, acc2.redundancyHead)), - Prod(Seq(acc1.redundancyHead, acc2.uniqueHead)) + Prod(Seq(acc1.uniqueHead(), acc2.uniqueHead())), + Prod(Seq(acc1.uniqueHead(), acc2.redundancyHead())), + Prod(Seq(acc1.redundancyHead(), acc2.uniqueHead())) ) ) - val us = Rule(lhs.uniqueHead, usBody) + val us = Rule(lhs.uniqueHead(), usBody) val rmBody = SoP( - Seq(Prod(Seq(acc1.redundancyHead, acc2.redundancyHead))) + Seq(Prod(Seq(acc1.redundancyHead(), acc2.redundancyHead()))) ) - val rm: Rule = Rule(lhs.redundancyHead, rmBody) + val rm: Rule = Rule(lhs.redundancyHead(), rmBody) val cBody = SoP( Seq( - Prod(Seq(acc1.compressedHead, acc2.compressedHead)), + Prod(Seq(acc1.compressedHead(), acc2.compressedHead())), Prod( Seq( - acc1.compressedHead, - acc2.redundancyHead, - acc2.vars2RedundancyVars.compressedHead() + acc1.compressedHead(), + acc2.redundancyHead(), + acc2.vars2RedundancyVars().compressedHead() ) ), Prod( Seq( - acc1.redundancyHead, - acc2.compressedHead, - acc1.vars2RedundancyVars.compressedHead() + acc1.redundancyHead(), + acc2.compressedHead(), + acc1.vars2RedundancyVars().compressedHead() ) ) ) ) - val c = Rule(lhs.compressedHead, cBody) + val c = Rule(lhs.compressedHead(), cBody) (us, rm, c) } else @@ -429,14 +429,14 @@ object Compiler { acc1 @ Access(name1, vars1, Tensor), Comparison(op2, index2, variable2) ) => { - val usBody = SoP(Seq(Prod(Seq(acc1.uniqueHead, e2)))) - val us = Rule(lhs.uniqueHead, usBody) + val usBody = SoP(Seq(Prod(Seq(acc1.uniqueHead(), e2)))) + val us = Rule(lhs.uniqueHead(), usBody) - val rmBody = SoP(Seq(Prod(Seq(acc1.redundancyHead, e2)))) - val rm = Rule(lhs.redundancyHead, rmBody) + val rmBody = SoP(Seq(Prod(Seq(acc1.redundancyHead(), e2)))) + val rm = Rule(lhs.redundancyHead(), rmBody) - val cBody = SoP(Seq(Prod(Seq(acc1.compressedHead, e2)))) - val c = Rule(lhs.compressedHead, cBody) + val cBody = SoP(Seq(Prod(Seq(acc1.compressedHead(), e2)))) + val c = Rule(lhs.compressedHead(), cBody) (us, rm, c) } @@ -501,7 +501,7 @@ object Compiler { else index1 val usBody = SoP(Seq(Prod(Seq(Comparison("=", indexEq, variable))))) - val us = Rule(lhs.uniqueHead, usBody) + val us = Rule(lhs.uniqueHead(), usBody) val rmBody = SoP( Seq( @@ -514,21 +514,21 @@ object Compiler { ) ) ) - val rm = Rule(lhs.redundancyHead, rmBody) + val rm = Rule(lhs.redundancyHead(), rmBody) val cBody = SoP(Seq(Prod(Seq(Comparison("=", indexEq, variable))))) - val c = Rule(lhs.compressedHead, cBody) + val c = Rule(lhs.compressedHead(), cBody) (us, rm, c) } else { val usBody = SoP(Seq(Prod(Seq(e1, e2)))) - val us = Rule(lhs.uniqueHead, usBody) + val us = Rule(lhs.uniqueHead(), usBody) val rmBody = emptySoP() - val rm = Rule(lhs.redundancyHead, rmBody) + val rm = Rule(lhs.redundancyHead(), rmBody) val cBody = SoP(Seq(Prod(Seq(e1, e2)))) - val c = Rule(lhs.compressedHead, cBody) + val c = Rule(lhs.compressedHead(), cBody) (us, rm, c) } @@ -676,7 +676,7 @@ object Compiler { us_minus_last.body, SoP(Seq(Prod(comparisonSeq))) ) - val us = Rule(lhs.uniqueHead, usBody) + val us = Rule(lhs.uniqueHead(), usBody) val all_vars_minus_last = init.map(_.vars) val rmBody1 = @@ -699,7 +699,7 @@ object Compiler { c_minus_last.body, SoP(Seq(Prod(comparisonSeq))) ) - val c = Rule(lhs.compressedHead, cBody) + val c = Rule(lhs.compressedHead(), cBody) (us, rm, c) } @@ -869,9 +869,9 @@ object Compiler { case (acc, ind) if indexList.contains(ind) => vectorizedMultVarEqualsRedundancyVarSeq( ind - ) :+ acc.uniqueHead + ) :+ acc.uniqueHead() case (acc, ind) if !indexList.contains(ind) => - Seq(acc.redundancyHead) + Seq(acc.redundancyHead()) } .flatten .toSeq @@ -885,7 +885,7 @@ object Compiler { prodTimesSoP(Prod(allUniqueHeads), injectedMapRM) val rmBody = concatSoP(Seq(SoP(bodyRMProdSeq1), uniqueHeadIncludedInRM)) - val rm = Rule(lhs.redundancyHead, rmBody) + val rm = Rule(lhs.redundancyHead(), rmBody) val cBody = prodTimesSoP(Prod(accSeq.map(_.compressedHead())), c2.body) val c = Rule(lhs.compressedHead(), cBody) @@ -910,9 +910,9 @@ object Compiler { all_intersect, all_intersect.redundancyVars ) :+ - acc.uniqueHead + acc.uniqueHead() case (acc, ind) if indexList.contains(ind) => - Seq(acc.redundancyHead) + Seq(acc.redundancyHead()) } .flatten .toSeq @@ -991,12 +991,12 @@ object Compiler { }) .unzip3 - val us = Rule(lhs.uniqueHead, concatSoP(usSoPSeq)) + val us = Rule(lhs.uniqueHead(), concatSoP(usSoPSeq)) val rm = Rule( - lhs.redundancyHead, + lhs.redundancyHead(), concatSoP(rmSoPSeq :+ SoP(Seq(Prod(accSeq.map(_.redundancyHead()))))) ) - val c = Rule(lhs.compressedHead, concatSoP(cSoPSeq)) + val c = Rule(lhs.compressedHead(), concatSoP(cSoPSeq)) (us, rm, c) } @@ -1050,13 +1050,13 @@ object Compiler { case (acc1 @ Access(_, vars1, _), acc2 @ Access(_, vars2, _)) if (vars1.toSet == vars2.toSet) => { val acc1UniqueSetBody = - locallyDenormalizeAndReturnBody(acc1.uniqueHead, ctx) + locallyDenormalizeAndReturnBody(acc1.uniqueHead(), ctx) val acc2UniqueSetBody = - locallyDenormalizeAndReturnBody(acc2.uniqueHead, ctx) + locallyDenormalizeAndReturnBody(acc2.uniqueHead(), ctx) val acc1RedundancyMapBody = - locallyDenormalizeAndReturnBody(acc1.redundancyHead, ctx) + locallyDenormalizeAndReturnBody(acc1.redundancyHead(), ctx) val acc2RedundancyMapBody = - locallyDenormalizeAndReturnBody(acc2.redundancyHead, ctx) + locallyDenormalizeAndReturnBody(acc2.redundancyHead(), ctx) if ( isSoPEquals(acc1UniqueSetBody, acc2UniqueSetBody) && isSoPEquals( @@ -1064,16 +1064,16 @@ object Compiler { acc2RedundancyMapBody ) ) { - val usBody = SoP(Seq(Prod(Seq(acc1.uniqueHead)))) - val us = Rule(lhs.uniqueHead, usBody) + val usBody = SoP(Seq(Prod(Seq(acc1.uniqueHead())))) + val us = Rule(lhs.uniqueHead(), usBody) - val rmBody = SoP(Seq(Prod(Seq(acc1.redundancyHead)))) - val rm = Rule(lhs.redundancyHead, rmBody) + val rmBody = SoP(Seq(Prod(Seq(acc1.redundancyHead())))) + val rm = Rule(lhs.redundancyHead(), rmBody) val cBody = SoP( - Seq(Prod(Seq(acc1.compressedHead)), Prod(Seq(acc2.compressedHead))) + Seq(Prod(Seq(acc1.compressedHead())), Prod(Seq(acc2.compressedHead()))) ) - val c = Rule(lhs.compressedHead, cBody) + val c = Rule(lhs.compressedHead(), cBody) (us, rm, c) } else if ( @@ -1083,53 +1083,53 @@ object Compiler { ) ) { val usBody = SoP( - Seq(Prod(Seq(acc1.uniqueHead)), Prod(Seq(acc2.uniqueHead))) + Seq(Prod(Seq(acc1.uniqueHead())), Prod(Seq(acc2.uniqueHead()))) ) - val us = Rule(lhs.uniqueHead, usBody) + val us = Rule(lhs.uniqueHead(), usBody) val rmBody = SoP( - Seq(Prod(Seq(acc1.redundancyHead)), Prod(Seq(acc2.redundancyHead))) + Seq(Prod(Seq(acc1.redundancyHead())), Prod(Seq(acc2.redundancyHead()))) ) - val rm = Rule(lhs.redundancyHead, rmBody) + val rm = Rule(lhs.redundancyHead(), rmBody) val cBody = SoP( - Seq(Prod(Seq(acc1.compressedHead)), Prod(Seq(acc2.compressedHead))) + Seq(Prod(Seq(acc1.compressedHead())), Prod(Seq(acc2.compressedHead()))) ) - val c = Rule(lhs.compressedHead, cBody) + val c = Rule(lhs.compressedHead(), cBody) (us, rm, c) } else { val usBody = SoP( Seq( - Prod(Seq(acc1.uniqueHead)), - Prod(Seq(acc2.uniqueHead)), - Prod(Seq(acc1.redundancyHead)), - Prod(Seq(acc2.redundancyHead)) + Prod(Seq(acc1.uniqueHead())), + Prod(Seq(acc2.uniqueHead())), + Prod(Seq(acc1.redundancyHead())), + Prod(Seq(acc2.redundancyHead())) ) ) - val us = Rule(lhs.uniqueHead, usBody) + val us = Rule(lhs.uniqueHead(), usBody) - val rm = Rule(lhs.redundancyHead, emptySoP()) + val rm = Rule(lhs.redundancyHead(), emptySoP()) val cBody = SoP( Seq( - Prod(Seq(acc1.compressedHead)), - Prod(Seq(acc2.compressedHead)), + Prod(Seq(acc1.compressedHead())), + Prod(Seq(acc2.compressedHead())), Prod( Seq( - acc1.redundancyHead, - acc1.vars2RedundancyVars.compressedHead() + acc1.redundancyHead(), + acc1.vars2RedundancyVars().compressedHead() ) ), Prod( Seq( - acc2.redundancyHead, - acc2.vars2RedundancyVars.compressedHead() + acc2.redundancyHead(), + acc2.vars2RedundancyVars().compressedHead() ) ) ) ) - val c = Rule(lhs.compressedHead, cBody) + val c = Rule(lhs.compressedHead(), cBody) (us, rm, c) } diff --git a/src/main/scala/uk/ac/ed/dal/structtensor/compiler/Optimizer.scala b/src/main/scala/uk/ac/ed/dal/structtensor/compiler/Optimizer.scala index 3384cca..c681764 100644 --- a/src/main/scala/uk/ac/ed/dal/structtensor/compiler/Optimizer.scala +++ b/src/main/scala/uk/ac/ed/dal/structtensor/compiler/Optimizer.scala @@ -130,12 +130,12 @@ object Optimizer { acc + (us.head -> usBody) + (rm.head -> rmBody) + (cc.head -> ccBody) + (tc.head -> tcBody) }) - val denormUS = Rule(head, denormMap(head.uniqueHead)) + val denormUS = Rule(head, denormMap(head.uniqueHead())) val denormRM = Rule( Access(head.name, head.vars.redundancyVarsInplace, head.kind), - denormMap(head.redundancyHead) + denormMap(head.redundancyHead()) ) - val denormCC = Rule(head, denormMap(head.compressedHead)) + val denormCC = Rule(head, denormMap(head.compressedHead())) val denormTC = Rule(head, denormMap(head)) (denormUS, denormRM, denormCC, denormTC) diff --git a/src/main/scala/uk/ac/ed/dal/structtensor/compiler/STUR.scala b/src/main/scala/uk/ac/ed/dal/structtensor/compiler/STUR.scala index 86320fd..0433f1a 100644 --- a/src/main/scala/uk/ac/ed/dal/structtensor/compiler/STUR.scala +++ b/src/main/scala/uk/ac/ed/dal/structtensor/compiler/STUR.scala @@ -61,21 +61,21 @@ case class Arithmetic(op: String, index1: Index, index2: Index) extends Index with Dim { def prettyFormat(): String = - s"(${index1.prettyFormat} $op ${index2.prettyFormat})" + s"(${index1.prettyFormat()} $op ${index2.prettyFormat()})" def vars2RedundancyVars(): Arithmetic = - Arithmetic(op, index1.vars2RedundancyVars, index2.vars2RedundancyVars) + Arithmetic(op, index1.vars2RedundancyVars(), index2.vars2RedundancyVars()) } case class Access(name: String, vars: Seq[Variable], kind: AccessType) extends Exp { import Utils._ def prettyFormat(): String = { - val pr = vars.map(_.prettyFormat).mkString(", ") + val pr = vars.map(_.prettyFormat()).mkString(", ") if (pr.isEmpty) name else s"$name($pr)" } def vars2RedundancyVars(): Access = - Access(name, vars.map(_.vars2RedundancyVars), kind) + Access(name, vars.map(_.vars2RedundancyVars()), kind) def uniqueHead(): Access = Access(name.uniqueName, vars, UniqueSet) def redundancyHead(): Access = Access(name.redundancyName, vars.redundancyVarsInplace, RedundancyMap) @@ -90,16 +90,16 @@ case class Comparison(op: String, index: Index, variable: Variable) extends Exp { def prettyFormat(): String = index match { case _: Variable | _ if op != "=" => - s"(${index.prettyFormat} $op ${variable.prettyFormat})" - case _ => s"(${variable.prettyFormat} $op ${index.prettyFormat})" + s"(${index.prettyFormat()} $op ${variable.prettyFormat()})" + case _ => s"(${variable.prettyFormat()} $op ${index.prettyFormat()})" } def vars2RedundancyVars(): Comparison = - Comparison(op, index.vars2RedundancyVars, variable.vars2RedundancyVars) + Comparison(op, index.vars2RedundancyVars(), variable.vars2RedundancyVars()) } case class Prod(exps: Seq[Exp]) { def prettyFormat(): String = { - val pr = exps.map(_.prettyFormat).mkString(" * ") + val pr = exps.map(_.prettyFormat()).mkString(" * ") if (pr.isEmpty) "∅" else pr } def inverse(): Prod = Prod( @@ -114,13 +114,13 @@ case class Prod(exps: Seq[Exp]) { case class SoP(prods: Seq[Prod]) extends RuleOrSoP { def prettyFormat(): String = { - val pr = prods.map(_.prettyFormat).mkString(" + ") + val pr = prods.map(_.prettyFormat()).mkString(" + ") if (pr.isEmpty) "∅" else pr } def vars2RedundancyVars(): SoP = { SoP(prods.map { prod => - Prod(prod.exps.map(_.vars2RedundancyVars)) + Prod(prod.exps.map(_.vars2RedundancyVars())) }) } @@ -128,7 +128,7 @@ case class SoP(prods: Seq[Prod]) extends RuleOrSoP { } case class Rule(head: Access, body: SoP) extends RuleOrSoP { - def prettyFormat(): String = s"${head.prettyFormat} := ${body.prettyFormat}" + def prettyFormat(): String = s"${head.prettyFormat()} := ${body.prettyFormat()}" def inverse(): Rule = Rule(head.inverseHead(), body.inverse()) }