Skip to content

Commit

Permalink
refactor: refactor load tls and dtls using generics
Browse files Browse the repository at this point in the history
Signed-off-by: 1998-felix <[email protected]>
  • Loading branch information
felixgateru committed Apr 29, 2024
1 parent e2586a5 commit 7a11601
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 79 deletions.
6 changes: 2 additions & 4 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,11 @@ func NewConfig(opts env.Options) (Config, error) {
if err != nil {
return Config{}, err
}

c.TLSConfig, err = mptls.LoadTLS(&cfg)
c.TLSConfig, err = mptls.LoadSecConfig(&cfg, &tls.Config{})
if err != nil {
return Config{}, err
}

c.DTLSConfig, err = mptls.LoadDTLS(&cfg)
c.DTLSConfig, err = mptls.LoadSecConfig(&cfg, &dtls.Config{})
if err != nil {
return Config{}, err
}
Expand Down
134 changes: 59 additions & 75 deletions pkg/tls/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,111 +14,95 @@ import (
)

var (
errTLSdetails = errors.New("failed to get TLS details of connection")
errLoadCerts = errors.New("failed to load certificates")
errLoadServerCA = errors.New("failed to load Server CA")
errLoadClientCA = errors.New("failed to load Client CA")
errAppendCA = errors.New("failed to append root ca tls.Config")
errTLSdetails = errors.New("failed to get TLS details of connection")
errLoadCerts = errors.New("failed to load certificates")
errLoadServerCA = errors.New("failed to load Server CA")
errLoadClientCA = errors.New("failed to load Client CA")
errAppendCA = errors.New("failed to append root ca tls.Config")
errUnsupportedSec = errors.New("unsupported security configuration")
)

// LoadTLS returns a TLS configuration that can be used in TLS servers.
func LoadTLS(c *Config) (*tls.Config, error) {
type SecConfig interface {
*tls.Config | *dtls.Config
}

// LoadSecConfig returns a TLS or DTLS configuration that can be used for TLS or DTLS servers.
func LoadSecConfig[sc SecConfig](c *Config, s sc) (sc, error) {
if c.CertFile == "" || c.KeyFile == "" {
return nil, nil
}

tlsConfig := &tls.Config{}

certificate, err := tls.LoadX509KeyPair(c.CertFile, c.KeyFile)
if err != nil {
return nil, errors.Join(errLoadCerts, err)
}
tlsConfig = &tls.Config{
Certificates: []tls.Certificate{certificate},
}

// Loading Server CA file
rootCA, err := loadCertFile(c.ServerCAFile)
if err != nil {
return nil, errors.Join(errLoadServerCA, err)
}
if len(rootCA) > 0 {
if tlsConfig.RootCAs == nil {
tlsConfig.RootCAs = x509.NewCertPool()
}
if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCA) {
return nil, errAppendCA
}
}

// Loading Client CA File
clientCA, err := loadCertFile(c.ClientCAFile)
if err != nil {
return nil, errors.Join(errLoadClientCA, err)
}
if len(clientCA) > 0 {
if tlsConfig.ClientCAs == nil {
tlsConfig.ClientCAs = x509.NewCertPool()
}
if !tlsConfig.ClientCAs.AppendCertsFromPEM(clientCA) {
return nil, errAppendCA
}
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
if c.Validator != nil {
tlsConfig.VerifyPeerCertificate = c.Validator
}
}
return tlsConfig, nil
}

// LoadDTLS returns a DTLS configuration that can be used in DTLS servers.
func LoadDTLS(c *Config) (*dtls.Config, error) {
if c.CertFile == "" || c.KeyFile == "" {
return nil, nil
}

dtlsConfig := &dtls.Config{}

certificate, err := tls.LoadX509KeyPair(c.CertFile, c.KeyFile)
if err != nil {
return nil, errors.Join(errLoadCerts, err)
}
dtlsConfig = &dtls.Config{
Certificates: []tls.Certificate{certificate},
}
switch config := any(s).(type) {
case *tls.Config:
config.Certificates = []tls.Certificate{certificate}

// Loading Server CA file
rootCA, err := loadCertFile(c.ServerCAFile)
if err != nil {
return nil, errors.Join(errLoadServerCA, err)
}
if len(rootCA) > 0 {
if dtlsConfig.RootCAs == nil {
dtlsConfig.RootCAs = x509.NewCertPool()
if len(rootCA) > 0 {
if config.RootCAs == nil {
config.RootCAs = x509.NewCertPool()
}
if !config.RootCAs.AppendCertsFromPEM(rootCA) {
return nil, errAppendCA
}
}
if !dtlsConfig.RootCAs.AppendCertsFromPEM(rootCA) {
return nil, errAppendCA
}
}

// Loading Client CA File
clientCA, err := loadCertFile(c.ClientCAFile)
if err != nil {
return nil, errors.Join(errLoadClientCA, err)
}
if len(clientCA) > 0 {
if dtlsConfig.ClientCAs == nil {
dtlsConfig.ClientCAs = x509.NewCertPool()
if len(clientCA) > 0 {
if config.ClientCAs == nil {
config.ClientCAs = x509.NewCertPool()
}
if !config.ClientCAs.AppendCertsFromPEM(clientCA) {
return nil, errAppendCA
}
config.ClientAuth = tls.RequireAndVerifyClientCert
if c.Validator != nil {
config.VerifyPeerCertificate = c.Validator
}
}
if !dtlsConfig.ClientCAs.AppendCertsFromPEM(clientCA) {
return nil, errAppendCA
return s, nil
case *dtls.Config:
config.Certificates = []tls.Certificate{certificate}

if len(rootCA) > 0 {
if config.RootCAs == nil {
config.RootCAs = x509.NewCertPool()
}
if !config.RootCAs.AppendCertsFromPEM(rootCA) {
return nil, errAppendCA
}
}
dtlsConfig.ClientAuth = dtls.RequireAndVerifyClientCert
if c.Validator != nil {
dtlsConfig.VerifyPeerCertificate = c.Validator

if len(clientCA) > 0 {
if config.ClientCAs == nil {
config.ClientCAs = x509.NewCertPool()
}
if !config.ClientCAs.AppendCertsFromPEM(clientCA) {
return nil, errAppendCA
}
config.ClientAuth = dtls.RequireAndVerifyClientCert
if c.Validator != nil {
config.VerifyPeerCertificate = c.Validator
}
}
return s, nil
default:
return nil, errUnsupportedSec
}
return dtlsConfig, nil
}

// ClientCert returns client certificate.
Expand Down

0 comments on commit 7a11601

Please sign in to comment.