Skip to content

Commit

Permalink
Make prefetcher asynchronous.
Browse files Browse the repository at this point in the history
Signed-off-by: Pascal Spörri <[email protected]>
  • Loading branch information
pspoerri committed Aug 30, 2023
1 parent 01068ec commit 494d9a2
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 37 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/**
* Copyright 2023- IBM Inc. All rights reserved
* SPDX-License-Identifier: Apache2.0
*/

package org.apache.spark.storage

import scala.collection.AbstractIterator
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future
import scala.util.{Failure, Success, Try}

class AsyncPrefetchIterator[A](iter: Iterator[A]) extends AbstractIterator[A] {

private var error: Option[Throwable] = Option.empty
private var value: Option[A] = Option.empty
private var nextValue: Boolean = false

override def hasNext: Boolean = synchronized {
populateNext()
nextValue
}

private def onComplete(result: Try[A]): Unit = synchronized {
result match {
case Failure(err) => error = Some(err)
case Success(v) =>
value = Some(v)
notifyAll()
}
}

private def populateNext(): Unit = synchronized {
if (nextValue) {
return
}
if (iter.hasNext) {
val fut = Future[A] {
iter.next()
}
fut.onComplete(onComplete)
nextValue = true
}
}

override def next(): A = synchronized {
while (value.isEmpty) {
if (error.isDefined) {
throw error.get
}
try {
wait()
}
catch {
case _: InterruptedException =>
Thread.currentThread.interrupt()
}
}
val result = value.get
value = Option.empty
nextValue = false
populateNext()
result
}
}
36 changes: 0 additions & 36 deletions src/main/scala/org/apache/spark/storage/PrefetchIterator.scala

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class S3ShuffleReader[K, C](
readMetrics.incRemoteBlocksFetched(1)
f
})
val recordIter = new S3BufferedPrefetchIterator(new PrefetchIterator(filteredStream), maxBufferSizeTask)
val recordIter = new S3BufferedPrefetchIterator(new AsyncPrefetchIterator(filteredStream), maxBufferSizeTask)
.flatMap(s => {
val stream = s._2
val blockId = s._1
Expand Down

0 comments on commit 494d9a2

Please sign in to comment.