Implemented a buffered stream using an elastic byte buffer pool.
Signed-off-by: Pascal Spörri <[email protected]>
pspoerri committed Aug 22, 2023
1 parent ed2404e commit 502cacf
Showing 3 changed files with 179 additions and 48 deletions.
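
For context: the new stream draws its buffers from Hadoop's ElasticByteBufferPool, which keeps released buffers around and hands out one whose capacity is at least the requested size. Below is a minimal standalone sketch of the get/put round-trip (illustrative only, not part of this commit; PoolDemo is a made-up name):

import org.apache.hadoop.io.ElasticByteBufferPool

object PoolDemo {
  def main(args: Array[String]): Unit = {
    val pool = new ElasticByteBufferPool()
    // false => heap buffer (array-backed); capacity is at least 4096 bytes.
    val buf = pool.getBuffer(false, 4096)
    try {
      buf.put("hello".getBytes("UTF-8"))
      buf.flip()
      println(s"buffered ${buf.remaining()} bytes") // buffered 5 bytes
    } finally {
      pool.putBuffer(buf) // return the buffer so later callers can reuse it
    }
  }
}
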
130 changes: 130 additions & 0 deletions src/main/scala/org/apache/spark/storage/S3BufferedBlockStream.scala
@@ -0,0 +1,130 @@
/**
* Copyright 2023- IBM Inc. All rights reserved
* SPDX-License-Identifier: Apache-2.0
*/

/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.storage

import java.io.InputStream
import java.nio.ByteBuffer

import org.apache.hadoop.io.ElasticByteBufferPool
import org.apache.spark.internal.Logging
import org.apache.spark.shuffle.helper.S3ShuffleDispatcher
import org.apache.spark.storage.S3BufferedBlockStream.{getBuffer, releaseBuffer}


/**
* Adapted class from org.apache.spark.io.NioBufferedFileInputStream
*/
class S3BufferedBlockStream(stream: S3ShuffleBlockStream) extends InputStream with Logging {
  private lazy val dispatcher = S3ShuffleDispatcher.get

  private var buffer: ByteBuffer = {
    val bufsize = scala.math.min(stream.maxBytes, dispatcher.bufferInputSize).toInt
    val buf = getBuffer(bufsize)
    // Start out empty (position == limit == 0) so that the first read()
    // triggers a refill instead of consuming uninitialized bytes.
    buf.clear()
    buf.flip()
    buf
  }

  private def refillBuffer(force: Boolean = false): Boolean = synchronized {
    if (buffer == null) {
      return false
    }
    if (!buffer.hasRemaining() || force) {
      buffer.clear()
      val array = buffer.array()
      val nRead = stream.readNBytes(array, 0, array.length)
      if (nRead <= 0) {
        // readNBytes returns 0 (never -1) once the stream is exhausted;
        // check before positioning, since position(-1) would throw.
        close()
        return false
      }
      buffer.position(nRead)
      buffer.flip()
    }
    true
  }

  override def close(): Unit = synchronized {
    if (buffer == null) {
      return
    }
    releaseBuffer(buffer)
    buffer = null
    stream.close()
    super.close()
  }

  override def read(): Int = synchronized {
    if (!refillBuffer()) {
      return -1
    }
    buffer.get & 0xFF
  }

  override def read(b: Array[Byte], off: Int, len: Int): Int = synchronized {
    if (off < 0 || len < 0 || off + len < 0 || off + len > b.length) {
      throw new IndexOutOfBoundsException
    }
    if (!refillBuffer()) {
      return -1
    }
    val length = scala.math.min(len, buffer.remaining())
    buffer.get(b, off, length)
    length
  }

  override def available(): Int = synchronized {
    if (buffer == null) {
      return 0
    }
    buffer.remaining()
  }

  override def skip(n: Long): Long = synchronized {
    if (buffer == null || n <= 0) {
      return 0
    }
    if (buffer.remaining() >= n) {
      // The buffered content is enough to skip.
      buffer.position(buffer.position() + n.toInt)
      return n
    }
    val skippedFromBuffer = buffer.remaining()
    val toSkipFromStream = n - skippedFromBuffer
    // Discard everything we have read in the buffer.
    buffer.position(0)
    buffer.flip()
    skippedFromBuffer + stream.skip(toSkipFromStream)
  }
}

object S3BufferedBlockStream {
  private lazy val bufferPool = new ElasticByteBufferPool()

  def getBuffer(size: Int): ByteBuffer = {
    // Enforce a 4 KiB minimum so very small blocks do not thrash the pool.
    val bufferSize = scala.math.max(size, 4096)
    bufferPool.getBuffer(false, bufferSize)
  }

  def releaseBuffer(buffer: ByteBuffer): Unit = {
    bufferPool.putBuffer(buffer)
  }
}
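
A note on the clear()/flip() sequencing above: it is plain java.nio state management, nothing project-specific. The sketch below (hypothetical BufferStates object, standard library only) walks through the same transitions the stream relies on:

import java.nio.ByteBuffer

object BufferStates {
  def main(args: Array[String]): Unit = {
    val buf = ByteBuffer.allocate(8)
    buf.clear()                   // position = 0, limit = capacity: ready for filling
    println(buf.hasRemaining)     // true, even though nothing was written yet
    buf.flip()                    // position = 0, limit = 0: reads as "empty"
    println(buf.hasRemaining)     // false: the next read() must trigger a refill
    buf.clear()
    buf.put(Array[Byte](1, 2, 3)) // pretend the stream filled three bytes
    buf.flip()                    // limit = 3: exactly the bytes that were read
    println(buf.remaining())      // 3
  }
}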
87 changes: 47 additions & 40 deletions src/main/scala/org/apache/spark/storage/S3ShuffleBlockStream.scala
@@ -45,64 +45,71 @@ class S3ShuffleBlockStream(
 
   private val singleByteBuffer = new Array[Byte](1)
 
-  override def close(): Unit = {
+  override def close(): Unit = synchronized {
     if (streamClosed) {
       return
     }
-    this.synchronized {
-      if (dispatcher.supportsUnbuffer) {
-        stream.unbuffer()
-        streamClosed = true
-      } else {
-        stream.close()
-        streamClosed = true
-      }
+    if (dispatcher.supportsUnbuffer) {
+      stream.unbuffer()
+      streamClosed = true
+    } else {
+      stream.close()
+      streamClosed = true
     }
     super.close()
   }
 
-  override def read(): Int = {
+  override def read(): Int = synchronized {
     if (streamClosed || numBytes >= maxBytes) {
       return -1
     }
-    this.synchronized {
-      try {
-        stream.readFully(startPosition + numBytes, singleByteBuffer)
-        numBytes += 1
-        if (numBytes >= maxBytes) {
-          close()
-        }
-        return singleByteBuffer(0)
-      } catch {
-        case e: IOException =>
-          logError(f"Encountered an unexpected IOException: ${e.toString}")
-          close()
-          return -1
-      }
+    try {
+      stream.readFully(startPosition + numBytes, singleByteBuffer)
+      numBytes += 1
+      if (numBytes >= maxBytes) {
+        close()
+      }
+      return singleByteBuffer(0) & 0xFF // mask to avoid sign extension
+    } catch {
+      case e: IOException =>
+        logError(f"Encountered an unexpected IOException: ${e.toString}")
+        close()
+        return -1
     }
   }
 
-  override def read(b: Array[Byte], off: Int, len: Int): Int = {
+  override def read(b: Array[Byte], off: Int, len: Int): Int = synchronized {
     if (streamClosed || numBytes >= maxBytes) {
       return -1
     }
-    this.synchronized {
-      val maxLength = (maxBytes - numBytes).toInt
-      assert(maxLength >= 0)
-      val length = math.min(maxLength, len)
-      try {
-        stream.readFully(startPosition + numBytes, b, off, length)
-        numBytes += length
-        if (numBytes >= maxBytes) {
-          close()
-        }
-        return length
-      } catch {
-        case e: IOException =>
-          logError(f"Encountered an unexpected IOException: ${e.toString}")
-          close()
-          return -1
-      }
+    val maxLength = (maxBytes - numBytes).toInt
+    assert(maxLength >= 0)
+    val length = math.min(maxLength, len)
+    try {
+      stream.readFully(startPosition + numBytes, b, off, length)
+      numBytes += length
+      if (numBytes >= maxBytes) {
+        close()
+      }
+      return length
+    } catch {
+      case e: IOException =>
+        logError(f"Encountered an unexpected IOException: ${e.toString}")
+        close()
+        return -1
     }
   }
+
+  override def skip(n: Long): Long = synchronized {
+    if (streamClosed) {
+      return 0
+    }
+    val maxLength = maxBytes - numBytes
+    val skipped = math.min(n, maxLength)
+    numBytes += skipped
+    if (numBytes >= maxBytes) {
+      close()
+    }
+    return skipped
+  }
 }
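
Worth flagging the & 0xFF mask in read() above: InputStream.read() must return a value in 0-255, or -1 only at end of stream, while a Scala Byte widens to Int with sign extension. A small demonstration (hypothetical SignExtensionDemo, standard library only):

object SignExtensionDemo {
  def main(args: Array[String]): Unit = {
    val b: Byte = 0xFF.toByte
    println(b.toInt)  // -1: indistinguishable from the EOF marker
    println(b & 0xFF) // 255: the value read() is required to return
  }
}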
10 changes: 2 additions & 8 deletions src/main/scala/org/apache/spark/storage/S3ShuffleReader.scala
@@ -22,6 +22,7 @@
 
 package org.apache.spark.storage
 
+import org.apache.hadoop.io.ElasticByteBufferPool
 import org.apache.spark.internal.{Logging, config}
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.serializer.SerializerManager
@@ -106,14 +107,7 @@ class S3ShuffleReader[K, C](
       // NextIterator. The NextIterator makes sure that close() is called on the
       // underlying InputStream when all records have been read.
       Future {
-        val bufferSize = scala.math.min(wrappedStream.maxBytes, bufferInputSize).toInt
-        val stream = new BufferedInputStream(wrappedStream, bufferSize)
-
-        // Fill the buffered input stream by reading and then resetting the stream.
-        stream.mark(bufferSize)
-        stream.read()
-        stream.reset()
-
+        val stream = new S3BufferedBlockStream(wrappedStream)
         val checkedStream = if (dispatcher.checksumEnabled) {
           new S3ChecksumValidationStream(blockId, stream, dispatcher.checksumAlgorithm)
         } else {
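
For reference, the deleted reader lines used a mark/read/reset trick to make BufferedInputStream prefetch eagerly, which the new S3BufferedBlockStream makes unnecessary. A self-contained sketch of that old pattern (PrefillDemo is a made-up name, standard library only):

import java.io.{BufferedInputStream, ByteArrayInputStream}

object PrefillDemo {
  def main(args: Array[String]): Unit = {
    val data = Array.tabulate(16)(i => i.toByte)
    val in = new BufferedInputStream(new ByteArrayInputStream(data), 8)
    in.mark(8)         // remember the start; keep up to 8 bytes replayable
    in.read()          // the first read fills the internal buffer eagerly
    in.reset()         // rewind: nothing is consumed, but the buffer is warm
    println(in.read()) // 0: the stream still starts at the first byte
  }
}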
