From f5248061e53cb31c08ad5891226815d28670821c Mon Sep 17 00:00:00 2001 From: Erik Geiser Date: Thu, 21 Mar 2024 11:23:18 +0100 Subject: [PATCH] Add DialContext and fix names --- pbtls.go => kbtls.go | 10 +++++++++- pbtls_test.go => kbtls_test.go | 16 ++++++++-------- 2 files changed, 17 insertions(+), 9 deletions(-) rename pbtls.go => kbtls.go (95%) rename pbtls_test.go => kbtls_test.go (95%) diff --git a/pbtls.go b/kbtls.go similarity index 95% rename from pbtls.go rename to kbtls.go index 8b9da5b..24710eb 100644 --- a/pbtls.go +++ b/kbtls.go @@ -2,6 +2,7 @@ package kbtls import ( + "context" "crypto" "crypto/ed25519" "crypto/rand" @@ -252,6 +253,11 @@ func ClientTLSConfigForClientName(key ConnectionKey, clientName string) (*tls.Co // Dial works like tls.Dial with a TLS config based on the provided connection key. func Dial(network string, address string, connectionKey string) (net.Conn, error) { + return DialContext(context.Background(), network, address, connectionKey) +} + +// DialContext works like tls.Dial with a TLS config based on the provided connection key and a context. +func DialContext(ctx context.Context, network string, address string, connectionKey string) (net.Conn, error) { key, err := ParseConnectionKey(connectionKey) if err != nil { return nil, fmt.Errorf("parse connection key: %w", err) @@ -262,7 +268,9 @@ func Dial(network string, address string, connectionKey string) (net.Conn, error return nil, fmt.Errorf("generate client TLS config: %w", err) } - return tls.Dial(network, address, tlsConfig) + dialer := tls.Dialer{Config: tlsConfig} + + return dialer.DialContext(ctx, network, address) } // Listen works like tls.Listen with a TLS config based on the provided connection key. diff --git a/pbtls_test.go b/kbtls_test.go similarity index 95% rename from pbtls_test.go rename to kbtls_test.go index fe1fe54..a3b2340 100644 --- a/pbtls_test.go +++ b/kbtls_test.go @@ -116,8 +116,8 @@ func TestClientServerConfigCompatibility(t *testing.T) { wait := make(chan struct{}) pipeA, pipeB := net.Pipe() - defer pipeA.Close() //nolint:errcheck,gosec - defer pipeB.Close() //nolint:errcheck,gosec + defer pipeA.Close() //nolint:errcheck + defer pipeB.Close() //nolint:errcheck defer func() { <-wait }() @@ -178,8 +178,8 @@ func TestClientServerErrorIfKeyAndNameDiffers(t *testing.T) { wait := make(chan struct{}) pipeA, pipeB := net.Pipe() - defer pipeA.Close() //nolint:errcheck,gosec - defer pipeB.Close() //nolint:errcheck,gosec + defer pipeA.Close() //nolint:errcheck + defer pipeB.Close() //nolint:errcheck defer func() { <-wait }() @@ -236,8 +236,8 @@ func TestClientServerErrorIfKeyDiffers(t *testing.T) { wait := make(chan struct{}) pipeA, pipeB := net.Pipe() - defer pipeA.Close() //nolint:errcheck,gosec - defer pipeB.Close() //nolint:errcheck,gosec + defer pipeA.Close() //nolint:errcheck + defer pipeB.Close() //nolint:errcheck defer func() { <-wait }() @@ -308,7 +308,7 @@ func TestDialListen(t *testing.T) { return } - defer conn.Close() //nolint:errcheck,gosec + defer conn.Close() //nolint:errcheck _, err = conn.Write(testData) if err != nil { @@ -321,7 +321,7 @@ func TestDialListen(t *testing.T) { t.Fatalf("accept: %v", err) } - defer conn.Close() //nolint:errcheck,gosec + defer conn.Close() //nolint:errcheck receivedData := make([]byte, len(testData))