TPCH testing for grpc #222

Open
wants to merge 1 commit into base: master

Changes from all commits
99 changes: 99 additions & 0 deletions src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueSuiteBase.scala
@@ -18,9 +18,17 @@
package edu.berkeley.cs.rise.opaque

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._
import org.apache.spark.unsafe.types.UTF8String

import java.sql.Date
import java.util.concurrent.TimeUnit

import org.scalatest.BeforeAndAfterAll
import org.scalactic.Equality

trait OpaqueSuiteBase extends OpaqueFunSuite with BeforeAndAfterAll with SQLTestData {

@@ -35,6 +43,61 @@ trait OpaqueSuiteBase extends OpaqueFunSuite with BeforeAndAfterAll with SQLTest
Utils.cleanup(spark)
}

// TODO: Override checkAnswer for new functionality. Uncomment after merging with
// key-gen, re-encryption, sp, and python pull requests
// override def checkAnswer[A: Equality](
// ignore: Boolean = false,
// isOrdered: Boolean = false,
// verbose: Boolean = false,
// printPlan: Boolean = false,
// shouldLogOperators: Boolean = false
// )(f: SecurityLevel => A): Unit = {
// if (ignore) {
// return
// }
//
// Utils.setOperatorLoggingLevel(shouldLogOperators)
// val (insecure, encrypted) = (f(Insecure), f(Encrypted))
// (insecure, encrypted) match {
// case (insecure: DataFrame, encrypted: DataFrame) =>
// val insecureSeq = insecure.collect
// val schema = insecure.schema
//
// encrypted.collect
// val encryptedRows = SPHelper.obtainRows(encrypted)
// val encryptedRows2 = encryptedRows.map(x => prepareRowWithSchema(x, schema))
// val encryptedDF = spark.createDataFrame(
// spark.sparkContext.makeRDD(encryptedRows2, numPartitions),
// schema)
//
// val encryptedSeq = encryptedDF.collect
//
// // Compound filtering functionality cannot be tested here because the results of an
// // operation are not returned immediately; they are stored in a file on disk.
// // See benchmark/KMeans.scala and benchmark/LogisticRegression.scala for examples.
// val equal = insecureSeq.toSet == encryptedSeq.toSet
// // if (isOrdered) insecureSeq === encryptedSeq
// // else insecureSeq.toSet === encryptedSeq.toSet
// if (!equal) {
// if (printPlan) {
// println("**************** Spark Plan ****************")
// insecure.explain()
// println("**************** Opaque Plan ****************")
// encrypted.explain()
// }
// println(genError(insecureSeq, encryptedSeq, isOrdered, verbose))
// }
// assert(equal)
// case (insecure: Array[Array[Double]], encrypted: Array[Array[Double]]) =>
// for ((x, y) <- insecure.zip(encrypted)) {
//
// assert(x === y)
// }
// case _ =>
// assert(insecure === encrypted)
// }
// }

def makeDF[A <: Product: scala.reflect.ClassTag: scala.reflect.runtime.universe.TypeTag](
data: Seq[A],
sl: SecurityLevel,
@@ -46,4 +109,40 @@ trait OpaqueSuiteBase extends OpaqueFunSuite with BeforeAndAfterAll with SQLTest
.toDF(columnNames: _*)
)
}

// Using the given schema, convert each value in the row to the corresponding data type.
// Only certain types are implemented currently; add more as needed.
private def prepareRowWithSchema(row: Row, schema: StructType): Row = {
val fields = schema.fields

if (row == null) {
println("null row")
return null
}

assert(row.length == fields.size)

Row.fromSeq(for (i <- 0 until row.length) yield {
val rowValue = row(i)
val fieldDataType = fields(i).dataType
val converted = fieldDataType match {
case ArrayType(_,_) => rowValue
case BinaryType => rowValue
case BooleanType => rowValue
case CalendarIntervalType => rowValue

// TPC-H dates arrive as an integer day count (Spark stores DateType as days since the Unix epoch)
case DateType => new Date(TimeUnit.DAYS.toMillis(rowValue.asInstanceOf[Integer].toLong))
case MapType(_,_,_) => rowValue
case NullType => rowValue
// case IntegerType => rowValue.asInstanceOf[Integer]
// case DoubleType => rowValue.asInstanceOf[Double]
case StringType => rowValue.asInstanceOf[UTF8String].toString()
case StructType(_) => rowValue
case TimestampType => rowValue
case _ => rowValue
}
converted
})
}
}
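As a quick illustration of the DateType branch in prepareRowWithSchema above, here is a minimal, self-contained sketch; the day count 9204 is an example value, not taken from this diff:

import java.sql.Date
import java.util.concurrent.TimeUnit

// Spark stores DateType values internally as days since the Unix epoch,
// which is why TPC-H dates arrive here as an Integer day count.
val daysSinceEpoch: Integer = 9204  // example: 1995-03-15 in UTC
val asDate = new Date(TimeUnit.DAYS.toMillis(daysSinceEpoch.toLong))
// asDate is a java.sql.Date that can be placed in a Row whose schema declares DateType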
101 changes: 101 additions & 0 deletions src/test/scala/edu/berkeley/cs/rise/opaque/SPHelper.scala
@@ -0,0 +1,101 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package edu.berkeley.cs.rise.opaque

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row

import org.apache.spark.sql.catalyst.util.GenericArrayData

import java.util.Base64

import java.nio.ByteBuffer

import edu.berkeley.cs.rise.opaque.execution.SP

object SPHelper {

val sp = new SP()

// Create a SP for decryption
// TODO: Change hard-coded path for user_cert
val userCert = scala.io.Source.fromFile("/home/opaque/opaque/user1.crt").mkString

// This empty (all-zero) key must be changed to the key used in RA.initRA to allow for decryption
final val GCM_KEY_LENGTH = 32
val sharedKey: Array[Byte] = Array.fill[Byte](GCM_KEY_LENGTH)(0)
sp.Init(sharedKey, userCert)

def convertGenericArrayData(rowData: Row, index: Int): Row = {

val data = rowData(index).asInstanceOf[GenericArrayData]

val dataArray = new Array[Double](data.numElements)
for (i <- 0 until dataArray.size) {
dataArray(i) = data.getDouble(i)
}
return Row(dataArray)
}

def convertGenericArrayDataKMeans(rowData: Row, index: Int): Row = {

val data = rowData(index).asInstanceOf[GenericArrayData]

val dataArray = new Array[Double](data.numElements)
for (i <- 0 until dataArray.size) {
dataArray(i) = data.getDouble(i)
}

val dataTwo = rowData(1).asInstanceOf[GenericArrayData]

val dataArrayTwo = new Array[Double](dataTwo.numElements)
for (i <- 0 until dataArrayTwo.size) {
dataArrayTwo(i) = dataTwo.getDouble(i)
}

return Row(dataArray, dataArrayTwo, rowData(2))
}

def obtainRows(df: DataFrame) : Seq[Row] = {

// Hardcoded user1 for driver
val ciphers = Utils.postVerifyAndReturn(df, "user1")

val internalRow = (for (cipher <- ciphers) yield {

val plain = sp.Decrypt(Base64.getEncoder().encodeToString(cipher))
val rows = tuix.Rows.getRootAsRows(ByteBuffer.wrap(plain))

for (j <- 0 until rows.rowsLength) yield {
val row = rows.rows(j)
assert(!row.isDummy)
Row.fromSeq(for (k <- 0 until row.fieldValuesLength) yield {
val field: Any =
if (!row.fieldValues(k).isNull()) {
Utils.flatbuffersExtractFieldValue(row.fieldValues(k))
} else {
null
}
field
})
}
}).flatten

return internalRow
}
}
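For context, a rough sketch of how a test might consume SPHelper; encryptedDF and insecureDF are placeholder DataFrames, and the direct set comparison only works for numeric columns (string and date columns still need a conversion such as prepareRowWithSchema above):

import org.apache.spark.sql.Row

// Force the encrypted query to execute; its results go to the SP
// rather than coming back through collect.
encryptedDF.collect

// Decrypt the rows that were returned to the SP.
val decrypted: Seq[Row] = SPHelper.obtainRows(encryptedDF)

// Compare against the unencrypted reference, ignoring row order.
val expected = insecureDF.collect.toSeq
assert(decrypted.map(_.toSeq).toSet == expected.map(_.toSeq).toSet)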
147 changes: 118 additions & 29 deletions src/test/scala/edu/berkeley/cs/rise/opaque/benchmark/KMeans.scala
@@ -26,6 +26,7 @@ import edu.berkeley.cs.rise.opaque.expressions.ClosestPoint.closestPoint
import edu.berkeley.cs.rise.opaque.expressions.VectorMultiply.vectormultiply
import edu.berkeley.cs.rise.opaque.expressions.VectorSum
import edu.berkeley.cs.rise.opaque.SecurityLevel
import edu.berkeley.cs.rise.opaque.SPHelper

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row
@@ -77,42 +78,130 @@ object KMeans {
"system" -> securityLevel.name,
"N" -> N
) {
if (securityLevel.name == "spark sql") {

// Sample k random points.
// TODO: Assumes points are already permuted randomly.
var centroids = points.take(K).map(_.getSeq[Double](0).toArray)

var tempDist = 1.0

while (tempDist > convergeDist) {
val newCentroids = points
.select(
closestPoint($"p", lit(centroids)).as("oldCentroid"),
$"p".as("centroidPartialSum"),
lit(1).as("centroidPartialCount")
)
.groupBy($"oldCentroid")
.agg(
vectorsum($"centroidPartialSum").as("centroidSum"),
sum($"centroidPartialCount").as("centroidCount")
)
.select(
$"oldCentroid",
vectormultiply($"centroidSum", (lit(1.0) / $"centroidCount")).as("newCentroid")
)
.collect

tempDist = 0.0
for (row <- newCentroids) {
tempDist += squaredDistance(
new DenseVector(row.getSeq[Double](0).toArray),
new DenseVector(row.getSeq[Double](1).toArray)
)
}

centroids = newCentroids.map(_.getSeq[Double](1).toArray)
}

centroids
} else {

// First operation block. Instead of using take, use collect for simplicity
// points.take(K)
points.collect

var centroids = SPHelper.obtainRows(points).map(x => SPHelper.convertGenericArrayData(x, 0))
.map(x => x(0).asInstanceOf[Array[Double]])
.toArray.slice(0, 3)

var tempDist = 1.0

while (tempDist > convergeDist) {

// Second operation block
val df_2 = points
.select(
closestPoint($"p", lit(centroids)).as("oldCentroid"),
$"p".as("centroidPartialSum"),
lit(1).as("centroidPartialCount")
).collect

// Of form Seq[Row[GenericArrayData, GenericArrayData, Int]]
val rows_2 = SPHelper.obtainRows(points)
.map(x => SPHelper.convertGenericArrayDataKMeans(x, 0))

// Third operation block
val schema = StructType(
Seq(StructField("oldCentroid", DataTypes.createArrayType(DoubleType)),
StructField("centroidPartialSum", DataTypes.createArrayType(DoubleType)),
StructField("centroidPartialCount", DataTypes.IntegerType))
)
val df_3 = securityLevel.applyTo(
spark.createDataFrame(
spark.sparkContext.makeRDD(rows_2, numPartitions),
schema))

df_3.groupBy($"oldCentroid")
.agg(
vectorsum($"centroidPartialSum").as("centroidSum"),
sum($"centroidPartialCount").as("centroidCount")
).collect

// Of form Seq[Row[GenericArrayData, GenericArrayData, Int]]
val rows_3 = SPHelper.obtainRows(df_3)
.map(x => SPHelper.convertGenericArrayDataKMeans(x, 0))

// Fourth operation block
val schema_2 = StructType(
Seq(StructField("oldCentroid", DataTypes.createArrayType(DoubleType)),
StructField("centroidSum", DataTypes.createArrayType(DoubleType)),
StructField("centroidCount", DataTypes.LongType))
)

val df_4 = securityLevel.applyTo(
spark.createDataFrame(
spark.sparkContext.makeRDD(rows_3, numPartitions),
schema_2))

df_4.select(
$"oldCentroid",
vectormultiply($"centroidSum", (lit(1.0) / $"centroidCount")).as("newCentroid"),
lit(1).as("filler")
)
.collect

val rows_4 = SPHelper.obtainRows(df_3)
.map(x => SPHelper.convertGenericArrayDataKMeans(x, 0))

// Final operation block

tempDist = 0.0
for (row <- rows_4) {
tempDist += squaredDistance(
new DenseVector(row(0).asInstanceOf[Array[Double]]),
new DenseVector(row(1).asInstanceOf[Array[Double]])
)
}

centroids = rows_4.map(x => x(1).asInstanceOf[Array[Double]]).toArray
}

centroids
}
}
}
}
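The encrypted branch above repeats the same round-trip for every operation block; the sketch below condenses that pattern into a hypothetical helper (roundTrip is not part of this diff, and it assumes SPHelper and SecurityLevel are in scope as in KMeans.scala):

import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.types.StructType

// 1. Run the encrypted operator (its results go to the SP, not to collect).
// 2. Decrypt and reshape the rows.
// 3. Re-encrypt them into a fresh DataFrame so the next operator can run.
def roundTrip(spark: SparkSession, df: DataFrame, schema: StructType,
              securityLevel: SecurityLevel, numPartitions: Int): DataFrame = {
  df.collect
  val rows: Seq[Row] = SPHelper.obtainRows(df)
    .map(r => SPHelper.convertGenericArrayDataKMeans(r, 0))
  securityLevel.applyTo(
    spark.createDataFrame(spark.sparkContext.makeRDD(rows, numPartitions), schema))
}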