Skip to content

Commit

Permalink
Add DialContext and fix names
Browse files Browse the repository at this point in the history
  • Loading branch information
rtpt-erikgeiser committed Mar 21, 2024
1 parent dbd65a8 commit f524806
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
10 changes: 9 additions & 1 deletion pbtls.go → kbtls.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package kbtls

import (
"context"
"crypto"
"crypto/ed25519"
"crypto/rand"
Expand Down Expand Up @@ -252,6 +253,11 @@ func ClientTLSConfigForClientName(key ConnectionKey, clientName string) (*tls.Co

// Dial works like tls.Dial with a TLS config based on the provided connection key.
func Dial(network string, address string, connectionKey string) (net.Conn, error) {
return DialContext(context.Background(), network, address, connectionKey)
}

// DialContext works like tls.Dial with a TLS config based on the provided connection key and a context.
func DialContext(ctx context.Context, network string, address string, connectionKey string) (net.Conn, error) {
key, err := ParseConnectionKey(connectionKey)
if err != nil {
return nil, fmt.Errorf("parse connection key: %w", err)
Expand All @@ -262,7 +268,9 @@ func Dial(network string, address string, connectionKey string) (net.Conn, error
return nil, fmt.Errorf("generate client TLS config: %w", err)
}

return tls.Dial(network, address, tlsConfig)
dialer := tls.Dialer{Config: tlsConfig}

return dialer.DialContext(ctx, network, address)
}

// Listen works like tls.Listen with a TLS config based on the provided connection key.
Expand Down
16 changes: 8 additions & 8 deletions pbtls_test.go → kbtls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ func TestClientServerConfigCompatibility(t *testing.T) {
wait := make(chan struct{})

pipeA, pipeB := net.Pipe()
defer pipeA.Close() //nolint:errcheck,gosec
defer pipeB.Close() //nolint:errcheck,gosec
defer pipeA.Close() //nolint:errcheck
defer pipeB.Close() //nolint:errcheck
defer func() {
<-wait
}()
Expand Down Expand Up @@ -178,8 +178,8 @@ func TestClientServerErrorIfKeyAndNameDiffers(t *testing.T) {
wait := make(chan struct{})

pipeA, pipeB := net.Pipe()
defer pipeA.Close() //nolint:errcheck,gosec
defer pipeB.Close() //nolint:errcheck,gosec
defer pipeA.Close() //nolint:errcheck
defer pipeB.Close() //nolint:errcheck
defer func() {
<-wait
}()
Expand Down Expand Up @@ -236,8 +236,8 @@ func TestClientServerErrorIfKeyDiffers(t *testing.T) {
wait := make(chan struct{})

pipeA, pipeB := net.Pipe()
defer pipeA.Close() //nolint:errcheck,gosec
defer pipeB.Close() //nolint:errcheck,gosec
defer pipeA.Close() //nolint:errcheck
defer pipeB.Close() //nolint:errcheck
defer func() {
<-wait
}()
Expand Down Expand Up @@ -308,7 +308,7 @@ func TestDialListen(t *testing.T) {
return
}

defer conn.Close() //nolint:errcheck,gosec
defer conn.Close() //nolint:errcheck

_, err = conn.Write(testData)
if err != nil {
Expand All @@ -321,7 +321,7 @@ func TestDialListen(t *testing.T) {
t.Fatalf("accept: %v", err)
}

defer conn.Close() //nolint:errcheck,gosec
defer conn.Close() //nolint:errcheck

receivedData := make([]byte, len(testData))

Expand Down

0 comments on commit f524806

Please sign in to comment.