Skip to content

Commit

Permalink
Bring back Blob support
Browse files Browse the repository at this point in the history
Signed-off-by: Hongxin Liang <[email protected]>
  • Loading branch information
honnix committed Oct 9, 2023
1 parent c892ceb commit 6ea206b
Show file tree
Hide file tree
Showing 17 changed files with 281 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@
import java.time.Instant;
import java.util.List;
import java.util.Map;
import org.flyte.api.v1.Blob;
import org.flyte.api.v1.BlobType.BlobDimensionality;
import org.flyte.flytekit.SdkBindingData;
import org.flyte.flytekit.SdkRunnableTask;
import org.flyte.flytekit.jackson.BlobTypeDescription;
import org.flyte.flytekit.jackson.JacksonSdkType;

@AutoService(SdkRunnableTask.class)
Expand All @@ -48,8 +51,8 @@ public abstract static class AutoAllInputsInput {

public abstract SdkBindingData<Duration> d();

// TODO add blobs to sdkbinding data
// public abstract SdkBindingData<Blob> blob();
@BlobTypeDescription(format = "csv", dimensionality = BlobDimensionality.MULTIPART)
public abstract SdkBindingData<Blob> blob();

public abstract SdkBindingData<List<String>> l();

Expand All @@ -66,13 +69,13 @@ public static AutoAllInputsInput create(
SdkBindingData<Boolean> b,
SdkBindingData<Instant> t,
SdkBindingData<Duration> d,
// Blob blob,
SdkBindingData<Blob> blob,
SdkBindingData<List<String>> l,
SdkBindingData<Map<String, String>> m,
SdkBindingData<List<String>> emptyList,
SdkBindingData<Map<String, Long>> emptyMap) {
return new AutoValue_AllInputsTask_AutoAllInputsInput(
i, f, s, b, t, d, l, m, emptyList, emptyMap);
i, f, s, b, t, d, blob, l, m, emptyList, emptyMap);
}
}

Expand All @@ -91,8 +94,8 @@ public abstract static class AutoAllInputsOutput {

public abstract SdkBindingData<Duration> d();

// TODO add blobs to sdkbinding data
// public abstract SdkBindingData<Blob> blob();
@BlobTypeDescription(format = "csv", dimensionality = BlobDimensionality.MULTIPART)
public abstract SdkBindingData<Blob> blob();

public abstract SdkBindingData<List<String>> l();

Expand All @@ -109,12 +112,13 @@ public static AutoAllInputsOutput create(
SdkBindingData<Boolean> b,
SdkBindingData<Instant> t,
SdkBindingData<Duration> d,
SdkBindingData<Blob> blob,
SdkBindingData<List<String>> l,
SdkBindingData<Map<String, String>> m,
SdkBindingData<List<String>> emptyList,
SdkBindingData<Map<String, Long>> emptyMap) {
return new AutoValue_AllInputsTask_AutoAllInputsOutput(
i, f, s, b, t, d, l, m, emptyList, emptyMap);
i, f, s, b, t, d, blob, l, m, emptyList, emptyMap);
}
}

Expand All @@ -127,6 +131,7 @@ public AutoAllInputsOutput run(AutoAllInputsInput input) {
input.b(),
input.t(),
input.d(),
input.blob(),
input.l(),
input.m(),
input.emptyList(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,18 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.flyte.api.v1.Blob;
import org.flyte.api.v1.BlobMetadata;
import org.flyte.api.v1.BlobType;
import org.flyte.api.v1.BlobType.BlobDimensionality;
import org.flyte.examples.AllInputsTask.AutoAllInputsOutput;
import org.flyte.flytekit.SdkBindingData;
import org.flyte.flytekit.SdkBindingDataFactory;
import org.flyte.flytekit.SdkNode;
import org.flyte.flytekit.SdkTypes;
import org.flyte.flytekit.SdkWorkflow;
import org.flyte.flytekit.SdkWorkflowBuilder;
import org.flyte.flytekit.jackson.BlobTypeDescription;
import org.flyte.flytekit.jackson.JacksonSdkType;

@AutoService(SdkWorkflow.class)
Expand All @@ -57,6 +62,18 @@ public AllInputsWorkflowOutput expand(SdkWorkflowBuilder builder, Void noInput)
SdkBindingDataFactory.of(true),
SdkBindingDataFactory.of(someInstant),
SdkBindingDataFactory.of(Duration.ofDays(1L)),
SdkBindingDataFactory.of(
Blob.builder()
.uri("file://test/test.csv")
.metadata(
BlobMetadata.builder()
.type(
BlobType.builder()
.format("csv")
.dimensionality(BlobDimensionality.MULTIPART)
.build())
.build())
.build()),
SdkBindingDataFactory.ofStringCollection(Arrays.asList("foo", "bar")),
SdkBindingDataFactory.ofStringMap(Map.of("test", "test")),
SdkBindingDataFactory.ofStringCollection(Collections.emptyList()),
Expand All @@ -71,6 +88,7 @@ public AllInputsWorkflowOutput expand(SdkWorkflowBuilder builder, Void noInput)
outputs.b(),
outputs.t(),
outputs.d(),
outputs.blob(),
outputs.l(),
outputs.m(),
outputs.emptyList(),
Expand All @@ -92,8 +110,8 @@ public abstract static class AllInputsWorkflowOutput {

public abstract SdkBindingData<Duration> d();

// TODO add blobs to sdkbinding data
// public abstract SdkBindingData<Blob> blob();
@BlobTypeDescription(format = "csv", dimensionality = BlobDimensionality.MULTIPART)
public abstract SdkBindingData<Blob> blob();

public abstract SdkBindingData<List<String>> l();

Expand All @@ -110,12 +128,13 @@ public static AllInputsWorkflow.AllInputsWorkflowOutput create(
SdkBindingData<Boolean> b,
SdkBindingData<Instant> t,
SdkBindingData<Duration> d,
SdkBindingData<Blob> blob,
SdkBindingData<List<String>> l,
SdkBindingData<Map<String, String>> m,
SdkBindingData<List<String>> emptyList,
SdkBindingData<Map<String, Long>> emptyMap) {
return new AutoValue_AllInputsWorkflow_AllInputsWorkflowOutput(
i, f, s, b, t, d, l, m, emptyList, emptyMap);
i, f, s, b, t, d, blob, l, m, emptyList, emptyMap);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright 2023 Flyte Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.flyte.flytekit.jackson;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import org.flyte.api.v1.BlobType.BlobDimensionality;

/** Applied to a blob property to annotate its type. */
@Target({ElementType.FIELD, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface BlobTypeDescription {
/**
* Describes the blob's format.
*
* @return format, not {@code null}
*/
String format();

/**
* Describes the blob's dimensionality.
*
* @return dimensionality, not {@code null}
*/
BlobDimensionality dimensionality();
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.flyte.api.v1.Blob;
import org.flyte.api.v1.BlobType;
import org.flyte.api.v1.Variable;
import org.flyte.flytekit.SdkBindingData;
import org.flyte.flytekit.SdkLiteralType;
Expand Down Expand Up @@ -63,11 +65,7 @@ public void property(BeanProperty prop) {
String propName = prop.getName();
AnnotatedMember member = prop.getMember();
SdkLiteralType<?> literalType =
toLiteralType(
handledType,
/*rootLevel=*/ true,
propName,
member.getMember().getDeclaringClass().getName());
toLiteralType(handledType, /* rootLevel= */ true, propName, member);

String description = getDescription(member);

Expand Down Expand Up @@ -132,18 +130,17 @@ private String getDescription(AnnotatedMember member) {

@SuppressWarnings("AlreadyChecked")
private SdkLiteralType<?> toLiteralType(
JavaType javaType, boolean rootLevel, String propName, String declaringClassName) {
JavaType javaType, boolean rootLevel, String propName, AnnotatedMember member) {
Class<?> type = javaType.getRawClass();

if (SdkBindingData.class.isAssignableFrom(type)) {
return toLiteralType(
javaType.getBindings().getBoundType(0), false, propName, declaringClassName);
return toLiteralType(javaType.getBindings().getBoundType(0), false, propName, member);
} else if (rootLevel) {
throw new UnsupportedOperationException(
String.format(
"Field '%s' from class '%s' is declared as '%s' and it is not matching any of the supported types. "
+ "Please make sure your variable declared type is wrapped in 'SdkBindingData<>'.",
propName, declaringClassName, type));
propName, member.getMember().getDeclaringClass().getName(), type));
} else if (isPrimitiveAssignableFrom(Long.class, type)) {
return SdkLiteralTypes.integers();
} else if (isPrimitiveAssignableFrom(Double.class, type)) {
Expand All @@ -159,8 +156,7 @@ private SdkLiteralType<?> toLiteralType(
} else if (List.class.isAssignableFrom(type)) {
JavaType elementType = javaType.getBindings().getBoundType(0);

return SdkLiteralTypes.collections(
toLiteralType(elementType, false, propName, declaringClassName));
return SdkLiteralTypes.collections(toLiteralType(elementType, false, propName, member));
} else if (Map.class.isAssignableFrom(type)) {
JavaType keyType = javaType.getBindings().getBoundType(0);
JavaType valueType = javaType.getBindings().getBoundType(1);
Expand All @@ -170,9 +166,22 @@ private SdkLiteralType<?> toLiteralType(
"Only Map<String, ?> is supported, got [" + javaType.getGenericSignature() + "]");
}

return SdkLiteralTypes.maps(toLiteralType(valueType, false, propName, declaringClassName));
return SdkLiteralTypes.maps(toLiteralType(valueType, false, propName, member));
} else if (Blob.class.isAssignableFrom(type)) {
BlobTypeDescription annotation = member.getAnnotation(BlobTypeDescription.class);
if (annotation == null) {
throw new UnsupportedOperationException(
String.format(
"Field '%s' from class '%s' is declared as '%s' and it must be annotated",
propName, member.getMember().getDeclaringClass().getName(), type));
}
return SdkLiteralTypes.blobs(
BlobType.builder()
.format(annotation.format())
.dimensionality(annotation.dimensionality())
.build());
}
// TODO: Support blobs and structs
// TODO: Support structs
throw new UnsupportedOperationException(
String.format("Unsupported type: [%s]", type.getName()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.deser.std.StdDeserializer;
import java.io.IOException;
import java.io.Serializable;
import java.time.Duration;
import java.time.Instant;
import java.util.Iterator;
Expand All @@ -39,6 +38,10 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.flyte.api.v1.Blob;
import org.flyte.api.v1.BlobMetadata;
import org.flyte.api.v1.BlobType;
import org.flyte.api.v1.BlobType.BlobDimensionality;
import org.flyte.api.v1.Literal;
import org.flyte.api.v1.LiteralType;
import org.flyte.api.v1.Primitive;
Expand Down Expand Up @@ -80,7 +83,7 @@ private SdkBindingData<?> transform(JsonNode tree) {
}
}

private static SdkBindingData<? extends Serializable> transformScalar(JsonNode tree) {
private static SdkBindingData<?> transformScalar(JsonNode tree) {
Scalar.Kind scalarKind = Scalar.Kind.valueOf(tree.get(SCALAR).asText());
switch (scalarKind) {
case PRIMITIVE:
Expand All @@ -102,14 +105,33 @@ private static SdkBindingData<? extends Serializable> transformScalar(JsonNode t
throw new UnsupportedOperationException(
"Type contains an unsupported primitive: " + primitiveKind);

case GENERIC:
case BLOB:
return transformBlob(tree);

case GENERIC:
default:
throw new UnsupportedOperationException(
"Type contains an unsupported scalar: " + scalarKind);
}
}

private static SdkBindingData<Blob> transformBlob(JsonNode tree) {
JsonNode value = tree.get(VALUE);
String uri = value.get("uri").asText();
JsonNode type = value.get("metadata").get("type");
String format = type.get("format").asText();
BlobDimensionality dimensionality =
BlobDimensionality.valueOf(type.get("dimensionality").asText());
return SdkBindingDataFactory.of(
Blob.builder()
.uri(uri)
.metadata(
BlobMetadata.builder()
.type(BlobType.builder().format(format).dimensionality(dimensionality).build())
.build())
.build());
}

@SuppressWarnings("unchecked")
private <T> SdkBindingData<List<T>> transformCollection(JsonNode tree) {
SdkLiteralType<T> literalType = (SdkLiteralType<T>) readLiteralType(tree.get(TYPE));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@
*/
package org.flyte.flytekit.jackson.serializers;

import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.SCALAR;
import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.VALUE;

import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.SerializerProvider;
import java.io.IOException;
import org.flyte.api.v1.Blob;
import org.flyte.api.v1.Literal;
import org.flyte.api.v1.LiteralType;
import org.flyte.api.v1.Scalar;
import org.flyte.api.v1.Scalar.Kind;

public class BlobSerializer extends ScalarSerializer {
public BlobSerializer(
Expand All @@ -38,8 +38,8 @@ public BlobSerializer(

@Override
void serializeScalar() throws IOException {
gen.writeFieldName(SCALAR);
gen.writeObject(Scalar.Kind.BLOB);
gen.writeObject(Kind.BLOB);
gen.writeFieldName(VALUE);
serializerProvider
.findValueSerializer(Blob.class)
.serialize(value.scalar().blob(), gen, serializerProvider);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ public class SdkBindingDataSerializationProtocol {
public static final String TYPE = "type";
public static final String KIND = "kind";
public static final String PRIMITIVE = "primitive";
public static final String BLOB = "blob";
}
Loading

0 comments on commit 6ea206b

Please sign in to comment.