Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

scalapb-runtime-grpc: optimize Marshaller for InProcessTransport #1615

Merged
merged 6 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ val e2eCommonSettings = commonSettings ++ Seq(
publish / skip := true,
javacOptions ++= Seq("-Xlint:deprecation"),
libraryDependencies ++= Seq(
grpcInprocess,
grpcNetty,
grpcProtobuf,
grpcServices,
Expand Down
6 changes: 6 additions & 0 deletions e2e-grpc/src/main/protobuf/service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ message Res4 {
message Req5 {}
message Res5 {}

message Res6 {
Req1 req = 1;
}

message SealedRequest {
oneof sealed_value {
Req1 req1 = 1;
Expand Down Expand Up @@ -91,6 +95,8 @@ service Service1 {
rpc CustomUnary(XYMessage) returns (Res5) {};

rpc PrimitiveValues(google.protobuf.Int32Value) returns (google.protobuf.StringValue) {};

rpc EchoRequest(Req1) returns (Res6) {}
}

service Issue774 {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ import io.grpc.stub.StreamObserver
import scala.concurrent.Future

class Service1ScalaImpl extends Service1 {

override def echoRequest(request: Req1): Future[Res6] =
Future.successful(Res6(Some(request)))

override def unaryStringLength(request: Req1): Future[Res1] =
Future.successful(Res1(length = request.request.length))

Expand Down
11 changes: 11 additions & 0 deletions e2e-grpc/src/test/scala/GrpcServiceScalaServerSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,17 @@ class GrpcServiceScalaServerSpec extends GrpcServiceSpecBase {
}
}

it("InProcessTransport skips serialization") {
withInMemoryTransportScalaServer { channel =>
val client = Service1GrpcScala.stub(channel)
val req = service.Req1(request = "AmIsraelChai")

val res = Await.result(client.echoRequest(req), 10.seconds)

res.req.get must be theSameInstanceAs(req)
}
}

it("companion object acts as stub factory") {
withScalaServer { channel =>
Service1GrpcScala.Service1Stub
Expand Down
37 changes: 28 additions & 9 deletions e2e-grpc/src/test/scala/GrpcServiceSpecBase.scala
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import java.util.concurrent.TimeUnit

import com.thesamet.pb.{Service1Interceptor, Service1JavaImpl, Service1ScalaImpl}
import com.thesamet.proto.e2e.service.{Service1Grpc => Service1GrpcScala}
import io.grpc.netty.{NegotiationType, NettyChannelBuilder, NettyServerBuilder}
import io.grpc.protobuf.services.ProtoReflectionService
import io.grpc.stub.StreamObserver
import io.grpc.{ManagedChannel, Server}
import io.grpc.inprocess.{InProcessChannelBuilder, InProcessServerBuilder}

import scala.concurrent.{ExecutionContext, Future, Promise}
import scala.util.Random
Expand All @@ -29,25 +29,44 @@ abstract class GrpcServiceSpecBase extends AnyFunSpec with Matchers {
withServer(_.addService(new Service1JavaImpl).intercept(new Service1Interceptor).build())(f)
}

protected[this] def withInMemoryTransportScalaServer[T](f: ManagedChannel => T): T = {
val channelName = "test-in-mem-server"
withManagedServer(
InProcessServerBuilder
.forName(channelName)
.addService(Service1GrpcScala.bindService(new Service1ScalaImpl, singleThreadExecutionContext))
.build()
) {
val channel = InProcessChannelBuilder
.forName(channelName)
.usePlaintext()
.build()
f(channel)
}
}

private[this] def withServer[T](
createServer: NettyServerBuilder => Server
)(f: ManagedChannel => T): T = {
val port = UniquePortGenerator.get()
val server = createServer(NettyServerBuilder.forPort(port))
try {
server.start()
val port = UniquePortGenerator.get()
withManagedServer(createServer(NettyServerBuilder.forPort(port))) {
val channel = NettyChannelBuilder
.forAddress("localhost", port)
.negotiationType(NegotiationType.PLAINTEXT)
.build()
f(channel)
} finally {
server.shutdown()
server.awaitTermination(3000, TimeUnit.MILLISECONDS)
()
}
}

private[this] def withManagedServer[T](server: Server)(f: => T): T = try {
server.start()
f
} finally {
server.shutdown()
server.awaitTermination(3000, TimeUnit.MILLISECONDS)
()
}

private[this] val singleThreadExecutionContext = new ExecutionContext {
override def reportFailure(cause: Throwable): Unit = cause.printStackTrace()

Expand Down
140 changes: 140 additions & 0 deletions e2e-grpc/src/test/scala/ProtoInputStreamSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import com.thesamet.proto.e2e.service.SealedResponseMessage
import org.scalatest.funspec.AnyFunSpec
import org.scalatest.matchers.must.Matchers
import scalapb.grpc.ProtoInputStream
import com.thesamet.proto.e2e.service.Res1

import scala.util.Random

class ProtoInputStreamSpec extends AnyFunSpec with Matchers {

trait Setup {
// message.serializedSize == 4
val message = SealedResponseMessage(
SealedResponseMessage.SealedValue.Res1(Res1(42))
)
val stream = new ProtoInputStream(message)
def newBuffer = Array.fill[Byte](message.serializedSize * 2)(0)

def fullyDrainStream() = {
stream.read(newBuffer, 0, message.serializedSize + 1)
}

def partiallyDrainStream() = {
stream.read(newBuffer, 0, message.serializedSize - 1)
}
}

describe("#available()") {
it("returns full length for a fresh stream") {
new Setup {
stream.available() must be(message.serializedSize)
}
}

it("returns zero for drained stream") {
new Setup {
fullyDrainStream()

stream.available() must be(0)
}
}

it("returns remaining length for partially drained stream") {
new Setup {
partiallyDrainStream()

stream.available() must be(1)
}
}
}

describe("#read(buffer, offset, length)") {

it("returns -1 for a fully drained stream") {
new Setup {
fullyDrainStream()

stream.read(newBuffer, 0, 10) must be(-1)
}
}

it("returns requested length and fills the buffer") {
new Setup {
val buf1 = newBuffer
stream.read(buf1, 0, 2) must be(2)
buf1.take(2) must be(message.toByteArray.take(2))
}
}

it("fully readable in chunks") {
new Setup {
var offset = 0
var count = 0
var buf = newBuffer
val res = Array.newBuilder[Byte]
do {
res ++= buf.slice(offset, offset + count)
buf = newBuffer
offset += count
count = stream.read(buf, offset, Random.nextInt(3))
} while (count !== -1)

res.result() must be(message.toByteArray)
}
}
}

describe("#read()") {

it("returns bytes for a fresh stream") {
new Setup {
val bytes = message.toByteArray

stream.read() must be(bytes(0))
stream.read() must be(bytes(1))
stream.read() must be(bytes(2))
}
}

it("returns -1 when fully drained") {
new Setup {
fullyDrainStream()

stream.read() must be(-1)
}
}

it("returns next byte when partially drained") {
new Setup {
partiallyDrainStream()

stream.read() must be(message.toByteArray.last)
}
}
}

describe("#message") {
it("returns the same instance passed in the constructor") {
new Setup {
stream.message must be theSameInstanceAs (message)
}
}

it("throws when fully drained") {
new Setup {
fullyDrainStream()

an[IllegalStateException] should be thrownBy stream.message
}
}

it("throws when partially drained") {
new Setup {
partiallyDrainStream()

an[IllegalStateException] should be thrownBy stream.message
}
}
}
}
1 change: 1 addition & 0 deletions project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ object Dependencies {
val grpcNetty = "io.grpc" % "grpc-netty" % versions.grpc
val grpcServices = "io.grpc" % "grpc-services" % versions.grpc
val grpcProtocGen = "io.grpc" % "protoc-gen-grpc-java" % versions.grpc
val grpcInprocess = "io.grpc" % "grpc-inprocess" % versions.grpc

// testing
val scalaTest = Def.setting { "org.scalatest" %%% "scalatest" % versions.scalaTest }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ import scalapb.{GeneratedMessage, GeneratedMessageCompanion, TypeMapper}

class Marshaller[T <: GeneratedMessage](companion: GeneratedMessageCompanion[T])
extends io.grpc.MethodDescriptor.Marshaller[T] {
override def stream(t: T): InputStream = new ByteArrayInputStream(t.toByteArray)
override def stream(t: T): InputStream = new ProtoInputStream[T](t)

override def parse(inputStream: InputStream): T =
companion.parseFrom(inputStream)
override def parse(inputStream: InputStream): T = inputStream match {
case pis: ProtoInputStream[_] => pis.message.asInstanceOf[T]
case _ => companion.parseFrom(inputStream)
}
}

class TypeMappedMarshaller[T <: GeneratedMessage, Custom](
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package scalapb.grpc

import com.google.protobuf.CodedOutputStream
import scalapb.GeneratedMessage

import java.io.{ByteArrayInputStream, InputStream}

/** Allows skipping serialization completely when the io.grpc.inprocess.InProcessTransport is used.
* Inspired by
* https://github.com/grpc/grpc-java/blob/master/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoInputStream.java
*/
class ProtoInputStream[T <: GeneratedMessage](msg: T) extends InputStream {

private var state: State = Message(msg)

private sealed trait State {
def message: T = throw new IllegalStateException("message not available")
def available: Int
def read(): Int
def read(b: Array[Byte], off: Int, len: Int): Int
}

private object Drained extends State {
override def available: Int = 0
override def read(): Int = -1
override def read(b: Array[Byte], off: Int, len: Int): Int = -1
}

private case class Message(value: T) extends State {
override def available: Int = value.serializedSize
override def message: T = value
override def read(): Int = toStream.read()
override def read(b: Array[Byte], off: Int, len: Int): Int = {
value.serializedSize match {
case 0 => toDrained.read(b, off, len)
case size if size <= len =>
val stream = CodedOutputStream.newInstance(b, off, size)
message.writeTo(stream)
stream.flush()
stream.checkNoSpaceLeft()
toDrained
size
case _ => toStream.read(b, off, len)
}
}
private def toStream: State = {
state = Stream(new ByteArrayInputStream(value.toByteArray))
state
}
private def toDrained: State = {
state = Drained
state
}
}

private case class Stream(value: InputStream) extends State {
override def available: Int = value.available()
override def read(): Int = value.read()
override def read(b: Array[Byte], off: Int, len: Int): Int = value.read(b, off, len)
}

override def read(): Int = state.read()
override def read(b: Array[Byte], off: Int, len: Int): Int = state.read(b, off, len)
override def available(): Int = state.available
def message: T = state.message
}