diff --git a/third-party/thrift/src/thrift/lib/go/thrift/rocket_client.go b/third-party/thrift/src/thrift/lib/go/thrift/rocket_client.go index ccbc06dad91d4c..881a6491ab9bd2 100644 --- a/third-party/thrift/src/thrift/lib/go/thrift/rocket_client.go +++ b/third-party/thrift/src/thrift/lib/go/thrift/rocket_client.go @@ -105,7 +105,15 @@ func (p *rocketClient) WriteMessageEnd() error { func (p *rocketClient) Flush() (err error) { dataBytes := p.wbuf.Bytes() - if err := p.client.SendSetup(p.onServerMetadataPush); err != nil { + + ctx := context.Background() + if p.ioTimeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, p.ioTimeout) + defer cancel() + } + + if err := p.client.SendSetup(ctx, p.onServerMetadataPush); err != nil { return err } headers := unionMaps(p.reqHeaders, p.persistentHeaders) @@ -115,12 +123,6 @@ func (p *rocketClient) Flush() (err error) { if p.writeType != types.CALL { return nil } - ctx := context.Background() - if p.ioTimeout > 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, p.ioTimeout) - defer cancel() - } p.respHeaders, p.resultData, p.resultErr = p.client.RequestResponse(ctx, p.messageName, p.protoID, p.writeType, headers, p.zstd, dataBytes) clear(p.reqHeaders) return nil diff --git a/third-party/thrift/src/thrift/lib/go/thrift/rocket_rsocket_client.go b/third-party/thrift/src/thrift/lib/go/thrift/rocket_rsocket_client.go index c63aad59b46eac..d32e44a9779874 100644 --- a/third-party/thrift/src/thrift/lib/go/thrift/rocket_rsocket_client.go +++ b/third-party/thrift/src/thrift/lib/go/thrift/rocket_rsocket_client.go @@ -31,7 +31,7 @@ import ( // RSocketClient is a client that uses a rsocket library. type RSocketClient interface { - SendSetup(onServerMetadataPush OnServerMetadataPush) error + SendSetup(ctx context.Context, onServerMetadataPush OnServerMetadataPush) error FireAndForget(messageName string, protoID types.ProtocolID, typeID types.MessageType, headers map[string]string, zstd bool, dataBytes []byte) error RequestResponse(ctx context.Context, messageName string, protoID types.ProtocolID, typeID types.MessageType, headers map[string]string, zstd bool, dataBytes []byte) (map[string]string, []byte, error) Close() error @@ -49,7 +49,7 @@ func newRSocketClient(conn net.Conn) RSocketClient { return &rsocketClient{conn: conn} } -func (r *rsocketClient) SendSetup(onServerMetadataPush OnServerMetadataPush) error { +func (r *rsocketClient) SendSetup(_ context.Context, onServerMetadataPush OnServerMetadataPush) error { if r.client != nil { // already setup return nil