diff --git a/README.md b/README.md
index 72ae637..e08d0f4 100644
--- a/README.md
+++ b/README.md
@@ -40,6 +40,7 @@ Changing these values might have an impact on performance.
 
 - `spark.shuffle.s3.bufferSize`: Default buffer size when writing (default: `8388608`)
 - `spark.shuffle.s3.maxBufferSizeTask`: Maximum size of the buffered output streams per task (default: `134217728`)
+- `spark.shuffle.s3.prefetchConcurrencyTask`: Number of concurrent prefetch threads per task (default: `2`)
 - `spark.shuffle.s3.cachePartitionLengths`: Cache partition lengths in memory (default: `true`)
 - `spark.shuffle.s3.cacheChecksums`: Cache checksums in memory (default: `true`)
 - `spark.shuffle.s3.cleanup`: Cleanup the shuffle files (default: `true`)
diff --git a/src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleDispatcher.scala b/src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleDispatcher.scala
index b730eb7..b23958f 100644
--- a/src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleDispatcher.scala
+++ b/src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleDispatcher.scala
@@ -33,6 +33,7 @@ class S3ShuffleDispatcher extends Logging {
   // Optional
   val bufferSize: Int = conf.getInt("spark.shuffle.s3.bufferSize", defaultValue = 8 * 1024 * 1024)
   val maxBufferSizeTask: Int = conf.getInt("spark.shuffle.s3.maxBufferSizeTask", defaultValue = 128 * 1024 * 1024)
+  val prefetchConcurrencyTask: Int = conf.getInt("spark.shuffle.s3.prefetchConcurrencyTask", defaultValue = 2)
   val cachePartitionLengths: Boolean = conf.getBoolean("spark.shuffle.s3.cachePartitionLengths", defaultValue = true)
   val cacheChecksums: Boolean = conf.getBoolean("spark.shuffle.s3.cacheChecksums", defaultValue = true)
   val cleanupShuffleFiles: Boolean = conf.getBoolean("spark.shuffle.s3.cleanup", defaultValue = true)
@@ -60,6 +61,7 @@ class S3ShuffleDispatcher extends Logging {
   // Optional
   logInfo(s"- spark.shuffle.s3.bufferSize=${bufferSize}")
   logInfo(s"- spark.shuffle.s3.maxBufferSizeTask=${maxBufferSizeTask}")
+  logInfo(s"- spark.shuffle.s3.prefetchConcurrencyTask=${prefetchConcurrencyTask}")
   logInfo(s"- spark.shuffle.s3.cachePartitionLengths=${cachePartitionLengths}")
   logInfo(s"- spark.shuffle.s3.cacheChecksums=${cacheChecksums}")
   logInfo(s"- spark.shuffle.s3.cleanup=${cleanupShuffleFiles}")
@@ -112,7 +114,7 @@ class S3ShuffleDispatcher extends Logging {
   def openBlock(blockId: BlockId): FSDataInputStream = {
     val status = getFileStatusCached(blockId)
     val builder = fs.openFile(status.getPath).withFileStatus(status)
-    val stream = builder.build().get()
+    val stream = builder.build().get()
     if (canSetReadahead) {
       stream.setReadahead(0)
     }
@@ -121,7 +123,7 @@ class S3ShuffleDispatcher extends Logging {
 
   private val cachedFileStatus = new ConcurrentObjectMap[BlockId, FileStatus]()
 
-  private def getFileStatusCached(blockId: BlockId): FileStatus = {
+  def getFileStatusCached(blockId: BlockId): FileStatus = {
     cachedFileStatus.getOrElsePut(blockId, (value: BlockId) => {
       fs.getFileStatus(getPath(value))
     })
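Usage note (editorial, not part of the patch): the new option is an ordinary Spark configuration key. A minimal, hypothetical driver-side sketch follows; only the three keys are taken from this change, and the values are illustrative.

```scala
import org.apache.spark.SparkConf

// Illustrative values; all three keys are read by S3ShuffleDispatcher above.
val conf = new SparkConf()
  .set("spark.shuffle.s3.prefetchConcurrencyTask", "4") // new option, default: 2
  .set("spark.shuffle.s3.bufferSize", (8 * 1024 * 1024).toString)
  .set("spark.shuffle.s3.maxBufferSizeTask", (128 * 1024 * 1024).toString)
```

All prefetch threads of a task share one bounded buffer budget (see `S3BufferedPrefetchIterator` below), so raising the concurrency mainly pays off when individual shuffle blocks are small relative to that budget.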
diff --git a/src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleHelper.scala b/src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleHelper.scala
index f70f03c..9b2f2fc 100644
--- a/src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleHelper.scala
+++ b/src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleHelper.scala
@@ -56,7 +56,7 @@ object S3ShuffleHelper extends Logging {
   def writeArrayAsBlock(blockId: BlockId, array: Array[Long]): Unit = {
     val serializerInstance = serializer.newInstance()
     val buffer = serializerInstance.serialize[Array[Long]](array)
-    val file = new BufferedOutputStream(dispatcher.createBlock(blockId), dispatcher.bufferSize)
+    val file = dispatcher.createBlock(blockId)
     file.write(buffer.array(), buffer.arrayOffset(), buffer.limit())
     file.flush()
     file.close()
@@ -132,11 +132,13 @@ object S3ShuffleHelper extends Logging {
   }
 
   private def readBlockAsArray(blockId: BlockId) = {
-    val file = new BufferedInputStream(dispatcher.openBlock(blockId), dispatcher.bufferSize)
-    var buffer = new Array[Byte](1024)
+    val stat = dispatcher.getFileStatusCached(blockId)
+    val fsize = scala.math.min(stat.getLen.toInt, dispatcher.bufferSize)
+    val file = new BufferedInputStream(dispatcher.openBlock(blockId), fsize)
+    var buffer = new Array[Byte](fsize)
     var numBytes = 0
     var done = false
-    do {
+    while (!done) {
       val c = file.read(buffer, numBytes, buffer.length - numBytes)
       if (c >= 0) {
         numBytes += c
@@ -146,7 +148,7 @@ object S3ShuffleHelper extends Logging {
       } else {
         done = true
       }
-    } while (!done)
+    }
     val serializerInstance = serializer.newInstance()
     try {
       val result = serializerInstance.deserialize[Array[Long]](ByteBuffer.wrap(buffer, 0, numBytes))
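Aside (editorial, not part of the patch): the rewritten `readBlockAsArray` sizes its buffer from the cached `FileStatus`, so a block of known length is usually read in a single call instead of growing a 1 KiB array. Below is a self-contained sketch of the same idea against plain Hadoop `FileSystem` APIs; the object and method names are hypothetical.

```scala
import java.io.BufferedInputStream
import java.util.Arrays

import org.apache.hadoop.fs.{FileSystem, Path}

object ReadFullyExample {
  // Read a whole file into a byte array, sizing the initial buffer from the
  // file's length (capped at maxBuffer) so small files need a single read().
  def readFully(fs: FileSystem, path: Path, maxBuffer: Int): Array[Byte] = {
    // At least 1: a zero-sized buffer would make read(buf, off, 0) loop forever.
    val fsize = math.max(1, math.min(fs.getFileStatus(path).getLen, maxBuffer.toLong).toInt)
    val in = new BufferedInputStream(fs.open(path), fsize)
    try {
      var buffer = new Array[Byte](fsize)
      var numBytes = 0
      var done = false
      while (!done) {
        val c = in.read(buffer, numBytes, buffer.length - numBytes)
        if (c >= 0) {
          numBytes += c
          if (numBytes == buffer.length) {
            // File larger than the initial guess: double the buffer, keep reading.
            buffer = Arrays.copyOf(buffer, buffer.length * 2)
          }
        } else {
          done = true
        }
      }
      Arrays.copyOf(buffer, numBytes)
    } finally {
      in.close()
    }
  }
}
```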
diff --git a/src/main/scala/org/apache/spark/storage/S3BufferedPrefetchIterator.scala b/src/main/scala/org/apache/spark/storage/S3BufferedPrefetchIterator.scala
index ab624c5..16e9ba3 100644
--- a/src/main/scala/org/apache/spark/storage/S3BufferedPrefetchIterator.scala
+++ b/src/main/scala/org/apache/spark/storage/S3BufferedPrefetchIterator.scala
@@ -6,35 +6,46 @@ package org.apache.spark.storage
 
 import org.apache.spark.internal.Logging
+import org.apache.spark.shuffle.helper.S3ShuffleDispatcher
 
 import java.io.{BufferedInputStream, InputStream}
 import java.util
 
 class S3BufferedPrefetchIterator(iter: Iterator[(BlockId, S3ShuffleBlockStream)], maxBufferSize: Long)
   extends Iterator[(BlockId, InputStream)] with Logging {
+
+  private val concurrencyTask = S3ShuffleDispatcher.get.prefetchConcurrencyTask
+  private val startTime = System.nanoTime()
+
   @volatile private var memoryUsage: Long = 0
   @volatile private var hasItem: Boolean = iter.hasNext
   private var timeWaiting: Long = 0
   private var timePrefetching: Long = 0
-  private var timeNext: Long = 0
   private var numStreams: Long = 0
   private var bytesRead: Long = 0
 
-  private var nextElement: (BlockId, S3ShuffleBlockStream) = null
+  private var activeTasks: Long = 0
 
   private val completed = new util.LinkedList[(InputStream, BlockId, Long)]()
 
   private def prefetchThread(): Unit = {
-    while (iter.hasNext || nextElement != null) {
-      if (nextElement == null) {
-        val now = System.nanoTime()
-        nextElement = iter.next()
-        timeNext = System.nanoTime() - now
+    var nextElement: (BlockId, S3ShuffleBlockStream) = null
+    while (true) {
+      synchronized {
+        if (!iter.hasNext && nextElement == null) {
+          hasItem = false
+          return
+        }
+        if (nextElement == null) {
+          nextElement = iter.next()
+          activeTasks += 1
+          hasItem = iter.hasNext
+        }
       }
-      val bsize = scala.math.min(maxBufferSize, nextElement._2.maxBytes).toInt
       var fetchNext = false
+      val bsize = scala.math.min(maxBufferSize, nextElement._2.maxBytes).toInt
       synchronized {
-        if (memoryUsage + math.min(bsize, maxBufferSize) > maxBufferSize) {
+        if (memoryUsage + bsize > maxBufferSize) {
           try {
             wait()
           }
@@ -43,6 +54,7 @@ class S3BufferedPrefetchIterator(iter: Iterator[(BlockId, S3ShuffleBlockStream)]
           }
         } else {
           fetchNext = true
+          memoryUsage += bsize
         }
       }
 
@@ -59,50 +71,49 @@ class S3BufferedPrefetchIterator(iter: Iterator[(BlockId, S3ShuffleBlockStream)]
       timePrefetching += System.nanoTime() - now
       bytesRead += bsize
       synchronized {
-        memoryUsage += bsize
         completed.push((stream, block, bsize))
-        hasItem = iter.hasNext
-        notify()
+        activeTasks -= 1
+        notifyAll()
       }
     }
   }
 
   private val self = this
-  private val thread = new Thread {
+  private val threads = Array.fill[Thread](concurrencyTask)(new Thread {
     override def run(): Unit = {
       self.prefetchThread()
     }
-  }
-  thread.start()
+  })
+  threads.foreach(_.start())
 
   private def printStatistics(): Unit = synchronized {
+    val totalRuntime = System.nanoTime() - startTime
     try {
+      val tR = totalRuntime / 1000000
+      val wPer = 100 * timeWaiting / totalRuntime
       val tW = timeWaiting / 1000000
       val tP = timePrefetching / 1000000
-      val tN = timeNext / 1000000
       val bR = bytesRead
      val r = numStreams
       // Average time per prefetch
       val atP = tP / r
       // Average time waiting
       val atW = tW / r
-      // Average time next
-      val atN = tN / r
       // Average read bandwidth
       val bW = bR.toDouble / (tP.toDouble / 1000) / (1024 * 1024)
       // Block size
       val bs = bR / r
       logInfo(s"Statistics: ${bR} bytes, ${tW} ms waiting (${atW} avg), " +
-        s"${tP} ms prefetching (avg: ${atP} ms - ${bs} block size - ${bW} MiB/s) " +
-        s"${tN} ms for next (${atN} avg)")
+        s"${tP} ms prefetching (avg: ${atP} ms - ${bs} block size - ${bW} MiB/s). " +
+        s"Total: ${tR} ms - ${wPer}% waiting")
     } catch {
       case e: Exception => logError(f"Unable to print statistics: ${e.getMessage}.")
     }
   }
 
   override def hasNext: Boolean = synchronized {
-    val result = hasItem || (completed.size() > 0)
+    val result = hasItem || activeTasks > 0 || (completed.size() > 0)
     if (!result) {
       printStatistics()
     }
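Aside (editorial, not part of the patch): the core of this change is the threading model. `concurrencyTask` producer threads share one synchronized iterator, reserve buffer budget before each read, and use `notifyAll()` instead of `notify()` because consumers and budget-blocked producers now wait on the same monitor. Below is a simplified, self-contained model of that pattern; all names here are illustrative, not the plugin's API.

```scala
import java.util

// Simplified model of the S3BufferedPrefetchIterator threading introduced in
// this change: `concurrency` producers pull from one shared iterator while
// never holding more than `maxBufferSize` bytes of fetched data at once.
class PrefetchPool[A](iter: Iterator[A],
                      sizeOf: A => Long,
                      fetch: A => Array[Byte],
                      maxBufferSize: Long,
                      concurrency: Int) {
  @volatile private var hasItem: Boolean = iter.hasNext
  private var memoryUsage: Long = 0
  private var activeTasks: Int = 0
  private val completed = new util.LinkedList[(A, Array[Byte], Long)]()

  private def prefetchThread(): Unit = {
    while (true) {
      var item: A = null.asInstanceOf[A]
      var bsize: Long = 0
      synchronized {
        if (!iter.hasNext) {
          hasItem = false
          return // no more work for this producer
        }
        item = iter.next()
        activeTasks += 1
        hasItem = iter.hasNext
        bsize = math.min(maxBufferSize, sizeOf(item))
        // Reserve budget *before* fetching; block while it is exhausted.
        while (memoryUsage + bsize > maxBufferSize) wait()
        memoryUsage += bsize
      }
      val data = fetch(item) // the (slow) S3 read happens outside the lock
      synchronized {
        completed.push((item, data, bsize))
        activeTasks -= 1
        notifyAll() // wake consumers *and* producers blocked on the budget
      }
    }
  }

  private val threads = Array.fill(concurrency)(new Thread(() => prefetchThread()))
  threads.foreach(_.start())

  // In-flight fetches (activeTasks > 0) must count towards hasNext, or a
  // consumer could see an exhausted iterator while data is still arriving.
  def hasNext: Boolean = synchronized { hasItem || activeTasks > 0 || !completed.isEmpty }

  def next(): (A, Array[Byte]) = synchronized {
    while (completed.isEmpty) wait()
    val (item, data, bsize) = completed.pop()
    memoryUsage -= bsize // consumer releases the budget
    notifyAll()
    (item, data)
  }
}
```

The sketch mirrors the patch's invariants: `activeTasks` keeps `hasNext` truthful while fetches are in flight, memory is reserved before the read rather than accounted after it, and `notifyAll()` ensures a consumer's wake-up is not swallowed by a second producer.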