From 9ee91c1c72cfba5dc92feb9500f19d0c6a9ee400 Mon Sep 17 00:00:00 2001 From: Jason Parraga Date: Mon, 11 Nov 2024 16:51:51 -0800 Subject: [PATCH] Add pod spec to workers/head Signed-off-by: Jason Parraga --- flyteidl/gen/pb-es/flyteidl/plugins/ray_pb.ts | 18 ++--- flyteidl/gen/pb-go/flyteidl/plugins/ray.pb.go | 77 +++++++++---------- .../gen/pb_python/flyteidl/plugins/ray_pb2.py | 16 ++-- .../pb_python/flyteidl/plugins/ray_pb2.pyi | 16 ++-- flyteidl/gen/pb_rust/flyteidl.plugins.rs | 8 +- flyteidl/protos/flyteidl/plugins/ray.proto | 8 +- flyteplugins/go/tasks/plugins/k8s/ray/ray.go | 68 ++++++++++------ .../go/tasks/plugins/k8s/ray/ray_test.go | 60 +++++++++++---- 8 files changed, 162 insertions(+), 109 deletions(-) diff --git a/flyteidl/gen/pb-es/flyteidl/plugins/ray_pb.ts b/flyteidl/gen/pb-es/flyteidl/plugins/ray_pb.ts index 0809147e506..0c0a33e0a1c 100644 --- a/flyteidl/gen/pb-es/flyteidl/plugins/ray_pb.ts +++ b/flyteidl/gen/pb-es/flyteidl/plugins/ray_pb.ts @@ -5,7 +5,7 @@ import type { BinaryReadOptions, FieldList, JsonReadOptions, JsonValue, PartialMessage, PlainMessage } from "@bufbuild/protobuf"; import { Message, proto3 } from "@bufbuild/protobuf"; -import { Resources } from "../core/tasks_pb.js"; +import { K8sPod } from "../core/tasks_pb.js"; /** * RayJobSpec defines the desired state of RayJob @@ -155,11 +155,11 @@ export class HeadGroupSpec extends Message { rayStartParams: { [key: string]: string } = {}; /** - * Resource specification for ray head pod + * Pod Spec for the ray head pod * - * @generated from field: flyteidl.core.Resources resources = 2; + * @generated from field: flyteidl.core.K8sPod k8s_pod = 2; */ - resources?: Resources; + k8sPod?: K8sPod; constructor(data?: PartialMessage) { super(); @@ -170,7 +170,7 @@ export class HeadGroupSpec extends Message { static readonly typeName = "flyteidl.plugins.HeadGroupSpec"; static readonly fields: FieldList = proto3.util.newFieldList(() => [ { no: 1, name: "ray_start_params", kind: "map", K: 9 /* ScalarType.STRING */, V: {kind: "scalar", T: 9 /* ScalarType.STRING */} }, - { no: 2, name: "resources", kind: "message", T: Resources }, + { no: 2, name: "k8s_pod", kind: "message", T: K8sPod }, ]); static fromBinary(bytes: Uint8Array, options?: Partial): HeadGroupSpec { @@ -233,11 +233,11 @@ export class WorkerGroupSpec extends Message { rayStartParams: { [key: string]: string } = {}; /** - * Resource specification for ray worker pods + * Pod Spec for ray worker pods * - * @generated from field: flyteidl.core.Resources resources = 6; + * @generated from field: flyteidl.core.K8sPod k8s_pod = 6; */ - resources?: Resources; + k8sPod?: K8sPod; constructor(data?: PartialMessage) { super(); @@ -252,7 +252,7 @@ export class WorkerGroupSpec extends Message { { no: 3, name: "min_replicas", kind: "scalar", T: 5 /* ScalarType.INT32 */ }, { no: 4, name: "max_replicas", kind: "scalar", T: 5 /* ScalarType.INT32 */ }, { no: 5, name: "ray_start_params", kind: "map", K: 9 /* ScalarType.STRING */, V: {kind: "scalar", T: 9 /* ScalarType.STRING */} }, - { no: 6, name: "resources", kind: "message", T: Resources }, + { no: 6, name: "k8s_pod", kind: "message", T: K8sPod }, ]); static fromBinary(bytes: Uint8Array, options?: Partial): WorkerGroupSpec { diff --git a/flyteidl/gen/pb-go/flyteidl/plugins/ray.pb.go b/flyteidl/gen/pb-go/flyteidl/plugins/ray.pb.go index 153eefecbb8..18d0c4c1cba 100644 --- a/flyteidl/gen/pb-go/flyteidl/plugins/ray.pb.go +++ b/flyteidl/gen/pb-go/flyteidl/plugins/ray.pb.go @@ -187,8 +187,8 @@ type HeadGroupSpec struct { // Optional. RayStartParams are the params of the start command: address, object-store-memory. // Refer to https://docs.ray.io/en/latest/ray-core/package-ref.html#ray-start RayStartParams map[string]string `protobuf:"bytes,1,rep,name=ray_start_params,json=rayStartParams,proto3" json:"ray_start_params,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` - // Resource specification for ray head pod - Resources *core.Resources `protobuf:"bytes,2,opt,name=resources,proto3" json:"resources,omitempty"` + // Pod Spec for the ray head pod + K8SPod *core.K8SPod `protobuf:"bytes,2,opt,name=k8s_pod,json=k8sPod,proto3" json:"k8s_pod,omitempty"` } func (x *HeadGroupSpec) Reset() { @@ -230,9 +230,9 @@ func (x *HeadGroupSpec) GetRayStartParams() map[string]string { return nil } -func (x *HeadGroupSpec) GetResources() *core.Resources { +func (x *HeadGroupSpec) GetK8SPod() *core.K8SPod { if x != nil { - return x.Resources + return x.K8SPod } return nil } @@ -254,8 +254,8 @@ type WorkerGroupSpec struct { // Optional. RayStartParams are the params of the start command: address, object-store-memory. // Refer to https://docs.ray.io/en/latest/ray-core/package-ref.html#ray-start RayStartParams map[string]string `protobuf:"bytes,5,rep,name=ray_start_params,json=rayStartParams,proto3" json:"ray_start_params,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` - // Resource specification for ray worker pods - Resources *core.Resources `protobuf:"bytes,6,opt,name=resources,proto3" json:"resources,omitempty"` + // Pod Spec for ray worker pods + K8SPod *core.K8SPod `protobuf:"bytes,6,opt,name=k8s_pod,json=k8sPod,proto3" json:"k8s_pod,omitempty"` } func (x *WorkerGroupSpec) Reset() { @@ -325,9 +325,9 @@ func (x *WorkerGroupSpec) GetRayStartParams() map[string]string { return nil } -func (x *WorkerGroupSpec) GetResources() *core.Resources { +func (x *WorkerGroupSpec) GetK8SPod() *core.K8SPod { if x != nil { - return x.Resources + return x.K8SPod } return nil } @@ -370,40 +370,39 @@ var file_flyteidl_plugins_ray_proto_rawDesc = []byte{ 0x70, 0x53, 0x70, 0x65, 0x63, 0x12, 0x2d, 0x0a, 0x12, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x5f, 0x61, 0x75, 0x74, 0x6f, 0x73, 0x63, 0x61, 0x6c, 0x69, 0x6e, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x11, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x41, 0x75, 0x74, 0x6f, 0x73, 0x63, 0x61, - 0x6c, 0x69, 0x6e, 0x67, 0x22, 0xe9, 0x01, 0x0a, 0x0d, 0x48, 0x65, 0x61, 0x64, 0x47, 0x72, 0x6f, + 0x6c, 0x69, 0x6e, 0x67, 0x22, 0xe1, 0x01, 0x0a, 0x0d, 0x48, 0x65, 0x61, 0x64, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x70, 0x65, 0x63, 0x12, 0x5d, 0x0a, 0x10, 0x72, 0x61, 0x79, 0x5f, 0x73, 0x74, 0x61, 0x72, 0x74, 0x5f, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x33, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x48, 0x65, 0x61, 0x64, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x70, 0x65, 0x63, 0x2e, 0x52, 0x61, 0x79, 0x53, 0x74, 0x61, 0x72, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0e, 0x72, 0x61, 0x79, 0x53, 0x74, 0x61, 0x72, 0x74, 0x50, - 0x61, 0x72, 0x61, 0x6d, 0x73, 0x12, 0x36, 0x0a, 0x09, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, - 0x65, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, - 0x69, 0x64, 0x6c, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, - 0x65, 0x73, 0x52, 0x09, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73, 0x1a, 0x41, 0x0a, - 0x13, 0x52, 0x61, 0x79, 0x53, 0x74, 0x61, 0x72, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x45, - 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, - 0x22, 0xee, 0x02, 0x0a, 0x0f, 0x57, 0x6f, 0x72, 0x6b, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, - 0x53, 0x70, 0x65, 0x63, 0x12, 0x1d, 0x0a, 0x0a, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x6e, 0x61, - 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x4e, - 0x61, 0x6d, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x72, 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x73, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x08, 0x72, 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x73, 0x12, - 0x21, 0x0a, 0x0c, 0x6d, 0x69, 0x6e, 0x5f, 0x72, 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x73, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0b, 0x6d, 0x69, 0x6e, 0x52, 0x65, 0x70, 0x6c, 0x69, 0x63, - 0x61, 0x73, 0x12, 0x21, 0x0a, 0x0c, 0x6d, 0x61, 0x78, 0x5f, 0x72, 0x65, 0x70, 0x6c, 0x69, 0x63, - 0x61, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0b, 0x6d, 0x61, 0x78, 0x52, 0x65, 0x70, - 0x6c, 0x69, 0x63, 0x61, 0x73, 0x12, 0x5f, 0x0a, 0x10, 0x72, 0x61, 0x79, 0x5f, 0x73, 0x74, 0x61, - 0x72, 0x74, 0x5f, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x35, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, - 0x6e, 0x73, 0x2e, 0x57, 0x6f, 0x72, 0x6b, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x70, - 0x65, 0x63, 0x2e, 0x52, 0x61, 0x79, 0x53, 0x74, 0x61, 0x72, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, - 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0e, 0x72, 0x61, 0x79, 0x53, 0x74, 0x61, 0x72, 0x74, - 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x12, 0x36, 0x0a, 0x09, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, - 0x63, 0x65, 0x73, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x66, 0x6c, 0x79, 0x74, - 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, - 0x63, 0x65, 0x73, 0x52, 0x09, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73, 0x1a, 0x41, + 0x61, 0x72, 0x61, 0x6d, 0x73, 0x12, 0x2e, 0x0a, 0x07, 0x6b, 0x38, 0x73, 0x5f, 0x70, 0x6f, 0x64, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, + 0x6c, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x4b, 0x38, 0x73, 0x50, 0x6f, 0x64, 0x52, 0x06, 0x6b, + 0x38, 0x73, 0x50, 0x6f, 0x64, 0x1a, 0x41, 0x0a, 0x13, 0x52, 0x61, 0x79, 0x53, 0x74, 0x61, 0x72, + 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, + 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, + 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, + 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0xe6, 0x02, 0x0a, 0x0f, 0x57, 0x6f, 0x72, + 0x6b, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x70, 0x65, 0x63, 0x12, 0x1d, 0x0a, 0x0a, + 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x09, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x72, + 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x08, 0x72, + 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x73, 0x12, 0x21, 0x0a, 0x0c, 0x6d, 0x69, 0x6e, 0x5f, 0x72, + 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0b, 0x6d, + 0x69, 0x6e, 0x52, 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x73, 0x12, 0x21, 0x0a, 0x0c, 0x6d, 0x61, + 0x78, 0x5f, 0x72, 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, + 0x52, 0x0b, 0x6d, 0x61, 0x78, 0x52, 0x65, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x73, 0x12, 0x5f, 0x0a, + 0x10, 0x72, 0x61, 0x79, 0x5f, 0x73, 0x74, 0x61, 0x72, 0x74, 0x5f, 0x70, 0x61, 0x72, 0x61, 0x6d, + 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x35, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, + 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x57, 0x6f, 0x72, 0x6b, 0x65, + 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x70, 0x65, 0x63, 0x2e, 0x52, 0x61, 0x79, 0x53, 0x74, + 0x61, 0x72, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0e, + 0x72, 0x61, 0x79, 0x53, 0x74, 0x61, 0x72, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x12, 0x2e, + 0x0a, 0x07, 0x6b, 0x38, 0x73, 0x5f, 0x70, 0x6f, 0x64, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x15, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, + 0x4b, 0x38, 0x73, 0x50, 0x6f, 0x64, 0x52, 0x06, 0x6b, 0x38, 0x73, 0x50, 0x6f, 0x64, 0x1a, 0x41, 0x0a, 0x13, 0x52, 0x61, 0x79, 0x53, 0x74, 0x61, 0x72, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, @@ -443,16 +442,16 @@ var file_flyteidl_plugins_ray_proto_goTypes = []interface{}{ (*WorkerGroupSpec)(nil), // 3: flyteidl.plugins.WorkerGroupSpec nil, // 4: flyteidl.plugins.HeadGroupSpec.RayStartParamsEntry nil, // 5: flyteidl.plugins.WorkerGroupSpec.RayStartParamsEntry - (*core.Resources)(nil), // 6: flyteidl.core.Resources + (*core.K8SPod)(nil), // 6: flyteidl.core.K8sPod } var file_flyteidl_plugins_ray_proto_depIdxs = []int32{ 1, // 0: flyteidl.plugins.RayJob.ray_cluster:type_name -> flyteidl.plugins.RayCluster 2, // 1: flyteidl.plugins.RayCluster.head_group_spec:type_name -> flyteidl.plugins.HeadGroupSpec 3, // 2: flyteidl.plugins.RayCluster.worker_group_spec:type_name -> flyteidl.plugins.WorkerGroupSpec 4, // 3: flyteidl.plugins.HeadGroupSpec.ray_start_params:type_name -> flyteidl.plugins.HeadGroupSpec.RayStartParamsEntry - 6, // 4: flyteidl.plugins.HeadGroupSpec.resources:type_name -> flyteidl.core.Resources + 6, // 4: flyteidl.plugins.HeadGroupSpec.k8s_pod:type_name -> flyteidl.core.K8sPod 5, // 5: flyteidl.plugins.WorkerGroupSpec.ray_start_params:type_name -> flyteidl.plugins.WorkerGroupSpec.RayStartParamsEntry - 6, // 6: flyteidl.plugins.WorkerGroupSpec.resources:type_name -> flyteidl.core.Resources + 6, // 6: flyteidl.plugins.WorkerGroupSpec.k8s_pod:type_name -> flyteidl.core.K8sPod 7, // [7:7] is the sub-list for method output_type 7, // [7:7] is the sub-list for method input_type 7, // [7:7] is the sub-list for extension type_name diff --git a/flyteidl/gen/pb_python/flyteidl/plugins/ray_pb2.py b/flyteidl/gen/pb_python/flyteidl/plugins/ray_pb2.py index 35cdb3dfdea..c625fd957b6 100644 --- a/flyteidl/gen/pb_python/flyteidl/plugins/ray_pb2.py +++ b/flyteidl/gen/pb_python/flyteidl/plugins/ray_pb2.py @@ -14,7 +14,7 @@ from flyteidl.core import tasks_pb2 as flyteidl_dot_core_dot_tasks__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1a\x66lyteidl/plugins/ray.proto\x12\x10\x66lyteidl.plugins\x1a\x19\x66lyteidl/core/tasks.proto\"\x92\x02\n\x06RayJob\x12=\n\x0bray_cluster\x18\x01 \x01(\x0b\x32\x1c.flyteidl.plugins.RayClusterR\nrayCluster\x12#\n\x0bruntime_env\x18\x02 \x01(\tB\x02\x18\x01R\nruntimeEnv\x12=\n\x1bshutdown_after_job_finishes\x18\x03 \x01(\x08R\x18shutdownAfterJobFinishes\x12;\n\x1attl_seconds_after_finished\x18\x04 \x01(\x05R\x17ttlSecondsAfterFinished\x12(\n\x10runtime_env_yaml\x18\x05 \x01(\tR\x0eruntimeEnvYaml\"\xd3\x01\n\nRayCluster\x12G\n\x0fhead_group_spec\x18\x01 \x01(\x0b\x32\x1f.flyteidl.plugins.HeadGroupSpecR\rheadGroupSpec\x12M\n\x11worker_group_spec\x18\x02 \x03(\x0b\x32!.flyteidl.plugins.WorkerGroupSpecR\x0fworkerGroupSpec\x12-\n\x12\x65nable_autoscaling\x18\x03 \x01(\x08R\x11\x65nableAutoscaling\"\xe9\x01\n\rHeadGroupSpec\x12]\n\x10ray_start_params\x18\x01 \x03(\x0b\x32\x33.flyteidl.plugins.HeadGroupSpec.RayStartParamsEntryR\x0erayStartParams\x12\x36\n\tresources\x18\x02 \x01(\x0b\x32\x18.flyteidl.core.ResourcesR\tresources\x1a\x41\n\x13RayStartParamsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xee\x02\n\x0fWorkerGroupSpec\x12\x1d\n\ngroup_name\x18\x01 \x01(\tR\tgroupName\x12\x1a\n\x08replicas\x18\x02 \x01(\x05R\x08replicas\x12!\n\x0cmin_replicas\x18\x03 \x01(\x05R\x0bminReplicas\x12!\n\x0cmax_replicas\x18\x04 \x01(\x05R\x0bmaxReplicas\x12_\n\x10ray_start_params\x18\x05 \x03(\x0b\x32\x35.flyteidl.plugins.WorkerGroupSpec.RayStartParamsEntryR\x0erayStartParams\x12\x36\n\tresources\x18\x06 \x01(\x0b\x32\x18.flyteidl.core.ResourcesR\tresources\x1a\x41\n\x13RayStartParamsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\xc0\x01\n\x14\x63om.flyteidl.pluginsB\x08RayProtoP\x01Z=github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins\xa2\x02\x03\x46PX\xaa\x02\x10\x46lyteidl.Plugins\xca\x02\x10\x46lyteidl\\Plugins\xe2\x02\x1c\x46lyteidl\\Plugins\\GPBMetadata\xea\x02\x11\x46lyteidl::Pluginsb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1a\x66lyteidl/plugins/ray.proto\x12\x10\x66lyteidl.plugins\x1a\x19\x66lyteidl/core/tasks.proto\"\x92\x02\n\x06RayJob\x12=\n\x0bray_cluster\x18\x01 \x01(\x0b\x32\x1c.flyteidl.plugins.RayClusterR\nrayCluster\x12#\n\x0bruntime_env\x18\x02 \x01(\tB\x02\x18\x01R\nruntimeEnv\x12=\n\x1bshutdown_after_job_finishes\x18\x03 \x01(\x08R\x18shutdownAfterJobFinishes\x12;\n\x1attl_seconds_after_finished\x18\x04 \x01(\x05R\x17ttlSecondsAfterFinished\x12(\n\x10runtime_env_yaml\x18\x05 \x01(\tR\x0eruntimeEnvYaml\"\xd3\x01\n\nRayCluster\x12G\n\x0fhead_group_spec\x18\x01 \x01(\x0b\x32\x1f.flyteidl.plugins.HeadGroupSpecR\rheadGroupSpec\x12M\n\x11worker_group_spec\x18\x02 \x03(\x0b\x32!.flyteidl.plugins.WorkerGroupSpecR\x0fworkerGroupSpec\x12-\n\x12\x65nable_autoscaling\x18\x03 \x01(\x08R\x11\x65nableAutoscaling\"\xe1\x01\n\rHeadGroupSpec\x12]\n\x10ray_start_params\x18\x01 \x03(\x0b\x32\x33.flyteidl.plugins.HeadGroupSpec.RayStartParamsEntryR\x0erayStartParams\x12.\n\x07k8s_pod\x18\x02 \x01(\x0b\x32\x15.flyteidl.core.K8sPodR\x06k8sPod\x1a\x41\n\x13RayStartParamsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xe6\x02\n\x0fWorkerGroupSpec\x12\x1d\n\ngroup_name\x18\x01 \x01(\tR\tgroupName\x12\x1a\n\x08replicas\x18\x02 \x01(\x05R\x08replicas\x12!\n\x0cmin_replicas\x18\x03 \x01(\x05R\x0bminReplicas\x12!\n\x0cmax_replicas\x18\x04 \x01(\x05R\x0bmaxReplicas\x12_\n\x10ray_start_params\x18\x05 \x03(\x0b\x32\x35.flyteidl.plugins.WorkerGroupSpec.RayStartParamsEntryR\x0erayStartParams\x12.\n\x07k8s_pod\x18\x06 \x01(\x0b\x32\x15.flyteidl.core.K8sPodR\x06k8sPod\x1a\x41\n\x13RayStartParamsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\xc0\x01\n\x14\x63om.flyteidl.pluginsB\x08RayProtoP\x01Z=github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins\xa2\x02\x03\x46PX\xaa\x02\x10\x46lyteidl.Plugins\xca\x02\x10\x46lyteidl\\Plugins\xe2\x02\x1c\x46lyteidl\\Plugins\\GPBMetadata\xea\x02\x11\x46lyteidl::Pluginsb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -34,11 +34,11 @@ _globals['_RAYCLUSTER']._serialized_start=353 _globals['_RAYCLUSTER']._serialized_end=564 _globals['_HEADGROUPSPEC']._serialized_start=567 - _globals['_HEADGROUPSPEC']._serialized_end=800 - _globals['_HEADGROUPSPEC_RAYSTARTPARAMSENTRY']._serialized_start=735 - _globals['_HEADGROUPSPEC_RAYSTARTPARAMSENTRY']._serialized_end=800 - _globals['_WORKERGROUPSPEC']._serialized_start=803 - _globals['_WORKERGROUPSPEC']._serialized_end=1169 - _globals['_WORKERGROUPSPEC_RAYSTARTPARAMSENTRY']._serialized_start=735 - _globals['_WORKERGROUPSPEC_RAYSTARTPARAMSENTRY']._serialized_end=800 + _globals['_HEADGROUPSPEC']._serialized_end=792 + _globals['_HEADGROUPSPEC_RAYSTARTPARAMSENTRY']._serialized_start=727 + _globals['_HEADGROUPSPEC_RAYSTARTPARAMSENTRY']._serialized_end=792 + _globals['_WORKERGROUPSPEC']._serialized_start=795 + _globals['_WORKERGROUPSPEC']._serialized_end=1153 + _globals['_WORKERGROUPSPEC_RAYSTARTPARAMSENTRY']._serialized_start=727 + _globals['_WORKERGROUPSPEC_RAYSTARTPARAMSENTRY']._serialized_end=792 # @@protoc_insertion_point(module_scope) diff --git a/flyteidl/gen/pb_python/flyteidl/plugins/ray_pb2.pyi b/flyteidl/gen/pb_python/flyteidl/plugins/ray_pb2.pyi index 3725d8e5d46..239e2fbc1ac 100644 --- a/flyteidl/gen/pb_python/flyteidl/plugins/ray_pb2.pyi +++ b/flyteidl/gen/pb_python/flyteidl/plugins/ray_pb2.pyi @@ -31,7 +31,7 @@ class RayCluster(_message.Message): def __init__(self, head_group_spec: _Optional[_Union[HeadGroupSpec, _Mapping]] = ..., worker_group_spec: _Optional[_Iterable[_Union[WorkerGroupSpec, _Mapping]]] = ..., enable_autoscaling: bool = ...) -> None: ... class HeadGroupSpec(_message.Message): - __slots__ = ["ray_start_params", "resources"] + __slots__ = ["ray_start_params", "k8s_pod"] class RayStartParamsEntry(_message.Message): __slots__ = ["key", "value"] KEY_FIELD_NUMBER: _ClassVar[int] @@ -40,13 +40,13 @@ class HeadGroupSpec(_message.Message): value: str def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... RAY_START_PARAMS_FIELD_NUMBER: _ClassVar[int] - RESOURCES_FIELD_NUMBER: _ClassVar[int] + K8S_POD_FIELD_NUMBER: _ClassVar[int] ray_start_params: _containers.ScalarMap[str, str] - resources: _tasks_pb2.Resources - def __init__(self, ray_start_params: _Optional[_Mapping[str, str]] = ..., resources: _Optional[_Union[_tasks_pb2.Resources, _Mapping]] = ...) -> None: ... + k8s_pod: _tasks_pb2.K8sPod + def __init__(self, ray_start_params: _Optional[_Mapping[str, str]] = ..., k8s_pod: _Optional[_Union[_tasks_pb2.K8sPod, _Mapping]] = ...) -> None: ... class WorkerGroupSpec(_message.Message): - __slots__ = ["group_name", "replicas", "min_replicas", "max_replicas", "ray_start_params", "resources"] + __slots__ = ["group_name", "replicas", "min_replicas", "max_replicas", "ray_start_params", "k8s_pod"] class RayStartParamsEntry(_message.Message): __slots__ = ["key", "value"] KEY_FIELD_NUMBER: _ClassVar[int] @@ -59,11 +59,11 @@ class WorkerGroupSpec(_message.Message): MIN_REPLICAS_FIELD_NUMBER: _ClassVar[int] MAX_REPLICAS_FIELD_NUMBER: _ClassVar[int] RAY_START_PARAMS_FIELD_NUMBER: _ClassVar[int] - RESOURCES_FIELD_NUMBER: _ClassVar[int] + K8S_POD_FIELD_NUMBER: _ClassVar[int] group_name: str replicas: int min_replicas: int max_replicas: int ray_start_params: _containers.ScalarMap[str, str] - resources: _tasks_pb2.Resources - def __init__(self, group_name: _Optional[str] = ..., replicas: _Optional[int] = ..., min_replicas: _Optional[int] = ..., max_replicas: _Optional[int] = ..., ray_start_params: _Optional[_Mapping[str, str]] = ..., resources: _Optional[_Union[_tasks_pb2.Resources, _Mapping]] = ...) -> None: ... + k8s_pod: _tasks_pb2.K8sPod + def __init__(self, group_name: _Optional[str] = ..., replicas: _Optional[int] = ..., min_replicas: _Optional[int] = ..., max_replicas: _Optional[int] = ..., ray_start_params: _Optional[_Mapping[str, str]] = ..., k8s_pod: _Optional[_Union[_tasks_pb2.K8sPod, _Mapping]] = ...) -> None: ... diff --git a/flyteidl/gen/pb_rust/flyteidl.plugins.rs b/flyteidl/gen/pb_rust/flyteidl.plugins.rs index ddfcba6e8bf..65f187c3e00 100644 --- a/flyteidl/gen/pb_rust/flyteidl.plugins.rs +++ b/flyteidl/gen/pb_rust/flyteidl.plugins.rs @@ -255,9 +255,9 @@ pub struct HeadGroupSpec { /// Refer to #[prost(map="string, string", tag="1")] pub ray_start_params: ::std::collections::HashMap<::prost::alloc::string::String, ::prost::alloc::string::String>, - /// Resource specification for ray head pod + /// Pod Spec for the ray head pod #[prost(message, optional, tag="2")] - pub resources: ::core::option::Option, + pub k8s_pod: ::core::option::Option, } /// WorkerGroupSpec are the specs for the worker pods #[allow(clippy::derive_partial_eq_without_eq)] @@ -279,9 +279,9 @@ pub struct WorkerGroupSpec { /// Refer to #[prost(map="string, string", tag="5")] pub ray_start_params: ::std::collections::HashMap<::prost::alloc::string::String, ::prost::alloc::string::String>, - /// Resource specification for ray worker pods + /// Pod Spec for ray worker pods #[prost(message, optional, tag="6")] - pub resources: ::core::option::Option, + pub k8s_pod: ::core::option::Option, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, Copy, PartialEq, ::prost::Message)] diff --git a/flyteidl/protos/flyteidl/plugins/ray.proto b/flyteidl/protos/flyteidl/plugins/ray.proto index 749f16932d9..749444ee049 100644 --- a/flyteidl/protos/flyteidl/plugins/ray.proto +++ b/flyteidl/protos/flyteidl/plugins/ray.proto @@ -37,8 +37,8 @@ message HeadGroupSpec { // Optional. RayStartParams are the params of the start command: address, object-store-memory. // Refer to https://docs.ray.io/en/latest/ray-core/package-ref.html#ray-start map ray_start_params = 1; - // Resource specification for ray head pod - core.Resources resources = 2; + // Pod Spec for the ray head pod + core.K8sPod k8s_pod = 2; } // WorkerGroupSpec are the specs for the worker pods @@ -54,6 +54,6 @@ message WorkerGroupSpec { // Optional. RayStartParams are the params of the start command: address, object-store-memory. // Refer to https://docs.ray.io/en/latest/ray-core/package-ref.html#ray-start map ray_start_params = 5; - // Resource specification for ray worker pods - core.Resources resources = 6; + // Pod Spec for ray worker pods + core.K8sPod k8s_pod = 6; } diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go index cc925c63053..7f26791632c 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go @@ -329,7 +329,7 @@ func injectLogsSidecar(primaryContainer *v1.Container, podSpec *v1.PodSpec) { podSpec.Containers = append(podSpec.Containers, *sidecar) } -func buildHeadPodTemplate(primaryContainer *v1.Container, podSpec *v1.PodSpec, objectMeta *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext, spec *plugins.HeadGroupSpec) (v1.PodTemplateSpec, error) { +func buildHeadPodTemplate(primaryContainer *v1.Container, basePodSpec *v1.PodSpec, objectMeta *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext, spec *plugins.HeadGroupSpec) (v1.PodTemplateSpec, error) { // Some configs are copy from https://github.com/ray-project/kuberay/blob/b72e6bdcd9b8c77a9dc6b5da8560910f3a0c3ffd/apiserver/pkg/util/cluster.go#L97 // They should always be the same, so we could hard code here. primaryContainer.Name = "ray-head" @@ -367,20 +367,15 @@ func buildHeadPodTemplate(primaryContainer *v1.Container, podSpec *v1.PodSpec, o primaryContainer.Ports = append(primaryContainer.Ports, ports...) // Inject a sidecar for capturing and exposing Ray job logs - injectLogsSidecar(primaryContainer, podSpec) + injectLogsSidecar(primaryContainer, basePodSpec) - // Overwrite head pod taskResources if specified - if spec.Resources != nil { - res, err := flytek8s.ToK8sResourceRequirements(spec.Resources) - if err != nil { - return v1.PodTemplateSpec{}, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification HeadGroupSpec Resources[%v], Err: [%v]", spec.Resources, err.Error()) - } - - primaryContainer.Resources = *res + basePodSpec, err := mergeCustomPodSpec(primaryContainer, basePodSpec, spec.K8SPod) + if err != nil { + return v1.PodTemplateSpec{}, err } podTemplateSpec := v1.PodTemplateSpec{ - Spec: *podSpec, + Spec: *basePodSpec, ObjectMeta: *objectMeta, } cfg := config.GetK8sPluginConfig() @@ -403,7 +398,7 @@ func buildSubmitterPodTemplate(podSpec *v1.PodSpec, objectMeta *metav1.ObjectMet return podTemplateSpec } -func buildWorkerPodTemplate(primaryContainer *v1.Container, podSpec *v1.PodSpec, objectMetadata *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext, spec *plugins.WorkerGroupSpec) (v1.PodTemplateSpec, error) { +func buildWorkerPodTemplate(primaryContainer *v1.Container, basePodSpec *v1.PodSpec, objectMetadata *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext, spec *plugins.WorkerGroupSpmergeCustomPodSpeceSpec, error) { // Some configs are copy from https://github.com/ray-project/kuberay/blob/b72e6bdcd9b8c77a9dc6b5da8560910f3a0c3ffd/apiserver/pkg/util/cluster.go#L185 // They should always be the same, so we could hard code here. @@ -502,18 +497,13 @@ func buildWorkerPodTemplate(primaryContainer *v1.Container, podSpec *v1.PodSpec, } primaryContainer.Ports = append(primaryContainer.Ports, ports...) - // Overwrite worker pod taskResources if specified - if spec.Resources != nil { - res, err := flytek8s.ToK8sResourceRequirements(spec.Resources) - if err != nil { - return v1.PodTemplateSpec{}, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on WorkerGroupSpec Resources[%v], Err: [%v]", spec.Resources, err.Error()) - } - - primaryContainer.Resources = *res + basePodSpec, err := mergeCustomPodSpec(primaryContainer, basePodSpec, spec.K8SPod) + if err != nil { + return v1.PodTemplateSpec{}, err } podTemplateSpec := v1.PodTemplateSpec{ - Spec: *podSpec, + Spec: *basePodSpec, ObjectMeta: *objectMetadata, } podTemplateSpec.SetLabels(utils.UnionMaps(podTemplateSpec.GetLabels(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))) @@ -521,6 +511,38 @@ func buildWorkerPodTemplate(primaryContainer *v1.Container, podSpec *v1.PodSpec, return podTemplateSpec, nil } +// Merges a ray head/worker node custom pod specs onto task's generated pod spec +func mergeCustomPodSpec(primaryContainer *v1.Container, podSpec *v1.PodSpec, k8sPod *core.K8SPod) (*v1.PodSpec, error) { + if k8sPod == nil { + return podSpec, nil + } + + if k8sPod.PodSpec == nil { + return podSpec, nil + } + + var customPodSpec *v1.PodSpec + + err := utils.UnmarshalStructToObj(k8sPod.PodSpec, &customPodSpec) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, + "Unable to unmarshal pod spec [%v], Err: [%v]", k8sPod.PodSpec, err.Error()) + } + + for _, container := range customPodSpec.Containers { + if container.Name != primaryContainer.Name { // Only support the primary container for now + continue + } + + // Just handle resources for now + if len(container.Resources.Requests) > 0 || len(container.Resources.Limits) > 0 { + primaryContainer.Resources = container.Resources + } + } + + return podSpec, nil +} + func (rayJobResourceHandler) BuildIdentityResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionMetadata) (client.Object, error) { return &rayv1.RayJob{ TypeMeta: metav1.TypeMeta{ @@ -545,7 +567,7 @@ func getEventInfoForRayJob(logConfig logs.LogConfig, pluginContext k8s.PluginCon ExtraTemplateVars: []tasklog.TemplateVar{}, } if rayJob.Status.JobId != "" { - input.ExtraTemplateVars = append( + input.EmergeCustomPodSpec append( input.ExtraTemplateVars, tasklog.TemplateVar{ Regex: logTemplateRegexes.RayJobID, @@ -567,7 +589,7 @@ func getEventInfoForRayJob(logConfig logs.LogConfig, pluginContext k8s.PluginCon // RayJob CRD does not include the name of the worker or head pod for now logOutput, err := logPlugin.GetTaskLogs(input) if err != nil { - return nil, fmt.Errorf("failed to generate task logs. Error: %w", err) + return nil, fmt.Errorf("failed to genermergeCustomPodSpecor: %w", err) } taskLogs = append(taskLogs, logOutput.TaskLogs...) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go index 1896000ee00..2cd3eb88934 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go @@ -2,19 +2,19 @@ package ray import ( "context" + "encoding/json" "reflect" "testing" "time" - structpb "github.com/golang/protobuf/ptypes/struct" rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/structpb" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - v1 "k8s.io/apimachinery/pkg/apis/meta/v1" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" @@ -149,7 +149,7 @@ func dummyRayTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.Reso taskExecutionMetadata.OnGetNamespace().Return("test-namespace") taskExecutionMetadata.OnGetAnnotations().Return(map[string]string{"annotation-1": "val1"}) taskExecutionMetadata.OnGetLabels().Return(map[string]string{"label-1": "val1"}) - taskExecutionMetadata.OnGetOwnerReference().Return(v1.OwnerReference{ + taskExecutionMetadata.OnGetOwnerReference().Return(metav1.OwnerReference{ Kind: "node", Name: "blah", }) @@ -420,7 +420,7 @@ func TestBuildResourceRayExtendedResources(t *testing.T) { } } -func TestBuildResourceRayCustomResources(t *testing.T) { +func TestBuildResourceRayCustomK8SPod(t *testing.T) { assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{})) headResourceEntries := []*core.Resources_ResourceEntry{ @@ -443,11 +443,28 @@ func TestBuildResourceRayCustomResources(t *testing.T) { expectedWorkerResources, err := flytek8s.ToK8sResourceRequirements(workerResources) require.NoError(t, err) + headPodSpec := &corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "ray-head", + Resources: *expectedHeadResources, + }, + }, + } + workerPodSpec := &corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "ray-worker", + Resources: *expectedWorkerResources, + }, + }, + } + params := []struct { name string taskResources *corev1.ResourceRequirements - headResources *core.Resources - workerResources *core.Resources + headK8SPod *core.K8SPod + workerK8SPod *core.K8SPod expectedSubmitterResources *corev1.ResourceRequirements expectedHeadResources *corev1.ResourceRequirements expectedWorkerResources *corev1.ResourceRequirements @@ -460,10 +477,14 @@ func TestBuildResourceRayCustomResources(t *testing.T) { expectedWorkerResources: resourceRequirements, }, { - name: "custom worker and head resources", - taskResources: resourceRequirements, - headResources: headResources, - workerResources: workerResources, + name: "custom worker and head resources", + taskResources: resourceRequirements, + headK8SPod: &core.K8SPod{ + PodSpec: transformStructToStructPB(t, headPodSpec), + }, + workerK8SPod: &core.K8SPod{ + PodSpec: transformStructToStructPB(t, workerPodSpec), + }, expectedSubmitterResources: resourceRequirements, expectedHeadResources: expectedHeadResources, expectedWorkerResources: expectedWorkerResources, @@ -474,13 +495,13 @@ func TestBuildResourceRayCustomResources(t *testing.T) { t.Run(p.name, func(t *testing.T) { rayJobInput := dummyRayCustomObj() - if p.headResources != nil { - rayJobInput.RayCluster.HeadGroupSpec.Resources = p.headResources + if p.headK8SPod != nil { + rayJobInput.RayCluster.HeadGroupSpec.K8SPod = p.headK8SPod } - if p.workerResources != nil { + if p.workerK8SPod != nil { for _, spec := range rayJobInput.RayCluster.WorkerGroupSpec { - spec.Resources = p.workerResources + spec.K8SPod = p.workerK8SPod } } @@ -1200,3 +1221,14 @@ func TestGetPropertiesRay(t *testing.T) { expected := k8s.PluginProperties{} assert.Equal(t, expected, rayJobResourceHandler.GetProperties()) } + +func transformStructToStructPB(t *testing.T, obj interface{}) *structpb.Struct { + data, err := json.Marshal(obj) + assert.Nil(t, err) + podSpecMap := make(map[string]interface{}) + err = json.Unmarshal(data, &podSpecMap) + assert.Nil(t, err) + s, err := structpb.NewStruct(podSpecMap) + assert.Nil(t, err) + return s +}