diff --git a/clientconn.go b/clientconn.go index 4f57b55434f9..3a98f0c2ac29 100644 --- a/clientconn.go +++ b/clientconn.go @@ -225,7 +225,7 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) { func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *ClientConn, err error) { // At the end of this method, we kick the channel out of idle, rather than // waiting for the first rpc. - opts = append([]DialOption{withDefaultScheme("passthrough")}, opts...) + opts = append([]DialOption{withDefaultScheme("passthrough"), WithTargetResolutionEnabled()}, opts...) cc, err := NewClient(target, opts...) if err != nil { return nil, err diff --git a/dialoptions.go b/dialoptions.go index 7494ae591f16..3950cc96853a 100644 --- a/dialoptions.go +++ b/dialoptions.go @@ -94,6 +94,11 @@ type dialOptions struct { idleTimeout time.Duration defaultScheme string maxCallAttempts int + // targetResolutionEnabled indicates whether the client should perform + // target resolution even when a proxy is enabled. + targetResolutionEnabled bool + // useProxy specifies if a proxy should be used. + useProxy bool } // DialOption configures how we set up the connection. @@ -377,7 +382,20 @@ func WithInsecure() DialOption { // later release. func WithNoProxy() DialOption { return newFuncDialOption(func(o *dialOptions) { - o.copts.UseProxy = false + o.useProxy = false + }) +} + +// WithTargetResolutionEnabled returns a DialOption which enables target +// resolution on client. This is ignored if WithNoProxy is used. +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a +// later release. +func WithTargetResolutionEnabled() DialOption { + return newFuncDialOption(func(o *dialOptions) { + o.targetResolutionEnabled = true }) } @@ -662,14 +680,15 @@ func defaultDialOptions() dialOptions { copts: transport.ConnectOptions{ ReadBufferSize: defaultReadBufSize, WriteBufferSize: defaultWriteBufSize, - UseProxy: true, UserAgent: grpcUA, BufferPool: mem.DefaultBufferPool(), }, - bs: internalbackoff.DefaultExponential, - idleTimeout: 30 * time.Minute, - defaultScheme: "dns", - maxCallAttempts: defaultMaxCallAttempts, + bs: internalbackoff.DefaultExponential, + idleTimeout: 30 * time.Minute, + defaultScheme: "dns", + maxCallAttempts: defaultMaxCallAttempts, + useProxy: true, + targetResolutionEnabled: false, } } diff --git a/internal/internal.go b/internal/internal.go index 7fbfaacde9bf..134c11ea7869 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -22,6 +22,8 @@ package internal import ( "context" + "net/http" + "net/url" "time" "google.golang.org/grpc/connectivity" @@ -238,6 +240,11 @@ var ( // SetBufferPoolingThresholdForTesting updates the buffer pooling threshold, for // testing purposes. SetBufferPoolingThresholdForTesting any // func(int) + + // HTTPSProxyFromEnvironmentForTesting returns the URL of the proxy to use + // for testing purposes. It is used to override the `http.ProxyFromEnvironment` + // function for testing purposes. + HTTPSProxyFromEnvironmentForTesting func(*http.Request) (*url.URL, error) ) // HealthChecker defines the signature of the client-side LB channel health diff --git a/internal/proxyattributes/proxyattributes.go b/internal/proxyattributes/proxyattributes.go new file mode 100644 index 000000000000..e5e22b287ba4 --- /dev/null +++ b/internal/proxyattributes/proxyattributes.go @@ -0,0 +1,66 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package proxyattributes contains functions for getting and setting proxy +// attributes like the CONNECT address and user info. +package proxyattributes + +import ( + "net/url" + + "google.golang.org/grpc/resolver" +) + +type keyType string + +const proxyOptionsKey = keyType("grpc.resolver.delegatingresolver.proxyOptions") + +// Options holds the proxy connection details needed during the CONNECT +// handshake. It includes the user information and the connect address. +type Options struct { + User *url.Userinfo + ConnectAddr string +} + +// Populate returns a copy of addr with attributes containing the provided user +// and connect address, which are needed during the CONNECT handshake for a +// proxy connection. +func Populate(addr resolver.Address, opts Options) resolver.Address { + addr.Attributes = addr.Attributes.WithValue(proxyOptionsKey, opts) + return addr +} + +// ConnectAddr returns the proxy connect address in resolver.Address, or nil if +// not present. The returned data should not be mutated. +func ConnectAddr(addr resolver.Address) string { + a := addr.Attributes.Value(proxyOptionsKey) + if a != nil { + return a.(Options).ConnectAddr + } + return "" +} + +// User returns the user info in the resolver.Address, or nil if not present. +// The returned data should not be mutated. +func User(addr resolver.Address) *url.Userinfo { + a := addr.Attributes.Value(proxyOptionsKey) + if a != nil { + return a.(Options).User + } + return nil +} diff --git a/internal/proxyattributes/proxyattributes_test.go b/internal/proxyattributes/proxyattributes_test.go new file mode 100644 index 000000000000..e76f3792e39c --- /dev/null +++ b/internal/proxyattributes/proxyattributes_test.go @@ -0,0 +1,126 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package proxyattributes + +import ( + "net/url" + "testing" + + "google.golang.org/grpc/attributes" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/resolver" +) + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +// Tests that ConnectAddr returns the correct connect address in the attribute. +func (s) TestConnectAddr(t *testing.T) { + tests := []struct { + name string + addr resolver.Address + want string + }{ + { + name: "connect address in attribute", + addr: resolver.Address{ + Addr: "test-address", + Attributes: attributes.New(proxyOptionsKey, Options{ + User: nil, + ConnectAddr: "proxy-address", + }), + }, + want: "proxy-address", + }, + { + name: "no attribute", + addr: resolver.Address{Addr: "test-address"}, + want: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Validate ConnectAddr returns the correct connect address in the attribute. + if got := ConnectAddr(tt.addr); got != tt.want { + t.Errorf("ConnetAddr(%v) = %v, want %v ", tt.addr, got, tt.want) + } + }) + } +} + +// Tests that User returns the correct user in the attribute. +func (s) TestUser(t *testing.T) { + user := url.UserPassword("username", "password") + tests := []struct { + name string + addr resolver.Address + want *url.Userinfo + }{ + { + name: "user in attribute", + addr: resolver.Address{ + Addr: "test-address", + Attributes: attributes.New(proxyOptionsKey, Options{ + User: user, + ConnectAddr: "proxy-address", + })}, + want: user, + }, + { + name: "no attribute", + addr: resolver.Address{Addr: "test-address"}, + want: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Validate User returns the correct user in the attribute. + if got := User(tt.addr); got != tt.want { + t.Errorf("User(%v) = %v, want %v ", tt.addr, got, tt.want) + } + }) + } +} + +// Tests that Populate returns a copy of addr with attributes containing correct +// user and connect address. +func (s) TestPopulate(t *testing.T) { + addr := resolver.Address{Addr: "test-address"} + pOpts := Options{ + User: url.UserPassword("username", "password"), + ConnectAddr: "proxy-address", + } + + // Call Populate and validate attributes + populatedAddr := Populate(addr, pOpts) + + // Verify that the returned address is updated correctly + if got, want := ConnectAddr(populatedAddr), pOpts.ConnectAddr; got != want { + t.Errorf("Unexpected ConnectAddr proxy atrribute = %v, want %v", got, want) + } + + if got, want := User(populatedAddr), pOpts.User; got != want { + t.Errorf("unexpected User proxy attribute = %v, want %v", got, want) + } +} diff --git a/internal/resolver/delegatingresolver/delegatingresolver.go b/internal/resolver/delegatingresolver/delegatingresolver.go new file mode 100644 index 000000000000..fdb119b97942 --- /dev/null +++ b/internal/resolver/delegatingresolver/delegatingresolver.go @@ -0,0 +1,306 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package delegatingresolver implements a resolver capable of resolving both +// target URIs and proxy addresses. +package delegatingresolver + +import ( + "fmt" + "net/http" + "net/url" + "sync" + + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/proxyattributes" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/serviceconfig" +) + +var ( + // HTTPSProxyFromEnvironment will be used and overwritten in the tests. + httpProxyFromEnvironmentFunc = http.ProxyFromEnvironment + // ProxyScheme will be ovwewritten in tests + ProxyScheme = "dns" + logger = grpclog.Component("delegating-resolver") +) + +// delegatingResolver manages both target URI and proxy address resolution by +// delegating these tasks to separate child resolvers. Essentially, it acts as +// a middleman between the gRPC ClientConn and the child resolvers. +// +// It implements the [resolver.Resolver] interface. +type delegatingResolver struct { + target resolver.Target // parsed target URI to be resolved + cc resolver.ClientConn // gRPC ClientConn + targetResolver resolver.Resolver // resolver for the target URI, based on its scheme + proxyResolver resolver.Resolver // resolver for the proxy URI; nil if no proxy is configured + + mu sync.Mutex // protects access to the resolver state and addresses during updates + targetAddrs []resolver.Address // resolved or unresolved target addresses, depending on proxy configuration + targetEndpt []resolver.Endpoint // resolved target endpoint + proxyAddrs []resolver.Address // resolved proxy addresses; empty if no proxy is configured + proxyURL *url.URL // proxy URL, derived from proxy environment and target + targetResolverReady bool // indicates if an update from the target resolver has been received + proxyResolverReady bool // indicates if an update from the proxy resolver has been received +} + +// parsedURLForProxy determines the proxy URL for the given address based on +// the environment. It can return the following: +// - nil URL, nil error: No proxy is configured or the address is excluded +// using the `NO_PROXY` environment variable or if req.URL.Host is +// "localhost" (with or without // a port number) +// - nil URL, non-nil error: An error occurred while retrieving the proxy URL. +// - non-nil URL, nil error: A proxy is configured, and the proxy URL was +// retrieved successfully without any errors. +func parsedURLForProxy(address string) (*url.URL, error) { + proxyFunc := httpProxyFromEnvironmentFunc + if pf := internal.HTTPSProxyFromEnvironmentForTesting; pf != nil { + proxyFunc = pf + } + + req := &http.Request{URL: &url.URL{ + Scheme: "https", + Host: address, + }} + url, err := proxyFunc(req) + if err != nil { + return nil, err + } + return url, nil +} + +// OnClientResolution is a no-op function in non-test code. In tests, it can +// be overwritten to send a signal to a channel, indicating that client-side +// name resolution was triggered. This enables tests to verify that resolution +// on client side is bypassed when a proxy is in use. +var OnClientResolution = func(int) { /* no-op */ } + +// New creates a new delegating resolver that can create up to two child +// resolvers: +// - one to resolve the proxy address specified using the supported +// environment variables. This uses the registered resolver for the "dns" +// scheme. +// - one to resolve the target URI using the resolver specified by the scheme +// in the target URI or specified by the user using the WithResolvers dial +// option. As a special case, if the target URI's scheme is "dns" and a +// proxy is specified using the supported environment variables, the target +// URI's path portion is used as the resolved address unless target +// resolution is enabled using the dial option. +func New(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions, targetResolverBuilder resolver.Builder, targetResolutionEnabled bool) (resolver.Resolver, error) { + r := &delegatingResolver{ + target: target, + cc: cc, + } + + var err error + r.proxyURL, err = parsedURLForProxy(target.Endpoint()) + if err != nil { + return nil, fmt.Errorf("delegating_resolver: failed to determine proxy URL for target %s: %v", target, err) + } + + // proxy is not configured or proxy address excluded using `NO_PROXY` env + // var, so only target resolver is used. + if r.proxyURL == nil { + OnClientResolution(1) + return targetResolverBuilder.Build(target, cc, opts) + } + + if logger.V(2) { + logger.Info("Proxy URL detected : %s", r.proxyURL) + } + + // When the scheme is 'dns' and target resolution on client is not enabled, + // resolution should be handled by the proxy, not the client. Therefore, we + // bypass the target resolver and store the unresolved target address. + if target.URL.Scheme == "dns" && !targetResolutionEnabled { + r.targetAddrs = []resolver.Address{{Addr: target.Endpoint()}} + r.targetResolverReady = true + } else { + OnClientResolution(1) + wcc := &wrappingClientConn{ + parent: r, + resolverType: targetResolverType, + } + if r.targetResolver, err = targetResolverBuilder.Build(target, wcc, opts); err != nil { + return nil, fmt.Errorf("delegating_resolver: unable to build the resolver for target %s: %v", target, err) + } + } + + if r.proxyResolver, err = r.proxyURIResolver(opts); err != nil { + return nil, fmt.Errorf("delegating_resolver: failed to build resolver for proxy URL %q: %v", r.proxyURL, err) + } + return r, nil +} + +// proxyURIResolver creates a resolver for resolving proxy URIs using the +// "dns" scheme. It adjusts the proxyURL to conform to the "dns:///" format and +// builds a resolver with a wrappingClientConn to capture resolved addresses. +func (r *delegatingResolver) proxyURIResolver(opts resolver.BuildOptions) (resolver.Resolver, error) { + proxyBuilder := resolver.Get(ProxyScheme) + if proxyBuilder == nil { + panic(fmt.Sprintln("delegating_resolver: resolver for proxy not found for scheme dns")) + } + r.proxyURL.Scheme = "dns" + r.proxyURL.Path = "/" + r.proxyURL.Host + r.proxyURL.Host = "" // Clear the Host field to conform to the "dns:///" format + + proxyTarget := resolver.Target{URL: *r.proxyURL} + wcc := &wrappingClientConn{ + parent: r, + resolverType: proxyResolverType, + } + return proxyBuilder.Build(proxyTarget, wcc, opts) +} + +func (r *delegatingResolver) ResolveNow(o resolver.ResolveNowOptions) { + if r.targetResolver != nil { + r.targetResolver.ResolveNow(o) + } + if r.proxyResolver != nil { + r.proxyResolver.ResolveNow(o) + } +} + +func (r *delegatingResolver) Close() { + if r.targetResolver != nil { + r.targetResolver.Close() + } + if r.proxyResolver != nil { + r.proxyResolver.Close() + } +} + +// generateCombinedAddressesLocked creates a list of combined addresses by +// pairing each proxy address with every target address. For each pair, it +// generates a new [resolver.Address] using the proxy address, and adding the +// target address as the attribute along with user info. +func (r *delegatingResolver) generateCombinedAddressesLocked() ([]resolver.Address, []resolver.Endpoint) { + var addresses []resolver.Address + for _, proxyAddr := range r.proxyAddrs { + for _, targetAddr := range r.targetAddrs { + newAddr := resolver.Address{Addr: proxyAddr.Addr} + newAddr = proxyattributes.Populate(newAddr, proxyattributes.Options{ + User: r.proxyURL.User, + ConnectAddr: targetAddr.Addr, + }) + addresses = append(addresses, newAddr) + } + } + + if r.targetEndpt == nil { + return addresses, nil + } + + var endpoints []resolver.Endpoint + for _, proxyAddr := range r.proxyAddrs { + for _, endpt := range r.targetEndpt { + var addrs []resolver.Address + for _, targetAddr := range endpt.Addresses { + newAddr := resolver.Address{Addr: proxyAddr.Addr} + newAddr = proxyattributes.Populate(newAddr, proxyattributes.Options{ + User: r.proxyURL.User, + ConnectAddr: targetAddr.Addr, + }) + addrs = append(addrs, newAddr) + } + endpoints = append(endpoints, resolver.Endpoint{Addresses: addrs}) + } + } + return addresses, endpoints +} + +// resolverType is an enum representing the type of resolver (target or proxy). +type resolverType int + +const ( + targetResolverType resolverType = iota + proxyResolverType +) + +// wrappingClientConn serves as an intermediary between the parent ClientConn +// and the child resolvers created here. It implements the resolver.ClientConn +// interface and is passed in that capacity to the child resolvers. +// +// Its primary function is to aggregate addresses returned by the child +// resolvers before passing them to the parent ClientConn. Any errors returned +// by the child resolvers are propagated verbatim to the parent ClientConn. +type wrappingClientConn struct { + parent *delegatingResolver + resolverType resolverType // represents the type of resolver (target or proxy) +} + +// UpdateState processes updates from the target or proxy resolver. It is called +// twice: once by the target resolver and once by the proxy resolver. It logs +// received addresses, and combines addresses from both resolvers once updates +// from both are received, sending the final state to the parent ClientConn. +func (wcc *wrappingClientConn) UpdateState(state resolver.State) error { + wcc.parent.mu.Lock() + defer wcc.parent.mu.Unlock() + + var curState resolver.State + if wcc.resolverType == targetResolverType { + if logger.V(2) { + logger.Infof("Addresses received from target resolver: %v", state.Addresses) + } + wcc.parent.targetAddrs = state.Addresses + wcc.parent.targetEndpt = state.Endpoints + wcc.parent.targetResolverReady = true + // Update curState to include other state information, such as the + // service config, provided by the target resolver. This ensures + // curState contains all necessary information when passed to + // UpdateState. The state update is only sent after both the target and + // proxy resolvers have sent their updates, and curState has been + // updated with the combined addresses. + curState = state + } + // The proxy resolver is always "dns," so only the address will be + // received, not an endpoint. + if wcc.resolverType == proxyResolverType { + if logger.V(2) { + logger.Infof("Addresses received from proxy resolver: %s", state.Addresses) + } + wcc.parent.proxyAddrs = state.Addresses + wcc.parent.proxyResolverReady = true + } + + // Proceed only if updates from both resolvers have been received. + if !wcc.parent.targetResolverReady || !wcc.parent.proxyResolverReady { + return nil + } + curState.Addresses, curState.Endpoints = wcc.parent.generateCombinedAddressesLocked() + return wcc.parent.cc.UpdateState(curState) +} + +// ReportError intercepts errors from the child resolvers and passes them to ClientConn. +func (wcc *wrappingClientConn) ReportError(err error) { + wcc.parent.cc.ReportError(err) +} + +// NewAddress intercepts the new resolved address from the child resolvers and +// passes them to ClientConn. +func (wcc *wrappingClientConn) NewAddress(addrs []resolver.Address) { + wcc.UpdateState(resolver.State{Addresses: addrs}) +} + +// ParseServiceConfig parses the provided service config and returns an +// object that provides the parsed config. +func (wcc *wrappingClientConn) ParseServiceConfig(serviceConfigJSON string) *serviceconfig.ParseResult { + return wcc.parent.cc.ParseServiceConfig(serviceConfigJSON) +} diff --git a/internal/resolver/delegatingresolver/delegatingresolver_ext_test.go b/internal/resolver/delegatingresolver/delegatingresolver_ext_test.go new file mode 100644 index 000000000000..a450ee36f481 --- /dev/null +++ b/internal/resolver/delegatingresolver/delegatingresolver_ext_test.go @@ -0,0 +1,340 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package delegatingresolver_test + +import ( + "net/http" + "net/url" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/pretty" + "google.golang.org/grpc/internal/proxyattributes" + "google.golang.org/grpc/internal/resolver/delegatingresolver" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/resolver" + _ "google.golang.org/grpc/resolver/dns" // To register dns resolver. + "google.golang.org/grpc/resolver/manual" + "google.golang.org/grpc/serviceconfig" +) + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +const ( + targetTestAddr = "test.com" + resolvedTargetTestAddr1 = "1.2.3.4:8080" + resolvedTargetTestAddr2 = "1.2.3.5:8080" + envProxyAddr = "proxytest.com" + resolvedProxyTestAddr1 = "2.3.4.5:7687" + resolvedProxyTestAddr2 = "2.3.4.6:7687" + defaultTestTimeout = 10 * time.Second + defaultTestShortTimeout = 10 * time.Millisecond +) + +// createTestResolverClientConn initializes a [testutils.ResolverClientConn] and +// returns it along with channels for resolver state updates and errors. +func createTestResolverClientConn(t *testing.T) (*testutils.ResolverClientConn, chan resolver.State, chan error) { + stateCh := make(chan resolver.State, 1) + errCh := make(chan error, 1) + + tcc := &testutils.ResolverClientConn{ + Logger: t, + UpdateStateF: func(s resolver.State) error { stateCh <- s; return nil }, + ReportErrorF: func(err error) { errCh <- err }, + } + return tcc, stateCh, errCh +} + +// Tests the scenario where no proxy environment variables are set or proxying +// is disabled by the `NO_PROXY` environment variable. The test verifies that +// the delegating resolver creates only a target resolver and that the +// addresses returned by the delegating resolver exactly match those returned +// by the target resolver. +func (s) TestDelegatingResolverNoProxyEnvVarsSet(t *testing.T) { + hpfe := func(req *http.Request) (*url.URL, error) { return nil, nil } + originalhpfe := internal.HTTPSProxyFromEnvironmentForTesting + internal.HTTPSProxyFromEnvironmentForTesting = hpfe + defer func() { + internal.HTTPSProxyFromEnvironmentForTesting = originalhpfe + }() + + mr := manual.NewBuilderWithScheme("test") // Set up a manual resolver to control the address resolution. + target := mr.Scheme() + ":///" + targetTestAddr + + // Create a delegating resolver with no proxy configuration + tcc, stateCh, _ := createTestResolverClientConn(t) + if _, err := delegatingresolver.New(resolver.Target{URL: *testutils.MustParseURL(target)}, tcc, resolver.BuildOptions{}, mr, false); err != nil { + t.Fatalf("Failed to create delegating resolver: %v", err) + } + + // Update the manual resolver with a test address. + mr.UpdateState(resolver.State{ + Addresses: []resolver.Address{{Addr: resolvedTargetTestAddr1}}, + ServiceConfig: &serviceconfig.ParseResult{}, + }) + + // Verify that the delegating resolver outputs the same addresses, as returned + // by the target resolver. + wantState := resolver.State{ + Addresses: []resolver.Address{{Addr: resolvedTargetTestAddr1}}, + ServiceConfig: &serviceconfig.ParseResult{}, + } + + var gotState resolver.State + select { + case gotState = <-stateCh: + case <-time.After(defaultTestTimeout): + t.Fatal("Timeout when waiting for a state update from the delegating resolver") + } + + if !cmp.Equal(gotState, wantState) { + t.Fatalf("Unexpected state from delegating resolver: %s, want %s", pretty.ToJSON(gotState), pretty.ToJSON(wantState)) + } +} + +// setupDNS registers a new manual resolver for the DNS scheme, effectively +// overwriting the previously registered DNS resolver. This allows the test to +// mock the DNS resolution for the proxy resolver. It also registers the +// original DNS resolver after the test is done. +func setupDNS(t *testing.T) *manual.Resolver { + mr := manual.NewBuilderWithScheme("dns") + + dnsResolverBuilder := resolver.Get("dns") + resolver.Register(mr) + + t.Cleanup(func() { resolver.Register(dnsResolverBuilder) }) + return mr +} + +// proxyAddressWithTargetAttribute creates a resolver.Address for the proxy, +// adding the target address as an attribute. +func proxyAddressWithTargetAttribute(proxyAddr string, targetAddr string) resolver.Address { + addr := resolver.Address{Addr: proxyAddr} + addr = proxyattributes.Populate(addr, proxyattributes.Options{ + User: nil, + ConnectAddr: targetAddr, + }) + return addr +} + +// Tests the scenario where proxy is configured and the target URI contains the +// "dns" scheme and target resolution is enabled. The test verifies that the +// addresses returned by the delegating resolver combines the addresses +// returned by the proxy resolver and the target resolver. +func (s) TestDelegatingResolverwithDNSAndProxyWithTargetResolution(t *testing.T) { + hpfe := func(req *http.Request) (*url.URL, error) { + if req.URL.Host == targetTestAddr { + return &url.URL{ + Scheme: "https", + Host: envProxyAddr, + }, nil + } + return nil, nil + } + originalhpfe := internal.HTTPSProxyFromEnvironmentForTesting + internal.HTTPSProxyFromEnvironmentForTesting = hpfe + defer func() { + internal.HTTPSProxyFromEnvironmentForTesting = originalhpfe + }() + + mrTarget := setupDNS(t) // Manual resolver to control the target resolution. + target := mrTarget.Scheme() + ":///" + targetTestAddr + mrProxy := setupDNS(t) // Set up a manual DNS resolver to control the proxy address resolution. + + tcc, stateCh, _ := createTestResolverClientConn(t) + if _, err := delegatingresolver.New(resolver.Target{URL: *testutils.MustParseURL(target)}, tcc, resolver.BuildOptions{}, mrTarget, true); err != nil { + t.Fatalf("Failed to create delegating resolver: %v", err) + } + + mrProxy.UpdateState(resolver.State{ + Addresses: []resolver.Address{ + {Addr: resolvedProxyTestAddr1}, + {Addr: resolvedProxyTestAddr2}, + }, + ServiceConfig: &serviceconfig.ParseResult{}, + }) + + select { + case <-stateCh: + t.Fatalf("Delegating resolver invoked UpdateState before both the proxy and target resolvers had updated their states.") + case <-time.After(defaultTestShortTimeout): + } + + mrTarget.UpdateState(resolver.State{ + Addresses: []resolver.Address{ + {Addr: resolvedTargetTestAddr1}, + {Addr: resolvedTargetTestAddr2}, + }, + ServiceConfig: &serviceconfig.ParseResult{}, + }) + + // Verify that the delegating resolver outputs the expected address. + wantState := resolver.State{Addresses: []resolver.Address{ + proxyAddressWithTargetAttribute(resolvedProxyTestAddr1, resolvedTargetTestAddr1), + proxyAddressWithTargetAttribute(resolvedProxyTestAddr1, resolvedTargetTestAddr2), + proxyAddressWithTargetAttribute(resolvedProxyTestAddr2, resolvedTargetTestAddr1), + proxyAddressWithTargetAttribute(resolvedProxyTestAddr2, resolvedTargetTestAddr2), + }, + ServiceConfig: &serviceconfig.ParseResult{}, + } + var gotState resolver.State + select { + case gotState = <-stateCh: + case <-time.After(defaultTestTimeout): + t.Fatal("Timeout when waiting for a state update from the delegating resolver") + } + + if !cmp.Equal(gotState, wantState) { + t.Fatalf("Unexpected state from delegating resolver: %s, want %s", pretty.ToJSON(gotState), pretty.ToJSON(wantState)) + } +} + +// Tests the scenario where a proxy is configured, the target URI contains the +// "dns" scheme, and target resolution is disabled(default behavior). The test +// verifies that the addresses returned by the delegating resolver include the +// proxy resolver's addresses, with the unresolved target URI as an attribute +// of the proxy address. +func (s) TestDelegatingResolverwithDNSAndProxyWithNoTargetResolution(t *testing.T) { + hpfe := func(req *http.Request) (*url.URL, error) { + if req.URL.Host == targetTestAddr { + return &url.URL{ + Scheme: "https", + Host: envProxyAddr, + }, nil + } + return nil, nil + } + originalhpfe := internal.HTTPSProxyFromEnvironmentForTesting + internal.HTTPSProxyFromEnvironmentForTesting = hpfe + defer func() { + internal.HTTPSProxyFromEnvironmentForTesting = originalhpfe + }() + + mrTarget := manual.NewBuilderWithScheme("test") // Manual resolver to control the target resolution. + target := "dns:///" + targetTestAddr + mrProxy := setupDNS(t) // Set up a manual DNS resolver to control the proxy address resolution. + + tcc, stateCh, _ := createTestResolverClientConn(t) + if _, err := delegatingresolver.New(resolver.Target{URL: *testutils.MustParseURL(target)}, tcc, resolver.BuildOptions{}, mrTarget, false); err != nil { + t.Fatalf("Failed to create delegating resolver: %v", err) + } + + mrProxy.UpdateState(resolver.State{ + Addresses: []resolver.Address{ + {Addr: resolvedProxyTestAddr1}, + {Addr: resolvedProxyTestAddr2}, + }, + }) + + wantState := resolver.State{Addresses: []resolver.Address{ + proxyAddressWithTargetAttribute(resolvedProxyTestAddr1, targetTestAddr), + proxyAddressWithTargetAttribute(resolvedProxyTestAddr2, targetTestAddr), + }} + + var gotState resolver.State + select { + case gotState = <-stateCh: + case <-time.After(defaultTestTimeout): + t.Fatal("Timeout when waiting for a state update from the delegating resolver") + } + + if !cmp.Equal(gotState, wantState) { + t.Fatalf("Unexpected state from delegating resolver: %s, want %s", pretty.ToJSON(gotState), pretty.ToJSON(wantState)) + } +} + +// Tests the scenario where a proxy is configured, and the target URI scheme is +// not "dns". The test verifies that the addresses returned by the delegating +// resolver include the resolved proxy address and the custom resolved target +// address as attributes of the proxy address. +func (s) TestDelegatingResolverwithCustomResolverAndProxy(t *testing.T) { + hpfe := func(req *http.Request) (*url.URL, error) { + if req.URL.Host == targetTestAddr { + return &url.URL{ + Scheme: "https", + Host: envProxyAddr, + }, nil + } + return nil, nil + } + originalhpfe := internal.HTTPSProxyFromEnvironmentForTesting + internal.HTTPSProxyFromEnvironmentForTesting = hpfe + defer func() { + internal.HTTPSProxyFromEnvironmentForTesting = originalhpfe + }() + + mrTarget := manual.NewBuilderWithScheme("test") // Manual resolver to control the target resolution. + target := mrTarget.Scheme() + ":///" + targetTestAddr + mrProxy := setupDNS(t) // Set up a manual DNS resolver to control the proxy address resolution. + + tcc, stateCh, _ := createTestResolverClientConn(t) + if _, err := delegatingresolver.New(resolver.Target{URL: *testutils.MustParseURL(target)}, tcc, resolver.BuildOptions{}, mrTarget, false); err != nil { + t.Fatalf("Failed to create delegating resolver: %v", err) + } + + mrProxy.UpdateState(resolver.State{ + Addresses: []resolver.Address{ + {Addr: resolvedProxyTestAddr1}, + {Addr: resolvedProxyTestAddr2}, + }, + ServiceConfig: &serviceconfig.ParseResult{}, + }) + + select { + case <-stateCh: + t.Fatalf("Delegating resolver invoked UpdateState before both the proxy and target resolvers had updated their states.") + case <-time.After(defaultTestShortTimeout): + } + + mrTarget.UpdateState(resolver.State{ + Addresses: []resolver.Address{ + {Addr: resolvedTargetTestAddr1}, + {Addr: resolvedTargetTestAddr2}, + }, + ServiceConfig: &serviceconfig.ParseResult{}, + }) + + wantState := resolver.State{Addresses: []resolver.Address{ + proxyAddressWithTargetAttribute(resolvedProxyTestAddr1, resolvedTargetTestAddr1), + proxyAddressWithTargetAttribute(resolvedProxyTestAddr1, resolvedTargetTestAddr2), + proxyAddressWithTargetAttribute(resolvedProxyTestAddr2, resolvedTargetTestAddr1), + proxyAddressWithTargetAttribute(resolvedProxyTestAddr2, resolvedTargetTestAddr2), + }, + ServiceConfig: &serviceconfig.ParseResult{}, + } + var gotState resolver.State + select { + case gotState = <-stateCh: + case <-time.After(defaultTestTimeout): + t.Fatal("Timeout when waiting for a state update from the delegating resolver") + } + + if !cmp.Equal(gotState, wantState) { + t.Fatalf("Unexpected state from delegating resolver: %s, want %s", pretty.ToJSON(gotState), pretty.ToJSON(wantState)) + } +} diff --git a/internal/resolver/delegatingresolver/delegatingresolver_test.go b/internal/resolver/delegatingresolver/delegatingresolver_test.go new file mode 100644 index 000000000000..55eaef79a954 --- /dev/null +++ b/internal/resolver/delegatingresolver/delegatingresolver_test.go @@ -0,0 +1,111 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package delegatingresolver + +import ( + "errors" + "net/http" + "net/url" + "testing" + + "github.com/google/go-cmp/cmp" + "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/grpctest" + _ "google.golang.org/grpc/resolver/dns" // To register dns resolver. +) + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +const ( + targetTestAddr = "test.com" + envProxyAddr = "proxytest.com" +) + +// overrideHTTPSProxyFromEnvironment function overwrites HTTPSProxyFromEnvironment and +// returns a function to restore the default values. +func overrideHTTPSProxyFromEnvironment(hpfe func(req *http.Request) (*url.URL, error)) func() { + internal.HTTPSProxyFromEnvironmentForTesting = hpfe + return func() { + internal.HTTPSProxyFromEnvironmentForTesting = nil + } +} + +// Tests that the parsedURLForProxy function correctly resolves the proxy URL +// for a given target address. Tests all the possible output cases. +func (s) TestParsedURLForProxyEnv(t *testing.T) { + err := errors.New("invalid proxy url") + tests := []struct { + name string + hpfeFunc func(req *http.Request) (*url.URL, error) + wantURL *url.URL + wantErr error + }{ + { + name: "valid proxy url and nil error", + hpfeFunc: func(_ *http.Request) (*url.URL, error) { + return &url.URL{ + Scheme: "https", + Host: "proxy.example.com", + }, nil + }, + wantURL: &url.URL{ + Scheme: "https", + Host: "proxy.example.com", + }, + wantErr: nil, + }, + { + name: "invalid proxy url and non-nil error", + hpfeFunc: func(_ *http.Request) (*url.URL, error) { + return &url.URL{ + Scheme: "https", + Host: "notproxy.example.com", + }, err + }, + wantURL: nil, + wantErr: err, + }, + { + name: "nil proxy url and nil error", + hpfeFunc: func(_ *http.Request) (*url.URL, error) { + return nil, nil + }, + wantURL: nil, + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer overrideHTTPSProxyFromEnvironment(tt.hpfeFunc)() + got, err := parsedURLForProxy(targetTestAddr) + if err != tt.wantErr { + t.Errorf("parsedProxyURLForProxy(%v) failed with error :%v, want %v\n", targetTestAddr, err, tt.wantErr) + } + if !cmp.Equal(got, tt.wantURL) { + t.Fatalf("parsedProxyURLForProxy(%v) = %v, want %v\n", targetTestAddr, got, tt.wantURL) + } + }) + } +} diff --git a/internal/testutils/proxy.go b/internal/testutils/proxy.go new file mode 100644 index 000000000000..3ebf5d3ec3ee --- /dev/null +++ b/internal/testutils/proxy.go @@ -0,0 +1,107 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package testutils + +import ( + "bufio" + "bytes" + "fmt" + "io" + "net" + "net/http" + "time" +) + +const defaultTestTimeout = 10 * time.Second + +// ProxyServer represents a test proxy server. +type ProxyServer struct { + lis net.Listener + in net.Conn // The connection from the client to the proxy. + out net.Conn // The connection from the proxy to the backend. + requestCheck func(*http.Request) error // The function to check the request sent to proxy. +} + +// Stop closes the ProxyServer and its connectionsto client and server. +func (p *ProxyServer) Stop() { + p.lis.Close() + if p.in != nil { + p.in.Close() + } + if p.out != nil { + p.out.Close() + } +} + +// NewProxyServer create and starts a proxy server. +func NewProxyServer(lis net.Listener, reqCheck func(*http.Request) error, errCh chan error, doneCh chan struct{}, backendAddr string, resOnClient bool, proxyStarted func()) *ProxyServer { + p := &ProxyServer{ + lis: lis, + requestCheck: reqCheck, + } + + // Start the proxy server. + go func() { + in, err := p.lis.Accept() + if err != nil { + return + } + p.in = in + if proxyStarted != nil { + proxyStarted() + } + req, err := http.ReadRequest(bufio.NewReader(in)) + if err != nil { + errCh <- fmt.Errorf("failed to read CONNECT req: %v", err) + return + } + if err := p.requestCheck(req); err != nil { + resp := http.Response{StatusCode: http.StatusMethodNotAllowed} + resp.Write(p.in) + p.in.Close() + errCh <- fmt.Errorf("get wrong CONNECT req: %+v, error: %v", req, err) + return + } + var out net.Conn + // if resolution is done on client,connect to address received in + // CONNECT request or else connect to backend address directly. + if resOnClient { + out, err = net.Dial("tcp", req.URL.Host) + } else { + out, err = net.Dial("tcp", backendAddr) + } + + if err != nil { + errCh <- fmt.Errorf("failed to dial to server: %v", err) + return + } + out.SetDeadline(time.Now().Add(defaultTestTimeout)) + //response OK to client + resp := http.Response{StatusCode: http.StatusOK, Proto: "HTTP/1.0"} + var buf bytes.Buffer + resp.Write(&buf) + p.in.Write(buf.Bytes()) + p.out = out + // perform the proxy function, i.e pass the data from client to server and server to client. + go io.Copy(p.in, p.out) + go io.Copy(p.out, p.in) + close(doneCh) + }() + return p +} diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index f0c5cbc47645..bc13dc77594e 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -43,6 +43,7 @@ import ( "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/grpcutil" imetadata "google.golang.org/grpc/internal/metadata" + "google.golang.org/grpc/internal/proxyattributes" istatus "google.golang.org/grpc/internal/status" isyscall "google.golang.org/grpc/internal/syscall" "google.golang.org/grpc/internal/transport/networktype" @@ -153,8 +154,14 @@ type http2Client struct { logger *grpclog.PrefixLogger } -func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr resolver.Address, useProxy bool, grpcUA string) (net.Conn, error) { +func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr resolver.Address, grpcUA string) (net.Conn, error) { address := addr.Addr + + // A non-empty ConnectAddr attribute indicates that a proxy is configured, + // so initiate a proxy dial. + if proxyattributes.ConnectAddr(addr) != "" { + return proxyDial(ctx, addr, grpcUA) + } networkType, ok := networktype.Get(addr) if fn != nil { // Special handling for unix scheme with custom dialer. Back in the day, @@ -177,9 +184,6 @@ func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error if !ok { networkType, address = parseDialTarget(address) } - if networkType == "tcp" && useProxy { - return proxyDial(ctx, address, grpcUA) - } return internal.NetDialerWithTCPKeepalive().DialContext(ctx, networkType, address) } @@ -217,7 +221,7 @@ func NewHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts // address specific arbitrary data to reach custom dialers and credential handshakers. connectCtx = icredentials.NewClientHandshakeInfoContext(connectCtx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes}) - conn, err := dial(connectCtx, opts.Dialer, addr, opts.UseProxy, opts.UserAgent) + conn, err := dial(connectCtx, opts.Dialer, addr, opts.UserAgent) if err != nil { if opts.FailOnNonTempDialError { return nil, connectionErrorf(isTemporary(err), err, "transport: error while dialing: %v", err) diff --git a/internal/transport/proxy.go b/internal/transport/proxy.go index 54b224436544..d94d30fe5dba 100644 --- a/internal/transport/proxy.go +++ b/internal/transport/proxy.go @@ -30,34 +30,16 @@ import ( "net/url" "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/proxyattributes" + "google.golang.org/grpc/resolver" ) const proxyAuthHeaderKey = "Proxy-Authorization" -var ( - // The following variable will be overwritten in the tests. - httpProxyFromEnvironment = http.ProxyFromEnvironment -) - -func mapAddress(address string) (*url.URL, error) { - req := &http.Request{ - URL: &url.URL{ - Scheme: "https", - Host: address, - }, - } - url, err := httpProxyFromEnvironment(req) - if err != nil { - return nil, err - } - return url, nil -} - // To read a response from a net.Conn, http.ReadResponse() takes a bufio.Reader. -// It's possible that this reader reads more than what's need for the response and stores -// those bytes in the buffer. -// bufConn wraps the original net.Conn and the bufio.Reader to make sure we don't lose the -// bytes in the buffer. +// It's possible that this reader reads more than what's need for the response +// and stores those bytes in the buffer. bufConn wraps the original net.Conn +// and the bufio.Reader to make sure we don't lose the bytes in the buffer. type bufConn struct { net.Conn r io.Reader @@ -72,21 +54,25 @@ func basicAuth(username, password string) string { return base64.StdEncoding.EncodeToString([]byte(auth)) } -func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr string, proxyURL *url.URL, grpcUA string) (_ net.Conn, err error) { +func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, addr resolver.Address, grpcUA string) (_ net.Conn, err error) { defer func() { if err != nil { conn.Close() } }() + a := proxyattributes.ConnectAddr(addr) req := &http.Request{ Method: http.MethodConnect, - URL: &url.URL{Host: backendAddr}, + URL: &url.URL{Host: a}, Header: map[string][]string{"User-Agent": {grpcUA}}, } - if t := proxyURL.User; t != nil { - u := t.Username() - p, _ := t.Password() + if user := proxyattributes.User(addr); user != nil { + u := user.Username() + p, pSet := user.Password() + if !pSet { + logger.Warningf("password not set for basic authentication for proxy dialing") + } req.Header.Add(proxyAuthHeaderKey, "Basic "+basicAuth(u, p)) } @@ -117,28 +103,13 @@ func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr stri return conn, nil } -// proxyDial dials, connecting to a proxy first if necessary. Checks if a proxy -// is necessary, dials, does the HTTP CONNECT handshake, and returns the -// connection. -func proxyDial(ctx context.Context, addr string, grpcUA string) (net.Conn, error) { - newAddr := addr - proxyURL, err := mapAddress(addr) - if err != nil { - return nil, err - } - if proxyURL != nil { - newAddr = proxyURL.Host - } - - conn, err := internal.NetDialerWithTCPKeepalive().DialContext(ctx, "tcp", newAddr) +// proxyDial establishes a TCP connection to the specified address and performs an HTTP CONNECT handshake. +func proxyDial(ctx context.Context, addr resolver.Address, grpcUA string) (net.Conn, error) { + conn, err := internal.NetDialerWithTCPKeepalive().DialContext(ctx, "tcp", addr.Addr) if err != nil { return nil, err } - if proxyURL == nil { - // proxy is disabled if proxyURL is nil. - return conn, err - } - return doHTTPConnectHandshake(ctx, conn, addr, proxyURL, grpcUA) + return doHTTPConnectHandshake(ctx, conn, addr, grpcUA) } func sendHTTPRequest(ctx context.Context, req *http.Request, conn net.Conn) error { diff --git a/internal/transport/proxy_test.go b/internal/transport/proxy_test.go deleted file mode 100644 index 41f0918d1c90..000000000000 --- a/internal/transport/proxy_test.go +++ /dev/null @@ -1,278 +0,0 @@ -//go:build !race -// +build !race - -/* - * - * Copyright 2017 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package transport - -import ( - "bufio" - "bytes" - "context" - "encoding/base64" - "fmt" - "io" - "net" - "net/http" - "net/url" - "testing" - "time" -) - -const ( - envTestAddr = "1.2.3.4:8080" - envProxyAddr = "2.3.4.5:7687" -) - -// overwriteAndRestore overwrite function httpProxyFromEnvironment and -// returns a function to restore the default values. -func overwrite(hpfe func(req *http.Request) (*url.URL, error)) func() { - backHPFE := httpProxyFromEnvironment - httpProxyFromEnvironment = hpfe - return func() { - httpProxyFromEnvironment = backHPFE - } -} - -type proxyServer struct { - t *testing.T - lis net.Listener - in net.Conn - out net.Conn - - requestCheck func(*http.Request) error -} - -func (p *proxyServer) run(waitForServerHello bool) { - in, err := p.lis.Accept() - if err != nil { - return - } - p.in = in - - req, err := http.ReadRequest(bufio.NewReader(in)) - if err != nil { - p.t.Errorf("failed to read CONNECT req: %v", err) - return - } - if err := p.requestCheck(req); err != nil { - resp := http.Response{StatusCode: http.StatusMethodNotAllowed} - resp.Write(p.in) - p.in.Close() - p.t.Errorf("get wrong CONNECT req: %+v, error: %v", req, err) - return - } - - out, err := net.Dial("tcp", req.URL.Host) - if err != nil { - p.t.Errorf("failed to dial to server: %v", err) - return - } - out.SetDeadline(time.Now().Add(defaultTestTimeout)) - resp := http.Response{StatusCode: http.StatusOK, Proto: "HTTP/1.0"} - var buf bytes.Buffer - resp.Write(&buf) - if waitForServerHello { - // Batch the first message from the server with the http connect - // response. This is done to test the cases in which the grpc client has - // the response to the connect request and proxied packets from the - // destination server when it reads the transport. - b := make([]byte, 50) - bytesRead, err := out.Read(b) - if err != nil { - p.t.Errorf("Got error while reading server hello: %v", err) - in.Close() - out.Close() - return - } - buf.Write(b[0:bytesRead]) - } - p.in.Write(buf.Bytes()) - p.out = out - go io.Copy(p.in, p.out) - go io.Copy(p.out, p.in) -} - -func (p *proxyServer) stop() { - p.lis.Close() - if p.in != nil { - p.in.Close() - } - if p.out != nil { - p.out.Close() - } -} - -type testArgs struct { - proxyURLModify func(*url.URL) *url.URL - proxyReqCheck func(*http.Request) error - serverMessage []byte -} - -func testHTTPConnect(t *testing.T, args testArgs) { - plis, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("failed to listen: %v", err) - } - p := &proxyServer{ - t: t, - lis: plis, - requestCheck: args.proxyReqCheck, - } - go p.run(len(args.serverMessage) > 0) - defer p.stop() - - blis, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("failed to listen: %v", err) - } - - msg := []byte{4, 3, 5, 2} - recvBuf := make([]byte, len(msg)) - done := make(chan error, 1) - go func() { - in, err := blis.Accept() - if err != nil { - done <- err - return - } - defer in.Close() - in.Write(args.serverMessage) - in.Read(recvBuf) - done <- nil - }() - - // Overwrite the function in the test and restore them in defer. - hpfe := func(*http.Request) (*url.URL, error) { - return args.proxyURLModify(&url.URL{Host: plis.Addr().String()}), nil - } - defer overwrite(hpfe)() - - // Dial to proxy server. - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - c, err := proxyDial(ctx, blis.Addr().String(), "test") - if err != nil { - t.Fatalf("HTTP connect Dial failed: %v", err) - } - defer c.Close() - c.SetDeadline(time.Now().Add(defaultTestTimeout)) - - // Send msg on the connection. - c.Write(msg) - if err := <-done; err != nil { - t.Fatalf("Failed to accept: %v", err) - } - - // Check received msg. - if string(recvBuf) != string(msg) { - t.Fatalf("Received msg: %v, want %v", recvBuf, msg) - } - - if len(args.serverMessage) > 0 { - gotServerMessage := make([]byte, len(args.serverMessage)) - if _, err := c.Read(gotServerMessage); err != nil { - t.Errorf("Got error while reading message from server: %v", err) - return - } - if string(gotServerMessage) != string(args.serverMessage) { - t.Errorf("Message from server: %v, want %v", gotServerMessage, args.serverMessage) - } - } -} - -func (s) TestHTTPConnect(t *testing.T) { - args := testArgs{ - proxyURLModify: func(in *url.URL) *url.URL { - return in - }, - proxyReqCheck: func(req *http.Request) error { - if req.Method != http.MethodConnect { - return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect) - } - return nil - }, - } - testHTTPConnect(t, args) -} - -func (s) TestHTTPConnectWithServerHello(t *testing.T) { - args := testArgs{ - proxyURLModify: func(in *url.URL) *url.URL { - return in - }, - proxyReqCheck: func(req *http.Request) error { - if req.Method != http.MethodConnect { - return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect) - } - return nil - }, - serverMessage: []byte("server-hello"), - } - testHTTPConnect(t, args) -} - -func (s) TestHTTPConnectBasicAuth(t *testing.T) { - const ( - user = "notAUser" - password = "notAPassword" - ) - args := testArgs{ - proxyURLModify: func(in *url.URL) *url.URL { - in.User = url.UserPassword(user, password) - return in - }, - proxyReqCheck: func(req *http.Request) error { - if req.Method != http.MethodConnect { - return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect) - } - wantProxyAuthStr := "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+password)) - if got := req.Header.Get(proxyAuthHeaderKey); got != wantProxyAuthStr { - gotDecoded, _ := base64.StdEncoding.DecodeString(got) - wantDecoded, _ := base64.StdEncoding.DecodeString(wantProxyAuthStr) - return fmt.Errorf("unexpected auth %q (%q), want %q (%q)", got, gotDecoded, wantProxyAuthStr, wantDecoded) - } - return nil - }, - } - testHTTPConnect(t, args) -} - -func (s) TestMapAddressEnv(t *testing.T) { - // Overwrite the function in the test and restore them in defer. - hpfe := func(req *http.Request) (*url.URL, error) { - if req.URL.Host == envTestAddr { - return &url.URL{ - Scheme: "https", - Host: envProxyAddr, - }, nil - } - return nil, nil - } - defer overwrite(hpfe)() - - // envTestAddr should be handled by ProxyFromEnvironment. - got, err := mapAddress(envTestAddr) - if err != nil { - t.Error(err) - } - if got.Host != envProxyAddr { - t.Errorf("want %v, got %v", envProxyAddr, got) - } -} diff --git a/internal/transport/transport.go b/internal/transport/transport.go index f3148e31c5dd..314899b847df 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -503,8 +503,6 @@ type ConnectOptions struct { ChannelzParent *channelz.SubChannel // MaxHeaderListSize sets the max (uncompressed) size of header list that is prepared to be received. MaxHeaderListSize *uint32 - // UseProxy specifies if a proxy should be used. - UseProxy bool // The mem.BufferPool to use when reading/writing to the wire. BufferPool mem.BufferPool } diff --git a/resolver_wrapper.go b/resolver_wrapper.go index 23bb3fb25824..c987aff6d6d5 100644 --- a/resolver_wrapper.go +++ b/resolver_wrapper.go @@ -26,6 +26,7 @@ import ( "google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/pretty" + "google.golang.org/grpc/internal/resolver/delegatingresolver" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" ) @@ -78,7 +79,17 @@ func (ccr *ccResolverWrapper) start() error { Authority: ccr.cc.authority, } var err error - ccr.resolver, err = ccr.cc.resolverBuilder.Build(ccr.cc.parsedTarget, ccr, opts) + // The delegating resolver is used unless: + // - A custom dialer is provided via WithContextDialer dialoption. + // - Proxy usage is disabled through WithNoProxy dialoption. + // - Client-side resolution is enforced with WithTargetResolutionEnabled. + // In these cases, the resolver is built based on the scheme of target, + // using the appropriate resolver builder. + if ccr.cc.dopts.copts.Dialer != nil || !ccr.cc.dopts.useProxy { + ccr.resolver, err = ccr.cc.resolverBuilder.Build(ccr.cc.parsedTarget, ccr, opts) + } else { + ccr.resolver, err = delegatingresolver.New(ccr.cc.parsedTarget, ccr, opts, ccr.cc.resolverBuilder, ccr.cc.dopts.targetResolutionEnabled) + } errCh <- err }) return <-errCh diff --git a/test/proxy_test.go b/test/proxy_test.go new file mode 100644 index 000000000000..7c567e8d6f21 --- /dev/null +++ b/test/proxy_test.go @@ -0,0 +1,673 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package test + +import ( + "context" + "encoding/base64" + "fmt" + "net" + "net/http" + "net/url" + "testing" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/resolver/delegatingresolver" + "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/internal/testutils" + testgrpc "google.golang.org/grpc/interop/grpc_testing" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/resolver/manual" +) + +const ( + unresolvedTargetURI = "example.com" + unresolvedProxyURI = "proxyexample.com" +) + +// createAndStartBackendServer creates and starts a test backend server, +// registering a cleanup to stop it. +func createAndStartBackendServer(t *testing.T) string { + backend := &stubserver.StubServer{ + EmptyCallF: func(context.Context, *testgrpc.Empty) (*testgrpc.Empty, error) { return &testgrpc.Empty{}, nil }, + } + if err := backend.StartServer(); err != nil { + t.Fatalf("failed to start backend: %v", err) + } + t.Logf("Started TestService backend at: %q", backend.Address) + t.Cleanup(backend.Stop) + return backend.Address +} + +// setupDNS sets up a manual DNS resolver and registers it, returning the +// builder for test customization. +func setupDNS(t *testing.T) *manual.Resolver { + mr := manual.NewBuilderWithScheme("dns") + origRes := resolver.Get("dns") + resolver.Register(mr) + t.Cleanup(func() { + resolver.Register(origRes) + }) + return mr +} + +// setupProxy initializes and starts a proxy server, registers a cleanup to +// stop it, and returns the proxy's listener and helper channels. +func setupProxy(t *testing.T, backendAddr string, resolutionOnClient bool, reqCheck func(*http.Request) error) (net.Listener, chan error, chan struct{}, chan struct{}) { + pLis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("failed to listen: %v", err) + } + + errCh := make(chan error, 1) + doneCh := make(chan struct{}) + proxyStartedCh := make(chan struct{}) + t.Cleanup(func() { + close(errCh) + }) + + proxyServer := testutils.NewProxyServer(pLis, reqCheck, errCh, doneCh, backendAddr, resolutionOnClient, func() { close(proxyStartedCh) }) + t.Cleanup(proxyServer.Stop) + + return pLis, errCh, doneCh, proxyStartedCh +} + +// requestCheck returns a function that checks the HTTP CONNECT request for the +// correct CONNECT method and address. +func requestCheck(connectAddr string) func(*http.Request) error { + return func(req *http.Request) error { + if req.Method != http.MethodConnect { + return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect) + } + if req.URL.Host != connectAddr { + return fmt.Errorf("unexpected URL.Host in CONNECT req %q, want %q", req.URL.Host, connectAddr) + } + return nil + } +} + +// Tests the scenario where grpc.Dial is performed using a proxy with the +// default resolver in the target URI. The test verifies that the connection is +// established to the proxy server, sends the unresolved target URI in the HTTP +// CONNECT request, and is successfully connected to the backend server. +func (s) TestGRPCDialWithProxy(t *testing.T) { + backendAddr := createAndStartBackendServer(t) + pLis, errCh, doneCh, _ := setupProxy(t, backendAddr, false, requestCheck(unresolvedTargetURI)) + + hpfe := func(req *http.Request) (*url.URL, error) { + if req.URL.Host == unresolvedTargetURI { + return &url.URL{ + Scheme: "https", + Host: pLis.Addr().String(), + }, nil + } + return nil, nil + } + orighpfe := internal.HTTPSProxyFromEnvironmentForTesting + internal.HTTPSProxyFromEnvironmentForTesting = hpfe + defer func() { + internal.HTTPSProxyFromEnvironmentForTesting = orighpfe + }() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + conn, err := grpc.Dial(unresolvedTargetURI, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("grpc.Dial(%s) failed: %v", unresolvedTargetURI, err) + } + defer conn.Close() + + client := testgrpc.NewTestServiceClient(conn) + if _, err := client.EmptyCall(ctx, &testgrpc.Empty{}); err != nil { + t.Errorf("EmptyCall failed: %v", err) + } + + select { + case err := <-errCh: + t.Fatalf("proxy server encountered an error: %v", err) + case <-doneCh: + t.Logf("proxy server succeeded") + } +} + +// Tests the scenario where `grpc.Dial` is performed with a proxy and the "dns" +// scheme for the target. The test verifies that the proxy URI is correctly +// resolved and that the target URI resolution on the client preserves the +// original behavior of `grpc.Dial`. It also ensures that a connection is +// established to the proxy server, with the resolved target URI sent in the +// HTTP CONNECT request, successfully connecting to the backend server. +func (s) TestGRPCDialWithDNSAndProxy(t *testing.T) { + backendAddr := createAndStartBackendServer(t) + pLis, errCh, doneCh, _ := setupProxy(t, backendAddr, false, requestCheck(backendAddr)) + + // Overwrite the default proxy function and restore it after the test. + hpfe := func(req *http.Request) (*url.URL, error) { + if req.URL.Host == unresolvedTargetURI { + return &url.URL{ + Scheme: "https", + Host: unresolvedProxyURI, + }, nil + } + return nil, nil + } + orighpfe := internal.HTTPSProxyFromEnvironmentForTesting + internal.HTTPSProxyFromEnvironmentForTesting = hpfe + defer func() { + internal.HTTPSProxyFromEnvironmentForTesting = orighpfe + }() + + // Configure manual resolvers for both proxy and target backends + mrTarget := setupDNS(t) + mrTarget.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: backendAddr}}}) + // Temporarily modify ProxyScheme for this test. + origScheme := delegatingresolver.ProxyScheme + delegatingresolver.ProxyScheme = "whatever" + t.Cleanup(func() { delegatingresolver.ProxyScheme = origScheme }) + mrProxy := manual.NewBuilderWithScheme("whatever") + resolver.Register(mrProxy) + mrProxy.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: pLis.Addr().String()}}}) + + // Dial to the proxy server. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + conn, err := grpc.Dial(mrTarget.Scheme()+":///"+unresolvedTargetURI, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("grpc.Dial(%s) failed: %v", mrTarget.Scheme()+":///"+unresolvedTargetURI, err) + } + defer conn.Close() + + // Send an RPC to the backend through the proxy. + client := testgrpc.NewTestServiceClient(conn) + if _, err := client.EmptyCall(ctx, &testgrpc.Empty{}); err != nil { + t.Errorf("EmptyCall failed: %v", err) + } + + // Verify that the proxy server encountered no errors. + select { + case err := <-errCh: + t.Fatalf("proxy server encountered an error: %v", err) + case <-doneCh: + t.Logf("proxy server succeeded") + } +} + +// Tests the scenario where grpc.NewClient is used with the default DNS +// resolver for the target URI and a proxy is configured. The test verifies +// that the client. resolves proxy URI, connects to the proxy server, sends the +// unresolved target URI in the HTTP CONNECT request, and successfully +// establishes a connection to the backend server. +func (s) TestGRPCNewClientWithProxy(t *testing.T) { + // Set up a channel to receive signals from OnClientResolution. + resCh := make(chan bool, 1) + // Overwrite OnClientResolution to send a signal to the channel. + origOnClientResolution := delegatingresolver.OnClientResolution + delegatingresolver.OnClientResolution = func(int) { + resCh <- true + } + t.Cleanup(func() { delegatingresolver.OnClientResolution = origOnClientResolution }) + + backendAddr := createAndStartBackendServer(t) + pLis, errCh, doneCh, _ := setupProxy(t, backendAddr, false, requestCheck(unresolvedTargetURI)) + + // Overwrite the proxy resolution function and restore it afterward. + hpfe := func(req *http.Request) (*url.URL, error) { + if req.URL.Host == unresolvedTargetURI { + return &url.URL{ + Scheme: "https", + Host: unresolvedProxyURI, + }, nil + } + return nil, nil + } + orighpfe := internal.HTTPSProxyFromEnvironmentForTesting + internal.HTTPSProxyFromEnvironmentForTesting = hpfe + defer func() { + internal.HTTPSProxyFromEnvironmentForTesting = orighpfe + }() + + // Set up and update a manual resolver for proxy resolution. + mrProxy := setupDNS(t) + mrProxy.InitialState(resolver.State{ + Addresses: []resolver.Address{ + {Addr: pLis.Addr().String()}, + }, + }) + + // Dial to the proxy server. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + conn, err := grpc.NewClient(unresolvedTargetURI, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("grpc.NewClient failed: %v", err) + } + defer conn.Close() + + // Send an RPC to the backend through the proxy. + client := testgrpc.NewTestServiceClient(conn) + if _, err := client.EmptyCall(ctx, &testgrpc.Empty{}); err != nil { + t.Errorf("EmptyCall failed: %v", err) + } + + // Verify if the proxy server encountered any errors. + select { + case err := <-errCh: + t.Fatalf("proxy server encountered an error: %v", err) + case <-doneCh: + t.Logf("proxy server succeeded") + } + + // Verify if OnClientResolution was triggered. + select { + case <-resCh: + t.Fatal("target resolution occurred on client unexpectedly") + default: + t.Log("target resolution did not occur on the client") + } +} + +// Tests the scenario where grpc.NewClient is used with a custom target URI +// scheme and a proxy is configured. The test verifies that the client +// successfully connects to the proxy server, resolves the proxy URI correctly, +// includes the resolved target URI in the HTTP CONNECT request, and +// establishes a connection to the backend server. +func (s) TestGRPCNewClientWithProxyAndCustomResolver(t *testing.T) { + // Set up a channel to receive signals from OnClientResolution. + resCh := make(chan bool, 1) + // Overwrite OnClientResolution to send a signal to the channel. + origOnClientResolution := delegatingresolver.OnClientResolution + delegatingresolver.OnClientResolution = func(int) { + resCh <- true + } + t.Cleanup(func() { delegatingresolver.OnClientResolution = origOnClientResolution }) + + backendAddr := createAndStartBackendServer(t) + // Set up and start the proxy server. + pLis, errCh, doneCh, _ := setupProxy(t, backendAddr, true, requestCheck(backendAddr)) + + // Overwrite the proxy resolution function and restore it afterward. + hpfe := func(req *http.Request) (*url.URL, error) { + if req.URL.Host == unresolvedTargetURI { + return &url.URL{ + Scheme: "https", + Host: unresolvedProxyURI, + }, nil + } + return nil, nil + } + orighpfe := internal.HTTPSProxyFromEnvironmentForTesting + internal.HTTPSProxyFromEnvironmentForTesting = hpfe + defer func() { + internal.HTTPSProxyFromEnvironmentForTesting = orighpfe + }() + + // Create and update a custom resolver for target URI. + mrTarget := manual.NewBuilderWithScheme("whatever") + resolver.Register(mrTarget) + mrTarget.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: backendAddr}}}) + // Set up and update a manual resolver for proxy resolution. + mrProxy := setupDNS(t) + mrProxy.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: pLis.Addr().String()}}}) + + // Dial to the proxy server. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + conn, err := grpc.NewClient(mrTarget.Scheme()+":///"+unresolvedTargetURI, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("grpc.NewClient(%s) failed: %v", mrTarget.Scheme()+":///"+unresolvedTargetURI, err) + } + t.Cleanup(func() { conn.Close() }) + + // Create a test service client and make an RPC call. + client := testgrpc.NewTestServiceClient(conn) + if _, err := client.EmptyCall(ctx, &testgrpc.Empty{}); err != nil { + t.Errorf("EmptyCall() failed: %v", err) + } + + // Check if the proxy encountered any errors. + select { + case err := <-errCh: + t.Fatalf("proxy server encountered an error: %v", err) + case <-doneCh: + t.Logf("proxy server succeeded") + } + + // Check if client-side resolution signal was sent to the channel. + select { + case <-resCh: + t.Log("target resolution occurred on client") + default: + t.Fatal("target resolution did not occur on the client unexpectedly") + } +} + +// Tests the scenario where grpc.NewClient is used with the default "dns" +// resolver and the dial option grpc.WithTargetResolutionEnabled() is set, +// enabling target resolution on the client. The test verifies that target +// resolution happens on the client by sending resolved target URi in HTTP +// CONNECT request, the proxy URI is resolved correctly, and the connection is +// successfully established with the backend server through the proxy. +func (s) TestGRPCNewClientWithProxyAndTargetResolutionEnabled(t *testing.T) { + // Set up a channel to receive signals from OnClientResolution. + resCh := make(chan bool, 1) + // Overwrite OnClientResolution to send a signal to the channel. + origOnClientResolution := delegatingresolver.OnClientResolution + delegatingresolver.OnClientResolution = func(int) { + resCh <- true + } + t.Cleanup(func() { delegatingresolver.OnClientResolution = origOnClientResolution }) + + backendAddr := createAndStartBackendServer(t) + pLis, errCh, doneCh, _ := setupProxy(t, backendAddr, true, requestCheck(backendAddr)) + + // Overwrite the proxy resolution function and restore it afterward. + hpfe := func(req *http.Request) (*url.URL, error) { + if req.URL.Host == unresolvedTargetURI { + return &url.URL{ + Scheme: "https", + Host: unresolvedProxyURI, + }, nil + } + return nil, nil + } + orighpfe := internal.HTTPSProxyFromEnvironmentForTesting + internal.HTTPSProxyFromEnvironmentForTesting = hpfe + defer func() { + internal.HTTPSProxyFromEnvironmentForTesting = orighpfe + }() + + // Configure manual resolvers for both proxy and target backends + mrTarget := setupDNS(t) + mrTarget.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: backendAddr}}}) + // Temporarily modify ProxyScheme for this test. + origScheme := delegatingresolver.ProxyScheme + delegatingresolver.ProxyScheme = "whatever" + t.Cleanup(func() { delegatingresolver.ProxyScheme = origScheme }) + mrProxy := manual.NewBuilderWithScheme("whatever") + resolver.Register(mrProxy) + mrProxy.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: pLis.Addr().String()}}}) + + // Dial options with target resolution enabled. + dopts := []grpc.DialOption{ + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithTargetResolutionEnabled(), + } + // Dial to the proxy server. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + conn, err := grpc.NewClient(unresolvedTargetURI, dopts...) + if err != nil { + t.Fatalf("grpc.NewClient(%s) failed: %v", unresolvedTargetURI, err) + } + t.Cleanup(func() { conn.Close() }) + + // Create a test service client and make an RPC call. + client := testgrpc.NewTestServiceClient(conn) + if _, err := client.EmptyCall(ctx, &testgrpc.Empty{}); err != nil { + t.Errorf("EmptyCall() failed: %v", err) + } + + // Verify if the proxy encountered any errors. + select { + case err := <-errCh: + t.Fatalf("proxy server encountered an error: %v", err) + case <-doneCh: + t.Logf("proxy server succeeded") + } + + // Check if client-side resolution signal was sent to the channel. + select { + case <-resCh: + t.Log("target resolution occurred on client") + default: + t.Fatal("target resolution did not occur on client unexpectedly") + } +} + +// Tests the scenario where grpc.NewClient is used with grpc.WithNoProxy() set, +// explicitly disabling proxy usage. The test verifies that the client does not +// dial the proxy but directly connects to the backend server. It also checks +// that the proxy resolution function is not called and that the proxy server +// never receives a connection request. +func (s) TestGRPCNewClientWithNoProxy(t *testing.T) { + backendAddr := createAndStartBackendServer(t) + _, _, _, proxyStartedCh := setupProxy(t, backendAddr, true, func(req *http.Request) error { + return fmt.Errorf("proxy server should not have received a Connect request: %v", req) + }) + + proxyCalled := false + hpfe := func(_ *http.Request) (*url.URL, error) { + proxyCalled = true + return &url.URL{ + Scheme: "https", + Host: unresolvedProxyURI, + }, nil + + } + orighpfe := internal.HTTPSProxyFromEnvironmentForTesting + internal.HTTPSProxyFromEnvironmentForTesting = hpfe + defer func() { + internal.HTTPSProxyFromEnvironmentForTesting = orighpfe + }() + + mrTarget := setupDNS(t) + mrTarget.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: backendAddr}}}) + + // Dial options with proxy explicitly disabled. + dopts := []grpc.DialOption{ + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithNoProxy(), // Disable proxy. + } + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + conn, err := grpc.NewClient(unresolvedTargetURI, dopts...) + if err != nil { + t.Fatalf("grpc.NewClient(%s) failed: %v", unresolvedTargetURI, err) + } + defer conn.Close() + + // Create a test service client and make an RPC call. + client := testgrpc.NewTestServiceClient(conn) + if _, err := client.EmptyCall(ctx, &testgrpc.Empty{}); err != nil { + t.Errorf("EmptyCall() failed: %v", err) + } + + if proxyCalled { + t.Error("http.ProxyFromEnvironment function was unexpectedly called") + } + // Verify that the proxy was not dialed. + select { + case <-proxyStartedCh: + t.Fatal("unexpected dial to proxy server") + default: + t.Log("proxy server was not dialed") + } +} + +// Tests the scenario where grpc.NewClient is used with grpc.WithContextDialer() +// set. The test verifies that the client bypasses proxy dialing and uses the +// custom dialer instead. It ensures that the proxy server is never dialed, the +// proxy resolution function is not triggered, and the custom dialer is invoked +// as expected. +func (s) TestGRPCNewClientWithContextDialer(t *testing.T) { + backendAddr := createAndStartBackendServer(t) + _, _, _, proxyStartedCh := setupProxy(t, backendAddr, true, func(req *http.Request) error { + return fmt.Errorf("proxy server should not have received a Connect request: %v", req) + }) + + proxyCalled := false + hpfe := func(_ *http.Request) (*url.URL, error) { + proxyCalled = true + return &url.URL{ + Scheme: "https", + Host: unresolvedProxyURI, + }, nil + + } + orighpfe := internal.HTTPSProxyFromEnvironmentForTesting + internal.HTTPSProxyFromEnvironmentForTesting = hpfe + defer func() { + internal.HTTPSProxyFromEnvironmentForTesting = orighpfe + }() + + // Create a custom dialer that directly dials the backend. We'll use this + // to bypass any proxy logic. + dialerCalled := make(chan bool, 1) + customDialer := func(_ context.Context, backendAddr string) (net.Conn, error) { + dialerCalled <- true + return net.Dial("tcp", backendAddr) + } + + mrTarget := setupDNS(t) + mrTarget.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: backendAddr}}}) + + dopts := []grpc.DialOption{ + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithContextDialer(customDialer), + } + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + conn, err := grpc.NewClient(unresolvedTargetURI, dopts...) + if err != nil { + t.Fatalf("grpc.NewClient(%s) failed: %v", unresolvedTargetURI, err) + } + defer conn.Close() + + client := testgrpc.NewTestServiceClient(conn) + if _, err := client.EmptyCall(ctx, &testgrpc.Empty{}); err != nil { + t.Errorf("EmptyCall() failed: %v", err) + } + + if proxyCalled { + t.Error("http.ProxyFromEnvironment function was unexpectedly called") + } + + select { + case <-dialerCalled: + t.Log("custom dialer was invoked") + default: + t.Error("custom dialer was not invoked") + } + + select { + case <-proxyStartedCh: + t.Fatal("unexpected dial to proxy server") + default: + t.Log("proxy server was not dialled") + } +} + +// Tests the scenario where grpc.NewClient is used with the default DNS resolver +// for targetURI and a proxy. The test verifies that the client connects to the +// proxy server, sends the unresolved target URI in the HTTP CONNECT request, +// and successfully connects to the backend. Additionally, it checks that the +// correct user information is included in the Proxy-Authorization header of +// the CONNECT request. The test also ensures that target resolution does not +// happen on the client. +func (s) TestBasicAuthInGrpcNewClientWithProxy(t *testing.T) { + // Set up a channel to receive signals from OnClientResolution. + resCh := make(chan bool, 1) + // Overwrite OnClientResolution to send a signal to the channel. + origOnClientResolution := delegatingresolver.OnClientResolution + delegatingresolver.OnClientResolution = func(int) { + resCh <- true + } + t.Cleanup(func() { delegatingresolver.OnClientResolution = origOnClientResolution }) + + backendAddr := createAndStartBackendServer(t) + const ( + user = "notAUser" + password = "notAPassword" + ) + pLis, errCh, doneCh, _ := setupProxy(t, backendAddr, false, func(req *http.Request) error { + if req.Method != http.MethodConnect { + return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect) + } + if req.URL.Host != unresolvedTargetURI { + return fmt.Errorf("unexpected URL.Host %q, want %q", req.URL.Host, unresolvedTargetURI) + } + wantProxyAuthStr := "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+password)) + if got := req.Header.Get("Proxy-Authorization"); got != wantProxyAuthStr { + gotDecoded, err := base64.StdEncoding.DecodeString(got) + if err != nil { + return fmt.Errorf("failed to decode Proxy-Authorization header: %v", err) + } + wantDecoded, _ := base64.StdEncoding.DecodeString(wantProxyAuthStr) + return fmt.Errorf("unexpected auth %q (%q), want %q (%q)", got, gotDecoded, wantProxyAuthStr, wantDecoded) + } + return nil + }) + + // Overwrite the proxy resolution function and restore it afterward. + hpfe := func(req *http.Request) (*url.URL, error) { + if req.URL.Host == unresolvedTargetURI { + u := url.URL{ + Scheme: "https", + Host: unresolvedProxyURI, + } + u.User = url.UserPassword(user, password) + return &u, nil + } + return nil, nil + } + orighpfe := internal.HTTPSProxyFromEnvironmentForTesting + internal.HTTPSProxyFromEnvironmentForTesting = hpfe + defer func() { + internal.HTTPSProxyFromEnvironmentForTesting = orighpfe + }() + + // Set up and update a manual resolver for proxy resolution. + mrProxy := setupDNS(t) + mrProxy.InitialState(resolver.State{ + Addresses: []resolver.Address{ + {Addr: pLis.Addr().String()}, + }, + }) + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + conn, err := grpc.NewClient(unresolvedTargetURI, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("grpc.NewClient(%s) failed: %v", unresolvedTargetURI, err) + } + defer conn.Close() + + // Send an RPC to the backend through the proxy. + client := testgrpc.NewTestServiceClient(conn) + if _, err := client.EmptyCall(ctx, &testgrpc.Empty{}); err != nil { + t.Errorf("EmptyCall failed: %v", err) + } + + // Verify if the proxy server encountered any errors. + select { + case err := <-errCh: + t.Fatalf("proxy server encountered an error: %v", err) + case <-doneCh: + t.Logf("proxy server succeeded") + } + + // Verify if OnClientResolution was triggered. + select { + case <-resCh: + t.Fatal("target resolution occurred on client unexpectedly") + default: + t.Log("target resolution did not occur on the client") + } +}