From 85162ae2ae0c4bc7b67b9dc2a1ec85db2d9ec90c Mon Sep 17 00:00:00 2001 From: Bruno De Bus Date: Mon, 30 Sep 2024 20:24:09 +0200 Subject: [PATCH] Fixes => force handshake, working dshdev ACLs, add topic IDs --- proxy/client.go | 7 +++ proxy/processor.go | 7 +++ proxy/processor_default.go | 59 ++++++++++++++++++-- proxy/processor_default_test.go | 7 ++- proxy/protocol/responses.go | 99 +++++++++++++++++++++++++-------- 5 files changed, 148 insertions(+), 31 deletions(-) diff --git a/proxy/client.go b/proxy/client.go index 42352beb..ab69c6e4 100644 --- a/proxy/client.go +++ b/proxy/client.go @@ -1,6 +1,7 @@ package proxy import ( + "context" "crypto/tls" "crypto/x509" "fmt" @@ -310,6 +311,12 @@ func (c *Client) handleConn(conn Conn) { tlsConn, ok := conn.LocalConnection.(*tls.Conn) if ok { + err := tlsConn.HandshakeContext(context.TODO()) + if err != nil { + logrus.Info(err) + return + } + if len(tlsConn.ConnectionState().PeerCertificates) > 0 { commonName := tlsConn.ConnectionState().PeerCertificates[0].Subject.CommonName id = &commonName diff --git a/proxy/processor.go b/proxy/processor.go index becd2788..fcc4eb38 100644 --- a/proxy/processor.go +++ b/proxy/processor.go @@ -66,6 +66,8 @@ type processor struct { clientID *string + topicIDMap *protocol.TopicIDMap + // producer will never send request with acks=0 producerAcks0Disabled bool } @@ -109,6 +111,7 @@ func newProcessor(cfg ProcessorConfig, brokerAddress string, id *string) *proces writeTimeout: writeTimeout, brokerAddress: brokerAddress, clientID: id, + topicIDMap: protocol.NewTopicIDMap(), localSasl: cfg.LocalSasl, authServer: cfg.AuthServer, forbiddenApiKeys: cfg.ForbiddenApiKeys, @@ -132,6 +135,7 @@ func (p *processor) RequestsLoop(dst DeadlineWriter, src DeadlineReaderWriter) ( timeout: p.writeTimeout, brokerAddress: p.brokerAddress, clientID: p.clientID, + topicIDMap: p.topicIDMap, forbiddenApiKeys: p.forbiddenApiKeys, buf: make([]byte, p.requestBufferSize), localSasl: p.localSasl, @@ -150,6 +154,7 @@ type RequestsLoopContext struct { timeout time.Duration brokerAddress string clientID *string + topicIDMap *protocol.TopicIDMap forbiddenApiKeys map[int16]struct{} buf []byte // bufSize @@ -225,6 +230,7 @@ func (p *processor) ResponsesLoop(dst DeadlineWriter, src DeadlineReader) (readE netAddressMappingFunc: p.netAddressMappingFunc, timeout: p.readTimeout, brokerAddress: p.brokerAddress, + topicIDMap: p.topicIDMap, buf: make([]byte, p.responseBufferSize), } return ctx.responsesLoop(dst, src) @@ -236,6 +242,7 @@ type ResponsesLoopContext struct { netAddressMappingFunc config.NetAddressMappingFunc timeout time.Duration brokerAddress string + topicIDMap *protocol.TopicIDMap buf []byte // bufSize } diff --git a/proxy/processor_default.go b/proxy/processor_default.go index cef37c8d..c715e2c7 100644 --- a/proxy/processor_default.go +++ b/proxy/processor_default.go @@ -8,6 +8,7 @@ import ( "strconv" "time" + "github.com/google/uuid" "github.com/grepplabs/kafka-proxy/proxy/protocol" "github.com/sirupsen/logrus" "github.com/twmb/franz-go/pkg/kbin" @@ -122,12 +123,29 @@ func (handler *DefaultRequestHandler) handleRequest(dst DeadlineWriter, src Dead if ctx.clientID != nil { for _, topic := range request.Topics { - if *ctx.clientID == "test1" { - if topic.Topic != "allowed" { + switch *ctx.clientID { + case "test1": + switch topic.Topic { + case "scratch.bdbtest1.dshdev": + default: return true, errors.New(fmt.Sprintf("Client %s is not allowed to produce to %s", *ctx.clientID, topic.Topic)) } + + case "test2": + switch topic.Topic { + case "scratch.bdbtest2.dshdev": + default: + return true, errors.New(fmt.Sprintf("Client %s is not allowed to produce to %s", *ctx.clientID, topic.Topic)) + } + case "unit-test": + // TODO: make sure to remove this + default: + return true, errors.New(fmt.Sprintf("Client %s is not allowed to produce to %s", *ctx.clientID, topic.Topic)) } + } + } else { + return true, errors.New("Client not set") } case apiKeyFetch: request := kmsg.NewPtrFetchRequest() @@ -157,12 +175,41 @@ func (handler *DefaultRequestHandler) handleRequest(dst DeadlineWriter, src Dead if ctx.clientID != nil { for _, topic := range request.Topics { - if *ctx.clientID == "test2" { - if topic.Topic != "allowed" { - return true, errors.New(fmt.Sprintf("Client %s is not allowed to produce to %s", *ctx.clientID, topic.Topic)) + var name string + if requestKeyVersion.ApiVersion >= 13 { + asUUID, err := uuid.FromBytes(topic.TopicID[:]) + if err != nil { + return true, errors.New(fmt.Sprintf("Could not get topic ID as UUID: %v", topic.TopicID)) + } + + name = ctx.topicIDMap.Get(asUUID.String()) + } else { + name = topic.Topic + } + + switch *ctx.clientID { + case "test1": + switch name { + case "scratch.bdbtest1.dshdev": + default: + return true, errors.New(fmt.Sprintf("Client %s is not allowed to consume from %s", *ctx.clientID, name)) + } + + case "test2": + switch name { + case "scratch.bdbtest2.dshdev": + default: + return true, errors.New(fmt.Sprintf("Client %s is not allowed to consume from %s", *ctx.clientID, name)) } + case "unit-test": + // TODO: make sure to remove this + default: + return true, errors.New(fmt.Sprintf("Client %s is not allowed to consume from %s", *ctx.clientID, name)) } + } + } else { + return true, errors.New("Client not set") } } @@ -313,7 +360,7 @@ func (handler *DefaultResponseHandler) handleResponse(dst DeadlineWriter, src De if _, err = io.ReadFull(src, resp); err != nil { return true, err } - newResponseBuf, err := responseModifier.Apply(resp) + newResponseBuf, err := responseModifier.Apply(resp, ctx.topicIDMap) if err != nil { return true, err } diff --git a/proxy/processor_default_test.go b/proxy/processor_default_test.go index f2720f4a..d7b499b2 100644 --- a/proxy/processor_default_test.go +++ b/proxy/processor_default_test.go @@ -3,11 +3,12 @@ package proxy import ( "bytes" "encoding/hex" + "testing" + "time" + "github.com/grepplabs/kafka-proxy/proxy/protocol" "github.com/pkg/errors" "github.com/stretchr/testify/assert" - "testing" - "time" ) func TestHandleRequest(t *testing.T) { @@ -86,6 +87,7 @@ func TestHandleRequest(t *testing.T) { nextRequestHandlerChannel := make(chan RequestHandler, 1) nextResponseHandlerChannel := make(chan ResponseHandler, 1) + clientID := "unit-test" ctx := &RequestsLoopContext{ openRequestsChannel: openRequestsChannel, nextRequestHandlerChannel: nextRequestHandlerChannel, @@ -93,6 +95,7 @@ func TestHandleRequest(t *testing.T) { timeout: 1 * time.Second, buf: buf, localSasl: &LocalSasl{}, + clientID: &clientID, } a := assert.New(t) diff --git a/proxy/protocol/responses.go b/proxy/protocol/responses.go index 3309719b..ffa9da9c 100644 --- a/proxy/protocol/responses.go +++ b/proxy/protocol/responses.go @@ -3,7 +3,9 @@ package protocol import ( "errors" "fmt" + "sync" + "github.com/google/uuid" "github.com/grepplabs/kafka-proxy/config" ) @@ -12,13 +14,47 @@ const ( apiKeyFindCoordinator = 10 brokersKeyName = "brokers" - hostKeyName = "host" - portKeyName = "port" + topicsKeyName = "topic_metadata" + + topicIDKeyName = "topic_id" + topicNameKeyName = "name" + + hostKeyName = "host" + portKeyName = "port" coordinatorKeyName = "coordinator" coordinatorsKeyName = "coordinators" ) +type TopicIDMap struct { + mu sync.Mutex + idMap map[string]string +} + +func (tim *TopicIDMap) Add(uuid string, name string) { + tim.mu.Lock() + + //fmt.Printf("ADD %p: %s %s\n", tim, uuid, name) + tim.idMap[uuid] = name + tim.mu.Unlock() +} + +func (tim *TopicIDMap) Get(uuid string) string { + tim.mu.Lock() + defer tim.mu.Unlock() + + //fmt.Printf("GET %p: %s\n", tim, uuid) + return tim.idMap[uuid] +} + +var globalIDMap = &TopicIDMap{ + idMap: make(map[string]string), +} + +func NewTopicIDMap() *TopicIDMap { + return globalIDMap +} + var ( metadataResponseSchemaVersions = createMetadataResponseSchemaVersions() findCoordinatorResponseSchemaVersions = createFindCoordinatorResponseSchemaVersions() @@ -47,7 +83,7 @@ func createMetadataResponseSchemaVersions() []Schema { metadataResponseV0 := NewSchema("metadata_response_v0", &Array{Name: brokersKeyName, Ty: metadataBrokerV0}, - &Array{Name: "topic_metadata", Ty: topicMetadataV0}, + &Array{Name: topicsKeyName, Ty: topicMetadataV0}, ) metadataBrokerV1 := NewSchema("metadata_broker_v1", @@ -137,8 +173,8 @@ func createMetadataResponseSchemaVersions() []Schema { topicMetadataSchema10 := NewSchema("topic_metadata_schema10", &Mfield{Name: "error_code", Ty: TypeInt16}, - &Mfield{Name: "name", Ty: TypeCompactStr}, - &Mfield{Name: "topic_id", Ty: TypeUuid}, + &Mfield{Name: topicNameKeyName, Ty: TypeCompactStr}, + &Mfield{Name: topicIDKeyName, Ty: TypeUuid}, &Mfield{Name: "is_internal", Ty: TypeBool}, &CompactArray{Name: "partition_metadata", Ty: partitionMetadataSchema9}, &Mfield{Name: "topic_authorized_operations", Ty: TypeInt32}, @@ -147,8 +183,8 @@ func createMetadataResponseSchemaVersions() []Schema { topicMetadataSchema12 := NewSchema("topic_metadata_schema12", &Mfield{Name: "error_code", Ty: TypeInt16}, - &Mfield{Name: "name", Ty: TypeCompactNullableStr}, - &Mfield{Name: "topic_id", Ty: TypeUuid}, + &Mfield{Name: topicNameKeyName, Ty: TypeCompactNullableStr}, + &Mfield{Name: topicIDKeyName, Ty: TypeUuid}, &Mfield{Name: "is_internal", Ty: TypeBool}, &CompactArray{Name: "partition_metadata", Ty: partitionMetadataSchema9}, &Mfield{Name: "topic_authorized_operations", Ty: TypeInt32}, @@ -158,14 +194,14 @@ func createMetadataResponseSchemaVersions() []Schema { metadataResponseV1 := NewSchema("metadata_response_v1", &Array{Name: brokersKeyName, Ty: metadataBrokerV1}, &Mfield{Name: "controller_id", Ty: TypeInt32}, - &Array{Name: "topic_metadata", Ty: topicMetadataV1}, + &Array{Name: topicsKeyName, Ty: topicMetadataV1}, ) metadataResponseV2 := NewSchema("metadata_response_v2", &Array{Name: brokersKeyName, Ty: metadataBrokerV1}, &Mfield{Name: "cluster_id", Ty: TypeNullableStr}, &Mfield{Name: "controller_id", Ty: TypeInt32}, - &Array{Name: "topic_metadata", Ty: topicMetadataV1}, + &Array{Name: topicsKeyName, Ty: topicMetadataV1}, ) metadataResponseV3 := NewSchema("metadata_response_v3", @@ -173,7 +209,7 @@ func createMetadataResponseSchemaVersions() []Schema { &Array{Name: brokersKeyName, Ty: metadataBrokerV1}, &Mfield{Name: "cluster_id", Ty: TypeNullableStr}, &Mfield{Name: "controller_id", Ty: TypeInt32}, - &Array{Name: "topic_metadata", Ty: topicMetadataV1}, + &Array{Name: topicsKeyName, Ty: topicMetadataV1}, ) metadataResponseV4 := metadataResponseV3 @@ -183,7 +219,7 @@ func createMetadataResponseSchemaVersions() []Schema { &Array{Name: brokersKeyName, Ty: metadataBrokerV1}, &Mfield{Name: "cluster_id", Ty: TypeNullableStr}, &Mfield{Name: "controller_id", Ty: TypeInt32}, - &Array{Name: "topic_metadata", Ty: topicMetadataV2}, + &Array{Name: topicsKeyName, Ty: topicMetadataV2}, ) metadataResponseV6 := metadataResponseV5 @@ -193,7 +229,7 @@ func createMetadataResponseSchemaVersions() []Schema { &Array{Name: brokersKeyName, Ty: metadataBrokerV1}, &Mfield{Name: "cluster_id", Ty: TypeNullableStr}, &Mfield{Name: "controller_id", Ty: TypeInt32}, - &Array{Name: "topic_metadata", Ty: topicMetadataV7}, + &Array{Name: topicsKeyName, Ty: topicMetadataV7}, ) metadataResponseV8 := NewSchema("metadata_response_v8", @@ -201,7 +237,7 @@ func createMetadataResponseSchemaVersions() []Schema { &Array{Name: brokersKeyName, Ty: metadataBrokerV1}, &Mfield{Name: "cluster_id", Ty: TypeNullableStr}, &Mfield{Name: "controller_id", Ty: TypeInt32}, - &Array{Name: "topic_metadata", Ty: topicMetadataV8}, + &Array{Name: topicsKeyName, Ty: topicMetadataV8}, &Mfield{Name: "cluster_authorized_operations", Ty: TypeInt32}, ) @@ -210,7 +246,7 @@ func createMetadataResponseSchemaVersions() []Schema { &CompactArray{Name: brokersKeyName, Ty: metadataBrokerSchema9}, &Mfield{Name: "cluster_id", Ty: TypeCompactNullableStr}, &Mfield{Name: "controller_id", Ty: TypeInt32}, - &CompactArray{Name: "topic_metadata", Ty: topicMetadataSchema9}, + &CompactArray{Name: topicsKeyName, Ty: topicMetadataSchema9}, &Mfield{Name: "cluster_authorized_operations", Ty: TypeInt32}, &SchemaTaggedFields{Name: "response_tagged_fields"}, ) @@ -220,7 +256,7 @@ func createMetadataResponseSchemaVersions() []Schema { &CompactArray{Name: brokersKeyName, Ty: metadataBrokerSchema9}, &Mfield{Name: "cluster_id", Ty: TypeCompactNullableStr}, &Mfield{Name: "controller_id", Ty: TypeInt32}, - &CompactArray{Name: "topic_metadata", Ty: topicMetadataSchema10}, + &CompactArray{Name: topicsKeyName, Ty: topicMetadataSchema10}, &Mfield{Name: "cluster_authorized_operations", Ty: TypeInt32}, &SchemaTaggedFields{Name: "response_tagged_fields"}, ) @@ -230,7 +266,7 @@ func createMetadataResponseSchemaVersions() []Schema { &CompactArray{Name: brokersKeyName, Ty: metadataBrokerSchema9}, &Mfield{Name: "cluster_id", Ty: TypeCompactNullableStr}, &Mfield{Name: "controller_id", Ty: TypeInt32}, - &CompactArray{Name: "topic_metadata", Ty: topicMetadataSchema10}, + &CompactArray{Name: topicsKeyName, Ty: topicMetadataSchema10}, &SchemaTaggedFields{Name: "response_tagged_fields"}, ) @@ -239,7 +275,7 @@ func createMetadataResponseSchemaVersions() []Schema { &CompactArray{Name: brokersKeyName, Ty: metadataBrokerSchema9}, &Mfield{Name: "cluster_id", Ty: TypeCompactNullableStr}, &Mfield{Name: "controller_id", Ty: TypeInt32}, - &CompactArray{Name: "topic_metadata", Ty: topicMetadataSchema12}, + &CompactArray{Name: topicsKeyName, Ty: topicMetadataSchema12}, &SchemaTaggedFields{Name: "response_tagged_fields"}, ) @@ -296,13 +332,30 @@ func createFindCoordinatorResponseSchemaVersions() []Schema { return []Schema{findCoordinatorResponseV0, findCoordinatorResponseV1, findCoordinatorResponseV2, findCoordinatorResponseV3, findCoordinatorResponseV4} } -func modifyMetadataResponse(decodedStruct *Struct, fn config.NetAddressMappingFunc) error { +func modifyMetadataResponse(decodedStruct *Struct, fn config.NetAddressMappingFunc, topicIDMap *TopicIDMap) error { if decodedStruct == nil { return errors.New("decoded struct must not be nil") } if fn == nil { return errors.New("net address mapper must not be nil") } + + topicsArray, ok := decodedStruct.Get(topicsKeyName).([]interface{}) + if !ok { + return errors.New("topics list not found") + } + for _, topicElement := range topicsArray { + topic := topicElement.(*Struct) + + id, ok := topic.Get(topicIDKeyName).(uuid.UUID) + if ok { + name, ok := topic.Get(topicNameKeyName).(*string) + if ok { + topicIDMap.Add(id.String(), *name) + } + } + } + brokersArray, ok := decodedStruct.Get(brokersKeyName).([]interface{}) if !ok { return errors.New("brokers list not found") @@ -342,7 +395,7 @@ func modifyMetadataResponse(decodedStruct *Struct, fn config.NetAddressMappingFu return nil } -func modifyFindCoordinatorResponse(decodedStruct *Struct, fn config.NetAddressMappingFunc) error { +func modifyFindCoordinatorResponse(decodedStruct *Struct, fn config.NetAddressMappingFunc, topicIDMap *TopicIDMap) error { if decodedStruct == nil { return errors.New("decoded struct must not be nil") } @@ -405,10 +458,10 @@ func modifyCoordinator(decodedStruct *Struct, fn config.NetAddressMappingFunc) e } type ResponseModifier interface { - Apply(resp []byte) ([]byte, error) + Apply(resp []byte, topicIDMap *TopicIDMap) ([]byte, error) } -type modifyResponseFunc func(decodedStruct *Struct, fn config.NetAddressMappingFunc) error +type modifyResponseFunc func(decodedStruct *Struct, fn config.NetAddressMappingFunc, topicIDMap *TopicIDMap) error type responseModifier struct { schema Schema @@ -416,12 +469,12 @@ type responseModifier struct { netAddressMappingFunc config.NetAddressMappingFunc } -func (f *responseModifier) Apply(resp []byte) ([]byte, error) { +func (f *responseModifier) Apply(resp []byte, topicIDMap *TopicIDMap) ([]byte, error) { decodedStruct, err := DecodeSchema(resp, f.schema) if err != nil { return nil, err } - err = f.modifyResponseFunc(decodedStruct, f.netAddressMappingFunc) + err = f.modifyResponseFunc(decodedStruct, f.netAddressMappingFunc, topicIDMap) if err != nil { return nil, err }