Skip to content

Commit

Permalink
UPSTREAM: <carry>: add tls support for apiserver http/grpc
Browse files Browse the repository at this point in the history
make mlpipeline server url scheme configurable
add tls handling for PA and ui
remove local grpc client tls.

Signed-off-by: Humair Khan <[email protected]>
  • Loading branch information
HumairAK committed Jul 14, 2024
1 parent ed99781 commit cd5eb38
Show file tree
Hide file tree
Showing 15 changed files with 352 additions and 152 deletions.
11 changes: 8 additions & 3 deletions backend/src/agent/persistence/client/pipeline_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,20 @@ func NewPipelineClient(
basePath string,
mlPipelineServiceName string,
mlPipelineServiceHttpPort string,
mlPipelineServiceGRPCPort string) (*PipelineClient, error) {
mlPipelineServiceGRPCPort string,
mlPipelineServiceTLSEnabled bool) (*PipelineClient, error) {
httpAddress := fmt.Sprintf(addressTemp, mlPipelineServiceName, mlPipelineServiceHttpPort)
grpcAddress := fmt.Sprintf(addressTemp, mlPipelineServiceName, mlPipelineServiceGRPCPort)
err := util.WaitForAPIAvailable(initializeTimeout, basePath, httpAddress)
scheme := "http"
if mlPipelineServiceTLSEnabled {
scheme = "https"
}
err := util.WaitForAPIAvailable(initializeTimeout, basePath, httpAddress, scheme)
if err != nil {
return nil, errors.Wrapf(err,
"Failed to initialize pipeline client. Error: %s", err.Error())
}
connection, err := util.GetRpcConnection(grpcAddress)
connection, err := util.GetRpcConnection(grpcAddress, mlPipelineServiceTLSEnabled)
if err != nil {
return nil, errors.Wrapf(err,
"Failed to get RPC connection. Error: %s", err.Error())
Expand Down
6 changes: 5 additions & 1 deletion backend/src/agent/persistence/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ var (
mlPipelineAPIServerBasePath string
mlPipelineServiceHttpPort string
mlPipelineServiceGRPCPort string
mlPipelineServiceTLSEnabled bool
namespace string
ttlSecondsAfterWorkflowFinish int64
numWorker int
Expand All @@ -55,6 +56,7 @@ const (
mlPipelineAPIServerNameFlagName = "mlPipelineAPIServerName"
mlPipelineAPIServerHttpPortFlagName = "mlPipelineServiceHttpPort"
mlPipelineAPIServerGRPCPortFlagName = "mlPipelineServiceGRPCPort"
mlPipelineAPIServerTLSEnabled = "mlPipelineServiceTLSEnabled"
namespaceFlagName = "namespace"
ttlSecondsAfterWorkflowFinishFlagName = "ttlSecondsAfterWorkflowFinish"
numWorkerName = "numWorker"
Expand Down Expand Up @@ -109,7 +111,8 @@ func main() {
mlPipelineAPIServerBasePath,
mlPipelineAPIServerName,
mlPipelineServiceHttpPort,
mlPipelineServiceGRPCPort)
mlPipelineServiceGRPCPort,
mlPipelineServiceTLSEnabled)
if err != nil {
log.Fatalf("Error creating ML pipeline API Server client: %v", err)
}
Expand All @@ -136,6 +139,7 @@ func init() {
flag.StringVar(&mlPipelineAPIServerName, mlPipelineAPIServerNameFlagName, "ml-pipeline", "Name of the ML pipeline API server.")
flag.StringVar(&mlPipelineServiceHttpPort, mlPipelineAPIServerHttpPortFlagName, "8888", "Http Port of the ML pipeline API server.")
flag.StringVar(&mlPipelineServiceGRPCPort, mlPipelineAPIServerGRPCPortFlagName, "8887", "GRPC Port of the ML pipeline API server.")
flag.BoolVar(&mlPipelineServiceTLSEnabled, mlPipelineAPIServerTLSEnabled, false, "Set to 'true' if mlpipeline api server serves over TLS (default: 'false').")
flag.StringVar(&mlPipelineAPIServerBasePath, mlPipelineAPIServerBasePathFlagName,
"/apis/v1beta1", "The base path for the ML pipeline API server.")
flag.StringVar(&namespace, namespaceFlagName, "", "The namespace name used for Kubernetes informers to obtain the listers.")
Expand Down
108 changes: 85 additions & 23 deletions backend/src/apiserver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ package main

import (
"context"
"crypto/tls"
"encoding/json"
"flag"
"fmt"
"github.com/kubeflow/pipelines/backend/src/apiserver/client"
"google.golang.org/grpc/credentials"
"io"
"io/ioutil"
"math"
Expand Down Expand Up @@ -52,21 +54,48 @@ var (
httpPortFlag = flag.String("httpPortFlag", ":8888", "Http Proxy Port")
configPath = flag.String("config", "", "Path to JSON file containing config")
sampleConfigPath = flag.String("sampleconfig", "", "Path to samples")
tlsCertPath = flag.String("tlsCertPath", "", "Path to the public tls cert.")
tlsCertKeyPath = flag.String("tlsCertKeyPath", "", "Path to the private tls key cert.")
collectMetricsFlag = flag.Bool("collectMetricsFlag", true, "Whether to collect Prometheus metrics in API server.")
)

type RegisterHttpHandlerFromEndpoint func(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) error

// initCerts builds the server-side TLS configuration from the tlsCertPath and
// tlsCertKeyPath flags.
//
// It returns (nil, nil) when neither flag is set; main passes that nil config
// on to the RPC server and HTTP proxy, which then serve without TLS. Setting
// only one of the two flags is an error, as is a certificate/key pair that
// cannot be loaded from the given paths.
func initCerts() (*tls.Config, error) {
	certPath, keyPath := *tlsCertPath, *tlsCertKeyPath

	switch {
	case certPath == "" && keyPath == "":
		// User can choose not to provide certs
		return nil, nil
	case certPath == "":
		return nil, fmt.Errorf("Missing tlsCertPath when specifying cert paths, both tlsCertPath and tlsCertKeyPath are required.")
	case keyPath == "":
		return nil, fmt.Errorf("Missing tlsCertKeyPath when specifying cert paths, both tlsCertPath and tlsCertKeyPath are required.")
	}

	serverCert, err := tls.LoadX509KeyPair(certPath, keyPath)
	if err != nil {
		// Include both paths so a bad mount or typo is diagnosable from the log line alone.
		return nil, fmt.Errorf("loading TLS key pair (cert %q, key %q): %w", certPath, keyPath, err)
	}

	return &tls.Config{
		Certificates: []tls.Certificate{serverCert},
		// Refuse TLS 1.0/1.1, which are deprecated (RFC 8996). Go toolchains
		// older than 1.22 would otherwise still accept them on the server side.
		MinVersion: tls.VersionTLS12,
	}, nil
}

func main() {
flag.Parse()

initConfig()
clientManager := cm.NewClientManager()

tlsConfig, err := initCerts()
if err != nil {
glog.Fatalf("Failed to parse Cert paths. Err: %v", err)
}

resourceManager := resource.NewResourceManager(
&clientManager,
&resource.ResourceManagerOptions{CollectMetrics: *collectMetricsFlag},
)
err := loadSamples(resourceManager)
err = loadSamples(resourceManager)
if err != nil {
glog.Fatalf("Failed to load samples. Err: %v", err)
}
Expand All @@ -78,8 +107,8 @@ func main() {
}
}

go startRpcServer(resourceManager)
startHttpProxy(resourceManager)
go startRpcServer(resourceManager, tlsConfig)
startHttpProxy(resourceManager, tlsConfig)

clientManager.Close()
}
Expand All @@ -93,13 +122,25 @@ func grpcCustomMatcher(key string) (string, bool) {
return strings.ToLower(key), false
}

func startRpcServer(resourceManager *resource.ResourceManager) {
glog.Info("Starting RPC server")
func startRpcServer(resourceManager *resource.ResourceManager, tlsConfig *tls.Config) {
var s *grpc.Server
if tlsConfig != nil {
glog.Info("Starting RPC server (TLS enabled)")
tlsCredentials := credentials.NewTLS(tlsConfig)
s = grpc.NewServer(
grpc.Creds(tlsCredentials),
grpc.UnaryInterceptor(apiServerInterceptor),
grpc.MaxRecvMsgSize(math.MaxInt32),
)
} else {
glog.Info("Starting RPC server")
s = grpc.NewServer(grpc.UnaryInterceptor(apiServerInterceptor), grpc.MaxRecvMsgSize(math.MaxInt32))
}

listener, err := net.Listen("tcp", *rpcPortFlag)
if err != nil {
glog.Fatalf("Failed to start RPC server: %v", err)
}
s := grpc.NewServer(grpc.UnaryInterceptor(apiServerInterceptor), grpc.MaxRecvMsgSize(math.MaxInt32))

sharedExperimentServer := server.NewExperimentServer(resourceManager, &server.ExperimentServerOptions{CollectMetrics: *collectMetricsFlag})
sharedPipelineServer := server.NewPipelineServer(
Expand Down Expand Up @@ -140,29 +181,28 @@ func startRpcServer(resourceManager *resource.ResourceManager) {
glog.Info("RPC server started")
}

func startHttpProxy(resourceManager *resource.ResourceManager) {
glog.Info("Starting Http Proxy")
func startHttpProxy(resourceManager *resource.ResourceManager, tlsConfig *tls.Config) {

ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer cancel()

// Create gRPC HTTP MUX and register services for v1beta1 api.
runtimeMux := runtime.NewServeMux(runtime.WithIncomingHeaderMatcher(grpcCustomMatcher))
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterPipelineServiceHandlerFromEndpoint, "PipelineService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterExperimentServiceHandlerFromEndpoint, "ExperimentService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterJobServiceHandlerFromEndpoint, "JobService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterRunServiceHandlerFromEndpoint, "RunService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterTaskServiceHandlerFromEndpoint, "TaskService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterReportServiceHandlerFromEndpoint, "ReportService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterVisualizationServiceHandlerFromEndpoint, "Visualization", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterAuthServiceHandlerFromEndpoint, "AuthService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterPipelineServiceHandlerFromEndpoint, "PipelineService", ctx, runtimeMux, tlsConfig)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterExperimentServiceHandlerFromEndpoint, "ExperimentService", ctx, runtimeMux, tlsConfig)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterJobServiceHandlerFromEndpoint, "JobService", ctx, runtimeMux, tlsConfig)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterRunServiceHandlerFromEndpoint, "RunService", ctx, runtimeMux, tlsConfig)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterTaskServiceHandlerFromEndpoint, "TaskService", ctx, runtimeMux, tlsConfig)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterReportServiceHandlerFromEndpoint, "ReportService", ctx, runtimeMux, tlsConfig)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterVisualizationServiceHandlerFromEndpoint, "Visualization", ctx, runtimeMux, tlsConfig)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterAuthServiceHandlerFromEndpoint, "AuthService", ctx, runtimeMux, tlsConfig)

// Create gRPC HTTP MUX and register services for v2beta1 api.
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterExperimentServiceHandlerFromEndpoint, "ExperimentService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterPipelineServiceHandlerFromEndpoint, "PipelineService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterRecurringRunServiceHandlerFromEndpoint, "RecurringRunService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterRunServiceHandlerFromEndpoint, "RunService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterExperimentServiceHandlerFromEndpoint, "ExperimentService", ctx, runtimeMux, tlsConfig)
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterPipelineServiceHandlerFromEndpoint, "PipelineService", ctx, runtimeMux, tlsConfig)
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterRecurringRunServiceHandlerFromEndpoint, "RecurringRunService", ctx, runtimeMux, tlsConfig)
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterRunServiceHandlerFromEndpoint, "RunService", ctx, runtimeMux, tlsConfig)

// Create a top level mux to include both pipeline upload server and gRPC servers.
topMux := mux.NewRouter()
Expand Down Expand Up @@ -195,13 +235,35 @@ func startHttpProxy(resourceManager *resource.ResourceManager) {
// Register a handler for Prometheus to poll.
topMux.Handle("/metrics", promhttp.Handler())

http.ListenAndServe(*httpPortFlag, topMux)
if tlsConfig != nil {
glog.Info("Starting Https Proxy")
https := http.Server{
TLSConfig: tlsConfig,
Addr: *httpPortFlag,
Handler: topMux,
}
https.ListenAndServeTLS("", "")
} else {
glog.Info("Starting Http Proxy")
http.ListenAndServe(*httpPortFlag, topMux)
}

glog.Info("Http Proxy started")
}

func registerHttpHandlerFromEndpoint(handler RegisterHttpHandlerFromEndpoint, serviceName string, ctx context.Context, mux *runtime.ServeMux) {
func registerHttpHandlerFromEndpoint(handler RegisterHttpHandlerFromEndpoint, serviceName string, ctx context.Context, mux *runtime.ServeMux, tlsConfig *tls.Config) {
endpoint := "localhost" + *rpcPortFlag
opts := []grpc.DialOption{grpc.WithInsecure(), grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32))}
var opts []grpc.DialOption
if tlsConfig != nil {
// local client connections via http proxy to grpc should not require tls
tlsConfig.InsecureSkipVerify = true
opts = []grpc.DialOption{
grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)),
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32)),
}
} else {
opts = []grpc.DialOption{grpc.WithInsecure(), grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32))}
}

if err := handler(ctx, mux, endpoint, opts); err != nil {
glog.Fatalf("Failed to register %v handler: %v", serviceName, err)
Expand Down
22 changes: 18 additions & 4 deletions backend/src/common/util/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
package util

import (
"crypto/tls"
"fmt"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"net/http"
"strings"
"time"
Expand All @@ -28,9 +31,9 @@ import (
"k8s.io/client-go/tools/clientcmd"
)

func WaitForAPIAvailable(initializeTimeout time.Duration, basePath string, apiAddress string) error {
func WaitForAPIAvailable(initializeTimeout time.Duration, basePath string, apiAddress string, scheme string) error {
operation := func() error {
response, err := http.Get(fmt.Sprintf("http://%s%s/healthz", apiAddress, basePath))
response, err := http.Get(fmt.Sprintf("%s://%s%s/healthz", scheme, apiAddress, basePath))
if err != nil {
return err
}
Expand Down Expand Up @@ -74,8 +77,19 @@ func GetKubernetesClientFromClientConfig(clientConfig clientcmd.ClientConfig) (
return clientSet, config, namespace, nil
}

func GetRpcConnection(address string) (*grpc.ClientConn, error) {
conn, err := grpc.Dial(address, grpc.WithInsecure())
func GetRpcConnection(address string, tlsEnabled bool) (*grpc.ClientConn, error) {
creds := insecure.NewCredentials()
if tlsEnabled {
config := &tls.Config{
InsecureSkipVerify: false,
}
creds = credentials.NewTLS(config)
}

conn, err := grpc.Dial(
address,
grpc.WithTransportCredentials(creds),
)
if err != nil {
return nil, errors.Wrapf(err, "Failed to create gRPC connection")
}
Expand Down
18 changes: 16 additions & 2 deletions backend/src/v2/cacheutils/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ package cacheutils
import (
"context"
"crypto/sha256"
"crypto/tls"
"encoding/hex"
"encoding/json"
"fmt"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"os"

"google.golang.org/grpc"
Expand Down Expand Up @@ -111,10 +114,21 @@ type Client struct {
}

// NewClient creates a Client.
func NewClient() (*Client, error) {
func NewClient(mlPipelineServiceTLSEnabled bool) (*Client, error) {
creds := insecure.NewCredentials()
if mlPipelineServiceTLSEnabled {
config := &tls.Config{
InsecureSkipVerify: false,
}
creds = credentials.NewTLS(config)
}
cacheEndPoint := cacheDefaultEndpoint()
glog.Infof("Connecting to cache endpoint %s", cacheEndPoint)
conn, err := grpc.Dial(cacheEndPoint, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(MaxClientGRPCMessageSize)), grpc.WithInsecure())
conn, err := grpc.Dial(
cacheEndPoint,
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(MaxClientGRPCMessageSize)),
grpc.WithTransportCredentials(creds),
)
if err != nil {
return nil, fmt.Errorf("metadata.NewClient() failed: %w", err)
}
Expand Down
24 changes: 16 additions & 8 deletions backend/src/v2/cmd/driver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ var (
// the value stored in the paths will be either 'true' or 'false'
cachedDecisionPath = flag.String("cached_decision_path", "", "Cached Decision output path")
conditionPath = flag.String("condition_path", "", "Condition output path")

mlPipelineServiceTLSEnabledStr = flag.String("mlPipelineServiceTLSEnabled", "false", "Set to 'true' if mlpipeline api server serves over TLS (default: 'false').")
)

// func RootDAG(pipelineName string, runID string, component *pipelinespec.ComponentSpec, task *pipelinespec.PipelineTaskSpec, mlmd *metadata.Client) (*Execution, error) {
Expand Down Expand Up @@ -147,18 +149,24 @@ func drive() (err error) {
if err != nil {
return err
}
cacheClient, err := cacheutils.NewClient()
mlPipelineServiceTLSEnabled, err := strconv.ParseBool(*mlPipelineServiceTLSEnabledStr)
if err != nil {
return err
}

cacheClient, err := cacheutils.NewClient(mlPipelineServiceTLSEnabled)
if err != nil {
return err
}
options := driver.Options{
PipelineName: *pipelineName,
RunID: *runID,
Namespace: namespace,
Component: componentSpec,
Task: taskSpec,
DAGExecutionID: *dagExecutionID,
IterationIndex: *iterationIndex,
PipelineName: *pipelineName,
RunID: *runID,
Namespace: namespace,
Component: componentSpec,
Task: taskSpec,
DAGExecutionID: *dagExecutionID,
IterationIndex: *iterationIndex,
MLPipelineTLSEnabled: mlPipelineServiceTLSEnabled,
}
var execution *driver.Execution
var driverErr error
Expand Down
Loading

0 comments on commit cd5eb38

Please sign in to comment.