Skip to content

Commit

Permalink
Add wss support for feed broadcast
Browse files Browse the repository at this point in the history
  • Loading branch information
joshuacolvin0 committed Sep 22, 2023
1 parent f273f2a commit 290a119
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 33 deletions.
39 changes: 16 additions & 23 deletions broadcastclient/broadcastclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package broadcastclient

import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -63,16 +62,15 @@ var FeedConfigDefault = FeedConfig{
}

type Config struct {
ReconnectInitialBackoff time.Duration `koanf:"reconnect-initial-backoff" reload:"hot"`
ReconnectMaximumBackoff time.Duration `koanf:"reconnect-maximum-backoff" reload:"hot"`
RequireChainId bool `koanf:"require-chain-id" reload:"hot"`
RequireFeedVersion bool `koanf:"require-feed-version" reload:"hot"`
Timeout time.Duration `koanf:"timeout" reload:"hot"`
URL []string `koanf:"url"`
Verify signature.VerifierConfig `koanf:"verify"`
CertFile string `koanf:"cert-file"`
KeyFile string `koanf:"key-file"`
EnableCompression bool `koanf:"enable-compression" reload:"hot"`
ReconnectInitialBackoff time.Duration `koanf:"reconnect-initial-backoff" reload:"hot"`
ReconnectMaximumBackoff time.Duration `koanf:"reconnect-maximum-backoff" reload:"hot"`
RequireChainId bool `koanf:"require-chain-id" reload:"hot"`
RequireFeedVersion bool `koanf:"require-feed-version" reload:"hot"`
Timeout time.Duration `koanf:"timeout" reload:"hot"`
URL []string `koanf:"url"`
Verify signature.VerifierConfig `koanf:"verify"`
EnableCompression bool `koanf:"enable-compression" reload:"hot"`
Tls wsbroadcastserver.TlsConfig `koanf:"tls"`
}

func (c *Config) Enable() bool {
Expand All @@ -89,9 +87,8 @@ func ConfigAddOptions(prefix string, f *flag.FlagSet) {
f.Duration(prefix+".timeout", DefaultConfig.Timeout, "duration to wait before timing out connection to sequencer feed")
f.StringSlice(prefix+".url", DefaultConfig.URL, "URL of sequencer feed source")
signature.FeedVerifierConfigAddOptions(prefix+".verify", f)
f.String(prefix+".cert-file", DefaultConfig.CertFile, "X509 client public certificate file")
f.String(prefix+".key-file", DefaultConfig.KeyFile, "X509 client private key file")
f.Bool(prefix+".enable-compression", DefaultConfig.EnableCompression, "enable per message deflate compression support")
wsbroadcastserver.TlsConfigAddOptions(prefix, f, "client", false)
}

var DefaultConfig = Config{
Expand All @@ -103,6 +100,7 @@ var DefaultConfig = Config{
URL: []string{""},
Timeout: 20 * time.Second,
EnableCompression: true,
Tls: wsbroadcastserver.DefaultTlsConfig,
}

var DefaultTestConfig = Config{
Expand All @@ -114,6 +112,7 @@ var DefaultTestConfig = Config{
URL: []string{""},
Timeout: 200 * time.Millisecond,
EnableCompression: true,
Tls: wsbroadcastserver.DefaultTlsConfig,
}

type TransactionStreamerInterface interface {
Expand Down Expand Up @@ -236,15 +235,9 @@ func (bc *BroadcastClient) connect(ctx context.Context, nextSeqNum arbutil.Messa
if config.EnableCompression {
extensions = []httphead.Option{deflateExt}
}
tlsConfig := tls.Config{
MinVersion: tls.VersionTLS12,
}
if config.CertFile != "" && config.KeyFile != "" {
clientCert, err := tls.LoadX509KeyPair(config.CertFile, config.KeyFile)
if err != nil {
return nil, err
}
tlsConfig.Certificates = []tls.Certificate{clientCert}
tlsConfig, err := config.Tls.GetConfig()
if err != nil {
return nil, err
}
timeoutDialer := ws.Dialer{
Header: header,
Expand Down Expand Up @@ -287,7 +280,7 @@ func (bc *BroadcastClient) connect(ctx context.Context, nextSeqNum arbutil.Messa
return nil
},
Timeout: 10 * time.Second,
TLSConfig: &tlsConfig,
TLSConfig: tlsConfig,
Extensions: extensions,
}

Expand Down
110 changes: 100 additions & 10 deletions wsbroadcastserver/wsbroadcastserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@ package wsbroadcastserver

import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"net/http"
"net/textproto"
"os"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -40,6 +43,75 @@ const (
LivenessProbeURI = "livenessprobe"
)

type TlsConfig struct {
CertFile string `koanf:"cert-file"`
KeyFile string `koanf:"key-file"`
RootCAs []string `koanf:"root-cas"`
ClientCAs []string `koanf:"client-cas"`
}

func (tc *TlsConfig) Validate() error {
if len(tc.CertFile) == 0 && len(tc.KeyFile) != 0 {
return errors.New("cert-file must be provided if key-file present")
}
if len(tc.CertFile) != 0 && len(tc.KeyFile) == 0 {
return errors.New("key-file must be provided if cert-file present")
}

return nil
}

func (tc *TlsConfig) GetConfig() (*tls.Config, error) {
config := &tls.Config{
MinVersion: tls.VersionTLS12,
}

if len(tc.CertFile) != 0 {
cert, err := tls.LoadX509KeyPair(tc.CertFile, tc.KeyFile)
if err != nil {
return nil, fmt.Errorf("unable to read key/cert pair %s/%s: %s", tc.KeyFile, tc.CertFile, err.Error())
}
config.Certificates = []tls.Certificate{cert}
}

if len(tc.RootCAs) != 0 {
rootCACerts := x509.NewCertPool()
for _, rootCACertFilename := range tc.RootCAs {
rootCACert, err := os.ReadFile(rootCACertFilename)
if err != nil {
return nil, fmt.Errorf("unable to read root certificate authority %s: %s", rootCACertFilename, err.Error())
}
rootCACerts.AppendCertsFromPEM(rootCACert)
}
config.RootCAs = rootCACerts
}

if len(tc.ClientCAs) != 0 {
clientCAs := x509.NewCertPool()
for _, clientCAFilename := range tc.ClientCAs {
clientCA, err := os.ReadFile(clientCAFilename)
if err != nil {
return nil, fmt.Errorf("unable to read client certificate authority %s: %s", clientCAFilename, err.Error())
}
clientCAs.AppendCertsFromPEM(clientCA)
}
config.ClientCAs = clientCAs
}

return config, nil
}

var DefaultTlsConfig = TlsConfig{}

func TlsConfigAddOptions(prefix string, f *flag.FlagSet, description string, includeClientCert bool) {
f.String(prefix+".cert-file", DefaultTlsConfig.CertFile, description+" certificate file")
f.String(prefix+".key-file", DefaultTlsConfig.KeyFile, description+" key file")
f.StringSlice(prefix+".root-cas", DefaultTlsConfig.RootCAs, "root certificate authority files")
if includeClientCert {
f.StringSlice(prefix+".client-cas", DefaultTlsConfig.ClientCAs, "client mTLS certificate authority files")
}
}

type BroadcasterConfig struct {
Enable bool `koanf:"enable"`
Signed bool `koanf:"signed"`
Expand All @@ -63,13 +135,14 @@ type BroadcasterConfig struct {
MaxCatchup int `koanf:"max-catchup" reload:"hot"`
ConnectionLimits ConnectionLimiterConfig `koanf:"connection-limits" reload:"hot"`
ClientDelay time.Duration `koanf:"client-delay" reload:"hot"`
Tls TlsConfig `koanf:"tls"`
}

func (bc *BroadcasterConfig) Validate() error {
if !bc.EnableCompression && bc.RequireCompression {
return errors.New("require-compression cannot be true while enable-compression is false")
}
return nil
return bc.Tls.Validate()
}

type BroadcasterConfigFetcher func() *BroadcasterConfig
Expand Down Expand Up @@ -97,6 +170,7 @@ func BroadcasterConfigAddOptions(prefix string, f *flag.FlagSet) {
f.Int(prefix+".max-catchup", DefaultBroadcasterConfig.MaxCatchup, "the maximum size of the catchup buffer (-1 means unlimited)")
ConnectionLimiterConfigAddOptions(prefix+".connection-limits", f)
f.Duration(prefix+".client-delay", DefaultBroadcasterConfig.ClientDelay, "delay the first messages sent to each client by this amount")
TlsConfigAddOptions(prefix+".tls", f, "server", true)
}

var DefaultBroadcasterConfig = BroadcasterConfig{
Expand All @@ -122,6 +196,7 @@ var DefaultBroadcasterConfig = BroadcasterConfig{
MaxCatchup: -1,
ConnectionLimits: DefaultConnectionLimiterConfig,
ClientDelay: 0,
Tls: DefaultTlsConfig,
}

var DefaultTestBroadcasterConfig = BroadcasterConfig{
Expand All @@ -147,6 +222,7 @@ var DefaultTestBroadcasterConfig = BroadcasterConfig{
MaxCatchup: -1,
ConnectionLimits: DefaultConnectionLimiterConfig,
ClientDelay: 0,
Tls: DefaultTlsConfig,
}

type WSBroadcastServer struct {
Expand Down Expand Up @@ -395,19 +471,33 @@ func (s *WSBroadcastServer) StartWithHeader(ctx context.Context, header ws.Hands

// Create tcp server for relay connections
config := s.config()
ln, err := net.Listen("tcp", config.Addr+":"+config.Port)
if err != nil {
log.Error("error calling net.Listen", "err", err)
return err
}
if len(config.Tls.CertFile) != 0 {
tlsConfig, err := config.Tls.GetConfig()
if err != nil {
return err
}
ln, err := tls.Listen("tcp", config.Addr+":"+config.Port, tlsConfig)
if err != nil {
log.Error("error calling net.Listen", "err", err)
return err
}

s.listener = ln
s.listener = ln
} else {
ln, err := net.Listen("tcp", config.Addr+":"+config.Port)
if err != nil {
log.Error("error calling net.Listen", "err", err)
return err
}

s.listener = ln
}

log.Info("arbitrum websocket broadcast server is listening", "address", ln.Addr().String())
log.Info("arbitrum websocket broadcast server is listening", "address", s.listener.Addr().String())

// Create netpoll descriptor for the listener.
// We use OneShot here to synchronously manage the rate that new connections are accepted
acceptDesc, err := netpoll.HandleListener(ln, netpoll.EventRead|netpoll.EventOneShot)
acceptDesc, err := netpoll.HandleListener(s.listener, netpoll.EventRead|netpoll.EventOneShot)
if err != nil {
log.Error("error calling HandleListener", "err", err)
return err
Expand All @@ -430,7 +520,7 @@ func (s *WSBroadcastServer) StartWithHeader(ctx context.Context, header ws.Hands
// cooldown the server and do not receive connection for some short
// time.
err := s.clientManager.pool.ScheduleTimeout(time.Millisecond, func() {
conn, err := ln.Accept()
conn, err := s.listener.Accept()
if err != nil {
acceptErrChan <- err
return
Expand Down

0 comments on commit 290a119

Please sign in to comment.