Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

scalapb-runtime-grpc: optimize Marshaller for InProcessTransport #1615

Merged
merged 6 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ val e2eCommonSettings = commonSettings ++ Seq(
publish / skip := true,
javacOptions ++= Seq("-Xlint:deprecation"),
libraryDependencies ++= Seq(
grpcInprocess,
grpcNetty,
grpcProtobuf,
grpcServices,
Expand Down
6 changes: 6 additions & 0 deletions e2e-grpc/src/main/protobuf/service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ message Res4 {
message Req5 {}
message Res5 {}

message Res6 {
Req1 req = 1;
}

message SealedRequest {
oneof sealed_value {
Req1 req1 = 1;
Expand Down Expand Up @@ -91,6 +95,8 @@ service Service1 {
rpc CustomUnary(XYMessage) returns (Res5) {};

rpc PrimitiveValues(google.protobuf.Int32Value) returns (google.protobuf.StringValue) {};

rpc EchoRequest(Req1) returns (Res6) {}
}

service Issue774 {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ import io.grpc.stub.StreamObserver
import scala.concurrent.Future

class Service1ScalaImpl extends Service1 {

override def echoRequest(request: Req1): Future[Res6] =
Future.successful(Res6(Some(request)))

override def unaryStringLength(request: Req1): Future[Res1] =
Future.successful(Res1(length = request.request.length))

Expand Down
11 changes: 11 additions & 0 deletions e2e-grpc/src/test/scala/GrpcServiceScalaServerSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,17 @@ class GrpcServiceScalaServerSpec extends GrpcServiceSpecBase {
}
}

it("InProcessTransport skips serialization") {
withInMemoryTransportScalaServer { channel =>
val client = Service1GrpcScala.stub(channel)
val req = service.Req1(request = "AmIsraelChai")

val res = Await.result(client.echoRequest(req), 10.seconds)

res.req.get must be theSameInstanceAs(req)
}
}

it("companion object acts as stub factory") {
withScalaServer { channel =>
Service1GrpcScala.Service1Stub
Expand Down
21 changes: 20 additions & 1 deletion e2e-grpc/src/test/scala/GrpcServiceSpecBase.scala
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import java.util.concurrent.TimeUnit

import com.thesamet.pb.{Service1Interceptor, Service1JavaImpl, Service1ScalaImpl}
import com.thesamet.proto.e2e.service.{Service1Grpc => Service1GrpcScala}
import io.grpc.netty.{NegotiationType, NettyChannelBuilder, NettyServerBuilder}
import io.grpc.protobuf.services.ProtoReflectionService
import io.grpc.stub.StreamObserver
import io.grpc.{ManagedChannel, Server}
import io.grpc.inprocess.{InProcessChannelBuilder, InProcessServerBuilder}

import scala.concurrent.{ExecutionContext, Future, Promise}
import scala.util.Random
Expand All @@ -29,6 +29,25 @@ abstract class GrpcServiceSpecBase extends AnyFunSpec with Matchers {
withServer(_.addService(new Service1JavaImpl).intercept(new Service1Interceptor).build())(f)
}

protected[this] def withInMemoryTransportScalaServer[T](f: ManagedChannel => T): T = {
val channelName = "test-in-mem-server"
val server = InProcessServerBuilder.forName(channelName)
.addService(Service1GrpcScala.bindService(new Service1ScalaImpl, singleThreadExecutionContext))
.build()

try {
server.start()
val channel = InProcessChannelBuilder.forName(channelName)
.usePlaintext()
.build()
f(channel)
} finally {
server.shutdown()
server.awaitTermination(3000, TimeUnit.MILLISECONDS)
()
}
}

private[this] def withServer[T](
createServer: NettyServerBuilder => Server
)(f: ManagedChannel => T): T = {
Expand Down
1 change: 1 addition & 0 deletions project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ object Dependencies {
val grpcNetty = "io.grpc" % "grpc-netty" % versions.grpc
val grpcServices = "io.grpc" % "grpc-services" % versions.grpc
val grpcProtocGen = "io.grpc" % "protoc-gen-grpc-java" % versions.grpc
val grpcInprocess = "io.grpc" % "grpc-inprocess" % versions.grpc
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: align the whitespace?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


// testing
val scalaTest = Def.setting { "org.scalatest" %%% "scalatest" % versions.scalaTest }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ import scalapb.{GeneratedMessage, GeneratedMessageCompanion, TypeMapper}

class Marshaller[T <: GeneratedMessage](companion: GeneratedMessageCompanion[T])
extends io.grpc.MethodDescriptor.Marshaller[T] {
override def stream(t: T): InputStream = new ByteArrayInputStream(t.toByteArray)
override def stream(t: T): InputStream = new ProtoInputStream[T](t, this)

override def parse(inputStream: InputStream): T =
companion.parseFrom(inputStream)
override def parse(inputStream: InputStream): T = inputStream match {
case pis: ProtoInputStream[T] if pis.marshaller == this => pis.message
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the reason for comparing the marshallers? if not needed, the val marshaller can be removed from the ProtoInputStream class.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed in 8d8256c

introduced it to reproduce the following precaution logic, however seems to be redundant indeed.

case _ => companion.parseFrom(inputStream)
}
}

class TypeMappedMarshaller[T <: GeneratedMessage, Custom](
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package scalapb.grpc

import com.google.protobuf.CodedOutputStream
import scalapb.GeneratedMessage

import java.io.{ByteArrayInputStream, InputStream}

/**
* Allows skipping serialization completely when the io.grpc.inprocess.InProcessTransport is used.
* Inspired by https://github.com/grpc/grpc-java/blob/master/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoInputStream.java
*/
class ProtoInputStream[T <: GeneratedMessage](msg: T, val marshaller: Marshaller[T]) extends InputStream {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add unit tests for this class?

Copy link
Contributor Author

@hugebdu hugebdu Nov 29, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added in 6c26c6a


private var state: State = Message(msg)

private sealed trait State {
def message: T = throw new IllegalStateException("message not available")
def available: Int
def read(): Int
def read(b: Array[Byte], off: Int, len: Int): Int
}

private object Drained extends State {
override def available: Int = 0
override def read(): Int = -1
override def read(b: Array[Byte], off: Int, len: Int): Int = -1
}

private case class Message(value: T) extends State {
override def available: Int = value.serializedSize
override def message: T = value
override def read(): Int = toStream.read()
override def read(b: Array[Byte], off: Int, len: Int): Int = {
value.serializedSize match {
case 0 => toDrained.read(b, off, len)
case size if size <= len =>
val stream = CodedOutputStream.newInstance(b, off, size)
message.writeTo(stream)
stream.flush()
stream.checkNoSpaceLeft()
toDrained
size
case _ => toStream.read(b, off, len)
}
}
private def toStream: State = {
state = Stream(new ByteArrayInputStream(value.toByteArray))
state
}
private def toDrained: State = {
state = Drained
state
}
}

private case class Stream(value: InputStream) extends State {
override def available: Int = value.available()
override def read(): Int = value.read()
override def read(b: Array[Byte], off: Int, len: Int): Int = value.read(b, off, len)
}

override def read(): Int = state.read()
override def read(b: Array[Byte], off: Int, len: Int): Int = state.read(b, off, len)
override def available(): Int = state.available
def message: T = state.message
}
Loading