diff --git a/flytepropeller/pkg/compiler/validators/utils.go b/flytepropeller/pkg/compiler/validators/utils.go index c053a8b399..bd1416b583 100644 --- a/flytepropeller/pkg/compiler/validators/utils.go +++ b/flytepropeller/pkg/compiler/validators/utils.go @@ -202,18 +202,13 @@ func buildMultipleTypeUnion(innerType []*core.LiteralType) *core.LiteralType { return unionLiteralType } -func literalTypeForLiterals(literals []*core.Literal) (*core.LiteralType, bool) { +func literalTypeForLiterals(literals []*core.Literal) *core.LiteralType { innerType := make([]*core.LiteralType, 0, 1) innerTypeSet := sets.NewString() var noneType *core.LiteralType - isOffloadedType := false for _, x := range literals { otherType := LiteralTypeForLiteral(x) otherTypeKey := otherType.String() - if _, ok := x.GetValue().(*core.Literal_OffloadedMetadata); ok { - isOffloadedType = true - return otherType, isOffloadedType - } if _, ok := x.GetValue().(*core.Literal_Collection); ok { if x.GetCollection().GetLiterals() == nil { noneType = otherType @@ -235,9 +230,9 @@ func literalTypeForLiterals(literals []*core.Literal) (*core.LiteralType, bool) if len(innerType) == 0 { return &core.LiteralType{ Type: &core.LiteralType_Simple{Simple: core.SimpleType_NONE}, - }, isOffloadedType + } } else if len(innerType) == 1 { - return innerType[0], isOffloadedType + return innerType[0] } // sort inner types to ensure consistent union types are generated @@ -252,7 +247,7 @@ func literalTypeForLiterals(literals []*core.Literal) (*core.LiteralType, bool) return 0 }) - return buildMultipleTypeUnion(innerType), isOffloadedType + return buildMultipleTypeUnion(innerType) } // ValidateLiteralType check if the literal type is valid, return error if the literal is invalid. @@ -279,23 +274,15 @@ func LiteralTypeForLiteral(l *core.Literal) *core.LiteralType { case *core.Literal_Scalar: return literalTypeForScalar(l.GetScalar()) case *core.Literal_Collection: - collectionType, isOffloaded := literalTypeForLiterals(l.GetCollection().GetLiterals()) - if isOffloaded { - return collectionType - } return &core.LiteralType{ Type: &core.LiteralType_CollectionType{ - CollectionType: collectionType, + CollectionType: literalTypeForLiterals(l.GetCollection().Literals), }, } case *core.Literal_Map: - mapValueType, isOffloaded := literalTypeForLiterals(maps.Values(l.GetMap().GetLiterals())) - if isOffloaded { - return mapValueType - } return &core.LiteralType{ Type: &core.LiteralType_MapValueType{ - MapValueType: mapValueType, + MapValueType: literalTypeForLiterals(maps.Values(l.GetMap().Literals)), }, } case *core.Literal_OffloadedMetadata: diff --git a/flytepropeller/pkg/compiler/validators/utils_test.go b/flytepropeller/pkg/compiler/validators/utils_test.go index 41a1333e62..09790849f3 100644 --- a/flytepropeller/pkg/compiler/validators/utils_test.go +++ b/flytepropeller/pkg/compiler/validators/utils_test.go @@ -13,9 +13,8 @@ import ( func TestLiteralTypeForLiterals(t *testing.T) { t.Run("empty", func(t *testing.T) { - lt, isOffloaded := literalTypeForLiterals(nil) + lt := literalTypeForLiterals(nil) assert.Equal(t, core.SimpleType_NONE.String(), lt.GetSimple().String()) - assert.False(t, isOffloaded) }) t.Run("binary idl with raw binary data and no tag", func(t *testing.T) { @@ -95,18 +94,17 @@ func TestLiteralTypeForLiterals(t *testing.T) { }) t.Run("homogeneous", func(t *testing.T) { - lt, isOffloaded := literalTypeForLiterals([]*core.Literal{ + lt := literalTypeForLiterals([]*core.Literal{ coreutils.MustMakeLiteral(5), coreutils.MustMakeLiteral(0), coreutils.MustMakeLiteral(5), }) assert.Equal(t, core.SimpleType_INTEGER.String(), lt.GetSimple().String()) - assert.False(t, isOffloaded) }) t.Run("non-homogenous", func(t *testing.T) { - lt, isOffloaded := literalTypeForLiterals([]*core.Literal{ + lt := literalTypeForLiterals([]*core.Literal{ coreutils.MustMakeLiteral("hello"), coreutils.MustMakeLiteral(5), coreutils.MustMakeLiteral("world"), @@ -114,24 +112,22 @@ func TestLiteralTypeForLiterals(t *testing.T) { coreutils.MustMakeLiteral(2), }) - assert.Len(t, lt.GetUnionType().GetVariants(), 2) - assert.Equal(t, core.SimpleType_INTEGER.String(), lt.GetUnionType().GetVariants()[0].GetSimple().String()) - assert.Equal(t, core.SimpleType_STRING.String(), lt.GetUnionType().GetVariants()[1].GetSimple().String()) - assert.False(t, isOffloaded) + assert.Len(t, lt.GetUnionType().Variants, 2) + assert.Equal(t, core.SimpleType_INTEGER.String(), lt.GetUnionType().Variants[0].GetSimple().String()) + assert.Equal(t, core.SimpleType_STRING.String(), lt.GetUnionType().Variants[1].GetSimple().String()) }) t.Run("non-homogenous ensure ordering", func(t *testing.T) { - lt, isOffloaded := literalTypeForLiterals([]*core.Literal{ + lt := literalTypeForLiterals([]*core.Literal{ coreutils.MustMakeLiteral(5), coreutils.MustMakeLiteral("world"), coreutils.MustMakeLiteral(0), coreutils.MustMakeLiteral(2), }) - assert.Len(t, lt.GetUnionType().GetVariants(), 2) - assert.Equal(t, core.SimpleType_INTEGER.String(), lt.GetUnionType().GetVariants()[0].GetSimple().String()) - assert.Equal(t, core.SimpleType_STRING.String(), lt.GetUnionType().GetVariants()[1].GetSimple().String()) - assert.False(t, isOffloaded) + assert.Len(t, lt.GetUnionType().Variants, 2) + assert.Equal(t, core.SimpleType_INTEGER.String(), lt.GetUnionType().Variants[0].GetSimple().String()) + assert.Equal(t, core.SimpleType_STRING.String(), lt.GetUnionType().Variants[1].GetSimple().String()) }) t.Run("list with mixed types", func(t *testing.T) { @@ -458,89 +454,6 @@ func TestLiteralTypeForLiterals(t *testing.T) { assert.True(t, proto.Equal(expectedLt, lt)) }) - t.Run("nested Lists of offloaded List of string types", func(t *testing.T) { - inferredType := &core.LiteralType{ - Type: &core.LiteralType_CollectionType{ - CollectionType: &core.LiteralType{ - Type: &core.LiteralType_Simple{ - Simple: core.SimpleType_STRING, - }, - }, - }, - } - literals := &core.Literal{ - Value: &core.Literal_Collection{ - Collection: &core.LiteralCollection{ - Literals: []*core.Literal{ - { - Value: &core.Literal_OffloadedMetadata{ - OffloadedMetadata: &core.LiteralOffloadedMetadata{ - Uri: "dummy/uri-1", - SizeBytes: 1000, - InferredType: inferredType, - }, - }, - }, - { - Value: &core.Literal_OffloadedMetadata{ - OffloadedMetadata: &core.LiteralOffloadedMetadata{ - Uri: "dummy/uri-2", - SizeBytes: 1000, - InferredType: inferredType, - }, - }, - }, - }, - }, - }, - } - expectedLt := inferredType - lt := LiteralTypeForLiteral(literals) - assert.True(t, proto.Equal(expectedLt, lt)) - }) - t.Run("nested map of offloaded map of string types", func(t *testing.T) { - inferredType := &core.LiteralType{ - Type: &core.LiteralType_MapValueType{ - MapValueType: &core.LiteralType{ - Type: &core.LiteralType_Simple{ - Simple: core.SimpleType_STRING, - }, - }, - }, - } - literals := &core.Literal{ - Value: &core.Literal_Map{ - Map: &core.LiteralMap{ - Literals: map[string]*core.Literal{ - - "key1": { - Value: &core.Literal_OffloadedMetadata{ - OffloadedMetadata: &core.LiteralOffloadedMetadata{ - Uri: "dummy/uri-1", - SizeBytes: 1000, - InferredType: inferredType, - }, - }, - }, - "key2": { - Value: &core.Literal_OffloadedMetadata{ - OffloadedMetadata: &core.LiteralOffloadedMetadata{ - Uri: "dummy/uri-2", - SizeBytes: 1000, - InferredType: inferredType, - }, - }, - }, - }, - }, - }, - } - - expectedLt := inferredType - lt := LiteralTypeForLiteral(literals) - assert.True(t, proto.Equal(expectedLt, lt)) - }) - } func TestJoinVariableMapsUniqueKeys(t *testing.T) {