Skip to content

Commit

Permalink
UPSTREAM: <carry>: add tls support for apiserver http/grpc
Browse files Browse the repository at this point in the history
make mlpipeline server url scheme configurable
add tls handling for PA and ui
remove local grpc client tls.

Signed-off-by: Humair Khan <[email protected]>
  • Loading branch information
HumairAK committed Jul 14, 2024
1 parent ed99781 commit 5aee619
Show file tree
Hide file tree
Showing 19 changed files with 419 additions and 171 deletions.
11 changes: 8 additions & 3 deletions backend/src/agent/persistence/client/pipeline_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,20 @@ func NewPipelineClient(
basePath string,
mlPipelineServiceName string,
mlPipelineServiceHttpPort string,
mlPipelineServiceGRPCPort string) (*PipelineClient, error) {
mlPipelineServiceGRPCPort string,
mlPipelineServiceTLSEnabled bool) (*PipelineClient, error) {
httpAddress := fmt.Sprintf(addressTemp, mlPipelineServiceName, mlPipelineServiceHttpPort)
grpcAddress := fmt.Sprintf(addressTemp, mlPipelineServiceName, mlPipelineServiceGRPCPort)
err := util.WaitForAPIAvailable(initializeTimeout, basePath, httpAddress)
scheme := "http"
if mlPipelineServiceTLSEnabled {
scheme = "https"
}
err := util.WaitForAPIAvailable(initializeTimeout, basePath, httpAddress, scheme)
if err != nil {
return nil, errors.Wrapf(err,
"Failed to initialize pipeline client. Error: %s", err.Error())
}
connection, err := util.GetRpcConnection(grpcAddress)
connection, err := util.GetRpcConnection(grpcAddress, mlPipelineServiceTLSEnabled)
if err != nil {
return nil, errors.Wrapf(err,
"Failed to get RPC connection. Error: %s", err.Error())
Expand Down
42 changes: 26 additions & 16 deletions backend/src/agent/persistence/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package main

import (
"flag"
"strconv"
"time"

"github.com/kubeflow/pipelines/backend/src/agent/persistence/client"
Expand All @@ -29,21 +30,22 @@ import (
)

var (
masterURL string
kubeconfig string
initializeTimeout time.Duration
timeout time.Duration
mlPipelineAPIServerName string
mlPipelineAPIServerPort string
mlPipelineAPIServerBasePath string
mlPipelineServiceHttpPort string
mlPipelineServiceGRPCPort string
namespace string
ttlSecondsAfterWorkflowFinish int64
numWorker int
clientQPS float64
clientBurst int
saTokenRefreshIntervalInSecs int64
masterURL string
kubeconfig string
initializeTimeout time.Duration
timeout time.Duration
mlPipelineAPIServerName string
mlPipelineAPIServerPort string
mlPipelineAPIServerBasePath string
mlPipelineServiceHttpPort string
mlPipelineServiceGRPCPort string
mlPipelineServiceTLSEnabledStr string
namespace string
ttlSecondsAfterWorkflowFinish int64
numWorker int
clientQPS float64
clientBurst int
saTokenRefreshIntervalInSecs int64
)

const (
Expand All @@ -55,6 +57,7 @@ const (
mlPipelineAPIServerNameFlagName = "mlPipelineAPIServerName"
mlPipelineAPIServerHttpPortFlagName = "mlPipelineServiceHttpPort"
mlPipelineAPIServerGRPCPortFlagName = "mlPipelineServiceGRPCPort"
mlPipelineAPIServerTLSEnabled = "mlPipelineServiceTLSEnabled"
namespaceFlagName = "namespace"
ttlSecondsAfterWorkflowFinishFlagName = "ttlSecondsAfterWorkflowFinish"
numWorkerName = "numWorker"
Expand Down Expand Up @@ -102,14 +105,20 @@ func main() {
log.Fatalf("Error starting Service Account Token Refresh Ticker due to: %v", err)
}

mlPipelineServiceTLSEnabled, err := strconv.ParseBool(mlPipelineServiceTLSEnabledStr)
if err != nil {
log.Fatalf("Error parsing boolean flag %s, please provide a valid bool value (true/false). %v", mlPipelineAPIServerTLSEnabled, err)
}

pipelineClient, err := client.NewPipelineClient(
initializeTimeout,
timeout,
tokenRefresher,
mlPipelineAPIServerBasePath,
mlPipelineAPIServerName,
mlPipelineServiceHttpPort,
mlPipelineServiceGRPCPort)
mlPipelineServiceGRPCPort,
mlPipelineServiceTLSEnabled)
if err != nil {
log.Fatalf("Error creating ML pipeline API Server client: %v", err)
}
Expand All @@ -136,6 +145,7 @@ func init() {
flag.StringVar(&mlPipelineAPIServerName, mlPipelineAPIServerNameFlagName, "ml-pipeline", "Name of the ML pipeline API server.")
flag.StringVar(&mlPipelineServiceHttpPort, mlPipelineAPIServerHttpPortFlagName, "8888", "Http Port of the ML pipeline API server.")
flag.StringVar(&mlPipelineServiceGRPCPort, mlPipelineAPIServerGRPCPortFlagName, "8887", "GRPC Port of the ML pipeline API server.")
flag.StringVar(&mlPipelineServiceTLSEnabledStr, mlPipelineAPIServerTLSEnabled, "false", "Set to 'true' if mlpipeline api server serves over TLS (default: 'false').")
flag.StringVar(&mlPipelineAPIServerBasePath, mlPipelineAPIServerBasePathFlagName,
"/apis/v1beta1", "The base path for the ML pipeline API server.")
flag.StringVar(&namespace, namespaceFlagName, "", "The namespace name used for Kubernetes informers to obtain the listers.")
Expand Down
109 changes: 86 additions & 23 deletions backend/src/apiserver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ package main

import (
"context"
"crypto/tls"
"encoding/json"
"flag"
"fmt"
"github.com/kubeflow/pipelines/backend/src/apiserver/client"
"google.golang.org/grpc/credentials"
"io"
"io/ioutil"
"math"
Expand Down Expand Up @@ -52,21 +54,49 @@ var (
httpPortFlag = flag.String("httpPortFlag", ":8888", "Http Proxy Port")
configPath = flag.String("config", "", "Path to JSON file containing config")
sampleConfigPath = flag.String("sampleconfig", "", "Path to samples")
tlsCertPath = flag.String("tlsCertPath", "", "Path to the public tls cert.")
tlsCertKeyPath = flag.String("tlsCertKeyPath", "", "Path to the private tls key cert.")
collectMetricsFlag = flag.Bool("collectMetricsFlag", true, "Whether to collect Prometheus metrics in API server.")
)

type RegisterHttpHandlerFromEndpoint func(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) error

// initCerts builds the server-side TLS configuration from the tlsCertPath and
// tlsCertKeyPath command-line flags.
//
// It returns (nil, nil) when neither flag is set; startRpcServer and
// startHttpProxy check for a nil config to decide whether to serve plaintext.
// It returns an error when exactly one of the two flags is set, or when the
// certificate/key pair cannot be loaded from the given paths.
func initCerts() (*tls.Config, error) {
	certPath, keyPath := *tlsCertPath, *tlsCertKeyPath

	switch {
	case certPath == "" && keyPath == "":
		// User can choose not to provide certs
		return nil, nil
	case certPath == "":
		return nil, fmt.Errorf("Missing tlsCertPath when specifying cert paths, both tlsCertPath and tlsCertKeyPath are required.")
	case keyPath == "":
		return nil, fmt.Errorf("Missing tlsCertKeyPath when specifying cert paths, both tlsCertPath and tlsCertKeyPath are required.")
	}

	serverCert, err := tls.LoadX509KeyPair(certPath, keyPath)
	if err != nil {
		// Include both paths so a misconfigured mount is easy to diagnose.
		return nil, fmt.Errorf("loading TLS key pair (cert %q, key %q): %w", certPath, keyPath, err)
	}
	glog.Info("TLS cert key/pair loaded.")

	return &tls.Config{
		Certificates: []tls.Certificate{serverCert},
		// Pin the floor explicitly: older Go toolchains default servers to
		// TLS 1.0, which should not be offered by the API server.
		MinVersion: tls.VersionTLS12,
	}, nil
}

func main() {
flag.Parse()

initConfig()
clientManager := cm.NewClientManager()

tlsConfig, err := initCerts()
if err != nil {
glog.Fatalf("Failed to parse Cert paths. Err: %v", err)
}

resourceManager := resource.NewResourceManager(
&clientManager,
&resource.ResourceManagerOptions{CollectMetrics: *collectMetricsFlag},
)
err := loadSamples(resourceManager)
err = loadSamples(resourceManager)
if err != nil {
glog.Fatalf("Failed to load samples. Err: %v", err)
}
Expand All @@ -78,8 +108,8 @@ func main() {
}
}

go startRpcServer(resourceManager)
startHttpProxy(resourceManager)
go startRpcServer(resourceManager, tlsConfig)
startHttpProxy(resourceManager, tlsConfig)

clientManager.Close()
}
Expand All @@ -93,13 +123,25 @@ func grpcCustomMatcher(key string) (string, bool) {
return strings.ToLower(key), false
}

func startRpcServer(resourceManager *resource.ResourceManager) {
glog.Info("Starting RPC server")
func startRpcServer(resourceManager *resource.ResourceManager, tlsConfig *tls.Config) {
var s *grpc.Server
if tlsConfig != nil {
glog.Info("Starting RPC server (TLS enabled)")
tlsCredentials := credentials.NewTLS(tlsConfig)
s = grpc.NewServer(
grpc.Creds(tlsCredentials),
grpc.UnaryInterceptor(apiServerInterceptor),
grpc.MaxRecvMsgSize(math.MaxInt32),
)
} else {
glog.Info("Starting RPC server")
s = grpc.NewServer(grpc.UnaryInterceptor(apiServerInterceptor), grpc.MaxRecvMsgSize(math.MaxInt32))
}

listener, err := net.Listen("tcp", *rpcPortFlag)
if err != nil {
glog.Fatalf("Failed to start RPC server: %v", err)
}
s := grpc.NewServer(grpc.UnaryInterceptor(apiServerInterceptor), grpc.MaxRecvMsgSize(math.MaxInt32))

sharedExperimentServer := server.NewExperimentServer(resourceManager, &server.ExperimentServerOptions{CollectMetrics: *collectMetricsFlag})
sharedPipelineServer := server.NewPipelineServer(
Expand Down Expand Up @@ -140,29 +182,28 @@ func startRpcServer(resourceManager *resource.ResourceManager) {
glog.Info("RPC server started")
}

func startHttpProxy(resourceManager *resource.ResourceManager) {
glog.Info("Starting Http Proxy")
func startHttpProxy(resourceManager *resource.ResourceManager, tlsConfig *tls.Config) {

ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer cancel()

// Create gRPC HTTP MUX and register services for v1beta1 api.
runtimeMux := runtime.NewServeMux(runtime.WithIncomingHeaderMatcher(grpcCustomMatcher))
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterPipelineServiceHandlerFromEndpoint, "PipelineService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterExperimentServiceHandlerFromEndpoint, "ExperimentService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterJobServiceHandlerFromEndpoint, "JobService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterRunServiceHandlerFromEndpoint, "RunService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterTaskServiceHandlerFromEndpoint, "TaskService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterReportServiceHandlerFromEndpoint, "ReportService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterVisualizationServiceHandlerFromEndpoint, "Visualization", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterAuthServiceHandlerFromEndpoint, "AuthService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterPipelineServiceHandlerFromEndpoint, "PipelineService", ctx, runtimeMux, tlsConfig)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterExperimentServiceHandlerFromEndpoint, "ExperimentService", ctx, runtimeMux, tlsConfig)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterJobServiceHandlerFromEndpoint, "JobService", ctx, runtimeMux, tlsConfig)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterRunServiceHandlerFromEndpoint, "RunService", ctx, runtimeMux, tlsConfig)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterTaskServiceHandlerFromEndpoint, "TaskService", ctx, runtimeMux, tlsConfig)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterReportServiceHandlerFromEndpoint, "ReportService", ctx, runtimeMux, tlsConfig)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterVisualizationServiceHandlerFromEndpoint, "Visualization", ctx, runtimeMux, tlsConfig)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterAuthServiceHandlerFromEndpoint, "AuthService", ctx, runtimeMux, tlsConfig)

// Create gRPC HTTP MUX and register services for v2beta1 api.
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterExperimentServiceHandlerFromEndpoint, "ExperimentService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterPipelineServiceHandlerFromEndpoint, "PipelineService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterRecurringRunServiceHandlerFromEndpoint, "RecurringRunService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterRunServiceHandlerFromEndpoint, "RunService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterExperimentServiceHandlerFromEndpoint, "ExperimentService", ctx, runtimeMux, tlsConfig)
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterPipelineServiceHandlerFromEndpoint, "PipelineService", ctx, runtimeMux, tlsConfig)
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterRecurringRunServiceHandlerFromEndpoint, "RecurringRunService", ctx, runtimeMux, tlsConfig)
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterRunServiceHandlerFromEndpoint, "RunService", ctx, runtimeMux, tlsConfig)

// Create a top level mux to include both pipeline upload server and gRPC servers.
topMux := mux.NewRouter()
Expand Down Expand Up @@ -195,13 +236,35 @@ func startHttpProxy(resourceManager *resource.ResourceManager) {
// Register a handler for Prometheus to poll.
topMux.Handle("/metrics", promhttp.Handler())

http.ListenAndServe(*httpPortFlag, topMux)
if tlsConfig != nil {
glog.Info("Starting Https Proxy")
https := http.Server{
TLSConfig: tlsConfig,
Addr: *httpPortFlag,
Handler: topMux,
}
https.ListenAndServeTLS("", "")
} else {
glog.Info("Starting Http Proxy")
http.ListenAndServe(*httpPortFlag, topMux)
}

glog.Info("Http Proxy started")
}

func registerHttpHandlerFromEndpoint(handler RegisterHttpHandlerFromEndpoint, serviceName string, ctx context.Context, mux *runtime.ServeMux) {
func registerHttpHandlerFromEndpoint(handler RegisterHttpHandlerFromEndpoint, serviceName string, ctx context.Context, mux *runtime.ServeMux, tlsConfig *tls.Config) {
endpoint := "localhost" + *rpcPortFlag
opts := []grpc.DialOption{grpc.WithInsecure(), grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32))}
var opts []grpc.DialOption
if tlsConfig != nil {
// local client connections via http proxy to grpc should not require tls
tlsConfig.InsecureSkipVerify = true
opts = []grpc.DialOption{
grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)),
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32)),
}
} else {
opts = []grpc.DialOption{grpc.WithInsecure(), grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32))}
}

if err := handler(ctx, mux, endpoint, opts); err != nil {
glog.Fatalf("Failed to register %v handler: %v", serviceName, err)
Expand Down
20 changes: 16 additions & 4 deletions backend/src/common/util/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
package util

import (
"crypto/tls"
"fmt"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"net/http"
"strings"
"time"
Expand All @@ -28,9 +31,9 @@ import (
"k8s.io/client-go/tools/clientcmd"
)

func WaitForAPIAvailable(initializeTimeout time.Duration, basePath string, apiAddress string) error {
func WaitForAPIAvailable(initializeTimeout time.Duration, basePath string, apiAddress string, scheme string) error {
operation := func() error {
response, err := http.Get(fmt.Sprintf("http://%s%s/healthz", apiAddress, basePath))
response, err := http.Get(fmt.Sprintf("%s://%s%s/healthz", scheme, apiAddress, basePath))
if err != nil {
return err
}
Expand Down Expand Up @@ -74,8 +77,17 @@ func GetKubernetesClientFromClientConfig(clientConfig clientcmd.ClientConfig) (
return clientSet, config, namespace, nil
}

func GetRpcConnection(address string) (*grpc.ClientConn, error) {
conn, err := grpc.Dial(address, grpc.WithInsecure())
func GetRpcConnection(address string, tlsEnabled bool) (*grpc.ClientConn, error) {
creds := insecure.NewCredentials()
if tlsEnabled {
config := &tls.Config{}
creds = credentials.NewTLS(config)
}

conn, err := grpc.Dial(
address,
grpc.WithTransportCredentials(creds),
)
if err != nil {
return nil, errors.Wrapf(err, "Failed to create gRPC connection")
}
Expand Down
18 changes: 16 additions & 2 deletions backend/src/v2/cacheutils/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ package cacheutils
import (
"context"
"crypto/sha256"
"crypto/tls"
"encoding/hex"
"encoding/json"
"fmt"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"os"

"google.golang.org/grpc"
Expand Down Expand Up @@ -111,10 +114,21 @@ type Client struct {
}

// NewClient creates a Client.
func NewClient() (*Client, error) {
func NewClient(mlPipelineServiceTLSEnabled bool) (*Client, error) {
creds := insecure.NewCredentials()
if mlPipelineServiceTLSEnabled {
config := &tls.Config{
InsecureSkipVerify: false,
}
creds = credentials.NewTLS(config)
}
cacheEndPoint := cacheDefaultEndpoint()
glog.Infof("Connecting to cache endpoint %s", cacheEndPoint)
conn, err := grpc.Dial(cacheEndPoint, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(MaxClientGRPCMessageSize)), grpc.WithInsecure())
conn, err := grpc.Dial(
cacheEndPoint,
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(MaxClientGRPCMessageSize)),
grpc.WithTransportCredentials(creds),
)
if err != nil {
return nil, fmt.Errorf("metadata.NewClient() failed: %w", err)
}
Expand Down
Loading

0 comments on commit 5aee619

Please sign in to comment.