Skip to content

Commit

Permalink
Add ability to run Postgres proxy on separate listener (#8323)
Browse files Browse the repository at this point in the history
  • Loading branch information
smallinsky authored Dec 10, 2021
1 parent c3dee23 commit d24ae5b
Show file tree
Hide file tree
Showing 16 changed files with 200 additions and 54 deletions.
2 changes: 2 additions & 0 deletions api/client/webclient/webclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,8 @@ type SSHProxySettings struct {

// DBProxySettings contains database access specific proxy settings.
type DBProxySettings struct {
// PostgresListenAddr is Postgres proxy listen address.
PostgresListenAddr string `json:"postgres_listen_addr,omitempty"`
// PostgresPublicAddr is advertised to Postgres clients.
PostgresPublicAddr string `json:"postgres_public_addr,omitempty"`
// MySQLListenAddr is MySQL proxy listen address.
Expand Down
34 changes: 34 additions & 0 deletions integration/db_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,40 @@ func TestDatabaseAccessUnspecifiedHostname(t *testing.T) {
require.NoError(t, err)
}

// TestDatabaseAccessPostgresSeparateListener tests postgres proxy listener running on separate port.
func TestDatabaseAccessPostgresSeparateListener(t *testing.T) {
pack := setupDatabaseTest(t,
withPortSetupDatabaseTest(separatePostgresPortSetup),
)

// Connect to the database service in root cluster.
client, err := postgres.MakeTestClient(context.Background(), common.TestClientConfig{
AuthClient: pack.root.cluster.GetSiteAPI(pack.root.cluster.Secrets.SiteName),
AuthServer: pack.root.cluster.Process.GetAuthServer(),
Address: net.JoinHostPort(Loopback, pack.root.cluster.GetPortPostgres()),
Cluster: pack.root.cluster.Secrets.SiteName,
Username: pack.root.user.GetName(),
RouteToDatabase: tlsca.RouteToDatabase{
ServiceName: pack.root.postgresService.Name,
Protocol: pack.root.postgresService.Protocol,
Username: "postgres",
Database: "test",
},
})
require.NoError(t, err)

// Execute a query.
result, err := client.Exec(context.Background(), "select 1").ReadAll()
require.NoError(t, err)
require.Equal(t, []*pgconn.Result{postgres.TestQueryResponse}, result)
require.Equal(t, uint32(1), pack.root.postgres.QueryCount())
require.Equal(t, uint32(0), pack.leaf.postgres.QueryCount())

// Disconnect.
err = client.Close(context.Background())
require.NoError(t, err)
}

func waitForAuditEventTypeWithBackoff(t *testing.T, cli *auth.Server, startTime time.Time, eventType string) []apievents.AuditEvent {
max := time.Second
timeout := time.After(max)
Expand Down
4 changes: 4 additions & 0 deletions integration/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,10 @@ func (i *TeleInstance) GenerateConfig(t *testing.T, trustedSecrets []*InstanceSe
tconf.Proxy.SSHAddr.Addr = net.JoinHostPort(i.Hostname, i.GetPortProxy())
tconf.Proxy.WebAddr.Addr = net.JoinHostPort(i.Hostname, i.GetPortWeb())
tconf.Proxy.MySQLAddr.Addr = net.JoinHostPort(i.Hostname, i.GetPortMySQL())
if i.Postgres != nil {
// Postgres proxy port was configured on a separate listener.
tconf.Proxy.PostgresAddr.Addr = net.JoinHostPort(i.Hostname, i.GetPortPostgres())
}
}
tconf.AuthServers = append(tconf.AuthServers, tconf.Auth.SSHAddr)
tconf.Auth.StorageConfig = backend.Config{
Expand Down
14 changes: 14 additions & 0 deletions integration/ports.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,18 @@ func webReverseTunnelMuxPortSetup() *InstancePorts {
}
}

func separatePostgresPortSetup() *InstancePorts {
return &InstancePorts{
Web: newInstancePort(),
SSH: newInstancePort(),
Auth: newInstancePort(),
SSHProxy: newInstancePort(),
ReverseTunnel: newInstancePort(),
MySQL: newInstancePort(),
Postgres: newInstancePort(),
}
}

type InstancePorts struct {
Host string
Web *InstancePort
Expand All @@ -97,6 +109,7 @@ type InstancePorts struct {
Auth *InstancePort
ReverseTunnel *InstancePort
MySQL *InstancePort
Postgres *InstancePort

isSinglePortSetup bool
}
Expand All @@ -107,6 +120,7 @@ func (i *InstancePorts) GetPortAuth() string { return i.Auth.String() }
func (i *InstancePorts) GetPortProxy() string { return i.SSHProxy.String() }
func (i *InstancePorts) GetPortWeb() string { return i.Web.String() }
func (i *InstancePorts) GetPortMySQL() string { return i.MySQL.String() }
func (i *InstancePorts) GetPortPostgres() string { return i.Postgres.String() }
func (i *InstancePorts) GetPortReverseTunnel() string { return i.ReverseTunnel.String() }

func (i *InstancePorts) GetSSHAddr() string {
Expand Down
7 changes: 7 additions & 0 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2651,6 +2651,13 @@ func (tc *TeleportClient) applyProxySettings(proxySettings webclient.ProxySettin
proxySettings.DB.PostgresPublicAddr)
}
tc.PostgresProxyAddr = net.JoinHostPort(addr.Host(), strconv.Itoa(addr.Port(tc.WebProxyPort())))
case proxySettings.DB.PostgresListenAddr != "":
addr, err := utils.ParseAddr(proxySettings.DB.PostgresListenAddr)
if err != nil {
return trace.BadParameter("failed to parse Postgres listen address received from server: %q, contact your administrator for help",
proxySettings.DB.PostgresListenAddr)
}
tc.PostgresProxyAddr = net.JoinHostPort(tc.WebProxyHost(), strconv.Itoa(addr.Port(defaults.PostgresListenPort)))
default:
webProxyHost, webProxyPort := tc.WebProxyHostPort()
tc.PostgresProxyAddr = net.JoinHostPort(webProxyHost, strconv.Itoa(webProxyPort))
Expand Down
37 changes: 27 additions & 10 deletions lib/config/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,13 @@ func applyProxyConfig(fc *FileConfig, cfg *service.Config) error {
}
cfg.Proxy.MySQLAddr = *addr
}
if fc.Proxy.PostgresAddr != "" {
addr, err := utils.ParseHostPortAddr(fc.Proxy.PostgresAddr, int(defaults.PostgresListenPort))
if err != nil {
return trace.Wrap(err)
}
cfg.Proxy.PostgresAddr = *addr
}

// This is the legacy format. Continue to support it forever, but ideally
// users now use the list format below.
Expand Down Expand Up @@ -788,22 +795,14 @@ func applyProxyConfig(fc *FileConfig, cfg *service.Config) error {
cfg.Proxy.TunnelPublicAddrs = addrs
}
if len(fc.Proxy.PostgresPublicAddr) != 0 {
// Postgres proxy is multiplexed on the web proxy port. If the port is
// not specified here explicitly, prefer defaults in the following
// order, depending on what's set:
// 1. Web proxy public port
// 2. Web proxy listen port
// 3. Web proxy default listen port
defaultPort := cfg.Proxy.WebAddr.Port(defaults.HTTPListenPort)
if len(cfg.Proxy.PublicAddrs) != 0 {
defaultPort = cfg.Proxy.PublicAddrs[0].Port(defaults.HTTPListenPort)
}
defaultPort := getPostgresDefaultPort(cfg)
addrs, err := utils.AddrsFromStrings(fc.Proxy.PostgresPublicAddr, defaultPort)
if err != nil {
return trace.Wrap(err)
}
cfg.Proxy.PostgresPublicAddrs = addrs
}

if len(fc.Proxy.MySQLPublicAddr) != 0 {
if fc.Proxy.MySQLAddr == "" {
return trace.BadParameter("mysql_listen_addr must be set when mysql_public_addr is set")
Expand All @@ -827,6 +826,24 @@ func applyProxyConfig(fc *FileConfig, cfg *service.Config) error {
return nil
}

func getPostgresDefaultPort(cfg *service.Config) int {
if !cfg.Proxy.PostgresAddr.IsEmpty() {
// If the proxy.PostgresAddr flag was provided return port
// from PostgresAddr address or default PostgresListenPort.
return cfg.Proxy.PostgresAddr.Port(defaults.PostgresListenPort)
}
// Postgres proxy is multiplexed on the web proxy port. If the proxy is
// not specified here explicitly, prefer defaults in the following
// order, depending on what's set:
// 1. Web proxy public port
// 2. Web proxy listen port
// 3. Web proxy default listen port
if len(cfg.Proxy.PublicAddrs) != 0 {
return cfg.Proxy.PublicAddrs[0].Port(defaults.HTTPListenPort)
}
return cfg.Proxy.WebAddr.Port(defaults.HTTPListenPort)
}

func applyDefaultProxyListenerAddresses(cfg *service.Config) {
if cfg.Version == defaults.TeleportConfigVersionV2 {
// For v2 configuration if an address is not provided don't fallback to the default values.
Expand Down
22 changes: 22 additions & 0 deletions lib/config/configuration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,28 @@ func TestPostgresPublicAddr(t *testing.T) {
},
out: []string{net.JoinHostPort("postgres.example.com", strconv.Itoa(defaults.HTTPListenPort))},
},
{
desc: "when PostgresAddr is provided with port, the explicitly provided port should be use",
fc: &FileConfig{
Proxy: Proxy{
WebAddr: "0.0.0.0:8080",
PostgresAddr: "0.0.0.0:12345",
PostgresPublicAddr: []string{"postgres.example.com"},
},
},
out: []string{"postgres.example.com:12345"},
},
{
desc: "when PostgresAddr is provided without port, defaults PostgresPort should be used",
fc: &FileConfig{
Proxy: Proxy{
WebAddr: "0.0.0.0:8080",
PostgresAddr: "0.0.0.0",
PostgresPublicAddr: []string{"postgres.example.com"},
},
},
out: []string{net.JoinHostPort("postgres.example.com", strconv.Itoa(defaults.PostgresListenPort))},
},
}
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
Expand Down
3 changes: 3 additions & 0 deletions lib/config/fileconf.go
Original file line number Diff line number Diff line change
Expand Up @@ -1115,6 +1115,9 @@ type Proxy struct {
// MySQLPublicAddr is the hostport the proxy advertises for MySQL
// client connections.
MySQLPublicAddr apiutils.Strings `yaml:"mysql_public_addr,omitempty"`

// PostgresAddr is Postgres proxy listen address.
PostgresAddr string `yaml:"postgres_listen_addr,omitempty"`
// PostgresPublicAddr is the hostport the proxy advertises for Postgres
// client connections.
PostgresPublicAddr apiutils.Strings `yaml:"postgres_public_addr,omitempty"`
Expand Down
3 changes: 3 additions & 0 deletions lib/defaults/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ const (
// MySQLListenPort is the default listen port for MySQL proxy.
MySQLListenPort = 3036

// PostgresListenPort is the default listen port for PostgreSQL proxy.
PostgresListenPort = 5432

// MetricsListenPort is the default listen port for the metrics service.
MetricsListenPort = 3081

Expand Down
6 changes: 3 additions & 3 deletions lib/multiplexer/multiplexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ type Config struct {
DisableSSH bool
// DisableTLS disables TLS socket
DisableTLS bool
// DisableDB disables database access proxy listener
DisableDB bool
// DisablePostgres disables Postgres access proxy listener
DisablePostgres bool
// ID is an identifier used for debugging purposes
ID string
}
Expand Down Expand Up @@ -249,7 +249,7 @@ func (m *Mux) detectAndForward(conn net.Conn) {
conn.Close()
case ProtoPostgres:
m.WithField("protocol", connWrapper.protocol).Debug("Detected Postgres client connection.")
if m.DisableDB {
if m.DisablePostgres {
m.Debug("Closing Postgres client connection: db proxy listener is disabled.")
conn.Close()
return
Expand Down
3 changes: 3 additions & 0 deletions lib/service/cfg.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,9 @@ type ProxyConfig struct {
// MySQLAddr is address of MySQL proxy.
MySQLAddr utils.NetAddr

// PostgresAddr is address of Postgres proxy.
PostgresAddr utils.NetAddr

Limiter limiter.Config

// PublicAddrs is a list of the public addresses the proxy advertises
Expand Down
1 change: 1 addition & 0 deletions lib/service/listeners.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ var (
listenerProxyWeb = listenerType(teleport.Component(teleport.ComponentProxy, "web"))
listenerProxyTunnel = listenerType(teleport.Component(teleport.ComponentProxy, "tunnel"))
listenerProxyMySQL = listenerType(teleport.Component(teleport.ComponentProxy, "mysql"))
listenerProxyPostgres = listenerType(teleport.Component(teleport.ComponentProxy, "postgres"))
listenerMetrics = listenerType(teleport.ComponentMetrics)
listenerWindowsDesktop = listenerType(teleport.ComponentWindowsDesktop)
)
Expand Down
36 changes: 33 additions & 3 deletions lib/service/proxy_settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package service

import (
"context"
"fmt"

"github.com/gravitational/trace"

Expand Down Expand Up @@ -75,6 +76,11 @@ func (p *proxySettings) buildProxySettings(proxyListenerMode types.ProxyListener
if !p.cfg.Proxy.MySQLAddr.IsEmpty() {
proxySettings.DB.MySQLListenAddr = p.cfg.Proxy.MySQLAddr.String()
}

if !p.cfg.Proxy.PostgresAddr.IsEmpty() {
proxySettings.DB.PostgresListenAddr = p.cfg.Proxy.PostgresAddr.String()
}

if p.cfg.Proxy.Kube.Enabled {
proxySettings.Kube.ListenAddr = p.cfg.Proxy.Kube.ListenAddr.String()
}
Expand All @@ -90,6 +96,7 @@ func (p *proxySettings) buildProxySettingsV2(proxyListenerMode types.ProxyListen
settings.SSH.TunnelListenAddr = multiplexAddr
settings.Kube.ListenAddr = multiplexAddr
settings.DB.MySQLListenAddr = multiplexAddr
settings.DB.PostgresListenAddr = multiplexAddr
}
return settings
}
Expand All @@ -107,10 +114,33 @@ func (p *proxySettings) setProxyPublicAddressesSettings(settings *webclient.Prox
if len(p.cfg.Proxy.Kube.PublicAddrs) > 0 {
settings.Kube.PublicAddr = p.cfg.Proxy.Kube.PublicAddrs[0].String()
}
if len(p.cfg.Proxy.PostgresPublicAddrs) > 0 {
settings.DB.PostgresPublicAddr = p.cfg.Proxy.PostgresPublicAddrs[0].String()
}
if len(p.cfg.Proxy.MySQLPublicAddrs) > 0 {
settings.DB.MySQLPublicAddr = p.cfg.Proxy.MySQLPublicAddrs[0].String()
}
settings.DB.PostgresPublicAddr = p.getPostgresPublicAddr()
}

// getPostgresPublicAddr returns the proxy PostgresPublicAddrs based on whether the Postgres proxy service
// was configured on separate listener. For backward compatibility if PostgresPublicAddrs was not provided.
// Proxy will reuse the PostgresPublicAddrs field to propagate postgres service address to legacy tsh clients.
func (p *proxySettings) getPostgresPublicAddr() string {
if len(p.cfg.Proxy.PostgresPublicAddrs) > 0 {
return p.cfg.Proxy.PostgresPublicAddrs[0].String()
}

if p.cfg.Proxy.PostgresAddr.IsEmpty() {
return ""
}

// DELETE IN 9.0.0
// If the PostgresPublicAddrs address was not set propagate separate postgres service listener address
// to legacy tsh clients reusing PostgresPublicAddrs field.
var host string
if len(p.cfg.Proxy.PublicAddrs) > 0 {
// Get proxy host address from public address.
host = p.cfg.Proxy.PublicAddrs[0].Host()
} else {
host = p.cfg.Proxy.WebAddr.Host()
}
return fmt.Sprintf("%s:%d", host, p.cfg.Proxy.PostgresAddr.Port(defaults.PostgresListenPort))
}
Loading

0 comments on commit d24ae5b

Please sign in to comment.