Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support to Binary input/output type #291

Merged
merged 1 commit into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions flytekit-api/src/main/java/org/flyte/api/v1/Binary.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright 2020-2021 Flyte Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.flyte.api.v1;

import com.google.auto.value.AutoValue;

/**
* A simple byte array with a tag to help different parts of the system communicate about what is in
* the byte array. It's strongly advisable that consumers of this type define a unique tag and
* validate the tag before parsing the data.
*/
@AutoValue
public abstract class Binary {
public static final String TAG_FIELD = "tag";
public static final String VALUE_FIELD = "value";

public abstract byte[] value();

public abstract String tag();

public static Builder builder() {
return new AutoValue_Binary.Builder();
}

@AutoValue.Builder
public abstract static class Builder {
public abstract Builder value(byte[] value);

public abstract Builder tag(String tag);

public abstract Binary build();
}
}
9 changes: 8 additions & 1 deletion flytekit-api/src/main/java/org/flyte/api/v1/Scalar.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ public abstract class Scalar {
public enum Kind {
PRIMITIVE,
GENERIC,
BLOB
BLOB,
BINARY
}

public abstract Kind kind();
Expand All @@ -36,6 +37,8 @@ public enum Kind {

public abstract Blob blob();

public abstract Binary binary();

// TODO: add the rest of the cases from src/main/proto/flyteidl/core/literals.proto

public static Scalar ofPrimitive(Primitive primitive) {
Expand All @@ -49,4 +52,8 @@ public static Scalar ofGeneric(Struct generic) {
public static Scalar ofBlob(Blob blob) {
return AutoOneOf_Scalar.blob(blob);
}

public static Scalar ofBinary(Binary binary) {
return AutoOneOf_Scalar.binary(binary);
}
}
3 changes: 2 additions & 1 deletion flytekit-api/src/main/java/org/flyte/api/v1/SimpleType.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,6 @@ public enum SimpleType {
BOOLEAN,
DATETIME,
DURATION,
STRUCT
STRUCT,
BINARY
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.flyte.api.v1.Binary;
import org.flyte.api.v1.Blob;
import org.flyte.api.v1.BlobType;
import org.flyte.api.v1.Variable;
Expand Down Expand Up @@ -172,6 +173,8 @@ private SdkLiteralType<?> toLiteralType(
// feature
// https://docs.flyte.org/projects/flytekit/en/latest/generated/flytekit.BlobType.html#flytekit-blobtype
return SdkLiteralTypes.blobs(BlobType.DEFAULT);
} else if (Binary.class.isAssignableFrom(type)) {
return SdkLiteralTypes.binary();
}
try {
return JacksonSdkLiteralType.of(type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JavaType;
import com.fasterxml.jackson.databind.deser.std.StdDeserializer;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.NotSerializableException;
import java.io.ObjectInputStream;
Expand All @@ -34,6 +35,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.flyte.api.v1.Binary;
import org.flyte.api.v1.Blob;
import org.flyte.api.v1.Literal;
import org.flyte.api.v1.LiteralType;
Expand Down Expand Up @@ -167,11 +169,56 @@ private static Literal deserialize(JsonParser p, SimpleType simpleType) throws I
Struct generic = readValueAsStruct(p);

return Literal.ofScalar(Scalar.ofGeneric(generic));

case BINARY:
Binary binary = readValueAsBinary(p);

return Literal.ofScalar(Scalar.ofBinary(binary));
}

throw new AssertionError(String.format("Unexpected SimpleType: [%s]", simpleType));
}

private static Binary readValueAsBinary(JsonParser p) throws IOException {
verifyToken(p, JsonToken.START_OBJECT);
p.nextToken();

Binary.Builder binaryBuilder = Binary.builder();

while (p.currentToken() != JsonToken.END_OBJECT) {
verifyToken(p, JsonToken.FIELD_NAME);
String fieldName = p.currentName();
p.nextToken();

switch (fieldName) {
case Binary.TAG_FIELD:
binaryBuilder.tag(p.readValueAs(String.class));
break;
case Binary.VALUE_FIELD:
ByteArrayOutputStream value = new ByteArrayOutputStream();
p.readBinaryValue(value);
binaryBuilder.value(value.toByteArray());
break;
default:
throw new IllegalStateException("Unexpected field [" + fieldName + "]");
}

p.nextToken();
}

Binary binary = binaryBuilder.build();

if (binary.tag() == null) {
throw new IllegalStateException("Missing field [" + Binary.TAG_FIELD + "]");
}

if (binary.value().length == 0) {
throw new IllegalStateException("Missing field [" + Binary.VALUE_FIELD + "]");
}

return binary;
}

private static Struct readValueAsStruct(JsonParser p) throws IOException {
verifyToken(p, JsonToken.START_OBJECT);
p.nextToken();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import java.io.IOException;
import java.time.Duration;
import java.time.Instant;
import java.util.Base64;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
Expand All @@ -42,6 +43,7 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.flyte.api.v1.Binary;
import org.flyte.api.v1.Blob;
import org.flyte.api.v1.BlobMetadata;
import org.flyte.api.v1.BlobType;
Expand Down Expand Up @@ -128,12 +130,24 @@ private SdkBindingData<?> transformScalar(
case GENERIC:
return transformGeneric(tree, deserializationContext, scalarKind, type);

case BINARY:
return transformBinary(tree);

default:
throw new UnsupportedOperationException(
"Type contains an unsupported scalar: " + scalarKind);
}
}

private static SdkBindingData<Binary> transformBinary(JsonNode tree) {
JsonNode value = tree.get(VALUE);
String tag = value.get(Binary.TAG_FIELD).asText();
String base64Value = value.get(Binary.VALUE_FIELD).asText();

return SdkBindingDataFactory.of(
Binary.builder().tag(tag).value(Base64.getDecoder().decode(base64Value)).build());
}

private static SdkBindingData<Blob> transformBlob(JsonNode tree) {
JsonNode value = tree.get(VALUE);
String uri = value.get("uri").asText();
Expand Down Expand Up @@ -256,6 +270,8 @@ private SdkLiteralType<?> readLiteralType(JsonNode typeNode) {
return SdkLiteralTypes.durations();
case STRUCT:
return JacksonSdkLiteralType.of(type.getContentType().getRawClass());
case BINARY:
return SdkLiteralTypes.binary();
}
throw new UnsupportedOperationException(
"Type contains a collection/map of an supported literal type: " + kind);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright 2020-2023 Flyte Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.flyte.flytekit.jackson.serializers;

import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.VALUE;

import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.SerializerProvider;
import java.io.IOException;
import java.util.Base64;
import org.flyte.api.v1.Binary;
import org.flyte.api.v1.Literal;
import org.flyte.api.v1.LiteralType;
import org.flyte.api.v1.Scalar.Kind;

public class BinarySerializer extends ScalarSerializer {
public BinarySerializer(
JsonGenerator gen,
String key,
Literal value,
SerializerProvider serializerProvider,
LiteralType literalType) {
super(gen, key, value, serializerProvider, literalType);
}

@Override
void serializeScalar() throws IOException {
gen.writeObject(Kind.BINARY);
gen.writeFieldName(VALUE);
gen.writeStartObject();
gen.writeFieldName(Binary.TAG_FIELD);
gen.writeString(value.scalar().binary().tag());
gen.writeFieldName(Binary.VALUE_FIELD);
gen.writeString(Base64.getEncoder().encodeToString(value.scalar().binary().value()));
gen.writeEndObject();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ private static ScalarSerializer createScalarSerializer(
return new GenericSerializer(gen, key, value, serializerProvider, literalType);
case BLOB:
return new BlobSerializer(gen, key, value, serializerProvider, literalType);
case BINARY:
return new BinarySerializer(gen, key, value, serializerProvider, literalType);
}
throw new AssertionError("Unexpected Literal.Kind: [" + value.scalar().kind() + "]");
}
Expand Down
Loading
Loading