diff --git a/broadcastclient/broadcastclient.go b/broadcastclient/broadcastclient.go index 7519b71477..8d88114fab 100644 --- a/broadcastclient/broadcastclient.go +++ b/broadcastclient/broadcastclient.go @@ -5,7 +5,6 @@ package broadcastclient import ( "context" - "crypto/tls" "encoding/json" "errors" "fmt" @@ -63,16 +62,15 @@ var FeedConfigDefault = FeedConfig{ } type Config struct { - ReconnectInitialBackoff time.Duration `koanf:"reconnect-initial-backoff" reload:"hot"` - ReconnectMaximumBackoff time.Duration `koanf:"reconnect-maximum-backoff" reload:"hot"` - RequireChainId bool `koanf:"require-chain-id" reload:"hot"` - RequireFeedVersion bool `koanf:"require-feed-version" reload:"hot"` - Timeout time.Duration `koanf:"timeout" reload:"hot"` - URL []string `koanf:"url"` - Verify signature.VerifierConfig `koanf:"verify"` - CertFile string `koanf:"cert-file"` - KeyFile string `koanf:"key-file"` - EnableCompression bool `koanf:"enable-compression" reload:"hot"` + ReconnectInitialBackoff time.Duration `koanf:"reconnect-initial-backoff" reload:"hot"` + ReconnectMaximumBackoff time.Duration `koanf:"reconnect-maximum-backoff" reload:"hot"` + RequireChainId bool `koanf:"require-chain-id" reload:"hot"` + RequireFeedVersion bool `koanf:"require-feed-version" reload:"hot"` + Timeout time.Duration `koanf:"timeout" reload:"hot"` + URL []string `koanf:"url"` + Verify signature.VerifierConfig `koanf:"verify"` + EnableCompression bool `koanf:"enable-compression" reload:"hot"` + Tls wsbroadcastserver.TlsConfig `koanf:"tls"` } func (c *Config) Enable() bool { @@ -89,9 +87,8 @@ func ConfigAddOptions(prefix string, f *flag.FlagSet) { f.Duration(prefix+".timeout", DefaultConfig.Timeout, "duration to wait before timing out connection to sequencer feed") f.StringSlice(prefix+".url", DefaultConfig.URL, "URL of sequencer feed source") signature.FeedVerifierConfigAddOptions(prefix+".verify", f) - f.String(prefix+".cert-file", DefaultConfig.CertFile, "X509 client public certificate file") - f.String(prefix+".key-file", DefaultConfig.KeyFile, "X509 client private key file") f.Bool(prefix+".enable-compression", DefaultConfig.EnableCompression, "enable per message deflate compression support") + wsbroadcastserver.TlsConfigAddOptions(prefix, f, "client", false) } var DefaultConfig = Config{ @@ -103,6 +100,7 @@ var DefaultConfig = Config{ URL: []string{""}, Timeout: 20 * time.Second, EnableCompression: true, + Tls: wsbroadcastserver.DefaultTlsConfig, } var DefaultTestConfig = Config{ @@ -114,6 +112,7 @@ var DefaultTestConfig = Config{ URL: []string{""}, Timeout: 200 * time.Millisecond, EnableCompression: true, + Tls: wsbroadcastserver.DefaultTlsConfig, } type TransactionStreamerInterface interface { @@ -236,15 +235,9 @@ func (bc *BroadcastClient) connect(ctx context.Context, nextSeqNum arbutil.Messa if config.EnableCompression { extensions = []httphead.Option{deflateExt} } - tlsConfig := tls.Config{ - MinVersion: tls.VersionTLS12, - } - if config.CertFile != "" && config.KeyFile != "" { - clientCert, err := tls.LoadX509KeyPair(config.CertFile, config.KeyFile) - if err != nil { - return nil, err - } - tlsConfig.Certificates = []tls.Certificate{clientCert} + tlsConfig, err := config.Tls.GetConfig() + if err != nil { + return nil, err } timeoutDialer := ws.Dialer{ Header: header, @@ -287,7 +280,7 @@ func (bc *BroadcastClient) connect(ctx context.Context, nextSeqNum arbutil.Messa return nil }, Timeout: 10 * time.Second, - TLSConfig: &tlsConfig, + TLSConfig: tlsConfig, Extensions: extensions, } diff --git a/wsbroadcastserver/wsbroadcastserver.go b/wsbroadcastserver/wsbroadcastserver.go index cd277387a0..3bd146eb1d 100644 --- a/wsbroadcastserver/wsbroadcastserver.go +++ b/wsbroadcastserver/wsbroadcastserver.go @@ -5,11 +5,14 @@ package wsbroadcastserver import ( "context" + "crypto/tls" + "crypto/x509" "errors" "fmt" "net" "net/http" "net/textproto" + "os" "strconv" "strings" "sync" @@ -40,6 +43,75 @@ const ( LivenessProbeURI = "livenessprobe" ) +type TlsConfig struct { + CertFile string `koanf:"cert-file"` + KeyFile string `koanf:"key-file"` + RootCAs []string `koanf:"root-cas"` + ClientCAs []string `koanf:"client-cas"` +} + +func (tc *TlsConfig) Validate() error { + if len(tc.CertFile) == 0 && len(tc.KeyFile) != 0 { + return errors.New("cert-file must be provided if key-file present") + } + if len(tc.CertFile) != 0 && len(tc.KeyFile) == 0 { + return errors.New("key-file must be provided if cert-file present") + } + + return nil +} + +func (tc *TlsConfig) GetConfig() (*tls.Config, error) { + config := &tls.Config{ + MinVersion: tls.VersionTLS12, + } + + if len(tc.CertFile) != 0 { + cert, err := tls.LoadX509KeyPair(tc.CertFile, tc.KeyFile) + if err != nil { + return nil, fmt.Errorf("unable to read key/cert pair %s/%s: %s", tc.KeyFile, tc.CertFile, err.Error()) + } + config.Certificates = []tls.Certificate{cert} + } + + if len(tc.RootCAs) != 0 { + rootCACerts := x509.NewCertPool() + for _, rootCACertFilename := range tc.RootCAs { + rootCACert, err := os.ReadFile(rootCACertFilename) + if err != nil { + return nil, fmt.Errorf("unable to read root certificate authority %s: %s", rootCACertFilename, err.Error()) + } + rootCACerts.AppendCertsFromPEM(rootCACert) + } + config.RootCAs = rootCACerts + } + + if len(tc.ClientCAs) != 0 { + clientCAs := x509.NewCertPool() + for _, clientCAFilename := range tc.ClientCAs { + clientCA, err := os.ReadFile(clientCAFilename) + if err != nil { + return nil, fmt.Errorf("unable to read client certificate authority %s: %s", clientCAFilename, err.Error()) + } + clientCAs.AppendCertsFromPEM(clientCA) + } + config.ClientCAs = clientCAs + } + + return config, nil +} + +var DefaultTlsConfig = TlsConfig{} + +func TlsConfigAddOptions(prefix string, f *flag.FlagSet, description string, includeClientCert bool) { + f.String(prefix+".cert-file", DefaultTlsConfig.CertFile, description+" certificate file") + f.String(prefix+".key-file", DefaultTlsConfig.KeyFile, description+" key file") + f.StringSlice(prefix+".root-cas", DefaultTlsConfig.RootCAs, "root certificate authority files") + if includeClientCert { + f.StringSlice(prefix+".client-cas", DefaultTlsConfig.ClientCAs, "client mTLS certificate authority files") + } +} + type BroadcasterConfig struct { Enable bool `koanf:"enable"` Signed bool `koanf:"signed"` @@ -63,13 +135,14 @@ type BroadcasterConfig struct { MaxCatchup int `koanf:"max-catchup" reload:"hot"` ConnectionLimits ConnectionLimiterConfig `koanf:"connection-limits" reload:"hot"` ClientDelay time.Duration `koanf:"client-delay" reload:"hot"` + Tls TlsConfig `koanf:"tls"` } func (bc *BroadcasterConfig) Validate() error { if !bc.EnableCompression && bc.RequireCompression { return errors.New("require-compression cannot be true while enable-compression is false") } - return nil + return bc.Tls.Validate() } type BroadcasterConfigFetcher func() *BroadcasterConfig @@ -97,6 +170,7 @@ func BroadcasterConfigAddOptions(prefix string, f *flag.FlagSet) { f.Int(prefix+".max-catchup", DefaultBroadcasterConfig.MaxCatchup, "the maximum size of the catchup buffer (-1 means unlimited)") ConnectionLimiterConfigAddOptions(prefix+".connection-limits", f) f.Duration(prefix+".client-delay", DefaultBroadcasterConfig.ClientDelay, "delay the first messages sent to each client by this amount") + TlsConfigAddOptions(prefix+".tls", f, "server", true) } var DefaultBroadcasterConfig = BroadcasterConfig{ @@ -122,6 +196,7 @@ var DefaultBroadcasterConfig = BroadcasterConfig{ MaxCatchup: -1, ConnectionLimits: DefaultConnectionLimiterConfig, ClientDelay: 0, + Tls: DefaultTlsConfig, } var DefaultTestBroadcasterConfig = BroadcasterConfig{ @@ -147,6 +222,7 @@ var DefaultTestBroadcasterConfig = BroadcasterConfig{ MaxCatchup: -1, ConnectionLimits: DefaultConnectionLimiterConfig, ClientDelay: 0, + Tls: DefaultTlsConfig, } type WSBroadcastServer struct { @@ -395,19 +471,33 @@ func (s *WSBroadcastServer) StartWithHeader(ctx context.Context, header ws.Hands // Create tcp server for relay connections config := s.config() - ln, err := net.Listen("tcp", config.Addr+":"+config.Port) - if err != nil { - log.Error("error calling net.Listen", "err", err) - return err - } + if len(config.Tls.CertFile) != 0 { + tlsConfig, err := config.Tls.GetConfig() + if err != nil { + return err + } + ln, err := tls.Listen("tcp", config.Addr+":"+config.Port, tlsConfig) + if err != nil { + log.Error("error calling net.Listen", "err", err) + return err + } - s.listener = ln + s.listener = ln + } else { + ln, err := net.Listen("tcp", config.Addr+":"+config.Port) + if err != nil { + log.Error("error calling net.Listen", "err", err) + return err + } + + s.listener = ln + } - log.Info("arbitrum websocket broadcast server is listening", "address", ln.Addr().String()) + log.Info("arbitrum websocket broadcast server is listening", "address", s.listener.Addr().String()) // Create netpoll descriptor for the listener. // We use OneShot here to synchronously manage the rate that new connections are accepted - acceptDesc, err := netpoll.HandleListener(ln, netpoll.EventRead|netpoll.EventOneShot) + acceptDesc, err := netpoll.HandleListener(s.listener, netpoll.EventRead|netpoll.EventOneShot) if err != nil { log.Error("error calling HandleListener", "err", err) return err @@ -430,7 +520,7 @@ func (s *WSBroadcastServer) StartWithHeader(ctx context.Context, header ws.Hands // cooldown the server and do not receive connection for some short // time. err := s.clientManager.pool.ScheduleTimeout(time.Millisecond, func() { - conn, err := ln.Accept() + conn, err := s.listener.Accept() if err != nil { acceptErrChan <- err return