diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 39af92baa06c7..4122c71a22073 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -99,6 +99,15 @@ object CommandPythonFunctionType extends PythonFunctionType{ override val value
object SqlUdfPythonFunctionType extends PythonFunctionType{ override val value = 1 }
object PandasUdfPythonFunctionType extends PythonFunctionType{ override val value = 2 }
+/**
+ * Interface used when building an iterator that reads data from a Python worker.
+ *
+ * `getDataStream` exposes the worker's stdout as a stream, `readLengthFromPython` reads a
+ * framed length and surfaces any exception reported by the worker, and `readFooter` consumes
+ * the trailing timing data, accumulator updates and end-of-stream markers.
+ */
+private[spark] trait PythonReadInterface {
+ def getDataStream: DataInputStream
+ def readLengthFromPython(): Int
+ def readFooter(): Unit
+}
+
/**
* A helper class to run Python mapPartition/UDFs in Spark.
*
@@ -123,10 +132,11 @@ private[spark] class PythonRunner(
// TODO: support accumulator in multiple UDF
private val accumulator = funcs.head.funcs.head.accumulator
- def compute(
- inputIterator: Iterator[_],
+ def process[U](
+ dataWriteBlock: DataOutputStream => Unit,
+ dataReadBuilder: PythonReadInterface => Iterator[U],
partitionIndex: Int,
- context: TaskContext): Iterator[Array[Byte]] = {
+ context: TaskContext): Iterator[U] = {
val startTime = System.currentTimeMillis
val env = SparkEnv.get
val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
@@ -139,7 +149,7 @@ private[spark] class PythonRunner(
@volatile var released = false
// Start a thread to feed the process input from our parent's iterator
- val writerThread = new WriterThread(env, worker, inputIterator, partitionIndex, context)
+ val writerThread = new WriterThread(env, worker, dataWriteBlock, partitionIndex, context)
context.addTaskCompletionListener { context =>
writerThread.shutdownOnTaskCompletion()
@@ -156,79 +166,29 @@ private[spark] class PythonRunner(
writerThread.start()
new MonitorThread(env, worker, context).start()
- // Return an iterator that read lines from the process's stdout
- val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
- val stdoutIterator = new Iterator[Array[Byte]] {
- override def next(): Array[Byte] = {
- val obj = _nextObj
- if (hasNext) {
- _nextObj = read()
- }
- obj
- }
+ // Create a stream to read data from the process's stdout
+ val dataIn = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
+
+ val stdoutIterator = new Iterator[U] with PythonReadInterface {
+
+ // Create iterator for reading data blocks
+ val _dataIterator = dataReadBuilder(this)
- private def read(): Array[Byte] = {
+ def safeRead[T](block: => T): T = {
if (writerThread.exception.isDefined) {
throw writerThread.exception.get
}
+
try {
- stream.readInt() match {
- case length if length > 0 =>
- val obj = new Array[Byte](length)
- stream.readFully(obj)
- obj
- case 0 => Array.empty[Byte]
- case SpecialLengths.TIMING_DATA =>
- // Timing data from worker
- val bootTime = stream.readLong()
- val initTime = stream.readLong()
- val finishTime = stream.readLong()
- val boot = bootTime - startTime
- val init = initTime - bootTime
- val finish = finishTime - initTime
- val total = finishTime - startTime
- logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
- init, finish))
- val memoryBytesSpilled = stream.readLong()
- val diskBytesSpilled = stream.readLong()
- context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
- context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
- read()
- case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
- // Signals that an exception has been thrown in python
- val exLength = stream.readInt()
- val obj = new Array[Byte](exLength)
- stream.readFully(obj)
- throw new PythonException(new String(obj, StandardCharsets.UTF_8),
- writerThread.exception.getOrElse(null))
- case SpecialLengths.END_OF_DATA_SECTION =>
- // We've finished the data section of the output, but we can still
- // read some accumulator updates:
- val numAccumulatorUpdates = stream.readInt()
- (1 to numAccumulatorUpdates).foreach { _ =>
- val updateLen = stream.readInt()
- val update = new Array[Byte](updateLen)
- stream.readFully(update)
- accumulator.add(update)
- }
- // Check whether the worker is ready to be re-used.
- if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
- if (reuse_worker) {
- env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker)
- released = true
- }
- }
- null
- }
+ block
} catch {
-
case e: Exception if context.isInterrupted =>
logDebug("Exception thrown after task interruption", e)
throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason"))
case e: Exception if env.isStopped =>
logDebug("Exception thrown after context is stopped", e)
- null // exit silently
+ throw new RuntimeException("TODO: exit silently")// exit silently
case e: Exception if writerThread.exception.isDefined =>
logError("Python worker exited unexpectedly (crashed)", e)
@@ -240,13 +200,112 @@ private[spark] class PythonRunner(
}
}
- var _nextObj = read()
+ override def next(): U = {
+ safeRead {
+ _dataIterator.next()
+ }
+ }
+
+ override def hasNext: Boolean = {
+ safeRead {
+ _dataIterator.hasNext
+ }
+ }
+
+ override def getDataStream: DataInputStream = dataIn
+
+ override def readLengthFromPython(): Int = {
+ val length = dataIn.readInt()
+ if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+ // Signals that an exception has been thrown in python
+ val exLength = dataIn.readInt()
+ val obj = new Array[Byte](exLength)
+ dataIn.readFully(obj)
+ throw new PythonException(new String(obj, StandardCharsets.UTF_8),
+   writerThread.exception.orNull)
+ }
+ length
+ }
+
+ override def readFooter(): Unit = {
+ // Timing data from worker. The SpecialLengths.TIMING_DATA marker has already been
+ // consumed by the caller via readLengthFromPython().
+ val bootTime = dataIn.readLong()
+ val initTime = dataIn.readLong()
+ val finishTime = dataIn.readLong()
+ val boot = bootTime - startTime
+ val init = initTime - bootTime
+ val finish = finishTime - initTime
+ val total = finishTime - startTime
+ logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
+ init, finish))
+ val memoryBytesSpilled = dataIn.readLong()
+ val diskBytesSpilled = dataIn.readLong()
+ context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
+ context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
+
+ // We've finished the data section of the output, but we can still
+ // read some accumulator updates:
+ readLengthFromPython() // == SpecialLengths.END_OF_DATA_SECTION
+ val numAccumulatorUpdates = readLengthFromPython()
+ (1 to numAccumulatorUpdates).foreach { _ =>
+ val updateLen = dataIn.readInt()
+ val update = new Array[Byte](updateLen)
+ dataIn.readFully(update)
+ accumulator.add(update)
+ }
- override def hasNext: Boolean = _nextObj != null
+ // Check whether the worker is ready to be re-used.
+ if (readLengthFromPython() == SpecialLengths.END_OF_STREAM) {
+ if (reuse_worker) {
+ env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker)
+ released = true
+ }
+ }
+ // otherwise the marker is SpecialLengths.END_OF_DATA_SECTION and the worker is not reused
+ }
}
+
new InterruptibleIterator(context, stdoutIterator)
}
+ def compute(
+ inputIterator: Iterator[_],
+ partitionIndex: Int,
+ context: TaskContext): Iterator[Array[Byte]] = {
+
+ val dataWriteBlock = (out: DataOutputStream) => {
+ PythonRDD.writeIteratorToStream(inputIterator, out)
+ }
+
+ val dataReadBuilder = (in: PythonReadInterface) => {
+ new Iterator[Array[Byte]] {
+ var _lastLength: Int = _
+
+ override def hasNext: Boolean = {
+ _lastLength = in.readLengthFromPython()
+ val result = _lastLength >= 0
+ if (!result) {
+ in.readFooter()
+ }
+ result
+ }
+
+ override def next(): Array[Byte] = {
+ _lastLength match {
+ case l if l > 0 =>
+ val obj = new Array[Byte](_lastLength)
+ in.getDataStream.readFully(obj)
+ obj
+ case 0 =>
+ Array.empty[Byte]
+ }
+ }
+ }
+ }
+
+ process(dataWriteBlock, dataReadBuilder, partitionIndex, context)
+ }
+
/**
* The thread responsible for writing the data from the PythonRDD's parent iterator to the
* Python process.
@@ -254,7 +313,7 @@ private[spark] class PythonRunner(
class WriterThread(
env: SparkEnv,
worker: Socket,
- inputIterator: Iterator[_],
+ dataWriteBlock: DataOutputStream => Unit,
partitionIndex: Int,
context: TaskContext)
extends Thread(s"stdout writer for $pythonExec") {
@@ -340,7 +399,7 @@ private[spark] class PythonRunner(
}
// Data values
- PythonRDD.writeIteratorToStream(inputIterator, dataOut)
+ dataWriteBlock(dataOut)
dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
dataOut.writeInt(SpecialLengths.END_OF_STREAM)
dataOut.flush()
@@ -701,6 +760,13 @@ private[spark] object PythonRDD extends Logging {
* The thread will terminate after all the data are sent or any exceptions happen.
*/
def serveIterator[T](items: Iterator[T], threadName: String): Int = {
+ serveToStream(threadName) { out =>
+ writeIteratorToStream(items, out)
+ }
+ }
+
+ /**
+  * Create a server socket on localhost and, in a separate thread, run `dataWriteBlock`
+  * against the accepted connection's output stream, returning the port the server listens on.
+  * The socket is closed if no client connects within 3 seconds, and the thread terminates
+  * after the write block completes or any exception occurs.
+  */
+ def serveToStream(threadName: String)(dataWriteBlock: DataOutputStream => Unit): Int = {
val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
// Close the socket if no connection in 3 seconds
serverSocket.setSoTimeout(3000)
@@ -712,13 +778,13 @@ private[spark] object PythonRDD extends Logging {
val sock = serverSocket.accept()
val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
Utils.tryWithSafeFinally {
- writeIteratorToStream(items, out)
+ dataWriteBlock(out)
} {
out.close()
}
} catch {
case NonFatal(e) =>
- logError(s"Error while sending iterator", e)
+ logError(s"Error while writing to stream", e)
} finally {
serverSocket.close()
}
diff --git a/pom.xml b/pom.xml
index dc967e224f987..8811ffcae0c93 100644
--- a/pom.xml
+++ b/pom.xml
@@ -184,7 +184,7 @@
2.6
1.8
1.0.0
- <arrow.version>0.3.0</arrow.version>
+ <arrow.version>0.4.0</arrow.version>
${java.home}
@@ -1885,10 +1885,6 @@
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</exclusion>
- <exclusion>
- <groupId>org.slf4j</groupId>
- <artifactId>log4j-over-slf4j</artifactId>
- </exclusion>
<exclusion>
<groupId>io.netty</groupId>
<artifactId>netty-handler</artifactId>
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 04b4dcf2d30a7..b3e8acf3c7a3b 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -198,7 +198,7 @@ def dumps(self, record_batch):
def loads(self, obj):
import pyarrow as pa
- reader = pa.FileReader(pa.BufferReader(obj))
+ reader = pa.RecordBatchFileReader(pa.BufferReader(obj))
assert reader.num_record_batches == 1, "Cannot read more than one record batches"
return reader.get_batch(0)
@@ -206,6 +206,56 @@ def __repr__(self):
return "ArrowSerializer"
+class ArrowStreamSerializer(Serializer):
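+ """
+ Serializes iterators of pyarrow record batches using the Arrow streaming format.
+
+ dump_stream first writes an integer marker (consumed by the JVM reader) and then the
+ batches; load_stream returns a single pyarrow.Table when load_to_single_batch is True,
+ otherwise an iterator over the record batches.
+ """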
+
+ def __init__(self, load_to_single_batch=True):
+ self._load_to_single = load_to_single_batch
+
+ def dump_stream(self, iterator, stream):
+ import pyarrow as pa
+ write_int(1, stream) # signal start of data block
+ writer = None
+ for batch in iterator:
+ if writer is None:
+ writer = pa.RecordBatchStreamWriter(stream, batch.schema)
+ writer.write_batch(batch)
+ if writer is not None:
+ writer.close()
+
+ def load_stream(self, stream):
+ import pyarrow as pa
+ reader = pa.RecordBatchStreamReader(stream)
+ if self._load_to_single:
+ return reader.read_all()
+ else:
+ return iter(reader)
+
+ def __repr__(self):
+ return "ArrowStreamSerializer"
+
+
+class ArrowPandasSerializer(ArrowStreamSerializer):
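+ """
+ Serializes a pandas.Series as a single Arrow record batch for sending to the JVM, and
+ deserializes an Arrow stream into a list of pandas.Series, one per column.
+ """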
+
+ def __init__(self):
+ super(ArrowPandasSerializer, self).__init__(load_to_single_batch=True)
+
+ # dumps a Pandas Series to stream
+ def dump_stream(self, iterator, stream):
+ import pyarrow as pa
+ # TODO: iterator could be a tuple
+ arr = pa.Array.from_pandas(iterator)
+ batch = pa.RecordBatch.from_arrays([arr], ["_0"])
+ super(ArrowPandasSerializer, self).dump_stream([batch], stream)
+
+ # loads stream to a list of Pandas Series
+ def load_stream(self, stream):
+ table = super(ArrowPandasSerializer, self).load_stream(stream)
+ return [c.to_pandas() for c in table.itercolumns()]
+
+ def __repr__(self):
+ return "ArrowPandasSerializer"
+
+
class BatchedSerializer(Serializer):
"""
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 184aa0821d3ee..45d7d198fadf3 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -2522,7 +2522,7 @@ def test_null_conversion(self):
def test_toPandas_arrow_toggle(self):
df = self.spark.createDataFrame(self.data, schema=self.schema)
- # NOTE - toPandas(useArrow=False) will infer standard python data types
+ # NOTE - toPandas() without pyarrow will infer standard python data types
df_sel = df.select("1_str_t", "3_long_t", "5_double_t")
self.spark.conf.set("spark.sql.execution.arrow.enable", "false")
pdf = df_sel.toPandas()
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index ac609d9438017..f885c79bacc70 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -31,7 +31,7 @@
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, BatchedSerializer, \
- ArrowSerializer
+ ArrowSerializer, ArrowPandasSerializer
from pyspark import shuffle
pickleSer = PickleSerializer()
@@ -118,8 +118,14 @@ def read_udfs(pickleSer, infile):
mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
mapper = eval(mapper_str, udfs)
- func = lambda _, it: map(mapper, it)
- ser = BatchedSerializer(PickleSerializer(), 100)
+ # These lines enable vectorized UDF evaluation with Arrow: the UDF arguments arrive as
+ # pandas.Series and the mapper is applied to the whole series at once.
+ ser = ArrowPandasSerializer()
+ func = lambda _, series_list: mapper(series_list)  # TODO: what if not vectorizable
+
+ # Uncomment these lines (and remove the Arrow path above) for the default row-at-a-time
+ # UDF evaluation
+ #func = lambda _, it: map(mapper, it)
+ #ser = BatchedSerializer(PickleSerializer(), 100)
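+ # For example (illustrative): with the Arrow path a UDF such as
+ # udf(lambda v: v + 1, LongType()) receives a whole pandas.Series for `v` and is expected
+ # to return a series/array of the same length, rather than being called once per row.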
+
# profiling is not supported for UDF
return func, None, ser, ser
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index dee2671ed7cad..95eedf90098fd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -34,7 +34,7 @@ import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -262,6 +262,72 @@ private[sql] object ArrowConverters {
reader.close()
}
}
+
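+ /**
+  * Convert the rows from `rowIter` into a single Arrow record batch with the given schema
+  * and write it to `out` using the Arrow streaming format.
+  */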
+ private[arrow] def writeRowsAsArrow(
+ rowIter: Iterator[InternalRow],
+ schema: StructType,
+ out: DataOutputStream): Unit = {
+ val allocator = new RootAllocator(Long.MaxValue)
+ val arrowSchema = ArrowConverters.schemaToArrowSchema(schema)
+ val root = VectorSchemaRoot.create(arrowSchema, allocator)
+ val loader = new VectorLoader(root)
+ val writer = new ArrowStreamWriter(root, null, Channels.newChannel(out))
+
+ val batch = internalRowIterToArrowBatch(rowIter, schema, allocator)
+
+ // TODO: catch exceptions
+ loader.load(batch)
+ writer.writeBatch()
+ writer.end()
+
+ batch.close()
+ root.close()
+ allocator.close()
+ }
+
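+ /**
+  * Read rows from Arrow stream-format input and return them as an iterator of InternalRows,
+  * loading the next record batch when the current one is exhausted.
+  */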
+ private[arrow] def readArrowAsRows(in: DataInputStream): Iterator[InternalRow] = {
+ new Iterator[InternalRow] {
+ val _allocator = new RootAllocator(Long.MaxValue)
+ private val _reader = new ArrowStreamReader(Channels.newChannel(in), _allocator)
+ private val _root = _reader.getVectorSchemaRoot
+ private var _index = 0
+ val mutableRow = new GenericInternalRow(1)
+
+ _reader.loadNextBatch()
+
+ override def hasNext: Boolean = _index < _root.getRowCount
+
+ override def next(): InternalRow = {
+ val fieldVecs = _root.getFieldVectors
+
+ if (fieldVecs.size() == 1) {
+ mutableRow(0) = fieldVecs.get(0).getAccessor.getObject(_index)
+ _index += 1
+ if (_index >= _root.getRowCount) {
+ _index = 0
+ _reader.loadNextBatch()
+ }
+ mutableRow
+ } else {
+ val fields = _root.getFieldVectors.asScala
+
+ val genericRowData = fields.map { field =>
+ val obj: Any = field.getAccessor.getObject(_index)
+ obj
+ }.toArray
+
+ _index += 1
+
+ if (_index >= _root.getRowCount) {
+ _index = 0
+ _reader.loadNextBatch()
+ }
+
+ new GenericInternalRow(genericRowData)
+ }
+ }
+ }
+ }
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowEvalPythonExec.scala
new file mode 100644
index 0000000000000..5ac134fdd1d10
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowEvalPythonExec.scala
@@ -0,0 +1,143 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution.arrow
+
+import java.io.{DataOutputStream, File}
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.api.python.{ChainedPythonFunctions, PandasUdfPythonFunctionType, PythonReadInterface, PythonRunner}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.python.{HybridRowQueue, PythonUDF}
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.util.Utils
+
+
+/**
+ * A physical plan that evaluates a [[PythonUDF]],
+ */
+case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
+ extends SparkPlan {
+
+ def children: Seq[SparkPlan] = child :: Nil
+
+ override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
+
+ private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
+ udf.children match {
+ case Seq(u: PythonUDF) =>
+ val (chained, children) = collectFunctions(u)
+ (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
+ case children =>
+ // There should not be any other UDFs, or the children can't be evaluated directly.
+ assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
+ (ChainedPythonFunctions(Seq(udf.func)), udf.children)
+ }
+ }
+
+ protected override def doExecute(): RDD[InternalRow] = {
+ val inputRDD = child.execute().map(_.copy())
+ val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
+ val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
+
+ inputRDD.mapPartitions { iter =>
+
+ // The queue used to buffer input rows so we can drain it to
+ // combine input with output from Python.
+ val queue = HybridRowQueue(TaskContext.get().taskMemoryManager(),
+ new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length)
+ TaskContext.get().addTaskCompletionListener({ ctx =>
+ queue.close()
+ })
+
+ val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip
+
+ // flatten all the arguments
+ val allInputs = new ArrayBuffer[Expression]
+ val dataTypes = new ArrayBuffer[DataType]
+ val argOffsets = inputs.map { input =>
+ input.map { e =>
+ if (allInputs.exists(_.semanticEquals(e))) {
+ allInputs.indexWhere(_.semanticEquals(e))
+ } else {
+ allInputs += e
+ dataTypes += e.dataType
+ allInputs.length - 1
+ }
+ }.toArray
+ }.toArray
+ val projection = newMutableProjection(allInputs, child.output)
+ val schema = StructType(dataTypes.map(dt => StructField("", dt)))
+
+ // Input iterator to Python: add each input row to the queue so it can later be joined with
+ // the UDF output, and project out only the UDF argument columns to send to Python.
+ val projectedRowIter = iter.map { inputRow =>
+ queue.add(inputRow.asInstanceOf[UnsafeRow])
+ projection(inputRow)
+ }
+
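+ // The write block streams the projected input rows to the Python worker as a single Arrow
+ // record batch; the read block parses the Arrow stream coming back and then consumes the
+ // protocol footer (timing data, accumulator updates, end-of-stream marker).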
+ val dataWriteBlock = (out: DataOutputStream) => {
+ ArrowConverters.writeRowsAsArrow(projectedRowIter, schema, out)
+ }
+
+ val dataReadBuilder = (in: PythonReadInterface) => {
+ new Iterator[InternalRow] {
+
+ // Check for initial error
+ in.readLengthFromPython()
+
+ val iter = ArrowConverters.readArrowAsRows(in.getDataStream)
+
+ override def hasNext: Boolean = {
+ val result = iter.hasNext
+ if (!result) {
+ in.readLengthFromPython() // == SpecialLengths.TIMING_DATA, marks end of data
+ in.readFooter()
+ }
+ result
+ }
+
+ override def next(): InternalRow = {
+ iter.next()
+ }
+ }
+ }
+
+ val context = TaskContext.get()
+
+ // Output iterator for results from Python.
+ val outputIterator = new PythonRunner(
+     pyFuncs, bufferSize, reuseWorker, PandasUdfPythonFunctionType, argOffsets)
+   .process(dataWriteBlock, dataReadBuilder, context.partitionId(), context)
+
+ val joined = new JoinedRow
+ val resultProj = UnsafeProjection.create(output, output)
+
+ outputIterator.map { outputRow =>
+ resultProj(joined(queue.remove(), outputRow))
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 69b4b7bb07de6..dec46dfdc09bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Proj
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution.{FilterExec, SparkPlan}
+import org.apache.spark.sql.execution.arrow.ArrowEvalPythonExec
/**
@@ -138,7 +139,13 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
val resultAttrs = udfs.zipWithIndex.map { case (u, i) =>
AttributeReference(s"pythonUDF$i", u.dataType)()
}
- val evaluation = BatchEvalPythonExec(validUdfs, child.output ++ resultAttrs, child)
+
+ // This line enables UDF evaluation with Arrow
+ val evaluation = ArrowEvalPythonExec(validUdfs, child.output ++ resultAttrs, child)
+
+ // Uncomment (and comment out the Arrow line above) to restore the default evaluation path
+ // val evaluation = BatchEvalPythonExec(validUdfs, child.output ++ resultAttrs, child)
+
attributeMap ++= validUdfs.zip(resultAttrs)
evaluation
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
index cd1e77f524afd..723b54e2e2a67 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
@@ -163,7 +163,7 @@ private[python] case class DiskRowQueue(file: File, fields: Int) extends RowQueu
* HybridRowQueue could be safely appended in one thread, and pulled in another thread in the same
* time.
*/
-private[python] case class HybridRowQueue(
+private[execution] case class HybridRowQueue(
memManager: TaskMemoryManager,
tempDir: File,
numFields: Int)