Skip to content

Commit

Permalink
scalapb-runtime-grpc: optimize Marshaller for InProcessTransport (#1615)
Browse files Browse the repository at this point in the history
* scalapb-runtime-grpc: optimize Marshaller for InProcessTransport

* post-review: drop redundant check; fix 2.13 compilation

* post-review: formatting

* post-review: added unit test for ProtoInputStream

* remove code duplication

* post-review: format code
  • Loading branch information
hugebdu authored Dec 1, 2023
1 parent 8b57da3 commit 427db01
Show file tree
Hide file tree
Showing 9 changed files with 262 additions and 12 deletions.
1 change: 1 addition & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ val e2eCommonSettings = commonSettings ++ Seq(
publish / skip := true,
javacOptions ++= Seq("-Xlint:deprecation"),
libraryDependencies ++= Seq(
grpcInprocess,
grpcNetty,
grpcProtobuf,
grpcServices,
Expand Down
6 changes: 6 additions & 0 deletions e2e-grpc/src/main/protobuf/service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ message Res4 {
message Req5 {}
message Res5 {}

message Res6 {
Req1 req = 1;
}

message SealedRequest {
oneof sealed_value {
Req1 req1 = 1;
Expand Down Expand Up @@ -91,6 +95,8 @@ service Service1 {
rpc CustomUnary(XYMessage) returns (Res5) {};

rpc PrimitiveValues(google.protobuf.Int32Value) returns (google.protobuf.StringValue) {};

rpc EchoRequest(Req1) returns (Res6) {}
}

service Issue774 {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ import io.grpc.stub.StreamObserver
import scala.concurrent.Future

class Service1ScalaImpl extends Service1 {

override def echoRequest(request: Req1): Future[Res6] =
Future.successful(Res6(Some(request)))

override def unaryStringLength(request: Req1): Future[Res1] =
Future.successful(Res1(length = request.request.length))

Expand Down
11 changes: 11 additions & 0 deletions e2e-grpc/src/test/scala/GrpcServiceScalaServerSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,17 @@ class GrpcServiceScalaServerSpec extends GrpcServiceSpecBase {
}
}

it("InProcessTransport skips serialization") {
withInMemoryTransportScalaServer { channel =>
val client = Service1GrpcScala.stub(channel)
val req = service.Req1(request = "AmIsraelChai")

val res = Await.result(client.echoRequest(req), 10.seconds)

res.req.get must be theSameInstanceAs(req)
}
}

it("companion object acts as stub factory") {
withScalaServer { channel =>
Service1GrpcScala.Service1Stub
Expand Down
37 changes: 28 additions & 9 deletions e2e-grpc/src/test/scala/GrpcServiceSpecBase.scala
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import java.util.concurrent.TimeUnit

import com.thesamet.pb.{Service1Interceptor, Service1JavaImpl, Service1ScalaImpl}
import com.thesamet.proto.e2e.service.{Service1Grpc => Service1GrpcScala}
import io.grpc.netty.{NegotiationType, NettyChannelBuilder, NettyServerBuilder}
import io.grpc.protobuf.services.ProtoReflectionService
import io.grpc.stub.StreamObserver
import io.grpc.{ManagedChannel, Server}
import io.grpc.inprocess.{InProcessChannelBuilder, InProcessServerBuilder}

import scala.concurrent.{ExecutionContext, Future, Promise}
import scala.util.Random
Expand All @@ -29,25 +29,44 @@ abstract class GrpcServiceSpecBase extends AnyFunSpec with Matchers {
withServer(_.addService(new Service1JavaImpl).intercept(new Service1Interceptor).build())(f)
}

protected[this] def withInMemoryTransportScalaServer[T](f: ManagedChannel => T): T = {
val channelName = "test-in-mem-server"
withManagedServer(
InProcessServerBuilder
.forName(channelName)
.addService(Service1GrpcScala.bindService(new Service1ScalaImpl, singleThreadExecutionContext))
.build()
) {
val channel = InProcessChannelBuilder
.forName(channelName)
.usePlaintext()
.build()
f(channel)
}
}

private[this] def withServer[T](
createServer: NettyServerBuilder => Server
)(f: ManagedChannel => T): T = {
val port = UniquePortGenerator.get()
val server = createServer(NettyServerBuilder.forPort(port))
try {
server.start()
val port = UniquePortGenerator.get()
withManagedServer(createServer(NettyServerBuilder.forPort(port))) {
val channel = NettyChannelBuilder
.forAddress("localhost", port)
.negotiationType(NegotiationType.PLAINTEXT)
.build()
f(channel)
} finally {
server.shutdown()
server.awaitTermination(3000, TimeUnit.MILLISECONDS)
()
}
}

private[this] def withManagedServer[T](server: Server)(f: => T): T = try {
server.start()
f
} finally {
server.shutdown()
server.awaitTermination(3000, TimeUnit.MILLISECONDS)
()
}

private[this] val singleThreadExecutionContext = new ExecutionContext {
override def reportFailure(cause: Throwable): Unit = cause.printStackTrace()

Expand Down
140 changes: 140 additions & 0 deletions e2e-grpc/src/test/scala/ProtoInputStreamSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import com.thesamet.proto.e2e.service.SealedResponseMessage
import org.scalatest.funspec.AnyFunSpec
import org.scalatest.matchers.must.Matchers
import scalapb.grpc.ProtoInputStream
import com.thesamet.proto.e2e.service.Res1

import scala.util.Random

class ProtoInputStreamSpec extends AnyFunSpec with Matchers {

trait Setup {
// message.serializedSize == 4
val message = SealedResponseMessage(
SealedResponseMessage.SealedValue.Res1(Res1(42))
)
val stream = new ProtoInputStream(message)
def newBuffer = Array.fill[Byte](message.serializedSize * 2)(0)

def fullyDrainStream() = {
stream.read(newBuffer, 0, message.serializedSize + 1)
}

def partiallyDrainStream() = {
stream.read(newBuffer, 0, message.serializedSize - 1)
}
}

describe("#available()") {
it("returns full length for a fresh stream") {
new Setup {
stream.available() must be(message.serializedSize)
}
}

it("returns zero for drained stream") {
new Setup {
fullyDrainStream()

stream.available() must be(0)
}
}

it("returns remaining length for partially drained stream") {
new Setup {
partiallyDrainStream()

stream.available() must be(1)
}
}
}

describe("#read(buffer, offset, length)") {

it("returns -1 for a fully drained stream") {
new Setup {
fullyDrainStream()

stream.read(newBuffer, 0, 10) must be(-1)
}
}

it("returns requested length and fills the buffer") {
new Setup {
val buf1 = newBuffer
stream.read(buf1, 0, 2) must be(2)
buf1.take(2) must be(message.toByteArray.take(2))
}
}

it("fully readable in chunks") {
new Setup {
var offset = 0
var count = 0
var buf = newBuffer
val res = Array.newBuilder[Byte]
do {
res ++= buf.slice(offset, offset + count)
buf = newBuffer
offset += count
count = stream.read(buf, offset, Random.nextInt(3))
} while (count !== -1)

res.result() must be(message.toByteArray)
}
}
}

describe("#read()") {

it("returns bytes for a fresh stream") {
new Setup {
val bytes = message.toByteArray

stream.read() must be(bytes(0))
stream.read() must be(bytes(1))
stream.read() must be(bytes(2))
}
}

it("returns -1 when fully drained") {
new Setup {
fullyDrainStream()

stream.read() must be(-1)
}
}

it("returns next byte when partially drained") {
new Setup {
partiallyDrainStream()

stream.read() must be(message.toByteArray.last)
}
}
}

describe("#message") {
it("returns the same instance passed in the constructor") {
new Setup {
stream.message must be theSameInstanceAs (message)
}
}

it("throws when fully drained") {
new Setup {
fullyDrainStream()

an[IllegalStateException] should be thrownBy stream.message
}
}

it("throws when partially drained") {
new Setup {
partiallyDrainStream()

an[IllegalStateException] should be thrownBy stream.message
}
}
}
}
1 change: 1 addition & 0 deletions project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ object Dependencies {
val grpcNetty = "io.grpc" % "grpc-netty" % versions.grpc
val grpcServices = "io.grpc" % "grpc-services" % versions.grpc
val grpcProtocGen = "io.grpc" % "protoc-gen-grpc-java" % versions.grpc
val grpcInprocess = "io.grpc" % "grpc-inprocess" % versions.grpc

// testing
val scalaTest = Def.setting { "org.scalatest" %%% "scalatest" % versions.scalaTest }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ import scalapb.{GeneratedMessage, GeneratedMessageCompanion, TypeMapper}

class Marshaller[T <: GeneratedMessage](companion: GeneratedMessageCompanion[T])
extends io.grpc.MethodDescriptor.Marshaller[T] {
override def stream(t: T): InputStream = new ByteArrayInputStream(t.toByteArray)
override def stream(t: T): InputStream = new ProtoInputStream[T](t)

override def parse(inputStream: InputStream): T =
companion.parseFrom(inputStream)
override def parse(inputStream: InputStream): T = inputStream match {
case pis: ProtoInputStream[_] => pis.message.asInstanceOf[T]
case _ => companion.parseFrom(inputStream)
}
}

class TypeMappedMarshaller[T <: GeneratedMessage, Custom](
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package scalapb.grpc

import com.google.protobuf.CodedOutputStream
import scalapb.GeneratedMessage

import java.io.{ByteArrayInputStream, InputStream}

/** Allows skipping serialization completely when the io.grpc.inprocess.InProcessTransport is used.
* Inspired by
* https://github.com/grpc/grpc-java/blob/master/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoInputStream.java
*/
class ProtoInputStream[T <: GeneratedMessage](msg: T) extends InputStream {

private var state: State = Message(msg)

private sealed trait State {
def message: T = throw new IllegalStateException("message not available")
def available: Int
def read(): Int
def read(b: Array[Byte], off: Int, len: Int): Int
}

private object Drained extends State {
override def available: Int = 0
override def read(): Int = -1
override def read(b: Array[Byte], off: Int, len: Int): Int = -1
}

private case class Message(value: T) extends State {
override def available: Int = value.serializedSize
override def message: T = value
override def read(): Int = toStream.read()
override def read(b: Array[Byte], off: Int, len: Int): Int = {
value.serializedSize match {
case 0 => toDrained.read(b, off, len)
case size if size <= len =>
val stream = CodedOutputStream.newInstance(b, off, size)
message.writeTo(stream)
stream.flush()
stream.checkNoSpaceLeft()
toDrained
size
case _ => toStream.read(b, off, len)
}
}
private def toStream: State = {
state = Stream(new ByteArrayInputStream(value.toByteArray))
state
}
private def toDrained: State = {
state = Drained
state
}
}

private case class Stream(value: InputStream) extends State {
override def available: Int = value.available()
override def read(): Int = value.read()
override def read(b: Array[Byte], off: Int, len: Int): Int = value.read(b, off, len)
}

override def read(): Int = state.read()
override def read(b: Array[Byte], off: Int, len: Int): Int = state.read(b, off, len)
override def available(): Int = state.available
def message: T = state.message
}

0 comments on commit 427db01

Please sign in to comment.