Skip to content

Commit

Permalink
Merge pull request #2269 from same-id/fix-non-root-hosts
Browse files Browse the repository at this point in the history
Fix non-root hosts failing on resolving DNS
  • Loading branch information
Roasbeef authored Dec 5, 2024
2 parents 25c804f + 42d6eba commit ec0b90d
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 58 deletions.
127 changes: 69 additions & 58 deletions rpcclient/infrastructure.go
Original file line number Diff line number Diff line change
Expand Up @@ -759,41 +759,26 @@ out:
// result, unmarshalling it, and delivering the unmarshalled result to the
// provided response channel.
func (c *Client) handleSendPostMessage(jReq *jsonRequest) {
protocol := "http"
if !c.config.DisableTLS {
protocol = "https"
}

var (
err, lastErr error
lastErr error
backoff time.Duration
httpResponse *http.Response
)

parsedAddr, err := ParseAddressString(c.config.Host)
httpURL, err := c.config.httpURL()
if err != nil {
jReq.responseChan <- &Response{
err: fmt.Errorf("failed to parse address %v", err),
}
return
}

var url string
switch parsedAddr.Network() {
case "unix", "unixpacket":
// Using a placeholder URL because a non-empty URL is required.
// The Unix domain socket is specified in the DialContext.
url = protocol + "://unix"
default:
url = protocol + "://" + c.config.Host
}

tries := 10
for i := 0; i < tries; i++ {
var httpReq *http.Request

bodyReader := bytes.NewReader(jReq.marshalledJSON)
httpReq, err = http.NewRequest("POST", url, bodyReader)
httpReq, err = http.NewRequest("POST", httpURL, bodyReader)
if err != nil {
jReq.responseChan <- &Response{result: nil, err: err}
return
Expand Down Expand Up @@ -1355,16 +1340,21 @@ func newHTTPClient(config *ConnConfig) (*http.Client, error) {
}
}

parsedAddr, err := ParseAddressString(config.Host)
parsedDialAddr, err := ParseAddressString(config.Host)
if err != nil {
return nil, err
}
client := http.Client{
Transport: &http.Transport{
Proxy: proxyFunc,
TLSClientConfig: tlsConfig,
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return net.Dial(parsedAddr.Network(), parsedAddr.String())
DialContext: func(_ context.Context, _,
_ string) (net.Conn, error) {

return net.Dial(
parsedDialAddr.Network(),
parsedDialAddr.String(),
)
},
},
Timeout: defaultHTTPTimeout,
Expand All @@ -1373,6 +1363,32 @@ func newHTTPClient(config *ConnConfig) (*http.Client, error) {
return &client, nil
}

// httpURL returns the URL to use for HTTP POST requests.
func (config *ConnConfig) httpURL() (string, error) {
protocol := "http"
if !config.DisableTLS {
protocol = "https"
}

parsedAddr, err := ParseAddressString(config.Host)
if err != nil {
return "", fmt.Errorf("error parsing host '%v': %v",
config.Host, err)
}

var httpURL string
switch parsedAddr.Network() {
case "unix", "unixpacket":
// Using a placeholder URL because a non-empty URL is required.
// The Unix domain socket is specified in the DialContext.
httpURL = protocol + "://unix"
default:
httpURL = protocol + "://" + config.Host
}

return httpURL, nil
}

// dial opens a websocket connection using the passed connection configuration
// details.
func dial(config *ConnConfig) (*websocket.Conn, error) {
Expand Down Expand Up @@ -1733,53 +1749,48 @@ func (c *Client) Send() error {
return nil
}

// cutPrefix returns s without the provided leading prefix string
// and reports whether it found the prefix.
// If s doesn't start with prefix, cutPrefix returns s, false.
// If prefix is the empty string, cutPrefix returns s, true.
// Copied from go1.20 version.
func cutPrefix(s, prefix string) (after string, found bool) {
if !strings.HasPrefix(s, prefix) {
return s, false
}
return s[len(prefix):], true
}

// ParseAddressString converts an address in string format to a net.Addr that is
// compatible with btcd. UDP is not supported because btcd needs reliable
// connections. We accept a custom function to resolve any TCP addresses so
// that caller is able control exactly how resolution is performed.
// connections.
func ParseAddressString(strAddress string) (net.Addr, error) {
var parsedNetwork, parsedAddr string
// Addresses can either be in unix://address, unixpacket://address URL
// format, or just address:port host format for tcp.
if after, ok := cutPrefix(strAddress, "unix://"); ok {
return net.ResolveUnixAddr("unix", after)
}
if after, ok := cutPrefix(strAddress, "unixpacket://"); ok {
return net.ResolveUnixAddr("unixpacket", after)
}

// Addresses can either be in network://address:port format,
// network:address:port, address:port, or just port. We want to support
// all possible types.
if strings.Contains(strAddress, "://") {
parts := strings.Split(strAddress, "://")
parsedNetwork, parsedAddr = parts[0], parts[1]
} else if strings.Contains(strAddress, ":") {
parts := strings.Split(strAddress, ":")
parsedNetwork = parts[0]
parsedAddr = strings.Join(parts[1:], ":")
} else {
parsedAddr = strAddress
// Not supporting :// anywhere in the host or path.
return nil, fmt.Errorf("unsupported protocol in address: %s",
strAddress)
}

// Only TCP and Unix socket addresses are valid. We can't use IP or
// UDP only connections for anything we do in lnd.
switch parsedNetwork {
case "unix", "unixpacket":
return net.ResolveUnixAddr(parsedNetwork, parsedAddr)

case "tcp", "tcp4", "tcp6":
return net.ResolveTCPAddr(parsedNetwork, verifyPort(parsedAddr))

case "ip", "ip4", "ip6", "udp", "udp4", "udp6", "unixgram":
return nil, fmt.Errorf("only TCP or unix socket "+
"addresses are supported: %s", parsedAddr)

default:
// We'll now possibly use the local host short circuit
// or parse out an all interfaces listen.
addrWithPort := verifyPort(strAddress)

// Otherwise, we'll attempt to resolve the host.
return net.ResolveTCPAddr("tcp", addrWithPort)
// Parse it as a dummy URL to get the host and port.
u, err := url.Parse("dummy://" + strAddress)
if err != nil {
return nil, err
}
return net.ResolveTCPAddr("tcp", verifyPort(u.Host))
}

// verifyPort makes sure that an address string has both a host and a port.
// If the address is just a port, then we'll assume that the user is using the
// short cut to specify a localhost:port address.
// shortcut to specify a localhost:port address.
func verifyPort(address string) string {
host, port, err := net.SplitHostPort(address)
if err != nil {
Expand All @@ -1801,8 +1812,8 @@ func verifyPort(address string) string {
return net.JoinHostPort(address, "")
}

// In the case that both the host and port are empty, we'll use the
// an empty port.
// In the case that both the host and port are empty, we'll use an empty
// port.
if host == "" && port == "" {
return ":"
}
Expand Down
110 changes: 110 additions & 0 deletions rpcclient/infrastructure_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package rpcclient

import (
"testing"

"github.com/stretchr/testify/require"
)

// TestParseAddressString checks different variation of supported and
// unsupported addresses.
func TestParseAddressString(t *testing.T) {
t.Parallel()

// Using localhost only to avoid network calls.
testCases := []struct {
name string
addressString string
expNetwork string
expAddress string
expErrStr string
}{
{
name: "localhost",
addressString: "localhost",
expNetwork: "tcp",
expAddress: "127.0.0.1:0",
},
{
name: "localhost ip",
addressString: "127.0.0.1",
expNetwork: "tcp",
expAddress: "127.0.0.1:0",
},
{
name: "localhost ipv6",
addressString: "::1",
expNetwork: "tcp",
expAddress: "[::1]:0",
},
{
name: "localhost and port",
addressString: "localhost:80",
expNetwork: "tcp",
expAddress: "127.0.0.1:80",
},
{
name: "localhost ipv6 and port",
addressString: "[::1]:80",
expNetwork: "tcp",
expAddress: "[::1]:80",
},
{
name: "colon and port",
addressString: ":80",
expNetwork: "tcp",
expAddress: ":80",
},
{
name: "colon only",
addressString: ":",
expNetwork: "tcp",
expAddress: ":0",
},
{
name: "localhost and path",
addressString: "localhost/path",
expNetwork: "tcp",
expAddress: "127.0.0.1:0",
},
{
name: "localhost port and path",
addressString: "localhost:80/path",
expNetwork: "tcp",
expAddress: "127.0.0.1:80",
},
{
name: "unix prefix",
addressString: "unix://the/rest/of/the/path",
expNetwork: "unix",
expAddress: "the/rest/of/the/path",
},
{
name: "unix prefix",
addressString: "unixpacket://the/rest/of/the/path",
expNetwork: "unixpacket",
expAddress: "the/rest/of/the/path",
},
{
name: "error http prefix",
addressString: "http://localhost:1010",
expErrStr: "unsupported protocol in address",
},
}

for _, tc := range testCases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
addr, err := ParseAddressString(tc.addressString)
if tc.expErrStr != "" {
require.Error(t, err)
require.Contains(t, err.Error(), tc.expErrStr)
return
}
require.NoError(t, err)
require.Equal(t, tc.expNetwork, addr.Network())
require.Equal(t, tc.expAddress, addr.String())
})
}
}

0 comments on commit ec0b90d

Please sign in to comment.