From f20b8aa082820c56c0c721670b54148afa7d36a4 Mon Sep 17 00:00:00 2001 From: Prafulla Mahindrakar Date: Tue, 12 Nov 2024 12:51:42 -0800 Subject: [PATCH] [COR-2297/] Fix nested offloaded type validation (#552) (#5996) The following workflow works when we are not offloading literals in flytekit ``` import logging from typing import List from flytekit import map_task, task, workflow,LaunchPlan logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger("flytekit") logger.setLevel(logging.DEBUG) @task(cache=True, cache_version="1.1") def my_30mb_task(i: str) -> str: return f"Hello world {i}" * 30 * 100 * 1024 @task(cache=True, cache_version="1.1") def generate_strs(count: int) -> List[str]: return ["a"] * count @workflow def my_30mb_wf(mbs: int) -> List[str]: strs = generate_strs(count=mbs) return map_task(my_30mb_task)(i=strs) @workflow def big_inputs_wf(input: List[str]): noop() @task(cache=True, cache_version="1.1") def noop(): ... big_inputs_wf_lp = LaunchPlan.get_or_create(name="big_inputs_wf_lp", workflow=big_inputs_wf) @workflow def ref_wf(mbs: int): big_inputs_wf_lp(input=my_30mb_wf(mbs)) ``` Without flytekit offloading the return type is OffloadedLiteral{inferredType:{Collection{String}} and when checked against big_inputs_wf launchplan which needs Collection{String} , the LiteralTypeToLiteral returns the inferredType : Collection{String} If we enable offloading in flytekit, the returned data from map task is Collection{OffloadedLiteral<{inferredType:{Collection{String}}} When passing this Input to big_inputs_wf which takes Collection{String} then the type check fails due to LiteralTypeToLiteral returning Collection{OffloadedLiteral{inferredType:{Collection{String}}} as Collection{Collection{String}} Flytekit handles this case by special casing Collection{OffloadedLiteral} and similar special casing is needed in flyte code base Tested this by deploying this PR changes https://dogfood-gcp.cloud-staging.union.ai/console/projects/flytesnacks/domains/development/executions/akxs97cdmkmxhhqp228x/nodes Earlier it would fail like this https://dogfood-gcp.cloud-staging.union.ai/console/projects/flytesnacks/domains/development/executions/ap4thjp5528kjfspcsds/nodes ``` [UserError] failed to launch workflow, caused by: rpc error: code = InvalidArgument desc = invalid input input wrong type. Expected collection_type:{simple:STRING}, but got collection_type:{collection_type:{simple:STRING}} ``` Rollout to canary and then all prod byoc and serverless tenants Should this change be upstreamed to OSS (flyteorg/flyte)? If not, please uncheck this box, which is used for auditing. Note, it is the responsibility of each developer to actually upstream their changes. See [this guide](https://unionai.atlassian.net/wiki/spaces/ENG/pages/447610883/Flyte+-+Union+Cloud+Development+Runbook/#When-are-versions-updated%3F). - [x] To be upstreamed to OSS *TODO: Link Linear issue(s) using [magic words](https://linear.app/docs/github#magic-words). `fixes` will move to merged status, while `ref` will only link the PR.* * [X] Added tests * [ ] Ran a deploy dry run and shared the terraform plan * [ ] Added logging and metrics * [ ] Updated [dashboards](https://unionai.grafana.net/dashboards) and [alerts](https://unionai.grafana.net/alerting/list) * [ ] Updated documentation --- .../pkg/compiler/validators/utils.go | 25 +++-- .../pkg/compiler/validators/utils_test.go | 95 ++++++++++++++++++- 2 files changed, 110 insertions(+), 10 deletions(-) diff --git a/flytepropeller/pkg/compiler/validators/utils.go b/flytepropeller/pkg/compiler/validators/utils.go index cbb14b3124..d3e533e00d 100644 --- a/flytepropeller/pkg/compiler/validators/utils.go +++ b/flytepropeller/pkg/compiler/validators/utils.go @@ -202,13 +202,18 @@ func buildMultipleTypeUnion(innerType []*core.LiteralType) *core.LiteralType { return unionLiteralType } -func literalTypeForLiterals(literals []*core.Literal) *core.LiteralType { +func literalTypeForLiterals(literals []*core.Literal) (*core.LiteralType, bool) { innerType := make([]*core.LiteralType, 0, 1) innerTypeSet := sets.NewString() var noneType *core.LiteralType + isOffloadedType := false for _, x := range literals { otherType := LiteralTypeForLiteral(x) otherTypeKey := otherType.String() + if _, ok := x.GetValue().(*core.Literal_OffloadedMetadata); ok { + isOffloadedType = true + return otherType, isOffloadedType + } if _, ok := x.GetValue().(*core.Literal_Collection); ok { if x.GetCollection().GetLiterals() == nil { noneType = otherType @@ -230,9 +235,9 @@ func literalTypeForLiterals(literals []*core.Literal) *core.LiteralType { if len(innerType) == 0 { return &core.LiteralType{ Type: &core.LiteralType_Simple{Simple: core.SimpleType_NONE}, - } + }, isOffloadedType } else if len(innerType) == 1 { - return innerType[0] + return innerType[0], isOffloadedType } // sort inner types to ensure consistent union types are generated @@ -247,7 +252,7 @@ func literalTypeForLiterals(literals []*core.Literal) *core.LiteralType { return 0 }) - return buildMultipleTypeUnion(innerType) + return buildMultipleTypeUnion(innerType), isOffloadedType } // ValidateLiteralType check if the literal type is valid, return error if the literal is invalid. @@ -274,15 +279,23 @@ func LiteralTypeForLiteral(l *core.Literal) *core.LiteralType { case *core.Literal_Scalar: return literalTypeForScalar(l.GetScalar()) case *core.Literal_Collection: + collectionType, isOffloaded := literalTypeForLiterals(l.GetCollection().Literals) + if isOffloaded { + return collectionType + } return &core.LiteralType{ Type: &core.LiteralType_CollectionType{ - CollectionType: literalTypeForLiterals(l.GetCollection().Literals), + CollectionType: collectionType, }, } case *core.Literal_Map: + mapValueType, isOffloaded := literalTypeForLiterals(maps.Values(l.GetMap().Literals)) + if isOffloaded { + return mapValueType + } return &core.LiteralType{ Type: &core.LiteralType_MapValueType{ - MapValueType: literalTypeForLiterals(maps.Values(l.GetMap().Literals)), + MapValueType: mapValueType, }, } case *core.Literal_OffloadedMetadata: diff --git a/flytepropeller/pkg/compiler/validators/utils_test.go b/flytepropeller/pkg/compiler/validators/utils_test.go index 09790849f3..dd32a98a53 100644 --- a/flytepropeller/pkg/compiler/validators/utils_test.go +++ b/flytepropeller/pkg/compiler/validators/utils_test.go @@ -13,8 +13,9 @@ import ( func TestLiteralTypeForLiterals(t *testing.T) { t.Run("empty", func(t *testing.T) { - lt := literalTypeForLiterals(nil) + lt, isOffloaded := literalTypeForLiterals(nil) assert.Equal(t, core.SimpleType_NONE.String(), lt.GetSimple().String()) + assert.False(t, isOffloaded) }) t.Run("binary idl with raw binary data and no tag", func(t *testing.T) { @@ -94,17 +95,18 @@ func TestLiteralTypeForLiterals(t *testing.T) { }) t.Run("homogeneous", func(t *testing.T) { - lt := literalTypeForLiterals([]*core.Literal{ + lt, isOffloaded := literalTypeForLiterals([]*core.Literal{ coreutils.MustMakeLiteral(5), coreutils.MustMakeLiteral(0), coreutils.MustMakeLiteral(5), }) assert.Equal(t, core.SimpleType_INTEGER.String(), lt.GetSimple().String()) + assert.False(t, isOffloaded) }) t.Run("non-homogenous", func(t *testing.T) { - lt := literalTypeForLiterals([]*core.Literal{ + lt, isOffloaded := literalTypeForLiterals([]*core.Literal{ coreutils.MustMakeLiteral("hello"), coreutils.MustMakeLiteral(5), coreutils.MustMakeLiteral("world"), @@ -115,10 +117,11 @@ func TestLiteralTypeForLiterals(t *testing.T) { assert.Len(t, lt.GetUnionType().Variants, 2) assert.Equal(t, core.SimpleType_INTEGER.String(), lt.GetUnionType().Variants[0].GetSimple().String()) assert.Equal(t, core.SimpleType_STRING.String(), lt.GetUnionType().Variants[1].GetSimple().String()) + assert.False(t, isOffloaded) }) t.Run("non-homogenous ensure ordering", func(t *testing.T) { - lt := literalTypeForLiterals([]*core.Literal{ + lt, isOffloaded := literalTypeForLiterals([]*core.Literal{ coreutils.MustMakeLiteral(5), coreutils.MustMakeLiteral("world"), coreutils.MustMakeLiteral(0), @@ -128,6 +131,7 @@ func TestLiteralTypeForLiterals(t *testing.T) { assert.Len(t, lt.GetUnionType().Variants, 2) assert.Equal(t, core.SimpleType_INTEGER.String(), lt.GetUnionType().Variants[0].GetSimple().String()) assert.Equal(t, core.SimpleType_STRING.String(), lt.GetUnionType().Variants[1].GetSimple().String()) + assert.False(t, isOffloaded) }) t.Run("list with mixed types", func(t *testing.T) { @@ -454,6 +458,89 @@ func TestLiteralTypeForLiterals(t *testing.T) { assert.True(t, proto.Equal(expectedLt, lt)) }) + t.Run("nested Lists of offloaded List of string types", func(t *testing.T) { + inferredType := &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_STRING, + }, + }, + }, + } + literals := &core.Literal{ + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: []*core.Literal{ + { + Value: &core.Literal_OffloadedMetadata{ + OffloadedMetadata: &core.LiteralOffloadedMetadata{ + Uri: "dummy/uri-1", + SizeBytes: 1000, + InferredType: inferredType, + }, + }, + }, + { + Value: &core.Literal_OffloadedMetadata{ + OffloadedMetadata: &core.LiteralOffloadedMetadata{ + Uri: "dummy/uri-2", + SizeBytes: 1000, + InferredType: inferredType, + }, + }, + }, + }, + }, + }, + } + expectedLt := inferredType + lt := LiteralTypeForLiteral(literals) + assert.True(t, proto.Equal(expectedLt, lt)) + }) + t.Run("nested map of offloaded map of string types", func(t *testing.T) { + inferredType := &core.LiteralType{ + Type: &core.LiteralType_MapValueType{ + MapValueType: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_STRING, + }, + }, + }, + } + literals := &core.Literal{ + Value: &core.Literal_Map{ + Map: &core.LiteralMap{ + Literals: map[string]*core.Literal{ + + "key1": { + Value: &core.Literal_OffloadedMetadata{ + OffloadedMetadata: &core.LiteralOffloadedMetadata{ + Uri: "dummy/uri-1", + SizeBytes: 1000, + InferredType: inferredType, + }, + }, + }, + "key2": { + Value: &core.Literal_OffloadedMetadata{ + OffloadedMetadata: &core.LiteralOffloadedMetadata{ + Uri: "dummy/uri-2", + SizeBytes: 1000, + InferredType: inferredType, + }, + }, + }, + }, + }, + }, + } + + expectedLt := inferredType + lt := LiteralTypeForLiteral(literals) + assert.True(t, proto.Equal(expectedLt, lt)) + }) + } func TestJoinVariableMapsUniqueKeys(t *testing.T) {