Prefetch using multiple threads. Optimize buffers when accessing index and checksum files. #66

Merged · 2 commits · Sep 7, 2023
README.md: 1 addition, 0 deletions
@@ -40,6 +40,7 @@ Changing these values might have an impact on performance.

- `spark.shuffle.s3.bufferSize`: Default buffer size when writing (default: `8388608`)
- `spark.shuffle.s3.maxBufferSizeTask`: Maximum size of the buffered output streams per task (default: `134217728`)
+ - `spark.shuffle.s3.prefetchConcurrencyTask`: The per-task concurrency when prefetching (default: `2`)
- `spark.shuffle.s3.cachePartitionLengths`: Cache partition lengths in memory (default: `true`)
- `spark.shuffle.s3.cacheChecksums`: Cache checksums in memory (default: `true`)
- `spark.shuffle.s3.cleanup`: Cleanup the shuffle files (default: `true`)
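For context, a minimal sketch of how these options could be set when building a job, assuming the plugin is already on the classpath (the values shown are just the documented defaults):

```scala
import org.apache.spark.SparkConf

// Hypothetical usage sketch: the keys below come from the README above;
// everything else about the job setup is omitted.
val conf = new SparkConf()
  .set("spark.shuffle.s3.prefetchConcurrencyTask", "2")   // prefetch threads per task
  .set("spark.shuffle.s3.maxBufferSizeTask", "134217728") // 128 MiB cap per task
  .set("spark.shuffle.s3.bufferSize", "8388608")          // 8 MiB write buffer
```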
@@ -33,6 +33,7 @@ class S3ShuffleDispatcher extends Logging {
// Optional
val bufferSize: Int = conf.getInt("spark.shuffle.s3.bufferSize", defaultValue = 8 * 1024 * 1024)
val maxBufferSizeTask: Int = conf.getInt("spark.shuffle.s3.maxBufferSizeTask", defaultValue = 128 * 1024 * 1024)
+ val prefetchConcurrencyTask: Int = conf.getInt("spark.shuffle.s3.prefetchConcurrencyTask", defaultValue = 2)
val cachePartitionLengths: Boolean = conf.getBoolean("spark.shuffle.s3.cachePartitionLengths", defaultValue = true)
val cacheChecksums: Boolean = conf.getBoolean("spark.shuffle.s3.cacheChecksums", defaultValue = true)
val cleanupShuffleFiles: Boolean = conf.getBoolean("spark.shuffle.s3.cleanup", defaultValue = true)
@@ -60,6 +61,7 @@ class S3ShuffleDispatcher extends Logging {
// Optional
logInfo(s"- spark.shuffle.s3.bufferSize=${bufferSize}")
logInfo(s"- spark.shuffle.s3.maxBufferSizeTask=${maxBufferSizeTask}")
+ logInfo(s"- spark.shuffle.s3.prefetchConcurrencyTask=${prefetchConcurrencyTask}")
logInfo(s"- spark.shuffle.s3.cachePartitionLengths=${cachePartitionLengths}")
logInfo(s"- spark.shuffle.s3.cacheChecksums=${cacheChecksums}")
logInfo(s"- spark.shuffle.s3.cleanup=${cleanupShuffleFiles}")
@@ -112,7 +114,7 @@ class S3ShuffleDispatcher extends Logging {
def openBlock(blockId: BlockId): FSDataInputStream = {
val status = getFileStatusCached(blockId)
val builder = fs.openFile(status.getPath).withFileStatus(status)
val stream = builder.build().get()
if (canSetReadahead) {
stream.setReadahead(0)
}
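A note on the `openBlock` pattern above: `openFile(...).build()` returns a future, and readahead is disabled because the prefetcher below buffers streams itself. A standalone sketch against a plain Hadoop `FileSystem` (illustrative helper name, assuming a Hadoop 3.3+ client):

```scala
import org.apache.hadoop.fs.{FSDataInputStream, FileSystem, Path}

// Illustrative sketch, not the plugin's exact code: open via the async
// builder API and turn off stream readahead to avoid double-buffering.
def openWithoutReadahead(fs: FileSystem, path: Path): FSDataInputStream = {
  val stream = fs.openFile(path).build().get() // await the CompletableFuture
  stream.setReadahead(0) // the shuffle plugin buffers reads explicitly
  stream
}
```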
@@ -121,7 +123,7 @@

private val cachedFileStatus = new ConcurrentObjectMap[BlockId, FileStatus]()

- private def getFileStatusCached(blockId: BlockId): FileStatus = {
+ def getFileStatusCached(blockId: BlockId): FileStatus = {
cachedFileStatus.getOrElsePut(blockId, (value: BlockId) => {
fs.getFileStatus(getPath(value))
})
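`getFileStatusCached` is made public here so that `S3ShuffleHelper.readBlockAsArray` (below) can size its buffers from the cached file length. A minimal sketch of the get-or-compute idiom it relies on, assuming `ConcurrentObjectMap.getOrElsePut` behaves like `computeIfAbsent` (the helper itself lives elsewhere in this repo):

```scala
import java.util.concurrent.ConcurrentHashMap

// Sketch with illustrative names: compute a value at most once per key,
// even when several threads ask for it concurrently.
final class GetOrElsePutCache[K, V] {
  private val map = new ConcurrentHashMap[K, V]()
  def getOrElsePut(key: K, compute: K => V): V =
    map.computeIfAbsent(key, k => compute(k))
}
```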
@@ -56,7 +56,7 @@ object S3ShuffleHelper extends Logging {
def writeArrayAsBlock(blockId: BlockId, array: Array[Long]): Unit = {
val serializerInstance = serializer.newInstance()
val buffer = serializerInstance.serialize[Array[Long]](array)
- val file = new BufferedOutputStream(dispatcher.createBlock(blockId), dispatcher.bufferSize)
+ val file = dispatcher.createBlock(blockId)
file.write(buffer.array(), buffer.arrayOffset(), buffer.limit())
file.flush()
file.close()
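Dropping the `BufferedOutputStream` wrapper looks safe here because the serialized index is already materialized as one `ByteBuffer` and written with a single `write` call, so the extra buffer only added a copy. The shape of the write path, as a sketch with stand-in names:

```scala
import java.io.OutputStream
import java.nio.ByteBuffer

// Sketch: the payload is one contiguous, already-serialized buffer, so a
// single write() suffices and an intermediate buffer adds nothing.
def writeBlock(out: OutputStream, buffer: ByteBuffer): Unit = {
  out.write(buffer.array(), buffer.arrayOffset(), buffer.limit())
  out.flush()
  out.close()
}
```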
@@ -132,11 +132,13 @@ object S3ShuffleHelper extends Logging {
}

private def readBlockAsArray(blockId: BlockId) = {
- val file = new BufferedInputStream(dispatcher.openBlock(blockId), dispatcher.bufferSize)
- var buffer = new Array[Byte](1024)
+ val stat = dispatcher.getFileStatusCached(blockId)
+ val fsize = scala.math.min(stat.getLen.toInt, dispatcher.bufferSize)
+ val file = new BufferedInputStream(dispatcher.openBlock(blockId), fsize)
+ var buffer = new Array[Byte](fsize)
var numBytes = 0
var done = false
- do {
+ while (!done) {
val c = file.read(buffer, numBytes, buffer.length - numBytes)
if (c >= 0) {
numBytes += c
@@ -146,7 +148,7 @@
} else {
done = true
}
- } while (!done)
+ }
val serializerInstance = serializer.newInstance()
try {
val result = serializerInstance.deserialize[Array[Long]](ByteBuffer.wrap(buffer, 0, numBytes))
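The rewritten loop sizes both the `BufferedInputStream` and the destination array from the cached `FileStatus`, so an index or checksum block is normally read in a single pass instead of starting from a 1 KiB buffer and growing in the elided fallback path. A self-contained sketch of the sized read (plain `java.io`, hypothetical helper name, without the plugin's growth path for blocks larger than `bufferSize`):

```scala
import java.io.{BufferedInputStream, EOFException, File, FileInputStream}

// Sketch: read a whole file into a pre-sized array in one loop.
def readFully(file: File): Array[Byte] = {
  val fsize = file.length().toInt
  val in = new BufferedInputStream(new FileInputStream(file), math.max(fsize, 1))
  val buffer = new Array[Byte](fsize)
  try {
    var numBytes = 0
    while (numBytes < fsize) {
      val c = in.read(buffer, numBytes, fsize - numBytes)
      if (c < 0) throw new EOFException(s"expected $fsize bytes, read $numBytes")
      numBytes += c
    }
  } finally in.close()
  buffer
}
```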
@@ -6,35 +6,46 @@
package org.apache.spark.storage

import org.apache.spark.internal.Logging
+ import org.apache.spark.shuffle.helper.S3ShuffleDispatcher

import java.io.{BufferedInputStream, InputStream}
import java.util

class S3BufferedPrefetchIterator(iter: Iterator[(BlockId, S3ShuffleBlockStream)], maxBufferSize: Long) extends Iterator[(BlockId, InputStream)] with Logging {

+ private val concurrencyTask = S3ShuffleDispatcher.get.prefetchConcurrencyTask
+ private val startTime = System.nanoTime()

@volatile private var memoryUsage: Long = 0
@volatile private var hasItem: Boolean = iter.hasNext
private var timeWaiting: Long = 0
private var timePrefetching: Long = 0
private var timeNext: Long = 0
private var numStreams: Long = 0
private var bytesRead: Long = 0

- private var nextElement: (BlockId, S3ShuffleBlockStream) = null
+ private var activeTasks: Long = 0

private val completed = new util.LinkedList[(InputStream, BlockId, Long)]()

private def prefetchThread(): Unit = {
- while (iter.hasNext || nextElement != null) {
- if (nextElement == null) {
- val now = System.nanoTime()
- nextElement = iter.next()
- timeNext = System.nanoTime() - now
+ var nextElement: (BlockId, S3ShuffleBlockStream) = null
+ while (true) {
+ synchronized {
+ if (!iter.hasNext && nextElement == null) {
+ hasItem = false
+ return
+ }
+ if (nextElement == null) {
+ nextElement = iter.next()
+ activeTasks += 1
+ hasItem = iter.hasNext
+ }
+ }
- val bsize = scala.math.min(maxBufferSize, nextElement._2.maxBytes).toInt

var fetchNext = false
+ val bsize = scala.math.min(maxBufferSize, nextElement._2.maxBytes).toInt
synchronized {
- if (memoryUsage + math.min(bsize, maxBufferSize) > maxBufferSize) {
+ if (memoryUsage + bsize > maxBufferSize) {
try {
wait()
}
@@ -43,6 +54,7 @@
}
} else {
fetchNext = true
+ memoryUsage += bsize
}
}

@@ -59,50 +71,49 @@
timePrefetching += System.nanoTime() - now
bytesRead += bsize
synchronized {
- memoryUsage += bsize
completed.push((stream, block, bsize))
- hasItem = iter.hasNext
- notify()
+ activeTasks -= 1
+ notifyAll()
}
}
}
}

private val self = this
- private val thread = new Thread {
+ private val threads = Array.fill[Thread](concurrencyTask)(new Thread {
override def run(): Unit = {
self.prefetchThread()
}
- }
- thread.start()
+ })
+ threads.foreach(_.start())

private def printStatistics(): Unit = synchronized {
+ val totalRuntime = System.nanoTime() - startTime
try {
+ val tR = totalRuntime / 1000000
+ val wPer = 100 * timeWaiting / totalRuntime
val tW = timeWaiting / 1000000
val tP = timePrefetching / 1000000
val tN = timeNext / 1000000
val bR = bytesRead
val r = numStreams
// Average time per prefetch
val atP = tP / r
// Average time waiting
val atW = tW / r
// Average time next
val atN = tN / r
// Average read bandwidth
val bW = bR.toDouble / (tP.toDouble / 1000) / (1024 * 1024)
// Block size
val bs = bR / r
logInfo(s"Statistics: ${bR} bytes, ${tW} ms waiting (${atW} avg), " +
s"${tP} ms prefetching (avg: ${atP} ms - ${bs} block size - ${bW} MiB/s) " +
s"${tN} ms for next (${atN} avg)")
s"${tP} ms prefetching (avg: ${atP} ms - ${bs} block size - ${bW} MiB/s). " +
s"Total: ${tR} ms - ${wPer}% waiting")
} catch {
case e: Exception => logError(f"Unable to print statistics: ${e.getMessage}.")
}
}

override def hasNext: Boolean = synchronized {
- val result = hasItem || (completed.size() > 0)
+ val result = hasItem || activeTasks > 0 || (completed.size() > 0)
if (!result) {
printStatistics()
}
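Net effect: the single prefetch thread becomes a pool of `prefetchConcurrencyTask` workers sharing one memory budget. Each worker reserves `bsize` under the lock before fetching, blocks in `wait()` when the next stream would exceed `maxBufferSize`, and `hasNext` now also stays true while `activeTasks > 0`, so the iterator cannot terminate while a fetch is still in flight. A condensed, illustrative sketch of the same bounded producer/consumer shape (names are not from the PR):

```scala
// Bounded prefetch skeleton: worker threads pull (item, size) pairs,
// reserve budget before fetching, and hand results to a consumer.
class BoundedPrefetch[A](source: Iterator[(A, Int)], maxBytes: Long, workers: Int) {
  private val ready = new java.util.LinkedList[(A, Int)]()
  private var used: Long = 0
  private var active: Int = 0
  private var exhausted = false

  private def worker(): Unit = while (true) {
    var next: (A, Int) = null
    synchronized {
      if (!source.hasNext) { exhausted = true; notifyAll(); return }
      next = source.next()
      active += 1
      while (used + next._2 > maxBytes) wait() // block until budget frees up
      used += next._2
    }
    // ... perform the actual fetch of next._1 outside the lock ...
    synchronized { ready.push(next); active -= 1; notifyAll() }
  }

  def take(): Option[(A, Int)] = synchronized {
    while (ready.isEmpty && (active > 0 || !exhausted)) wait()
    if (ready.isEmpty) None
    else { val r = ready.pop(); used -= r._2; notifyAll(); Some(r) }
  }

  (1 to workers).foreach(_ => new Thread(() => worker()).start())
}
```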