From 51f3bb3b644b7306bb1c670a0089ee65c2f07154 Mon Sep 17 00:00:00 2001 From: Brandon Duffany Date: Mon, 13 May 2024 13:12:51 -0400 Subject: [PATCH] Add server streaming support --- .bazelignore | 2 +- BUILD | 10 --- WORKSPACE | 6 +- codegen.go | 105 +++++++++++++++++++++++--- .npmrc => test/.npmrc | 0 test/BUILD | 14 +++- test/defs.bzl | 10 +-- package.json => test/package.json | 7 -- pnpm-lock.yaml => test/pnpm-lock.yaml | 0 test/proto/service.proto | 3 + tsconfig.json => test/tsconfig.json | 0 ts.go | 30 ++++++-- 12 files changed, 140 insertions(+), 47 deletions(-) rename .npmrc => test/.npmrc (100%) rename package.json => test/package.json (70%) rename pnpm-lock.yaml => test/pnpm-lock.yaml (100%) rename tsconfig.json => test/tsconfig.json (100%) diff --git a/.bazelignore b/.bazelignore index 85dcc16..d270388 100644 --- a/.bazelignore +++ b/.bazelignore @@ -1,2 +1,2 @@ .git -node_modules +test/node_modules diff --git a/BUILD b/BUILD index 076e057..93e0889 100644 --- a/BUILD +++ b/BUILD @@ -1,7 +1,5 @@ -load("@aspect_rules_ts//ts:defs.bzl", "ts_config") load("@bazel_gazelle//:def.bzl", "gazelle") load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library") -load("@npm//:defs.bzl", "npm_link_all_packages") # gazelle:prefix github.com/buildbuddy-io/protoc-gen-protobufjs gazelle(name = "gazelle") @@ -35,11 +33,3 @@ go_binary( embed = [":protoc-gen-protobufjs_lib"], visibility = ["//visibility:public"], ) - -npm_link_all_packages(name = "node_modules") - -ts_config( - name = "tsconfig", - src = "tsconfig.json", - visibility = ["//visibility:public"], -) diff --git a/WORKSPACE b/WORKSPACE index 8f78a30..9298d3e 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -85,8 +85,8 @@ load("@aspect_rules_js//npm:repositories.bzl", "npm_translate_lock") npm_translate_lock( name = "npm", - npmrc = "//:.npmrc", - pnpm_lock = "//:pnpm-lock.yaml", + npmrc = "//test:.npmrc", + pnpm_lock = "//test:pnpm-lock.yaml", verify_node_modules_ignored = "//:.bazelignore", ) @@ -106,7 +106,7 @@ http_archive( load("@aspect_rules_ts//ts:repositories.bzl", "rules_ts_dependencies") rules_ts_dependencies( - ts_version_from = "//:package.json", + ts_version_from = "//test:package.json", ) http_archive( diff --git a/codegen.go b/codegen.go index cff51c4..3bc0b77 100644 --- a/codegen.go +++ b/codegen.go @@ -22,6 +22,35 @@ var ( importPaths = map[string]string{} ) +// Top-level declarations that are duplicated at the top of each generated file. +// +// TODO: maybe publish an NPM package with these declarations so they aren't +// duplicated as much. +const ( + streamTypeDeclarations = ` +export namespace $stream { + export type ServerStream = { + /** Cancels the RPC. */ + cancel(): void; + } + + export type ServerStreamHandler = { + /** Handles a message on the stream. */ + next: (message: T) => void; + /** Handles an error on the stream. */ + error?: (e: any) => void; + /** Called when all messages are done being streamed. Not called on error. */ + complete?: () => void; + } + + export type StreamingRPCParams = { + signal: AbortSignal; + complete: () => void; + } +} +` +) + func generateCode(req *pluginpb.CodeGeneratorRequest) (*pluginpb.CodeGeneratorResponse, error) { for _, p := range *importPathFlag { parts := strings.SplitN(p, "=", 2) @@ -103,6 +132,9 @@ type Codegen struct { // Generated JS implementation (.js content) j *TS + // Whether the server-streaming RPC util needs to be generated. + HasServerStreamingRPC bool + // Paths to the top-level protos to be translated to TS. Paths []string // Out is the file being generated by this codegen. @@ -337,7 +369,7 @@ func (c *Codegen) generate(file *descriptorpb.FileDescriptorProto, sourcePath [] } // Constructor - d.Lf("constructor(properties?: I%s): %s;", messageType.GetName(), messageType.GetName()) + d.Lf("constructor(properties?: I%s);", messageType.GetName()) j.Lf("constructor(properties) {") // All repeated fields and maps are initialized to empty for _, f := range messageType.GetField() { @@ -686,7 +718,7 @@ func (c *Codegen) generate(file *descriptorpb.FileDescriptorProto, sourcePath [] d.L("toJSON(): Record") // getTypeUrl - d.L(`static getTypeUrl(typeUrlPrefix = "type.googleapis.com"): string;`) + d.L(`static getTypeUrl(typeUrlPrefix?: string): string;`) j.L(`static getTypeUrl(typeUrlPrefix = "type.googleapis.com") {`) j.Lf(`return typeUrlPrefix + "/%s.%s";`, ns, messageType.GetName()) j.L("}") @@ -752,9 +784,17 @@ func (c *Codegen) generate(file *descriptorpb.FileDescriptorProto, sourcePath [] for methodIndex, method := range serviceType.GetMethod() { methodPath := append(servicePath, serviceMethodTagNumber, int32(methodIndex)) d.BlockComment(c.Comments[file.GetName()][fmt.Sprint(methodPath)]) - d.Lf( - `%s: { readonly name: "%s" } & ((request: %s) => Promise<%s>);`, - serviceMethodJSName(method), method.GetName(), interfaceTypeName(c.resolveTypeName(method.GetInputType(), "")), c.resolveTypeName(method.GetOutputType(), "")) + if method.GetServerStreaming() { + d.AddTopLevelDeclaration(streamTypeDeclarations) + d.Lf(`%s: { readonly name: "%s"; readonly serverStreaming: true; } & ((request: %s, handler: $stream.ServerStreamHandler<%s>) => $stream.ServerStream<%s>);`, + serviceMethodJSName(method), method.GetName(), interfaceTypeName(c.resolveTypeName(method.GetInputType(), "")), c.resolveTypeName(method.GetOutputType(), ""), c.resolveTypeName(method.GetOutputType(), "")) + } else { + // Note: we don't support the callback style of method calling that + // protobufjs supports - only promises are supported for now. + d.Lf( + `%s: { readonly name: "%s"; readonly serverStreaming: false; } & ((request: %s) => Promise<%s>);`, + serviceMethodJSName(method), method.GetName(), interfaceTypeName(c.resolveTypeName(method.GetInputType(), "")), c.resolveTypeName(method.GetOutputType(), "")) + } } d.L("}") @@ -762,18 +802,18 @@ func (c *Codegen) generate(file *descriptorpb.FileDescriptorProto, sourcePath [] j.Lf("%s.%s = (() => {", curNS, serviceType.GetName()) j.Lf("class %s extends $protobuf.rpc.Service {", serviceType.GetName()) - d.L("constructor(rpcImpl: $protobuf.RPCImpl, requestDelimited = false, responseDelimited = false);") + d.L("constructor(rpcImpl: $protobuf.RPCImpl, requestDelimited?: boolean, responseDelimited?: boolean);") j.L("constructor(rpcImpl, requestDelimited = false, responseDelimited = false) {") j.Lf("super(rpcImpl, requestDelimited, responseDelimited);") j.L("}") - d.L("static create(rpcImpl: $protobuf.RPCImpl, requestDelimited = false, responseDelimited = false);") + d.Lf("static create(rpcImpl: $protobuf.RPCImpl, requestDelimited?: boolean, responseDelimited?: boolean): %s;", serviceType.GetName()) j.L("static create(rpcImpl, requestDelimited = false, responseDelimited = false) {") j.Lf("return new %s(rpcImpl, requestDelimited, responseDelimited);", serviceType.GetName()) j.L("}") for _, method := range serviceType.GetMethod() { - d.Lf(`%s!: I%s["%s"];`, serviceMethodJSName(method), serviceType.GetName(), serviceMethodJSName(method)) + d.Lf(`%s: I%s["%s"];`, serviceMethodJSName(method), serviceType.GetName(), serviceMethodJSName(method)) } // End class definition @@ -782,11 +822,52 @@ func (c *Codegen) generate(file *descriptorpb.FileDescriptorProto, sourcePath [] // Define service methods on class prototype, including readonly "name" prop for _, method := range serviceType.GetMethod() { - j.Lf( - "Object.defineProperty(%s.prototype.%s = function %s(request, callback) {", - serviceType.GetName(), serviceMethodJSName(method), serviceMethodJSName(method)) - j.Lf("return this.rpcCall(%s, %s, %s, request, callback);", serviceMethodJSName(method), c.resolveTypeName(method.GetInputType(), "$root."), c.resolveTypeName(method.GetOutputType(), "$root.")) + if method.GetServerStreaming() { + j.Lf( + "Object.defineProperty(%s.prototype.%s = function %s(request, handler) {", + serviceType.GetName(), serviceMethodJSName(method), serviceMethodJSName(method)) + j.Lf("if (!handler) throw TypeError('stream handler is required for server-streaming RPC');") + j.Lf(`if (!handler.next) throw TypeError("stream handler is missing 'next' callback property");`) + j.Lf("const controller = new AbortController();") + j.Lf("const stream = { cancel: () => controller.abort() };") + j.Lf("const callback = (error, data) => {") + // Ignore AbortError + j.Lf("if (error) {") + j.Lf("if (typeof error === 'object' && error.name === 'AbortError') {") + j.Lf("return; // stream canceled") + j.Lf("}") + j.Lf("if (handler.error) {") + j.Lf("handler.error(error)") + j.Lf("} else {") + j.Lf("console.error('Unhandled error in %s.%s RPC:', error)", serviceType.GetName(), method.GetName()) + j.Lf("}") + j.Lf("} else if (data) {") + j.Lf("handler.next(data);") + j.Lf("}") + j.Lf("}") + // Hack: temporarily patch rpcImpl to receive an extra options parameter + // indicating that this is a server-streaming RPC. This means that we + // will invoke the callback multiple times. + j.Lf(`const originalRPCImpl = this.rpcImpl;`) + j.Lf("const streamParams = { serverStream: true, complete: () => stream.complete(), signal: controller.signal };") + j.Lf(`this.rpcImpl = (method, data, callback) => originalRPCImpl(method, data, callback, streamParams);`) + j.Lf("try {") + j.Lf("this.rpcCall(%s, %s, %s, request, callback);", serviceMethodJSName(method), c.resolveTypeName(method.GetInputType(), "$root."), c.resolveTypeName(method.GetOutputType(), "$root.")) + j.Lf("return stream;") + j.Lf("} finally {") + j.Lf("this.rpcImpl = originalRPCImpl;") + j.Lf("}") + } else { + // Assume unary + j.Lf( + "Object.defineProperty(%s.prototype.%s = function %s(request) {", + serviceType.GetName(), serviceMethodJSName(method), serviceMethodJSName(method)) + j.Lf("return this.rpcCall(%s, %s, %s, request);", serviceMethodJSName(method), c.resolveTypeName(method.GetInputType(), "$root."), c.resolveTypeName(method.GetOutputType(), "$root.")) + } j.Lf(`}, "name", { value: "%s" });`, method.GetName()) + j.Lf( + `Object.defineProperty(%s.prototype.%s, "serverStreaming", { value: %t });`, + serviceType.GetName(), serviceMethodJSName(method), method.GetServerStreaming()) } j.Lf("return %s;", serviceType.GetName()) diff --git a/.npmrc b/test/.npmrc similarity index 100% rename from .npmrc rename to test/.npmrc diff --git a/test/BUILD b/test/BUILD index f76af80..a4981ad 100644 --- a/test/BUILD +++ b/test/BUILD @@ -1,14 +1,24 @@ +load("@aspect_rules_ts//ts:defs.bzl", "ts_config") +load("@npm//:defs.bzl", "npm_link_all_packages") load(":defs.bzl", "ts_jasmine_node_test") +npm_link_all_packages() + ts_jasmine_node_test( name = "encode_test", entry_point = "encode_test.ts", deps = [ - "//:node_modules/long", - "//:node_modules/protobufjs", + "//test:node_modules/long", + "//test:node_modules/protobufjs", "//test/proto:trivial_baseline_pbjs_proto", "//test/proto:trivial_ts_proto", "//test/proto:types_baseline_pbjs_proto", "//test/proto:types_ts_proto", ], ) + +ts_config( + name = "tsconfig", + src = "tsconfig.json", + visibility = ["//visibility:public"], +) diff --git a/test/defs.bzl b/test/defs.bzl index e574cac..772460c 100644 --- a/test/defs.bzl +++ b/test/defs.bzl @@ -2,13 +2,13 @@ load("@aspect_rules_esbuild//esbuild:defs.bzl", "esbuild") load("@aspect_rules_jasmine//jasmine:defs.bzl", "jasmine_test") load("@aspect_rules_js//js:defs.bzl", "js_library") load("@aspect_rules_ts//ts:defs.bzl", "ts_project") -load("@npm//:protobufjs-cli/package_json.bzl", protobufjs_cli_bin = "bin") +load("@npm//test:protobufjs-cli/package_json.bzl", protobufjs_cli_bin = "bin") load("//:rules.bzl", "protoc_gen_protobufjs") def ts_library(name, srcs, **kwargs): ts_project( name = name, - tsconfig = "//:tsconfig", + tsconfig = "//test:tsconfig", transpiler = "tsc", srcs = srcs, **kwargs @@ -77,8 +77,8 @@ def ts_proto_library(name, out, proto, **kwargs): def ts_jasmine_node_test(name, entry_point, deps = [], size = "small", **kwargs): deps = list(deps) deps.extend([ - "//:node_modules/@types/node", - "//:node_modules/@types/jasmine", + "//test:node_modules/@types/node", + "//test:node_modules/@types/jasmine", ]) ts_library( @@ -103,5 +103,5 @@ def ts_jasmine_node_test(name, entry_point, deps = [], size = "small", **kwargs) args = ["*.test.js"], chdir = native.package_name(), data = [":%s__bundle.test.js" % name] + deps, - node_modules = "//:node_modules", + node_modules = "//test:node_modules", ) diff --git a/package.json b/test/package.json similarity index 70% rename from package.json rename to test/package.json index 1a6e3eb..87d7312 100644 --- a/package.json +++ b/test/package.json @@ -1,14 +1,7 @@ { - "name": "protoc-gen-protobufjs", - "version": "0.0.1", - "description": "", - "main": "index.js", "scripts": { "test": "bazel test //..." }, - "keywords": [], - "author": "", - "license": "MIT", "dependencies": { "@bazel/jasmine": "^5.8.1", "@types/jasmine": "^5.1.4", diff --git a/pnpm-lock.yaml b/test/pnpm-lock.yaml similarity index 100% rename from pnpm-lock.yaml rename to test/pnpm-lock.yaml diff --git a/test/proto/service.proto b/test/proto/service.proto index 149e560..4d766bb 100644 --- a/test/proto/service.proto +++ b/test/proto/service.proto @@ -16,4 +16,7 @@ service Things { rpc GetThing(GetThingRequest) returns (GetThingResponse); rpc GetOtherThing(GetThingRequest) returns (GetThingResponse); + + // Streams things. + rpc GetThings(GetThingRequest) returns (stream GetThingResponse); } \ No newline at end of file diff --git a/tsconfig.json b/test/tsconfig.json similarity index 100% rename from tsconfig.json rename to test/tsconfig.json diff --git a/ts.go b/ts.go index d172582..b7f3d49 100644 --- a/ts.go +++ b/ts.go @@ -23,12 +23,13 @@ var ( type TS struct { JS bool - indentation string - buf string - lastLine string - imports map[string]*tsImport - scope []string - methodScopeIndex int + indentation string + buf string + lastLine string + imports map[string]*tsImport + topLevelDeclarations map[string]struct{} + scope []string + methodScopeIndex int } func (t *TS) Lf(format string, args ...any) { @@ -168,6 +169,13 @@ func (t *TS) BlockComment(comment string) { t.Lf(" */") } +func (t *TS) AddTopLevelDeclaration(code string) { + if t.topLevelDeclarations == nil { + t.topLevelDeclarations = map[string]struct{}{} + } + t.topLevelDeclarations[strings.TrimSpace(code)] = struct{}{} +} + func (t *TS) DefaultImport(pkg, nameAndAlias string) { t.addImport(&tsImport{pkg: pkg, defaultImport: nameAndAlias}) } @@ -202,7 +210,15 @@ func (g *TS) String() string { jsHeader += "\n" } - return jsHeader + importSection + g.buf + topLevelDeclarations := "" + for code := range g.topLevelDeclarations { + topLevelDeclarations += "\n" + code + "\n" + } + if len(g.topLevelDeclarations) != 0 { + topLevelDeclarations += "\n" + } + + return jsHeader + importSection + topLevelDeclarations + g.buf } func isMethodOrConstructorDefinitionStart(line string) bool {