diff --git a/src/main/scala/edu/berkeley/cs/rise/quartz/PayableExtractor.scala b/src/main/scala/edu/berkeley/cs/rise/quartz/PayableExtractor.scala new file mode 100644 index 0000000..23ed1f1 --- /dev/null +++ b/src/main/scala/edu/berkeley/cs/rise/quartz/PayableExtractor.scala @@ -0,0 +1,50 @@ +package edu.berkeley.cs.rise.quartz + +object PayableExtractor { + + def extractPayableVars(stateMachine: StateMachine): (Set[String], Set[(String, String)], Set[(String, String)]) = { + var fields: Set[String] = Set.empty[String] + var params: Set[(String, String)] = Set.empty[(String, String)] + val structFields: Set[(String, String)] = Set.empty[(String, String)] + + var previousSize = 0 + do { + previousSize = fields.size + params.size + structFields.size + stateMachine.transitions foreach { transition => + val transParams = transition.parameters.getOrElse(Seq.empty[Variable]).map(row => row.name) + transition.body.getOrElse(Seq.empty[Statement]) foreach { + case Send(destination, _, _) => { + val destVar = identifyVariable(destination) + if (transParams.contains(destVar)) { + params += ((transition.name, destVar)) + } else { + fields += destVar + } + } + case Assignment(left, right) if fields.contains(left.rootName) || + params.contains((transition.name, left.rootName)) => { + val rightVar = identifyVariable(right) + if (transParams.contains(rightVar)) { + params += ((transition.name, rightVar)) + } else { + fields += rightVar + } + } + case _ => "default" + } + } + } while (previousSize != fields.size + params.size + structFields.size) + + println("Fields:", fields) + println("Params:", params) + println("Struct:", structFields) + return (fields, params, structFields) + } + + def identifyVariable(expression: Expression): String = expression match { + case MappingRef(map, key) => identifyVariable(map) + case VarRef(name) => name + case SequenceSize(sequence) => identifyVariable(sequence) + case _ => "" + } +} diff --git a/src/main/scala/edu/berkeley/cs/rise/quartz/Solidity.scala b/src/main/scala/edu/berkeley/cs/rise/quartz/Solidity.scala index 976d25a..20c1383 100644 --- a/src/main/scala/edu/berkeley/cs/rise/quartz/Solidity.scala +++ b/src/main/scala/edu/berkeley/cs/rise/quartz/Solidity.scala @@ -34,19 +34,17 @@ object Solidity { case Bool => "bool" case Timespan => "uint" case HashValue(_) => "bytes32" - case Mapping(keyType, valueType) => s"mapping(${writeType(keyType, payable)} => ${writeType(valueType, payable)})" + case Mapping(keyType, valueType) => s"mapping(${writeType(keyType, false)} => ${writeType(valueType, payable)})" case Sequence(elementType) => s"${writeType(elementType, payable)}[]" case Struct(name) => name } - private def writeStructDefinition(name: String, fields: Map[String, DataType]): String = { + private def writeStructDefinition(name: String, fields: Map[String, DataType], payableFields: Set[(String, String)]): String = { val builder = new StringBuilder() appendLine(builder, s"struct $name {") - indentationLevel += 1 fields.foreach { case (fName, fTy) => - // TODO Determine when field must be marked payable - appendLine(builder, s"${writeType(fTy, payable = false)} $fName;") + appendLine(builder, s"${writeType(fTy, payable = payableFields.contains((name, fName)))} $fName;") } indentationLevel -= 1 appendLine(builder, "}") @@ -200,14 +198,14 @@ object Solidity { builder.toString() } - private def writeTransition(transition: Transition, useCall: Boolean = false): String = { + private def writeTransition(transition: Transition, useCall: Boolean = false, payableParams: Set[(String, String)]): String = { val builder = new StringBuilder() val paramsRepr = transition.parameters.fold("") { params => // Remove parameters that are used in the original source but are built in to Solidity val effectiveParams = params.filter(p => !BUILTIN_PARAMS.contains(p.name)) - val payableParams = extractPayableVars(transition.body.getOrElse(Seq.empty[Statement]), effectiveParams.map(_.name).toSet) - writeParameters(effectiveParams.zip(effectiveParams.map(p => payableParams.contains(p.name)))) + val payables = payableParams.filter(_._1.equals(transition.name)).map(_._2) + writeParameters(effectiveParams.zip(effectiveParams.map(p => payables.contains(p.name)))) } val payable = if (transition.parameters.getOrElse(Seq.empty[Variable]).exists(_.name == "tokens")) { @@ -485,9 +483,9 @@ object Solidity { appendLine(builder, s"contract $name {") indentationLevel += 1 - val payableFields = extractPayableVars(stateMachine.flattenStatements, stateMachine.fields.map(_.name).toSet) + var (fields, params, structFields) = PayableExtractor.extractPayableVars(stateMachine) - stateMachine.structs.foreach { case (name, fields) => builder.append(writeStructDefinition(name, fields)) } + stateMachine.structs.foreach { case (name, fields) => builder.append(writeStructDefinition(name, fields, structFields)) } appendLine(builder, "enum State {") indentationLevel += 1 @@ -497,12 +495,12 @@ object Solidity { indentationLevel -= 1 appendLine(builder, "}") - stateMachine.fields.foreach(f => appendLine(builder, writeField(f, payableFields.contains(f.name)) + ";")) + stateMachine.fields.foreach(f => appendLine(builder, writeField(f, fields.contains(f.name)) + ";")) appendLine(builder, s"State public $CURRENT_STATE_VAR;") builder.append(writeAuthorizationFields(stateMachine)) builder.append("\n") - stateMachine.transitions foreach { t => builder.append(writeTransition(t, useCall)) } + stateMachine.transitions foreach { t => builder.append(writeTransition(t, useCall, params)) } extractAllMembershipTypes(stateMachine).foreach(ty => builder.append(writeSequenceContainsTest(ty) + "\n")) builder.append("\n") @@ -557,28 +555,4 @@ object Solidity { expressionChecks } } - - private def extractVarNames(expression: Expression): Set[String] = expression match { - case MappingRef(map, key) => extractVarNames(map) ++ extractVarNames(key) - case VarRef(name) => Set(name) - case LogicalOperation(left, _, right) => extractVarNames(left) ++ extractVarNames(right) - case ArithmeticOperation(left, _, right) => extractVarNames(left) ++ extractVarNames(right) - case SequenceSize(sequence) => extractVarNames(sequence) - case _ => Set.empty[String] - } - - private def extractPayableVars(statements: Seq[Statement], scope: Set[String] = Set.empty[String]): Set[String] = { - val names = statements.foldLeft(Set.empty[String]) { (current, statement) => - statement match { - case Send(destination, _, _) => current.union(extractVarNames(destination)) - case _ => current - } - } - - if (scope.nonEmpty) { - names.intersect(scope) - } else { - names - } - } } diff --git a/src/main/scala/edu/berkeley/cs/rise/quartz/StateMachine.scala b/src/main/scala/edu/berkeley/cs/rise/quartz/StateMachine.scala index 2413991..92d6bf3 100644 --- a/src/main/scala/edu/berkeley/cs/rise/quartz/StateMachine.scala +++ b/src/main/scala/edu/berkeley/cs/rise/quartz/StateMachine.scala @@ -154,8 +154,6 @@ case class StateMachine(structs: Map[String, Map[String, DataType]], fields: Seq } def flattenExpressions: Seq[Expression] = transitions.flatMap(_.flattenExpressions()) - - def flattenStatements: Seq[Statement] = transitions.flatMap(_.body.getOrElse(Seq.empty[Statement])) } object StateMachine { diff --git a/src/test/resources/payable/field.qtz b/src/test/resources/payable/field.qtz new file mode 100644 index 0000000..1b66a86 --- /dev/null +++ b/src/test/resources/payable/field.qtz @@ -0,0 +1,19 @@ +contract Field { + data { + A: Identity + B: Identity + C: Identity + } + + initialize: -> open { + B = A + } + + test1: open -> open { + send 0 to A + } + + test2: open -> open { + C = B + } +} diff --git a/src/test/resources/payable/field2.qtz b/src/test/resources/payable/field2.qtz new file mode 100644 index 0000000..678b7c9 --- /dev/null +++ b/src/test/resources/payable/field2.qtz @@ -0,0 +1,19 @@ +contract Field { + data { + A: Identity + B: Identity + C: Identity + } + + initialize: -> open { + A = B + } + + test1: open -> open { + send 0 to A + } + + test2: open -> open { + B = C + } +} diff --git a/src/test/resources/payable/mapping.qtz b/src/test/resources/payable/mapping.qtz new file mode 100644 index 0000000..8955c1a --- /dev/null +++ b/src/test/resources/payable/mapping.qtz @@ -0,0 +1,14 @@ +contract Mapping { + data { + Map1: Mapping[Uint, Identity] + Map2: Mapping[Identity, Identity] + } + + initialize: ->(id: Uint) open { + send 0 to Map1[id] + } + + test1: open ->(id: Identity) open { + send 0 to Map2[id] + } +} diff --git a/src/test/resources/payable/mapping2.qtz b/src/test/resources/payable/mapping2.qtz new file mode 100644 index 0000000..d348168 --- /dev/null +++ b/src/test/resources/payable/mapping2.qtz @@ -0,0 +1,11 @@ +contract Mapping { + data { + Map1: Mapping[Uint, Mapping[Identity, Identity]] + Map2: Mapping[Identity, Mapping[Uint, Identity]] + } + + initialize: ->(id: Uint, id2: Identity) open { + send 0 to Map1[id][id2] + send 0 to Map2[id2][id] + } +} diff --git a/src/test/resources/payable/param.qtz b/src/test/resources/payable/param.qtz new file mode 100644 index 0000000..0826825 --- /dev/null +++ b/src/test/resources/payable/param.qtz @@ -0,0 +1,19 @@ +contract Param { + data { + A: Identity + B: Identity + } + + initialize: ->(id: Identity) open { + A = id + A = B + } + + test1: open -> open { + send 0 to A + } + + test2: open ->(id: Identity) open { + B = id + } +} diff --git a/src/test/resources/payable/param2.qtz b/src/test/resources/payable/param2.qtz new file mode 100644 index 0000000..b33ba5d --- /dev/null +++ b/src/test/resources/payable/param2.qtz @@ -0,0 +1,15 @@ +contract Param { + data { + A: Identity + } + + initialize: ->(id: Identity, id2: Identity, id3: Identity) open { + A = id + id = id2 + id3 = id2 + } + + test1: open -> open { + send 0 to A + } +} diff --git a/src/test/resources/payable/struct.qtz b/src/test/resources/payable/struct.qtz new file mode 100644 index 0000000..ddeb27f --- /dev/null +++ b/src/test/resources/payable/struct.qtz @@ -0,0 +1,33 @@ +contract Struct { + struct Wrapper1 { + payee: Identity + } + + struct Wrapper2 { + payee2: Identity + wrap: Wrapper1 + } + + struct Wrapper3 { + payee3: Identity + wrap: Wrapper2 + } + + data { + wrap1: Wrapper1 + wrap2: Wrapper2 + wrap3: Wrapper3 + } + + initialize: -> open { + send 0 to wrap1.payee + } + + test1: open -> open { + send 0 to wrap2.wrap.payee + } + + test2: open -> open { + send 0 to wrap3.wrap.wrap.payee + } +} \ No newline at end of file diff --git a/src/test/resources/payable/struct2.qtz b/src/test/resources/payable/struct2.qtz new file mode 100644 index 0000000..aa3a08f --- /dev/null +++ b/src/test/resources/payable/struct2.qtz @@ -0,0 +1,31 @@ +contract Struct { + struct Wrapper1 { + payee: Identity + } + + struct Wrapper2 { + map: Mapping[Identity, Wrapper1] + map2: Mapping[Identity, Identity] + } + + struct Wrapper3 { + seq: Sequence[Wrapper1] + } + + struct Wrapper4 { + payeeMap: Wrapper2 + payeeSeq: Wrapper3 + } + + data { + wrap2: Wrapper2 + wrap3: Wrapper3 + wrap4: Wrapper4 + } + + initialize: ->(id: Identity) open { + send 0 to wrap2.map[id].payee + send 0 to wrap2.map2[id] + send 0 to wrap4.payeeMap.map[id].payee + } +} \ No newline at end of file diff --git a/src/test/resources/payable/struct3.qtz b/src/test/resources/payable/struct3.qtz new file mode 100644 index 0000000..a379dab --- /dev/null +++ b/src/test/resources/payable/struct3.qtz @@ -0,0 +1,24 @@ +contract Struct { + struct Wrapper { + payee: Identity + payee2: Identity + } + + data { + wrap: Wrapper + A: Identity + } + + initialize: -> open { + A = wrap.payee + } + + test1: open -> open { + send 0 to A + } + + test2: open ->(id: Identity) open { + wrap.payee2 = id + send 0 to wrap.payee2 + } +} \ No newline at end of file