Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UPSTREAM: <carry>: Added support for TLS to MLMD GRPC Server #72

Merged
merged 1 commit into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backend/src/apiserver/client_manager/client_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,8 @@ func (c *ClientManager) init() {

c.k8sCoreClient = client.CreateKubernetesCoreOrFatal(common.GetDurationConfig(initConnectionTimeout), clientParams)

newClient, err := metadata.NewClient(common.GetMetadataGrpcServiceServiceHost(), common.GetMetadataGrpcServiceServicePort())
newClient, err := metadata.NewClient(common.GetMetadataGrpcServiceServiceHost(), common.GetMetadataGrpcServiceServicePort(), common.GetMetadataTLSEnabled())

if err != nil {
glog.Fatalf("Failed to create metadata client. Error: %v", err)
}
Expand Down
5 changes: 5 additions & 0 deletions backend/src/apiserver/common/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ const (
TokenReviewAudience string = "TOKEN_REVIEW_AUDIENCE"
MetadataGrpcServiceServiceHost string = "METADATA_GRPC_SERVICE_SERVICE_HOST"
MetadataGrpcServiceServicePort string = "METADATA_GRPC_SERVICE_SERVICE_PORT"
MetadataTLSEnabled string = "METADATA_TLS_ENABLED"
SignedURLExpiryTimeSeconds string = "SIGNED_URL_EXPIRY_TIME_SECONDS"
)

Expand Down Expand Up @@ -142,3 +143,7 @@ func GetMetadataGrpcServiceServicePort() string {
func GetSignedURLExpiryTimeSeconds() int {
return GetIntConfigWithDefault(SignedURLExpiryTimeSeconds, DefaultSignedURLExpiryTimeSeconds)
}

func GetMetadataTLSEnabled() bool {
return GetBoolConfigWithDefault(MetadataTLSEnabled, DefaultMetadataTLSEnabled)
}
1 change: 1 addition & 0 deletions backend/src/apiserver/common/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ const DefaultTokenReviewAudience string = "pipelines.kubeflow.org"
const (
DefaultMetadataGrpcServiceServiceHost = "metadata-grpc-service"
DefaultMetadataGrpcServiceServicePort = "8080"
DefaultMetadataTLSEnabled = false
)

const (
Expand Down
17 changes: 16 additions & 1 deletion backend/src/v2/cmd/driver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ var (
conditionPath = flag.String("condition_path", "", "Condition output path")

mlPipelineServiceTLSEnabledStr = flag.String("mlPipelineServiceTLSEnabled", "false", "Set to 'true' if mlpipeline api server serves over TLS (default: 'false').")
metadataTLSEnabledStr = flag.String("metadataTLSEnabled", "false", "Set to 'true' if metadata server serves over TLS (default: 'false').")
)

// func RootDAG(pipelineName string, runID string, component *pipelinespec.ComponentSpec, task *pipelinespec.PipelineTaskSpec, mlmd *metadata.Client) (*Execution, error) {
Expand Down Expand Up @@ -154,6 +155,11 @@ func drive() (err error) {
return err
}

metadataTLSEnabled, err := strconv.ParseBool(*metadataTLSEnabledStr)
if err != nil {
return err
}

cacheClient, err := cacheutils.NewClient(mlPipelineServiceTLSEnabled)
if err != nil {
return err
Expand All @@ -167,6 +173,9 @@ func drive() (err error) {
DAGExecutionID: *dagExecutionID,
IterationIndex: *iterationIndex,
MLPipelineTLSEnabled: mlPipelineServiceTLSEnabled,
MLMDServerAddress: *mlmdServerAddress,
MLMDServerPort: *mlmdServerPort,
MLMDTLSEnabled: metadataTLSEnabled,
}
var execution *driver.Execution
var driverErr error
Expand Down Expand Up @@ -292,5 +301,11 @@ func newMlmdClient() (*metadata.Client, error) {
mlmdConfig.Address = *mlmdServerAddress
mlmdConfig.Port = *mlmdServerPort
}
return metadata.NewClient(mlmdConfig.Address, mlmdConfig.Port)

tlsEnabled, err := strconv.ParseBool(*metadataTLSEnabledStr)
if err != nil {
return nil, err
}

return metadata.NewClient(mlmdConfig.Address, mlmdConfig.Port, tlsEnabled)
}
8 changes: 8 additions & 0 deletions backend/src/v2/cmd/launcher-v2/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ var (
mlmdServerAddress = flag.String("mlmd_server_address", "", "The MLMD gRPC server address.")
mlmdServerPort = flag.String("mlmd_server_port", "8080", "The MLMD gRPC server port.")
mlPipelineServiceTLSEnabledStr = flag.String("mlPipelineServiceTLSEnabled", "false", "Set to 'true' if mlpipeline api server serves over TLS (default: 'false').")
metadataTLSEnabledStr = flag.String("metadataTLSEnabled", "false", "Set to 'true' if metadata server serves over TLS (default: 'false').")
)

func main() {
Expand Down Expand Up @@ -71,6 +72,12 @@ func run() error {
if err != nil {
return err
}

metadataServiceTLSEnabled, err := strconv.ParseBool(*metadataTLSEnabledStr)
if err != nil {
return err
}

launcherV2Opts := &component.LauncherV2Options{
Namespace: namespace,
PodName: *podName,
Expand All @@ -80,6 +87,7 @@ func run() error {
PipelineName: *pipelineName,
RunID: *runID,
MLPipelineTLSEnabled: mlPipelineServiceTLSEnabled,
MetadataTLSEnabled: metadataServiceTLSEnabled,
}

switch *executorType {
Expand Down
4 changes: 4 additions & 0 deletions backend/src/v2/compiler/argocompiler/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package argocompiler
import (
wfapi "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1"
"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
"github.com/kubeflow/pipelines/backend/src/apiserver/common"
"github.com/kubeflow/pipelines/backend/src/v2/component"
k8score "k8s.io/api/core/v1"
"os"
Expand Down Expand Up @@ -163,6 +164,9 @@ func (c *workflowCompiler) addContainerDriverTemplate() string {
"--condition_path", outputPath(paramCondition),
"--kubernetes_config", inputValue(paramKubernetesConfig),
"--mlPipelineServiceTLSEnabled", strconv.FormatBool(c.mlPipelineServiceTLSEnabled),
"--mlmd_server_address", common.GetMetadataGrpcServiceServiceHost(),
"--mlmd_server_port", common.GetMetadataGrpcServiceServicePort(),
"--metadataTLSEnabled", strconv.FormatBool(common.GetMetadataTLSEnabled()),
},
Resources: driverResources,
},
Expand Down
4 changes: 4 additions & 0 deletions backend/src/v2/compiler/argocompiler/dag.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package argocompiler

import (
"fmt"
"github.com/kubeflow/pipelines/backend/src/apiserver/common"
"sort"
"strconv"
"strings"
Expand Down Expand Up @@ -443,6 +444,9 @@ func (c *workflowCompiler) addDAGDriverTemplate() string {
"--iteration_count_path", outputPath(paramIterationCount),
"--condition_path", outputPath(paramCondition),
"--mlPipelineServiceTLSEnabled", strconv.FormatBool(c.mlPipelineServiceTLSEnabled),
"--mlmd_server_address", common.GetMetadataGrpcServiceServiceHost(),
"--mlmd_server_port", common.GetMetadataGrpcServiceServicePort(),
"--metadataTLSEnabled", strconv.FormatBool(common.GetMetadataTLSEnabled()),
},
Resources: driverResources,
},
Expand Down
7 changes: 3 additions & 4 deletions backend/src/v2/compiler/argocompiler/importer.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package argocompiler

import (
"fmt"
"github.com/kubeflow/pipelines/backend/src/apiserver/common"

wfapi "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1"
"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
Expand Down Expand Up @@ -76,10 +77,8 @@ func (c *workflowCompiler) addImporterTemplate() string {
fmt.Sprintf("$(%s)", component.EnvPodName),
"--pod_uid",
fmt.Sprintf("$(%s)", component.EnvPodUID),
"--mlmd_server_address",
fmt.Sprintf("$(%s)", component.EnvMetadataHost),
"--mlmd_server_port",
fmt.Sprintf("$(%s)", component.EnvMetadataPort),
"--mlmd_server_address", common.GetMetadataGrpcServiceServiceHost(),
"--mlmd_server_port", common.GetMetadataGrpcServiceServicePort(),
HumairAK marked this conversation as resolved.
Show resolved Hide resolved
}
importerTemplate := &wfapi.Template{
Name: name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ spec:
- '{{inputs.parameters.kubernetes-config}}'
- "--mlPipelineServiceTLSEnabled"
- "false"
- "--mlmd_server_address"
- "metadata-grpc-service"
- "--mlmd_server_port"
- "8080"
- "--metadataTLSEnabled"
- "false"
command:
- driver
image: gcr.io/ml-pipeline/kfp-driver
Expand Down Expand Up @@ -312,6 +318,12 @@ spec:
- '{{outputs.parameters.condition.path}}'
- "--mlPipelineServiceTLSEnabled"
- "false"
- "--mlmd_server_address"
- "metadata-grpc-service"
- "--mlmd_server_port"
- "8080"
- "--metadataTLSEnabled"
- "false"
command:
- driver
image: gcr.io/ml-pipeline/kfp-driver
Expand Down
12 changes: 12 additions & 0 deletions backend/src/v2/compiler/argocompiler/testdata/hello_world.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ spec:
- '{{inputs.parameters.kubernetes-config}}'
- "--mlPipelineServiceTLSEnabled"
- "false"
- "--mlmd_server_address"
- "metadata-grpc-service"
- "--mlmd_server_port"
- "8080"
- "--metadataTLSEnabled"
- "false"
command:
- driver
image: gcr.io/ml-pipeline/kfp-driver
Expand Down Expand Up @@ -242,6 +248,12 @@ spec:
- '{{outputs.parameters.condition.path}}'
- "--mlPipelineServiceTLSEnabled"
- "false"
- "--mlmd_server_address"
- "metadata-grpc-service"
- "--mlmd_server_port"
- "8080"
- "--metadataTLSEnabled"
- "false"
env:
- name: ML_PIPELINE_SERVICE_HOST
value: ml-pipeline.kubeflow.svc.cluster.local
Expand Down
10 changes: 8 additions & 2 deletions backend/src/v2/compiler/argocompiler/testdata/importer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ spec:
- --pod_uid
- $(KFP_POD_UID)
- --mlmd_server_address
- $(METADATA_GRPC_SERVICE_HOST)
- "metadata-grpc-service"
- --mlmd_server_port
- $(METADATA_GRPC_SERVICE_PORT)
- "8080"
command:
- launcher-v2
env:
Expand Down Expand Up @@ -120,6 +120,12 @@ spec:
- '{{outputs.parameters.condition.path}}'
- "--mlPipelineServiceTLSEnabled"
- "false"
- "--mlmd_server_address"
- "metadata-grpc-service"
- "--mlmd_server_port"
- "8080"
- "--metadataTLSEnabled"
- "false"
command:
- driver
image: gcr.io/ml-pipeline/kfp-driver
Expand Down
2 changes: 1 addition & 1 deletion backend/src/v2/component/importer_launcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func NewImporterLauncher(ctx context.Context, componentSpecJSON, importerSpecJSO
if err != nil {
return nil, fmt.Errorf("failed to initialize kubernetes client set: %w", err)
}
metadataClient, err := metadata.NewClient(launcherV2Opts.MLMDServerAddress, launcherV2Opts.MLMDServerPort)
metadataClient, err := metadata.NewClient(launcherV2Opts.MLMDServerAddress, launcherV2Opts.MLMDServerPort, launcherV2Opts.MetadataTLSEnabled)
if err != nil {
return nil, err
}
Expand Down
4 changes: 3 additions & 1 deletion backend/src/v2/component/launcher_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ type LauncherV2Options struct {
RunID string
// set to true if ml pipeline server is serving over tls
MLPipelineTLSEnabled bool
// set to true if metadata server is serving over tls
MetadataTLSEnabled bool
}

type LauncherV2 struct {
Expand Down Expand Up @@ -110,7 +112,7 @@ func NewLauncherV2(ctx context.Context, executionID int64, executorInputJSON, co
if err != nil {
return nil, fmt.Errorf("failed to initialize kubernetes client set: %w", err)
}
metadataClient, err := metadata.NewClient(opts.MLMDServerAddress, opts.MLMDServerPort)
metadataClient, err := metadata.NewClient(opts.MLMDServerAddress, opts.MLMDServerPort, opts.MetadataTLSEnabled)
if err != nil {
return nil, err
}
Expand Down
19 changes: 14 additions & 5 deletions backend/src/v2/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ type Options struct {

// set to true if ml pipeline server is serving over tls
MLPipelineTLSEnabled bool

MLMDServerAddress string

MLMDServerPort string

// set to true if MLMD server is serving over tls
MLMDTLSEnabled bool
}

// Identifying information used for error messages
Expand Down Expand Up @@ -339,7 +346,7 @@ func Container(ctx context.Context, opts Options, mlmd *metadata.Client, cacheCl
return execution, nil
}

podSpec, err := initPodSpecPatch(opts.Container, opts.Component, executorInput, execution.ID, opts.PipelineName, opts.RunID, opts.MLPipelineTLSEnabled)
podSpec, err := initPodSpecPatch(opts.Container, opts.Component, executorInput, execution.ID, opts.PipelineName, opts.RunID, opts.MLPipelineTLSEnabled, opts.MLMDServerAddress, opts.MLMDServerPort, opts.MLMDTLSEnabled)
if err != nil {
return execution, err
}
Expand Down Expand Up @@ -373,6 +380,9 @@ func initPodSpecPatch(
pipelineName string,
runID string,
mlPipelineTLSEnabled bool,
mlmdServerAddress string,
mlmdServerPort string,
mlmdTLSEnabled bool,
) (*k8score.PodSpec, error) {
executorInputJSON, err := protojson.Marshal(executorInput)
if err != nil {
Expand Down Expand Up @@ -407,10 +417,9 @@ func initPodSpecPatch(
fmt.Sprintf("$(%s)", component.EnvPodName),
"--pod_uid",
fmt.Sprintf("$(%s)", component.EnvPodUID),
"--mlmd_server_address",
fmt.Sprintf("$(%s)", component.EnvMetadataHost),
"--mlmd_server_port",
fmt.Sprintf("$(%s)", component.EnvMetadataPort),
"--mlmd_server_address", mlmdServerAddress,
"--mlmd_server_port", mlmdServerPort,
"--metadataTLSEnabled", fmt.Sprintf("%v", mlmdTLSEnabled),
"--mlPipelineServiceTLSEnabled",
fmt.Sprintf("%v", mlPipelineTLSEnabled),
"--", // separater before user command and args
Expand Down
4 changes: 2 additions & 2 deletions backend/src/v2/driver/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ func Test_initPodSpecPatch_acceleratorConfig(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID, false)
podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID, false, "unused-mlmd-server-address", "unused-mlmd-server-port", false)
if tt.wantErr {
assert.Nil(t, podSpec)
assert.NotNil(t, err)
Expand Down Expand Up @@ -403,7 +403,7 @@ func Test_initPodSpecPatch_resourceRequests(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID, false)
podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID, false, "unused-mlmd-server-address", "unused-mlmd-server-port", false)
assert.Nil(t, err)
assert.NotEmpty(t, podSpec)
podSpecString, err := json.Marshal(podSpec)
Expand Down
14 changes: 12 additions & 2 deletions backend/src/v2/metadata/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ package metadata

import (
"context"
"crypto/tls"
"errors"
"fmt"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"path"
"strconv"
"strings"
Expand Down Expand Up @@ -105,14 +108,21 @@ type Client struct {
}

// NewClient creates a Client given the MLMD server address and port.
func NewClient(serverAddress, serverPort string) (*Client, error) {
func NewClient(serverAddress, serverPort string, tlsEnabled bool) (*Client, error) {
opts := []grpc_retry.CallOption{
grpc_retry.WithMax(mlmdClientSideMaxRetries),
grpc_retry.WithBackoff(grpc_retry.BackoffExponentialWithJitter(300*time.Millisecond, 0.20)),
grpc_retry.WithCodes(codes.Aborted),
}

creds := insecure.NewCredentials()
if tlsEnabled {
config := &tls.Config{}
creds = credentials.NewTLS(config)
}

conn, err := grpc.Dial(fmt.Sprintf("%s:%s", serverAddress, serverPort),
grpc.WithInsecure(),
grpc.WithTransportCredentials(creds),
HumairAK marked this conversation as resolved.
Show resolved Hide resolved
grpc.WithStreamInterceptor(grpc_retry.StreamClientInterceptor(opts...)),
grpc.WithUnaryInterceptor(grpc_retry.UnaryClientInterceptor(opts...)),
)
Expand Down
Loading
Loading