Skip to content

Commit

Permalink
Fix propeller crash when inferring literal type for an offloaded lite…
Browse files Browse the repository at this point in the history
…ral (#5771)

* Fix propeller crash when inferring literal type for an offloaded literal

Signed-off-by: pmahindrakar-oss <[email protected]>

* Added reading of large offloaded literal for map task

Signed-off-by: pmahindrakar-oss <[email protected]>

* added changes for promise resolver

Signed-off-by: pmahindrakar-oss <[email protected]>

* flytectl support for offloaded literal

Signed-off-by: pmahindrakar-oss <[email protected]>

* refactor

Signed-off-by: pmahindrakar-oss <[email protected]>

* Add hash of the offloaded value for caching

Signed-off-by: pmahindrakar-oss <[email protected]>

* add hash unit test

Signed-off-by: pmahindrakar-oss <[email protected]>

* unit tests

Signed-off-by: pmahindrakar-oss <[email protected]>

* remove type

Signed-off-by: pmahindrakar-oss <[email protected]>

* lint

Signed-off-by: pmahindrakar-oss <[email protected]>

* remove overwriting the inputs and fix the issue in flytekit

Signed-off-by: pmahindrakar-oss <[email protected]>

* review comments

Signed-off-by: pmahindrakar-oss <[email protected]>

---------

Signed-off-by: pmahindrakar-oss <[email protected]>
  • Loading branch information
pmahindrakar-oss authored Oct 2, 2024
1 parent 1942173 commit e7ce437
Show file tree
Hide file tree
Showing 15 changed files with 205 additions and 26 deletions.
2 changes: 2 additions & 0 deletions flyte-single-binary-local.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ storage:
access_key_id: minio
secret_key: miniostorage
container: my-s3-bucket
limits:
maxDownloadMBs: 1000

task_resources:
defaults:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,7 @@ func CheckAndFetchInputsForExecution(
}
executionInputMap[name] = expectedInput.GetDefault()
} else {
var inputType *core.LiteralType
switch executionInputMap[name].GetValue().(type) {
case *core.Literal_OffloadedMetadata:
inputType = executionInputMap[name].GetOffloadedMetadata().GetInferredType()
default:
inputType = validators.LiteralTypeForLiteral(executionInputMap[name])
}
inputType := validators.LiteralTypeForLiteral(executionInputMap[name])
err := validators.ValidateLiteralType(inputType)
if err != nil {
return nil, errors.NewInvalidLiteralTypeError(name, err)
Expand Down
4 changes: 4 additions & 0 deletions flyteidl/clients/go/coreutils/extract_literal.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ func ExtractFromLiteral(literal *core.Literal) (interface{}, error) {
}
}
return mapResult, nil
case *core.Literal_OffloadedMetadata:
// Return the URI of the offloaded metadata to be used when displaying in flytectl
return literalValue.OffloadedMetadata.Uri, nil

}
return nil, fmt.Errorf("unsupported literal type %T", literal)
}
24 changes: 24 additions & 0 deletions flyteidl/clients/go/coreutils/extract_literal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,30 @@ func TestFetchLiteral(t *testing.T) {
assert.Equal(t, literalVal, extractedLiteralVal)
})

t.Run("Offloaded metadata", func(t *testing.T) {
literalVal := "s3://blah/blah/blah"
var storedLiteralType = &core.LiteralType{
Type: &core.LiteralType_CollectionType{
CollectionType: &core.LiteralType{
Type: &core.LiteralType_Simple{
Simple: core.SimpleType_INTEGER,
},
},
},
}
offloadedLiteral := &core.Literal{
Value: &core.Literal_OffloadedMetadata{
OffloadedMetadata: &core.LiteralOffloadedMetadata{
Uri: literalVal,
InferredType: storedLiteralType,
},
},
}
extractedLiteralVal, err := ExtractFromLiteral(offloadedLiteral)
assert.NoError(t, err)
assert.Equal(t, literalVal, extractedLiteralVal)
})

t.Run("Union", func(t *testing.T) {
literalVal := int64(1)
var literalType = &core.LiteralType{
Expand Down
1 change: 0 additions & 1 deletion flyteidl/clients/go/coreutils/literals.go
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,6 @@ func MakeLiteralForType(t *core.LiteralType, v interface{}) (*core.Literal, erro
if !found {
return nil, fmt.Errorf("incorrect union value [%s], supported values %+v", v, newT.UnionType.Variants)
}

default:
return nil, fmt.Errorf("unsupported type %s", t.String())
}
Expand Down
10 changes: 1 addition & 9 deletions flytepropeller/pkg/compiler/transformers/k8s/inputs.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,12 @@ func validateInputs(nodeID common.NodeID, iface *core.TypedInterface, inputs cor
continue
}

var inputType *core.LiteralType
switch inputVal.GetValue().(type) {
case *core.Literal_OffloadedMetadata:
inputType = inputVal.GetOffloadedMetadata().GetInferredType()
default:
inputType = validators.LiteralTypeForLiteral(inputVal)
}

inputType := validators.LiteralTypeForLiteral(inputVal)
err := validators.ValidateLiteralType(inputType)
if err != nil {
errs.Collect(errors.NewInvalidLiteralTypeErr(nodeID, inputVar, err))
continue
}

if !validators.AreTypesCastable(inputType, v.Type) {
errs.Collect(errors.NewMismatchingTypesErr(nodeID, inputVar, v.Type.String(), inputType.String()))
continue
Expand Down
3 changes: 2 additions & 1 deletion flytepropeller/pkg/compiler/validators/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,9 @@ func LiteralTypeForLiteral(l *core.Literal) *core.LiteralType {
MapValueType: literalTypeForLiterals(maps.Values(l.GetMap().Literals)),
},
}
case *core.Literal_OffloadedMetadata:
return l.GetOffloadedMetadata().GetInferredType()
}

return nil
}

Expand Down
41 changes: 41 additions & 0 deletions flytepropeller/pkg/compiler/validators/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,47 @@ func TestLiteralTypeForLiterals(t *testing.T) {
assert.True(t, proto.Equal(expectedLt, lt))
})

t.Run("nested Lists with different types", func(t *testing.T) {
inferredType := &core.LiteralType{
Type: &core.LiteralType_CollectionType{
CollectionType: &core.LiteralType{
Type: &core.LiteralType_CollectionType{
CollectionType: &core.LiteralType{
Type: &core.LiteralType_UnionType{
UnionType: &core.UnionType{
Variants: []*core.LiteralType{
{
Type: &core.LiteralType_Simple{
Simple: core.SimpleType_INTEGER,
},
},
{
Type: &core.LiteralType_Simple{
Simple: core.SimpleType_STRING,
},
},
},
},
},
},
},
},
},
}
literals := &core.Literal{
Value: &core.Literal_OffloadedMetadata{
OffloadedMetadata: &core.LiteralOffloadedMetadata{
Uri: "dummy/uri",
SizeBytes: 1000,
InferredType: inferredType,
},
},
}
expectedLt := inferredType
lt := LiteralTypeForLiteral(literals)
assert.True(t, proto.Equal(expectedLt, lt))
})

}

func TestJoinVariableMapsUniqueKeys(t *testing.T) {
Expand Down
12 changes: 11 additions & 1 deletion flytepropeller/pkg/controller/nodes/array/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
}

size := -1

for key, variable := range literalMap.Literals {
literalType := validators.LiteralTypeForLiteral(variable)
err := validators.ValidateLiteralType(literalType)
Expand All @@ -200,10 +201,19 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
handler.PhaseInfoFailure(idlcore.ExecutionError_USER, errors.IDLNotFoundErr, errMsg, nil),
), nil
}
if variable.GetOffloadedMetadata() != nil {
// variable will be overwritten with the contents of the offloaded data which contains the actual large literal.
// We need this for the map task to be able to create the subNodeSpec
err := common.ReadLargeLiteral(ctx, nCtx.DataStore(), variable)
if err != nil {
return handler.DoTransition(handler.TransitionTypeEphemeral,
handler.PhaseInfoFailure(idlcore.ExecutionError_SYSTEM, errors.RuntimeExecutionError, "couldn't read the offloaded literal", nil),
), nil
}
}
switch literalType.Type.(type) {
case *idlcore.LiteralType_CollectionType:
collectionLength := len(variable.GetCollection().Literals)

if size == -1 {
size = collectionLength
} else if size != collectionLength {
Expand Down
15 changes: 14 additions & 1 deletion flytepropeller/pkg/controller/nodes/attr_path_resolver.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,36 @@
package nodes

import (
"context"

"google.golang.org/protobuf/types/known/structpb"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/common"
"github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/errors"
"github.com/flyteorg/flyte/flytestdlib/storage"
)

// resolveAttrPathInPromise resolves the literal with attribute path
// If the promise is chained with attributes (e.g. promise.a["b"][0]), then we need to resolve the promise
func resolveAttrPathInPromise(nodeID string, literal *core.Literal, bindAttrPath []*core.PromiseAttribute) (*core.Literal, error) {
func resolveAttrPathInPromise(ctx context.Context, datastore *storage.DataStore, nodeID string, literal *core.Literal, bindAttrPath []*core.PromiseAttribute) (*core.Literal, error) {
var currVal *core.Literal = literal
var tmpVal *core.Literal
var err error
var exist bool
count := 0

for _, attr := range bindAttrPath {
if currVal.GetOffloadedMetadata() != nil {
// currVal will be overwritten with the contents of the offloaded data which contains the actual large literal.
err := common.ReadLargeLiteral(ctx, datastore, currVal)
if err != nil {
return nil, errors.Errorf(errors.PromiseAttributeResolveError, nodeID, "failed to read offloaded metadata for promise")
}
}
switch currVal.GetValue().(type) {
case *core.Literal_OffloadedMetadata:
return nil, errors.Errorf(errors.PromiseAttributeResolveError, nodeID, "unexpected offloaded metadata type")
case *core.Literal_Map:
tmpVal, exist = currVal.GetMap().GetLiterals()[attr.GetStringValue()]
if !exist {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package nodes

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -319,7 +320,7 @@ func TestResolveAttrPathIn(t *testing.T) {
}

for i, arg := range args {
resolved, err := resolveAttrPathInPromise("", arg.literal, arg.path)
resolved, err := resolveAttrPathInPromise(context.Background(), nil, "", arg.literal, arg.path)
if arg.hasError {
assert.Error(t, err, i)
assert.ErrorContains(t, err, errors.PromiseAttributeResolveError, i)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func (m *CatalogClient) Get(ctx context.Context, key catalog.Key) (catalog.Entry
logger.Debugf(ctx, "DataCatalog failed to get artifact by tag %+v, err: %+v", tag, err)
return catalog.Entry{}, err
}
logger.Debugf(ctx, "Artifact found %v from tag %v", artifact, tag)
logger.Debugf(ctx, "Artifact found %v from tag %v", artifact.GetId(), tag)

var relevantTag *datacatalog.Tag
if len(artifact.GetTags()) > 0 {
Expand Down Expand Up @@ -230,7 +230,7 @@ func (m *CatalogClient) createArtifact(ctx context.Context, key catalog.Key, dat
createArtifactRequest := &datacatalog.CreateArtifactRequest{Artifact: cachedArtifact}
_, err := m.client.CreateArtifact(ctx, createArtifactRequest)
if err != nil {
logger.Errorf(ctx, "Failed to create Artifact %+v, err: %v", cachedArtifact, err)
logger.Errorf(ctx, "Failed to create Artifact %+v, err: %v", cachedArtifact.Id, err)
return catalog.Status{}, err
}
logger.Debugf(ctx, "Created artifact: %v, with %v outputs from execution %+v", cachedArtifact.Id, len(artifactDataList), metadata)
Expand Down Expand Up @@ -259,7 +259,7 @@ func (m *CatalogClient) createArtifact(ctx context.Context, key catalog.Key, dat
}
}

logger.Debugf(ctx, "Successfully created artifact %+v for key %+v, dataset %+v and execution %+v", cachedArtifact, key, datasetID, metadata)
logger.Debugf(ctx, "Successfully created artifact %+v for key %+v, dataset %+v and execution %+v", cachedArtifact.Id, key, datasetID, metadata)
return catalog.NewStatus(core.CatalogCacheStatus_CACHE_POPULATED, EventCatalogMetadata(datasetID, tag, nil)), nil
}

Expand Down
33 changes: 33 additions & 0 deletions flytepropeller/pkg/controller/nodes/common/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package common

import (
"context"
"encoding/base64"
"fmt"
"strconv"

Expand All @@ -17,6 +18,7 @@ import (
"github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/handler"
"github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/interfaces"
"github.com/flyteorg/flyte/flytestdlib/logger"
"github.com/flyteorg/flyte/flytestdlib/pbhash"
"github.com/flyteorg/flyte/flytestdlib/storage"
)

Expand Down Expand Up @@ -79,6 +81,27 @@ func GetTargetEntity(ctx context.Context, nCtx interfaces.NodeExecutionContext)
return targetEntity
}

// ReadLargeLiteral reads an offloaded large literal back from the datastore,
// as needed by the array node task and the promise attribute resolver.
//
// tobeRead must currently hold offloaded metadata with a non-empty URI. On
// success the message stored at that URI is read into tobeRead itself, so the
// caller's literal is overwritten in place with the actual large literal.
//
// An error is returned when tobeRead carries no offloaded metadata, when the
// URI is empty, when datastore is nil, or when the read from the datastore
// fails.
func ReadLargeLiteral(ctx context.Context, datastore *storage.DataStore,
	tobeRead *idlcore.Literal) error {
	offloadedMetadata := tobeRead.GetOffloadedMetadata()
	if offloadedMetadata == nil {
		return fmt.Errorf("unsupported type for reading offloaded literal")
	}
	dataReference := offloadedMetadata.GetUri()
	if dataReference == "" {
		return fmt.Errorf("uri is empty for offloaded literal")
	}
	// Guard against a nil datastore so callers get an error instead of a
	// nil-pointer panic inside ReadProtobuf.
	if datastore == nil {
		return fmt.Errorf("datastore is nil, cannot read offloaded literal at location [%s]", dataReference)
	}
	// Capture the size before reading: the read below targets tobeRead, so the
	// offloaded metadata is not expected to be available afterwards.
	size := offloadedMetadata.GetSizeBytes()
	if err := datastore.ReadProtobuf(ctx, storage.DataReference(dataReference), tobeRead); err != nil {
		logger.Errorf(ctx, "Failed to read the offloaded literal at location [%s] with error [%s]", dataReference, err)
		return err
	}

	// size is an integer byte count, so it must be formatted with %d rather than %s.
	logger.Infof(ctx, "read offloaded literal at location [%s] with size [%d]", dataReference, size)
	return nil
}

// OffloadLargeLiteral offloads the large literal if meets the threshold conditions
func OffloadLargeLiteral(ctx context.Context, datastore *storage.DataStore, dataReference storage.DataReference,
toBeOffloaded *idlcore.Literal, literalOffloadingConfig config.LiteralOffloadingConfig) error {
Expand Down Expand Up @@ -108,6 +131,16 @@ func OffloadLargeLiteral(ctx context.Context, datastore *storage.DataStore, data
return err
}

if toBeOffloaded.GetHash() == "" {
// compute the hash of the literal
literalDigest, err := pbhash.ComputeHash(ctx, toBeOffloaded)
if err != nil {
logger.Errorf(ctx, "Failed to compute hash for offloaded literal with error [%s]", err)
return err
}
// Set the hash or else respect what the user set in the literal
toBeOffloaded.Hash = base64.RawURLEncoding.EncodeToString(literalDigest)
}
// update the literal with the offloaded URI, size and inferred type
toBeOffloaded.Value = &idlcore.Literal_OffloadedMetadata{
OffloadedMetadata: &idlcore.LiteralOffloadedMetadata{
Expand Down
Loading

0 comments on commit e7ce437

Please sign in to comment.