Skip to content

Commit

Permalink
Fixes => force handshake, working dshdev ACLs, add topic IDs
Browse files Browse the repository at this point in the history
  • Loading branch information
brunodebus committed Sep 30, 2024
1 parent 41e319a commit 85162ae
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 31 deletions.
7 changes: 7 additions & 0 deletions proxy/client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package proxy

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
Expand Down Expand Up @@ -310,6 +311,12 @@ func (c *Client) handleConn(conn Conn) {

tlsConn, ok := conn.LocalConnection.(*tls.Conn)
if ok {
err := tlsConn.HandshakeContext(context.TODO())
if err != nil {
logrus.Info(err)
return
}

if len(tlsConn.ConnectionState().PeerCertificates) > 0 {
commonName := tlsConn.ConnectionState().PeerCertificates[0].Subject.CommonName
id = &commonName
Expand Down
7 changes: 7 additions & 0 deletions proxy/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ type processor struct {

clientID *string

topicIDMap *protocol.TopicIDMap

// producer will never send request with acks=0
producerAcks0Disabled bool
}
Expand Down Expand Up @@ -109,6 +111,7 @@ func newProcessor(cfg ProcessorConfig, brokerAddress string, id *string) *proces
writeTimeout: writeTimeout,
brokerAddress: brokerAddress,
clientID: id,
topicIDMap: protocol.NewTopicIDMap(),
localSasl: cfg.LocalSasl,
authServer: cfg.AuthServer,
forbiddenApiKeys: cfg.ForbiddenApiKeys,
Expand All @@ -132,6 +135,7 @@ func (p *processor) RequestsLoop(dst DeadlineWriter, src DeadlineReaderWriter) (
timeout: p.writeTimeout,
brokerAddress: p.brokerAddress,
clientID: p.clientID,
topicIDMap: p.topicIDMap,
forbiddenApiKeys: p.forbiddenApiKeys,
buf: make([]byte, p.requestBufferSize),
localSasl: p.localSasl,
Expand All @@ -150,6 +154,7 @@ type RequestsLoopContext struct {
timeout time.Duration
brokerAddress string
clientID *string
topicIDMap *protocol.TopicIDMap
forbiddenApiKeys map[int16]struct{}
buf []byte // bufSize

Expand Down Expand Up @@ -225,6 +230,7 @@ func (p *processor) ResponsesLoop(dst DeadlineWriter, src DeadlineReader) (readE
netAddressMappingFunc: p.netAddressMappingFunc,
timeout: p.readTimeout,
brokerAddress: p.brokerAddress,
topicIDMap: p.topicIDMap,
buf: make([]byte, p.responseBufferSize),
}
return ctx.responsesLoop(dst, src)
Expand All @@ -236,6 +242,7 @@ type ResponsesLoopContext struct {
netAddressMappingFunc config.NetAddressMappingFunc
timeout time.Duration
brokerAddress string
topicIDMap *protocol.TopicIDMap
buf []byte // bufSize
}

Expand Down
59 changes: 53 additions & 6 deletions proxy/processor_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"strconv"
"time"

"github.com/google/uuid"
"github.com/grepplabs/kafka-proxy/proxy/protocol"
"github.com/sirupsen/logrus"
"github.com/twmb/franz-go/pkg/kbin"
Expand Down Expand Up @@ -122,12 +123,29 @@ func (handler *DefaultRequestHandler) handleRequest(dst DeadlineWriter, src Dead

if ctx.clientID != nil {
for _, topic := range request.Topics {
if *ctx.clientID == "test1" {
if topic.Topic != "allowed" {
switch *ctx.clientID {
case "test1":
switch topic.Topic {
case "scratch.bdbtest1.dshdev":
default:
return true, errors.New(fmt.Sprintf("Client %s is not allowed to produce to %s", *ctx.clientID, topic.Topic))
}

case "test2":
switch topic.Topic {
case "scratch.bdbtest2.dshdev":
default:
return true, errors.New(fmt.Sprintf("Client %s is not allowed to produce to %s", *ctx.clientID, topic.Topic))
}
case "unit-test":
// TODO: make sure to remove this
default:
return true, errors.New(fmt.Sprintf("Client %s is not allowed to produce to %s", *ctx.clientID, topic.Topic))
}

}
} else {
return true, errors.New("Client not set")
}
case apiKeyFetch:
request := kmsg.NewPtrFetchRequest()
Expand Down Expand Up @@ -157,12 +175,41 @@ func (handler *DefaultRequestHandler) handleRequest(dst DeadlineWriter, src Dead

if ctx.clientID != nil {
for _, topic := range request.Topics {
if *ctx.clientID == "test2" {
if topic.Topic != "allowed" {
return true, errors.New(fmt.Sprintf("Client %s is not allowed to produce to %s", *ctx.clientID, topic.Topic))
var name string
if requestKeyVersion.ApiVersion >= 13 {
asUUID, err := uuid.FromBytes(topic.TopicID[:])
if err != nil {
return true, errors.New(fmt.Sprintf("Could not get topic ID as UUID: %v", topic.TopicID))
}

name = ctx.topicIDMap.Get(asUUID.String())
} else {
name = topic.Topic
}

switch *ctx.clientID {
case "test1":
switch name {
case "scratch.bdbtest1.dshdev":
default:
return true, errors.New(fmt.Sprintf("Client %s is not allowed to consume from %s", *ctx.clientID, name))
}

case "test2":
switch name {
case "scratch.bdbtest2.dshdev":
default:
return true, errors.New(fmt.Sprintf("Client %s is not allowed to consume from %s", *ctx.clientID, name))
}
case "unit-test":
// TODO: make sure to remove this
default:
return true, errors.New(fmt.Sprintf("Client %s is not allowed to consume from %s", *ctx.clientID, name))
}

}
} else {
return true, errors.New("Client not set")
}
}

Expand Down Expand Up @@ -313,7 +360,7 @@ func (handler *DefaultResponseHandler) handleResponse(dst DeadlineWriter, src De
if _, err = io.ReadFull(src, resp); err != nil {
return true, err
}
newResponseBuf, err := responseModifier.Apply(resp)
newResponseBuf, err := responseModifier.Apply(resp, ctx.topicIDMap)
if err != nil {
return true, err
}
Expand Down
7 changes: 5 additions & 2 deletions proxy/processor_default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ package proxy
import (
"bytes"
"encoding/hex"
"testing"
"time"

"github.com/grepplabs/kafka-proxy/proxy/protocol"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"testing"
"time"
)

func TestHandleRequest(t *testing.T) {
Expand Down Expand Up @@ -86,13 +87,15 @@ func TestHandleRequest(t *testing.T) {
nextRequestHandlerChannel := make(chan RequestHandler, 1)
nextResponseHandlerChannel := make(chan ResponseHandler, 1)

clientID := "unit-test"
ctx := &RequestsLoopContext{
openRequestsChannel: openRequestsChannel,
nextRequestHandlerChannel: nextRequestHandlerChannel,
nextResponseHandlerChannel: nextResponseHandlerChannel,
timeout: 1 * time.Second,
buf: buf,
localSasl: &LocalSasl{},
clientID: &clientID,
}

a := assert.New(t)
Expand Down
Loading

0 comments on commit 85162ae

Please sign in to comment.