From 427db01d7f5c5ceeadd1fe8316f79d989d81d753 Mon Sep 17 00:00:00 2001 From: Daniel Shmuglin Date: Fri, 1 Dec 2023 06:34:47 +0200 Subject: [PATCH] scalapb-runtime-grpc: optimize Marshaller for InProcessTransport (#1615) * scalapb-runtime-grpc: optimize Marshaller for InProcessTransport * post-review: drop redundant check; fix 2.13 compilation * post-review: formatting * post-review: added unit test for ProtoInputStream * remove code duplication * post-review: format code --- build.sbt | 1 + e2e-grpc/src/main/protobuf/service.proto | 6 + .../com/thesamet/pb/Service1ScalaImpl.scala | 4 + .../scala/GrpcServiceScalaServerSpec.scala | 11 ++ .../src/test/scala/GrpcServiceSpecBase.scala | 37 +++-- .../src/test/scala/ProtoInputStreamSpec.scala | 140 ++++++++++++++++++ project/Dependencies.scala | 1 + .../main/scala/scalapb/grpc/Marshaller.scala | 8 +- .../scala/scalapb/grpc/ProtoInputStream.scala | 66 +++++++++ 9 files changed, 262 insertions(+), 12 deletions(-) create mode 100644 e2e-grpc/src/test/scala/ProtoInputStreamSpec.scala create mode 100644 scalapb-runtime-grpc/src/main/scala/scalapb/grpc/ProtoInputStream.scala diff --git a/build.sbt b/build.sbt index 204f17aa1..073ba8aef 100644 --- a/build.sbt +++ b/build.sbt @@ -303,6 +303,7 @@ val e2eCommonSettings = commonSettings ++ Seq( publish / skip := true, javacOptions ++= Seq("-Xlint:deprecation"), libraryDependencies ++= Seq( + grpcInprocess, grpcNetty, grpcProtobuf, grpcServices, diff --git a/e2e-grpc/src/main/protobuf/service.proto b/e2e-grpc/src/main/protobuf/service.proto index 94f216444..8dc3f72c4 100644 --- a/e2e-grpc/src/main/protobuf/service.proto +++ b/e2e-grpc/src/main/protobuf/service.proto @@ -33,6 +33,10 @@ message Res4 { message Req5 {} message Res5 {} +message Res6 { + Req1 req = 1; +} + message SealedRequest { oneof sealed_value { Req1 req1 = 1; @@ -91,6 +95,8 @@ service Service1 { rpc CustomUnary(XYMessage) returns (Res5) {}; rpc PrimitiveValues(google.protobuf.Int32Value) returns (google.protobuf.StringValue) {}; + + rpc EchoRequest(Req1) returns (Res6) {} } service Issue774 { diff --git a/e2e-grpc/src/main/scala/com/thesamet/pb/Service1ScalaImpl.scala b/e2e-grpc/src/main/scala/com/thesamet/pb/Service1ScalaImpl.scala index 3c1c01b9a..ebacfcdc5 100644 --- a/e2e-grpc/src/main/scala/com/thesamet/pb/Service1ScalaImpl.scala +++ b/e2e-grpc/src/main/scala/com/thesamet/pb/Service1ScalaImpl.scala @@ -10,6 +10,10 @@ import io.grpc.stub.StreamObserver import scala.concurrent.Future class Service1ScalaImpl extends Service1 { + + override def echoRequest(request: Req1): Future[Res6] = + Future.successful(Res6(Some(request))) + override def unaryStringLength(request: Req1): Future[Res1] = Future.successful(Res1(length = request.request.length)) diff --git a/e2e-grpc/src/test/scala/GrpcServiceScalaServerSpec.scala b/e2e-grpc/src/test/scala/GrpcServiceScalaServerSpec.scala index d31b5b11b..4ad4002d8 100644 --- a/e2e-grpc/src/test/scala/GrpcServiceScalaServerSpec.scala +++ b/e2e-grpc/src/test/scala/GrpcServiceScalaServerSpec.scala @@ -211,6 +211,17 @@ class GrpcServiceScalaServerSpec extends GrpcServiceSpecBase { } } + it("InProcessTransport skips serialization") { + withInMemoryTransportScalaServer { channel => + val client = Service1GrpcScala.stub(channel) + val req = service.Req1(request = "AmIsraelChai") + + val res = Await.result(client.echoRequest(req), 10.seconds) + + res.req.get must be theSameInstanceAs(req) + } + } + it("companion object acts as stub factory") { withScalaServer { channel => Service1GrpcScala.Service1Stub diff --git a/e2e-grpc/src/test/scala/GrpcServiceSpecBase.scala b/e2e-grpc/src/test/scala/GrpcServiceSpecBase.scala index d8affea18..38400f037 100644 --- a/e2e-grpc/src/test/scala/GrpcServiceSpecBase.scala +++ b/e2e-grpc/src/test/scala/GrpcServiceSpecBase.scala @@ -1,11 +1,11 @@ import java.util.concurrent.TimeUnit - import com.thesamet.pb.{Service1Interceptor, Service1JavaImpl, Service1ScalaImpl} import com.thesamet.proto.e2e.service.{Service1Grpc => Service1GrpcScala} import io.grpc.netty.{NegotiationType, NettyChannelBuilder, NettyServerBuilder} import io.grpc.protobuf.services.ProtoReflectionService import io.grpc.stub.StreamObserver import io.grpc.{ManagedChannel, Server} +import io.grpc.inprocess.{InProcessChannelBuilder, InProcessServerBuilder} import scala.concurrent.{ExecutionContext, Future, Promise} import scala.util.Random @@ -29,25 +29,44 @@ abstract class GrpcServiceSpecBase extends AnyFunSpec with Matchers { withServer(_.addService(new Service1JavaImpl).intercept(new Service1Interceptor).build())(f) } + protected[this] def withInMemoryTransportScalaServer[T](f: ManagedChannel => T): T = { + val channelName = "test-in-mem-server" + withManagedServer( + InProcessServerBuilder + .forName(channelName) + .addService(Service1GrpcScala.bindService(new Service1ScalaImpl, singleThreadExecutionContext)) + .build() + ) { + val channel = InProcessChannelBuilder + .forName(channelName) + .usePlaintext() + .build() + f(channel) + } + } + private[this] def withServer[T]( createServer: NettyServerBuilder => Server )(f: ManagedChannel => T): T = { - val port = UniquePortGenerator.get() - val server = createServer(NettyServerBuilder.forPort(port)) - try { - server.start() + val port = UniquePortGenerator.get() + withManagedServer(createServer(NettyServerBuilder.forPort(port))) { val channel = NettyChannelBuilder .forAddress("localhost", port) .negotiationType(NegotiationType.PLAINTEXT) .build() f(channel) - } finally { - server.shutdown() - server.awaitTermination(3000, TimeUnit.MILLISECONDS) - () } } + private[this] def withManagedServer[T](server: Server)(f: => T): T = try { + server.start() + f + } finally { + server.shutdown() + server.awaitTermination(3000, TimeUnit.MILLISECONDS) + () + } + private[this] val singleThreadExecutionContext = new ExecutionContext { override def reportFailure(cause: Throwable): Unit = cause.printStackTrace() diff --git a/e2e-grpc/src/test/scala/ProtoInputStreamSpec.scala b/e2e-grpc/src/test/scala/ProtoInputStreamSpec.scala new file mode 100644 index 000000000..56c2d4ff0 --- /dev/null +++ b/e2e-grpc/src/test/scala/ProtoInputStreamSpec.scala @@ -0,0 +1,140 @@ +import com.thesamet.proto.e2e.service.SealedResponseMessage +import org.scalatest.funspec.AnyFunSpec +import org.scalatest.matchers.must.Matchers +import scalapb.grpc.ProtoInputStream +import com.thesamet.proto.e2e.service.Res1 + +import scala.util.Random + +class ProtoInputStreamSpec extends AnyFunSpec with Matchers { + + trait Setup { + // message.serializedSize == 4 + val message = SealedResponseMessage( + SealedResponseMessage.SealedValue.Res1(Res1(42)) + ) + val stream = new ProtoInputStream(message) + def newBuffer = Array.fill[Byte](message.serializedSize * 2)(0) + + def fullyDrainStream() = { + stream.read(newBuffer, 0, message.serializedSize + 1) + } + + def partiallyDrainStream() = { + stream.read(newBuffer, 0, message.serializedSize - 1) + } + } + + describe("#available()") { + it("returns full length for a fresh stream") { + new Setup { + stream.available() must be(message.serializedSize) + } + } + + it("returns zero for drained stream") { + new Setup { + fullyDrainStream() + + stream.available() must be(0) + } + } + + it("returns remaining length for partially drained stream") { + new Setup { + partiallyDrainStream() + + stream.available() must be(1) + } + } + } + + describe("#read(buffer, offset, length)") { + + it("returns -1 for a fully drained stream") { + new Setup { + fullyDrainStream() + + stream.read(newBuffer, 0, 10) must be(-1) + } + } + + it("returns requested length and fills the buffer") { + new Setup { + val buf1 = newBuffer + stream.read(buf1, 0, 2) must be(2) + buf1.take(2) must be(message.toByteArray.take(2)) + } + } + + it("fully readable in chunks") { + new Setup { + var offset = 0 + var count = 0 + var buf = newBuffer + val res = Array.newBuilder[Byte] + do { + res ++= buf.slice(offset, offset + count) + buf = newBuffer + offset += count + count = stream.read(buf, offset, Random.nextInt(3)) + } while (count !== -1) + + res.result() must be(message.toByteArray) + } + } + } + + describe("#read()") { + + it("returns bytes for a fresh stream") { + new Setup { + val bytes = message.toByteArray + + stream.read() must be(bytes(0)) + stream.read() must be(bytes(1)) + stream.read() must be(bytes(2)) + } + } + + it("returns -1 when fully drained") { + new Setup { + fullyDrainStream() + + stream.read() must be(-1) + } + } + + it("returns next byte when partially drained") { + new Setup { + partiallyDrainStream() + + stream.read() must be(message.toByteArray.last) + } + } + } + + describe("#message") { + it("returns the same instance passed in the constructor") { + new Setup { + stream.message must be theSameInstanceAs (message) + } + } + + it("throws when fully drained") { + new Setup { + fullyDrainStream() + + an[IllegalStateException] should be thrownBy stream.message + } + } + + it("throws when partially drained") { + new Setup { + partiallyDrainStream() + + an[IllegalStateException] should be thrownBy stream.message + } + } + } +} diff --git a/project/Dependencies.scala b/project/Dependencies.scala index 33ece56f1..69b978a1b 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -60,6 +60,7 @@ object Dependencies { val grpcNetty = "io.grpc" % "grpc-netty" % versions.grpc val grpcServices = "io.grpc" % "grpc-services" % versions.grpc val grpcProtocGen = "io.grpc" % "protoc-gen-grpc-java" % versions.grpc + val grpcInprocess = "io.grpc" % "grpc-inprocess" % versions.grpc // testing val scalaTest = Def.setting { "org.scalatest" %%% "scalatest" % versions.scalaTest } diff --git a/scalapb-runtime-grpc/src/main/scala/scalapb/grpc/Marshaller.scala b/scalapb-runtime-grpc/src/main/scala/scalapb/grpc/Marshaller.scala index f06a7f817..4e6d1547d 100644 --- a/scalapb-runtime-grpc/src/main/scala/scalapb/grpc/Marshaller.scala +++ b/scalapb-runtime-grpc/src/main/scala/scalapb/grpc/Marshaller.scala @@ -6,10 +6,12 @@ import scalapb.{GeneratedMessage, GeneratedMessageCompanion, TypeMapper} class Marshaller[T <: GeneratedMessage](companion: GeneratedMessageCompanion[T]) extends io.grpc.MethodDescriptor.Marshaller[T] { - override def stream(t: T): InputStream = new ByteArrayInputStream(t.toByteArray) + override def stream(t: T): InputStream = new ProtoInputStream[T](t) - override def parse(inputStream: InputStream): T = - companion.parseFrom(inputStream) + override def parse(inputStream: InputStream): T = inputStream match { + case pis: ProtoInputStream[_] => pis.message.asInstanceOf[T] + case _ => companion.parseFrom(inputStream) + } } class TypeMappedMarshaller[T <: GeneratedMessage, Custom]( diff --git a/scalapb-runtime-grpc/src/main/scala/scalapb/grpc/ProtoInputStream.scala b/scalapb-runtime-grpc/src/main/scala/scalapb/grpc/ProtoInputStream.scala new file mode 100644 index 000000000..50363642d --- /dev/null +++ b/scalapb-runtime-grpc/src/main/scala/scalapb/grpc/ProtoInputStream.scala @@ -0,0 +1,66 @@ +package scalapb.grpc + +import com.google.protobuf.CodedOutputStream +import scalapb.GeneratedMessage + +import java.io.{ByteArrayInputStream, InputStream} + +/** Allows skipping serialization completely when the io.grpc.inprocess.InProcessTransport is used. + * Inspired by + * https://github.com/grpc/grpc-java/blob/master/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoInputStream.java + */ +class ProtoInputStream[T <: GeneratedMessage](msg: T) extends InputStream { + + private var state: State = Message(msg) + + private sealed trait State { + def message: T = throw new IllegalStateException("message not available") + def available: Int + def read(): Int + def read(b: Array[Byte], off: Int, len: Int): Int + } + + private object Drained extends State { + override def available: Int = 0 + override def read(): Int = -1 + override def read(b: Array[Byte], off: Int, len: Int): Int = -1 + } + + private case class Message(value: T) extends State { + override def available: Int = value.serializedSize + override def message: T = value + override def read(): Int = toStream.read() + override def read(b: Array[Byte], off: Int, len: Int): Int = { + value.serializedSize match { + case 0 => toDrained.read(b, off, len) + case size if size <= len => + val stream = CodedOutputStream.newInstance(b, off, size) + message.writeTo(stream) + stream.flush() + stream.checkNoSpaceLeft() + toDrained + size + case _ => toStream.read(b, off, len) + } + } + private def toStream: State = { + state = Stream(new ByteArrayInputStream(value.toByteArray)) + state + } + private def toDrained: State = { + state = Drained + state + } + } + + private case class Stream(value: InputStream) extends State { + override def available: Int = value.available() + override def read(): Int = value.read() + override def read(b: Array[Byte], off: Int, len: Int): Int = value.read(b, off, len) + } + + override def read(): Int = state.read() + override def read(b: Array[Byte], off: Int, len: Int): Int = state.read(b, off, len) + override def available(): Int = state.available + def message: T = state.message +}