diff --git a/flytepropeller/pkg/compiler/validators/utils.go b/flytepropeller/pkg/compiler/validators/utils.go index cbb14b3124..d3e533e00d 100644 --- a/flytepropeller/pkg/compiler/validators/utils.go +++ b/flytepropeller/pkg/compiler/validators/utils.go @@ -202,13 +202,18 @@ func buildMultipleTypeUnion(innerType []*core.LiteralType) *core.LiteralType { return unionLiteralType } -func literalTypeForLiterals(literals []*core.Literal) *core.LiteralType { +func literalTypeForLiterals(literals []*core.Literal) (*core.LiteralType, bool) { innerType := make([]*core.LiteralType, 0, 1) innerTypeSet := sets.NewString() var noneType *core.LiteralType + isOffloadedType := false for _, x := range literals { otherType := LiteralTypeForLiteral(x) otherTypeKey := otherType.String() + if _, ok := x.GetValue().(*core.Literal_OffloadedMetadata); ok { + isOffloadedType = true + return otherType, isOffloadedType + } if _, ok := x.GetValue().(*core.Literal_Collection); ok { if x.GetCollection().GetLiterals() == nil { noneType = otherType @@ -230,9 +235,9 @@ func literalTypeForLiterals(literals []*core.Literal) *core.LiteralType { if len(innerType) == 0 { return &core.LiteralType{ Type: &core.LiteralType_Simple{Simple: core.SimpleType_NONE}, - } + }, isOffloadedType } else if len(innerType) == 1 { - return innerType[0] + return innerType[0], isOffloadedType } // sort inner types to ensure consistent union types are generated @@ -247,7 +252,7 @@ func literalTypeForLiterals(literals []*core.Literal) *core.LiteralType { return 0 }) - return buildMultipleTypeUnion(innerType) + return buildMultipleTypeUnion(innerType), isOffloadedType } // ValidateLiteralType check if the literal type is valid, return error if the literal is invalid. @@ -274,15 +279,23 @@ func LiteralTypeForLiteral(l *core.Literal) *core.LiteralType { case *core.Literal_Scalar: return literalTypeForScalar(l.GetScalar()) case *core.Literal_Collection: + collectionType, isOffloaded := literalTypeForLiterals(l.GetCollection().Literals) + if isOffloaded { + return collectionType + } return &core.LiteralType{ Type: &core.LiteralType_CollectionType{ - CollectionType: literalTypeForLiterals(l.GetCollection().Literals), + CollectionType: collectionType, }, } case *core.Literal_Map: + mapValueType, isOffloaded := literalTypeForLiterals(maps.Values(l.GetMap().Literals)) + if isOffloaded { + return mapValueType + } return &core.LiteralType{ Type: &core.LiteralType_MapValueType{ - MapValueType: literalTypeForLiterals(maps.Values(l.GetMap().Literals)), + MapValueType: mapValueType, }, } case *core.Literal_OffloadedMetadata: diff --git a/flytepropeller/pkg/compiler/validators/utils_test.go b/flytepropeller/pkg/compiler/validators/utils_test.go index 09790849f3..dd32a98a53 100644 --- a/flytepropeller/pkg/compiler/validators/utils_test.go +++ b/flytepropeller/pkg/compiler/validators/utils_test.go @@ -13,8 +13,9 @@ import ( func TestLiteralTypeForLiterals(t *testing.T) { t.Run("empty", func(t *testing.T) { - lt := literalTypeForLiterals(nil) + lt, isOffloaded := literalTypeForLiterals(nil) assert.Equal(t, core.SimpleType_NONE.String(), lt.GetSimple().String()) + assert.False(t, isOffloaded) }) t.Run("binary idl with raw binary data and no tag", func(t *testing.T) { @@ -94,17 +95,18 @@ func TestLiteralTypeForLiterals(t *testing.T) { }) t.Run("homogeneous", func(t *testing.T) { - lt := literalTypeForLiterals([]*core.Literal{ + lt, isOffloaded := literalTypeForLiterals([]*core.Literal{ coreutils.MustMakeLiteral(5), coreutils.MustMakeLiteral(0), coreutils.MustMakeLiteral(5), }) assert.Equal(t, core.SimpleType_INTEGER.String(), lt.GetSimple().String()) + assert.False(t, isOffloaded) }) t.Run("non-homogenous", func(t *testing.T) { - lt := literalTypeForLiterals([]*core.Literal{ + lt, isOffloaded := literalTypeForLiterals([]*core.Literal{ coreutils.MustMakeLiteral("hello"), coreutils.MustMakeLiteral(5), coreutils.MustMakeLiteral("world"), @@ -115,10 +117,11 @@ func TestLiteralTypeForLiterals(t *testing.T) { assert.Len(t, lt.GetUnionType().Variants, 2) assert.Equal(t, core.SimpleType_INTEGER.String(), lt.GetUnionType().Variants[0].GetSimple().String()) assert.Equal(t, core.SimpleType_STRING.String(), lt.GetUnionType().Variants[1].GetSimple().String()) + assert.False(t, isOffloaded) }) t.Run("non-homogenous ensure ordering", func(t *testing.T) { - lt := literalTypeForLiterals([]*core.Literal{ + lt, isOffloaded := literalTypeForLiterals([]*core.Literal{ coreutils.MustMakeLiteral(5), coreutils.MustMakeLiteral("world"), coreutils.MustMakeLiteral(0), @@ -128,6 +131,7 @@ func TestLiteralTypeForLiterals(t *testing.T) { assert.Len(t, lt.GetUnionType().Variants, 2) assert.Equal(t, core.SimpleType_INTEGER.String(), lt.GetUnionType().Variants[0].GetSimple().String()) assert.Equal(t, core.SimpleType_STRING.String(), lt.GetUnionType().Variants[1].GetSimple().String()) + assert.False(t, isOffloaded) }) t.Run("list with mixed types", func(t *testing.T) { @@ -454,6 +458,89 @@ func TestLiteralTypeForLiterals(t *testing.T) { assert.True(t, proto.Equal(expectedLt, lt)) }) + t.Run("nested Lists of offloaded List of string types", func(t *testing.T) { + inferredType := &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_STRING, + }, + }, + }, + } + literals := &core.Literal{ + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: []*core.Literal{ + { + Value: &core.Literal_OffloadedMetadata{ + OffloadedMetadata: &core.LiteralOffloadedMetadata{ + Uri: "dummy/uri-1", + SizeBytes: 1000, + InferredType: inferredType, + }, + }, + }, + { + Value: &core.Literal_OffloadedMetadata{ + OffloadedMetadata: &core.LiteralOffloadedMetadata{ + Uri: "dummy/uri-2", + SizeBytes: 1000, + InferredType: inferredType, + }, + }, + }, + }, + }, + }, + } + expectedLt := inferredType + lt := LiteralTypeForLiteral(literals) + assert.True(t, proto.Equal(expectedLt, lt)) + }) + t.Run("nested map of offloaded map of string types", func(t *testing.T) { + inferredType := &core.LiteralType{ + Type: &core.LiteralType_MapValueType{ + MapValueType: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_STRING, + }, + }, + }, + } + literals := &core.Literal{ + Value: &core.Literal_Map{ + Map: &core.LiteralMap{ + Literals: map[string]*core.Literal{ + + "key1": { + Value: &core.Literal_OffloadedMetadata{ + OffloadedMetadata: &core.LiteralOffloadedMetadata{ + Uri: "dummy/uri-1", + SizeBytes: 1000, + InferredType: inferredType, + }, + }, + }, + "key2": { + Value: &core.Literal_OffloadedMetadata{ + OffloadedMetadata: &core.LiteralOffloadedMetadata{ + Uri: "dummy/uri-2", + SizeBytes: 1000, + InferredType: inferredType, + }, + }, + }, + }, + }, + }, + } + + expectedLt := inferredType + lt := LiteralTypeForLiteral(literals) + assert.True(t, proto.Equal(expectedLt, lt)) + }) + } func TestJoinVariableMapsUniqueKeys(t *testing.T) {