PySpark Arrow Stream Serializer #3

Open · wants to merge 3 commits into base: pandas-udf-integration

208 changes: 137 additions & 71 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -99,6 +99,15 @@ object CommandPythonFunctionType extends PythonFunctionType{ override val value
object SqlUdfPythonFunctionType extends PythonFunctionType{ override val value = 1 }
object PandasUdfPythonFunctionType extends PythonFunctionType{ override val value = 2 }

/**
* Interface that can be used when building an iterator to read data from Python
*/
private[spark] trait PythonReadInterface {
def getDataStream: DataInputStream
def readLengthFromPython(): Int
def readFooter(): Unit
}

/**
* A helper class to run Python mapPartition/UDFs in Spark.
*
@@ -123,10 +132,11 @@ private[spark] class PythonRunner(
// TODO: support accumulator in multiple UDF
private val accumulator = funcs.head.funcs.head.accumulator

def compute(
inputIterator: Iterator[_],
def process[U](
dataWriteBlock: DataOutputStream => Unit,
dataReadBuilder: PythonReadInterface => Iterator[U],
partitionIndex: Int,
context: TaskContext): Iterator[Array[Byte]] = {
context: TaskContext): Iterator[U] = {
val startTime = System.currentTimeMillis
val env = SparkEnv.get
val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
@@ -139,7 +149,7 @@
@volatile var released = false

// Start a thread to feed the process input from our parent's iterator
val writerThread = new WriterThread(env, worker, inputIterator, partitionIndex, context)
val writerThread = new WriterThread(env, worker, dataWriteBlock, partitionIndex, context)

context.addTaskCompletionListener { context =>
writerThread.shutdownOnTaskCompletion()
@@ -156,79 +166,29 @@
writerThread.start()
new MonitorThread(env, worker, context).start()

// Return an iterator that reads lines from the process's stdout
val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
val stdoutIterator = new Iterator[Array[Byte]] {
override def next(): Array[Byte] = {
val obj = _nextObj
if (hasNext) {
_nextObj = read()
}
obj
}
// Create stream to read data from process's stdout
val dataIn = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))

val stdoutIterator = new Iterator[U] with PythonReadInterface {

// Create iterator for reading data blocks
val _dataIterator = dataReadBuilder(this.asInstanceOf[PythonReadInterface])

private def read(): Array[Byte] = {
def safeRead[T](block: => T): T = {
if (writerThread.exception.isDefined) {
throw writerThread.exception.get
}

try {
stream.readInt() match {
case length if length > 0 =>
val obj = new Array[Byte](length)
stream.readFully(obj)
obj
case 0 => Array.empty[Byte]
case SpecialLengths.TIMING_DATA =>
// Timing data from worker
val bootTime = stream.readLong()
val initTime = stream.readLong()
val finishTime = stream.readLong()
val boot = bootTime - startTime
val init = initTime - bootTime
val finish = finishTime - initTime
val total = finishTime - startTime
logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
init, finish))
val memoryBytesSpilled = stream.readLong()
val diskBytesSpilled = stream.readLong()
context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
read()
case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
// Signals that an exception has been thrown in python
val exLength = stream.readInt()
val obj = new Array[Byte](exLength)
stream.readFully(obj)
throw new PythonException(new String(obj, StandardCharsets.UTF_8),
writerThread.exception.getOrElse(null))
case SpecialLengths.END_OF_DATA_SECTION =>
// We've finished the data section of the output, but we can still
// read some accumulator updates:
val numAccumulatorUpdates = stream.readInt()
(1 to numAccumulatorUpdates).foreach { _ =>
val updateLen = stream.readInt()
val update = new Array[Byte](updateLen)
stream.readFully(update)
accumulator.add(update)
}
// Check whether the worker is ready to be re-used.
if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
if (reuse_worker) {
env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker)
released = true
}
}
null
}
block
} catch {

case e: Exception if context.isInterrupted =>
logDebug("Exception thrown after task interruption", e)
throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason"))

case e: Exception if env.isStopped =>
logDebug("Exception thrown after context is stopped", e)
null // exit silently
throw new RuntimeException("TODO: exit silently")// exit silently

case e: Exception if writerThread.exception.isDefined =>
logError("Python worker exited unexpectedly (crashed)", e)
@@ -240,21 +200,120 @@ private[spark] class PythonRunner(
}
}

var _nextObj = read()
override def next(): U = {
safeRead {
_dataIterator.next()
}
}

override def hasNext: Boolean = {
safeRead {
_dataIterator.hasNext
}
}

override def getDataStream: DataInputStream = dataIn

override def readLengthFromPython(): Int = {
val length = dataIn.readInt()
if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
// Signals that an exception has been thrown in python
val exLength = dataIn.readInt()
val obj = new Array[Byte](exLength)
dataIn.readFully(obj)
throw new PythonException(new String(obj, StandardCharsets.UTF_8), writerThread.exception.orNull)
}
length
}

override def readFooter(): Unit = {
// Timing data from worker
//readLengthFromPython() // == SpecialLengths.TIMING_DATA
val bootTime = dataIn.readLong()
val initTime = dataIn.readLong()
val finishTime = dataIn.readLong()
val boot = bootTime - startTime
val init = initTime - bootTime
val finish = finishTime - initTime
val total = finishTime - startTime
logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
init, finish))
val memoryBytesSpilled = dataIn.readLong()
val diskBytesSpilled = dataIn.readLong()
context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)

// We've finished the data section of the output, but we can still
// read some accumulator updates:
readLengthFromPython() // == SpecialLengths.END_OF_DATA_SECTION
val numAccumulatorUpdates = readLengthFromPython()
(1 to numAccumulatorUpdates).foreach { _ =>
val updateLen = dataIn.readInt()
val update = new Array[Byte](updateLen)
dataIn.readFully(update)
accumulator.add(update)
}

override def hasNext: Boolean = _nextObj != null
// Check whether the worker is ready to be re-used.
if (readLengthFromPython() == SpecialLengths.END_OF_STREAM) {
if (reuse_worker) {
env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker)
released = true
}
}
// else == SpecialLengths.END_OF_DATA_SECTION to not reuse worker
}
}

new InterruptibleIterator(context, stdoutIterator)
}

def compute(
inputIterator: Iterator[_],
partitionIndex: Int,
context: TaskContext): Iterator[Array[Byte]] = {

val dataWriteBlock = (out: DataOutputStream) => {
PythonRDD.writeIteratorToStream(inputIterator, out)
}

val dataReadBuilder = (in: PythonReadInterface) => {
new Iterator[Array[Byte]] {
var _lastLength: Int = _

override def hasNext: Boolean = {
_lastLength = in.readLengthFromPython()
val result = _lastLength >= 0
if (!result) {
in.readFooter()
}
result
}

override def next(): Array[Byte] = {
_lastLength match {
case l if l > 0 =>
val obj = new Array[Byte](_lastLength)
in.getDataStream.readFully(obj)
obj
case 0 =>
Array.empty[Byte]
}
}
}
}

process(dataWriteBlock, dataReadBuilder, partitionIndex, context)
}

/**
* The thread responsible for writing the data from the PythonRDD's parent iterator to the
* Python process.
*/
class WriterThread(
env: SparkEnv,
worker: Socket,
inputIterator: Iterator[_],
dataWriteBlock: DataOutputStream => Unit,
partitionIndex: Int,
context: TaskContext)
extends Thread(s"stdout writer for $pythonExec") {
@@ -340,7 +399,7 @@
}

// Data values
PythonRDD.writeIteratorToStream(inputIterator, dataOut)
dataWriteBlock(dataOut)
dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
dataOut.writeInt(SpecialLengths.END_OF_STREAM)
dataOut.flush()
@@ -701,6 +760,13 @@ private[spark] object PythonRDD extends Logging {
* The thread will terminate after all the data are sent or any exceptions happen.
*/
def serveIterator[T](items: Iterator[T], threadName: String): Int = {
serveToStream(threadName) { out =>
writeIteratorToStream(items, out)
}
}

// TODO: scaladoc
def serveToStream(threadName: String)(dataWriteBlock: DataOutputStream => Unit): Int = {
val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
// Close the socket if no connection in 3 seconds
serverSocket.setSoTimeout(3000)
@@ -712,13 +778,13 @@
val sock = serverSocket.accept()
val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
Utils.tryWithSafeFinally {
writeIteratorToStream(items, out)
dataWriteBlock(out)
} {
out.close()
}
} catch {
case NonFatal(e) =>
logError(s"Error while sending iterator", e)
logError(s"Error while writing to stream", e)
} finally {
serverSocket.close()
}
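
The new readFooter() consumes, in order: the worker's timing data, the spill metrics, the accumulator-update section, and a final marker telling the JVM whether the worker may be reused. For reference, a hedged Python sketch of a worker-side footer that would satisfy those reads follows; write_int, write_long and SpecialLengths are existing names from pyspark.serializers, while write_footer and its arguments are illustrative stand-ins rather than the actual worker.py code.

```python
# Sketch of the footer that the JVM-side readFooter() expects (illustrative, not worker.py).
import time
from pyspark.serializers import write_int, write_long, SpecialLengths

def write_footer(outfile, boot_time, init_time, accumulator_updates,
                 memory_bytes_spilled=0, disk_bytes_spilled=0):
    finish_time = time.time()
    # TIMING_DATA is the negative "length" that ends the data iterator on the JVM side.
    write_int(SpecialLengths.TIMING_DATA, outfile)
    write_long(int(1000 * boot_time), outfile)     # times are sent as milliseconds
    write_long(int(1000 * init_time), outfile)
    write_long(int(1000 * finish_time), outfile)
    write_long(memory_bytes_spilled, outfile)
    write_long(disk_bytes_spilled, outfile)
    # Accumulator section: a count followed by length-prefixed pickled updates.
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(accumulator_updates), outfile)
    for pickled_update in accumulator_updates:
        write_int(len(pickled_update), outfile)
        outfile.write(pickled_update)
    # END_OF_STREAM signals the worker can be reused; END_OF_DATA_SECTION would decline reuse.
    write_int(SpecialLengths.END_OF_STREAM, outfile)
    outfile.flush()
```
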
6 changes: 1 addition & 5 deletions pom.xml
@@ -184,7 +184,7 @@
<paranamer.version>2.6</paranamer.version>
<maven-antrun.version>1.8</maven-antrun.version>
<commons-crypto.version>1.0.0</commons-crypto.version>
<arrow.version>0.3.0</arrow.version>
<arrow.version>0.4.0</arrow.version>

<test.java.home>${java.home}</test.java.home>
<test.exclude.tags></test.exclude.tags>
@@ -1885,10 +1885,6 @@
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</exclusion>
<exclusion>
<groupId>org.slf4j</groupId>
<artifactId>log4j-over-slf4j</artifactId>
</exclusion>
<exclusion>
<groupId>io.netty</groupId>
<artifactId>netty-handler</artifactId>
52 changes: 51 additions & 1 deletion python/pyspark/serializers.py
@@ -198,14 +198,64 @@ def dumps(self, record_batch):

def loads(self, obj):
import pyarrow as pa
reader = pa.FileReader(pa.BufferReader(obj))
reader = pa.RecordBatchFileReader(pa.BufferReader(obj))
assert reader.num_record_batches == 1, "Cannot read more than one record batch"
return reader.get_batch(0)

def __repr__(self):
return "ArrowSerializer"


class ArrowStreamSerializer(Serializer):

def __init__(self, load_to_single_batch=True):
self._load_to_single = load_to_single_batch

def dump_stream(self, iterator, stream):
import pyarrow as pa
write_int(1, stream) # signal start of data block
writer = None
for batch in iterator:
if writer is None:
writer = pa.RecordBatchStreamWriter(stream, batch.schema)
writer.write_batch(batch)
if writer is not None:
writer.close()

def load_stream(self, stream):
import pyarrow as pa
reader = pa.RecordBatchStreamReader(stream)
if self._load_to_single:
return reader.read_all()
else:
return iter(reader)

def __repr__(self):
return "ArrowStreamSerializer"


class ArrowPandasSerializer(ArrowStreamSerializer):

def __init__(self):
super(ArrowPandasSerializer, self).__init__(load_to_single_batch=True)

# dumps a Pandas Series to stream
def dump_stream(self, iterator, stream):
import pyarrow as pa
# TODO: iterator could be a tuple
arr = pa.Array.from_pandas(iterator)
batch = pa.RecordBatch.from_arrays([arr], ["_0"])
super(ArrowPandasSerializer, self).dump_stream([batch], stream)

# loads stream to a list of Pandas Series
def load_stream(self, stream):
table = super(ArrowPandasSerializer, self).load_stream(stream)
return [c.to_pandas() for c in table.itercolumns()]

def __repr__(self):
return "ArrowPandasSerializer"


class BatchedSerializer(Serializer):

"""
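
The new ArrowStreamSerializer frames a sequence of Arrow record batches with a leading length marker that is normally consumed by the JVM reader, not by load_stream. A minimal local round-trip sketch, assuming pyarrow is installed and skipping that 4-byte marker by hand; the column name "_0" mirrors the convention used by ArrowPandasSerializer above.

```python
# Local round-trip through ArrowStreamSerializer (illustrative sketch).
import io
import pandas as pd
import pyarrow as pa
from pyspark.serializers import ArrowStreamSerializer

ser = ArrowStreamSerializer(load_to_single_batch=True)

arr = pa.Array.from_pandas(pd.Series([1, 2, 3]))
batch = pa.RecordBatch.from_arrays([arr], ["_0"])

buf = io.BytesIO()
ser.dump_stream(iter([batch]), buf)   # writes the start marker plus an Arrow stream

buf.seek(4)                           # skip the write_int(1, stream) marker
table = ser.load_stream(buf)          # a pyarrow.Table when load_to_single_batch=True
print(table.to_pandas())
```
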
2 changes: 1 addition & 1 deletion python/pyspark/sql/tests.py
@@ -2522,7 +2522,7 @@ def test_null_conversion(self):

def test_toPandas_arrow_toggle(self):
df = self.spark.createDataFrame(self.data, schema=self.schema)
# NOTE - toPandas(useArrow=False) will infer standard python data types
# NOTE - toPandas() without pyarrow will infer standard python data types
df_sel = df.select("1_str_t", "3_long_t", "5_double_t")
self.spark.conf.set("spark.sql.execution.arrow.enable", "false")
pdf = df_sel.toPandas()
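
The toggle exercised by this test can be reproduced directly in user code. A small sketch, assuming a local SparkSession and pyarrow installed for the Arrow path:

```python
# Sketch: comparing Arrow-backed and plain toPandas() conversion.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "val"])

spark.conf.set("spark.sql.execution.arrow.enable", "true")
pdf_arrow = df.toPandas()    # Arrow-accelerated path

spark.conf.set("spark.sql.execution.arrow.enable", "false")
pdf_plain = df.toPandas()    # plain path; infers standard Python data types
```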