From 7f6db511bb9da47cdbbf7bd33488a23719fe3c41 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pascal=20Spo=CC=88rri?=
Date: Wed, 6 Sep 2023 11:38:04 +0200
Subject: [PATCH] Use multiple threads for prefetching.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Pascal Spörri
---
 README.md                                     |  1 +
 .../shuffle/helper/S3ShuffleDispatcher.scala  |  4 +-
 .../storage/S3BufferedPrefetchIterator.scala  | 53 +++++++++++--------
 3 files changed, 36 insertions(+), 22 deletions(-)

diff --git a/README.md b/README.md
index 72ae637..e08d0f4 100644
--- a/README.md
+++ b/README.md
@@ -40,6 +40,7 @@ Changing these values might have an impact on performance.
 
 - `spark.shuffle.s3.bufferSize`: Default buffer size when writing (default: `8388608`)
 - `spark.shuffle.s3.maxBufferSizeTask`: Maximum size of the buffered output streams per task (default: `134217728`)
+- `spark.shuffle.s3.prefetchConcurrencyTask`: The per-task concurrency when prefetching (default: `2`).
 - `spark.shuffle.s3.cachePartitionLengths`: Cache partition lengths in memory (default: `true`)
 - `spark.shuffle.s3.cacheChecksums`: Cache checksums in memory (default: `true`)
 - `spark.shuffle.s3.cleanup`: Cleanup the shuffle files (default: `true`)
diff --git a/src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleDispatcher.scala b/src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleDispatcher.scala
index 9d7513e..b23958f 100644
--- a/src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleDispatcher.scala
+++ b/src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleDispatcher.scala
@@ -33,6 +33,7 @@ class S3ShuffleDispatcher extends Logging {
   // Optional
   val bufferSize: Int = conf.getInt("spark.shuffle.s3.bufferSize", defaultValue = 8 * 1024 * 1024)
   val maxBufferSizeTask: Int = conf.getInt("spark.shuffle.s3.maxBufferSizeTask", defaultValue = 128 * 1024 * 1024)
+  val prefetchConcurrencyTask: Int = conf.getInt("spark.shuffle.s3.prefetchConcurrencyTask", defaultValue = 2)
   val cachePartitionLengths: Boolean = conf.getBoolean("spark.shuffle.s3.cachePartitionLengths", defaultValue = true)
   val cacheChecksums: Boolean = conf.getBoolean("spark.shuffle.s3.cacheChecksums", defaultValue = true)
   val cleanupShuffleFiles: Boolean = conf.getBoolean("spark.shuffle.s3.cleanup", defaultValue = true)
@@ -60,6 +61,7 @@ class S3ShuffleDispatcher extends Logging {
   // Optional
   logInfo(s"- spark.shuffle.s3.bufferSize=${bufferSize}")
   logInfo(s"- spark.shuffle.s3.maxBufferSizeTask=${maxBufferSizeTask}")
+  logInfo(s"- spark.shuffle.s3.prefetchConcurrencyTask=${prefetchConcurrencyTask}")
   logInfo(s"- spark.shuffle.s3.cachePartitionLengths=${cachePartitionLengths}")
   logInfo(s"- spark.shuffle.s3.cacheChecksums=${cacheChecksums}")
   logInfo(s"- spark.shuffle.s3.cleanup=${cleanupShuffleFiles}")
@@ -112,7 +114,7 @@ class S3ShuffleDispatcher extends Logging {
   def openBlock(blockId: BlockId): FSDataInputStream = {
     val status = getFileStatusCached(blockId)
     val builder = fs.openFile(status.getPath).withFileStatus(status)
-    val stream = builder.build().get()
+    val stream = builder.build().get()
     if (canSetReadahead) {
       stream.setReadahead(0)
     }
diff --git a/src/main/scala/org/apache/spark/storage/S3BufferedPrefetchIterator.scala b/src/main/scala/org/apache/spark/storage/S3BufferedPrefetchIterator.scala
index ab624c5..16e9ba3 100644
--- a/src/main/scala/org/apache/spark/storage/S3BufferedPrefetchIterator.scala
+++ b/src/main/scala/org/apache/spark/storage/S3BufferedPrefetchIterator.scala
@@ -6,35 +6,46 @@
 package org.apache.spark.storage
 
 import org.apache.spark.internal.Logging
+import org.apache.spark.shuffle.helper.S3ShuffleDispatcher
 
 import java.io.{BufferedInputStream, InputStream}
 import java.util
 
 class S3BufferedPrefetchIterator(iter: Iterator[(BlockId, S3ShuffleBlockStream)], maxBufferSize: Long)
   extends Iterator[(BlockId, InputStream)] with Logging {
+
+  private val concurrencyTask = S3ShuffleDispatcher.get.prefetchConcurrencyTask
+  private val startTime = System.nanoTime()
+
   @volatile private var memoryUsage: Long = 0
   @volatile private var hasItem: Boolean = iter.hasNext
   private var timeWaiting: Long = 0
   private var timePrefetching: Long = 0
-  private var timeNext: Long = 0
   private var numStreams: Long = 0
   private var bytesRead: Long = 0
-  private var nextElement: (BlockId, S3ShuffleBlockStream) = null
+  private var activeTasks: Long = 0
 
   private val completed = new util.LinkedList[(InputStream, BlockId, Long)]()
 
   private def prefetchThread(): Unit = {
-    while (iter.hasNext || nextElement != null) {
-      if (nextElement == null) {
-        val now = System.nanoTime()
-        nextElement = iter.next()
-        timeNext = System.nanoTime() - now
+    var nextElement: (BlockId, S3ShuffleBlockStream) = null
+    while (true) {
+      synchronized {
+        if (!iter.hasNext && nextElement == null) {
+          hasItem = false
+          return
+        }
+        if (nextElement == null) {
+          nextElement = iter.next()
+          activeTasks += 1
+          hasItem = iter.hasNext
+        }
       }
-      val bsize = scala.math.min(maxBufferSize, nextElement._2.maxBytes).toInt
       var fetchNext = false
+      val bsize = scala.math.min(maxBufferSize, nextElement._2.maxBytes).toInt
       synchronized {
-        if (memoryUsage + math.min(bsize, maxBufferSize) > maxBufferSize) {
+        if (memoryUsage + bsize > maxBufferSize) {
           try {
             wait()
           }
@@ -43,6 +54,7 @@ class S3BufferedPrefetchIterator(iter: Iterator[(BlockId, S3ShuffleBlockStream)]
           }
         } else {
           fetchNext = true
+          memoryUsage += bsize
         }
       }
 
@@ -59,50 +71,49 @@ class S3BufferedPrefetchIterator(iter: Iterator[(BlockId, S3ShuffleBlockStream)]
         timePrefetching += System.nanoTime() - now
         bytesRead += bsize
         synchronized {
-          memoryUsage += bsize
           completed.push((stream, block, bsize))
-          hasItem = iter.hasNext
-          notify()
+          activeTasks -= 1
+          notifyAll()
         }
       }
     }
   }
 
   private val self = this
-  private val thread = new Thread {
+  private val threads = Array.fill[Thread](concurrencyTask)(new Thread {
     override def run(): Unit = {
       self.prefetchThread()
     }
-  }
-  thread.start()
+  })
+  threads.foreach(_.start())
 
   private def printStatistics(): Unit = synchronized {
+    val totalRuntime = System.nanoTime() - startTime
     try {
+      val tR = totalRuntime / 1000000
+      val wPer = 100 * timeWaiting / totalRuntime
       val tW = timeWaiting / 1000000
       val tP = timePrefetching / 1000000
-      val tN = timeNext / 1000000
       val bR = bytesRead
       val r = numStreams
       // Average time per prefetch
       val atP = tP / r
       // Average time waiting
       val atW = tW / r
-      // Average time next
-      val atN = tN / r
       // Average read bandwidth
       val bW = bR.toDouble / (tP.toDouble / 1000) / (1024 * 1024)
       // Block size
       val bs = bR / r
       logInfo(s"Statistics: ${bR} bytes, ${tW} ms waiting (${atW} avg), " +
-        s"${tP} ms prefetching (avg: ${atP} ms - ${bs} block size - ${bW} MiB/s) " +
-        s"${tN} ms for next (${atN} avg)")
+        s"${tP} ms prefetching (avg: ${atP} ms - ${bs} block size - ${bW} MiB/s). " +
+        s"Total: ${tR} ms - ${wPer}% waiting")
     } catch {
       case e: Exception => logError(f"Unable to print statistics: ${e.getMessage}.")
     }
   }
 
   override def hasNext: Boolean = synchronized {
-    val result = hasItem || (completed.size() > 0)
+    val result = hasItem || activeTasks > 0 || (completed.size() > 0)
     if (!result) {
       printStatistics()
     }