Skip to content

Commit

Permalink
Fix linting
Browse files Browse the repository at this point in the history
  • Loading branch information
shydefoo committed Sep 25, 2024
1 parent 2c64047 commit fbc5c3f
Show file tree
Hide file tree
Showing 3 changed files with 286 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ public static void setup() throws IOException {
.setInstanceId(INSTANCE_ID)
.build());
Configuration config = BigtableConfiguration.configure(PROJECT_ID, INSTANCE_ID);
config.set(BigtableOptionsFactory.BIGTABLE_EMULATOR_HOST_KEY, "localhost:" + bigtableEmulator.getMappedPort(BIGTABLE_EMULATOR_PORT));
config.set(
BigtableOptionsFactory.BIGTABLE_EMULATOR_HOST_KEY,
"localhost:" + bigtableEmulator.getMappedPort(BIGTABLE_EMULATOR_PORT));
hbaseClient = BigtableConfiguration.connect(config);
ingestData();
}
Expand Down Expand Up @@ -237,38 +239,37 @@ public void shouldFilterOutMissingFeatureRef() {
}

@Test
public void shouldRetrieveFeaturesSuccessfullyWhenUsingHbase(){
public void shouldRetrieveFeaturesSuccessfullyWhenUsingHbase() {
HBaseOnlineRetriever retriever = new HBaseOnlineRetriever(hbaseClient);
List<FeatureReference> featureReferences =
Stream.of("trip_cost", "trip_distance")
.map(f -> FeatureReference.newBuilder().setFeatureTable("rides").setName(f).build())
.toList();
Stream.of("trip_cost", "trip_distance")
.map(f -> FeatureReference.newBuilder().setFeatureTable("rides").setName(f).build())
.toList();
List<String> entityNames = List.of("driver");
List<EntityRow> entityRows =
List.of(DataGenerator.createEntityRow("driver", DataGenerator.createInt64Value(1), 100));
List.of(DataGenerator.createEntityRow("driver", DataGenerator.createInt64Value(1), 100));
List<List<Feature>> featuresForRows =
retriever.getOnlineFeatures(FEAST_PROJECT, entityRows, featureReferences, entityNames);
retriever.getOnlineFeatures(FEAST_PROJECT, entityRows, featureReferences, entityNames);
assertEquals(1, featuresForRows.size());
List<Feature> features = featuresForRows.get(0);
assertEquals(2, features.size());
assertEquals(5L, features.get(0).getFeatureValue(ValueType.Enum.INT64).getInt64Val());
assertEquals(featureReferences.get(0), features.get(0).getFeatureReference());
assertEquals(3.5, features.get(1).getFeatureValue(ValueType.Enum.DOUBLE).getDoubleVal());
assertEquals(featureReferences.get(1), features.get(1).getFeatureReference());

}

@Test
public void shouldFilterOutMissingFeatureRefUsingHbase() {
BigTableOnlineRetriever retriever = new BigTableOnlineRetriever(client);
HBaseOnlineRetriever retriever = new HBaseOnlineRetriever(hbaseClient);
List<FeatureReference> featureReferences =
List.of(
FeatureReference.newBuilder().setFeatureTable("rides").setName("not_exists").build());
List.of(
FeatureReference.newBuilder().setFeatureTable("rides").setName("not_exists").build());
List<String> entityNames = List.of("driver");
List<EntityRow> entityRows =
List.of(DataGenerator.createEntityRow("driver", DataGenerator.createInt64Value(1), 100));
List.of(DataGenerator.createEntityRow("driver", DataGenerator.createInt64Value(1), 100));
List<List<Feature>> features =
retriever.getOnlineFeatures(FEAST_PROJECT, entityRows, featureReferences, entityNames);
retriever.getOnlineFeatures(FEAST_PROJECT, entityRows, featureReferences, entityNames);
assertEquals(1, features.size());
assertEquals(0, features.get(0).size());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package dev.caraml.serving.store.bigtable;

import java.net.InetAddress;
import java.net.UnknownHostException;
import java.time.Duration;
import java.util.Arrays;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hbase.HBaseConfiguration;
import org.testcontainers.containers.GenericContainer;
import org.testcontainers.containers.wait.strategy.Wait;
import org.testcontainers.utility.DockerImageName;

/**
 * Testcontainers wrapper around a standalone HBase 2.x image for integration tests.
 *
 * <p>Exposes a ready-to-use client {@link Configuration} via {@link #getHbase2Configuration()}
 * once the container has started.
 */
public class GenericHbase2Container extends GenericContainer<GenericHbase2Container> {

  // Hostname the container advertises; resolved from the local host so clients can reach it.
  private final String hostName;

  // Client configuration for this container; the zookeeper port is filled in by doStart().
  private final Configuration hbase2Configuration = HBaseConfiguration.create();

  public GenericHbase2Container() {
    super(DockerImageName.parse("jcjabouille/hbase-standalone:2.4.9"));
    try {
      hostName = InetAddress.getLocalHost().getHostName();
    } catch (UnknownHostException e) {
      throw new RuntimeException(e);
    }

    int masterPort = 16010;
    addExposedPort(masterPort);
    int regionPort = 16011;
    addExposedPort(regionPort);
    addExposedPort(2181);

    withCreateContainerCmdModifier(cmd -> cmd.withHostName(hostName));

    waitingFor(Wait.forLogMessage(".*running regionserver.*", 1));
    withStartupTimeout(Duration.ofMinutes(10));

    withEnv("HBASE_MASTER_PORT", Integer.toString(masterPort));
    withEnv("HBASE_REGION_PORT", Integer.toString(regionPort));
    // NOTE(review): master/region ports are pinned host:container — presumably because HBase
    // advertises these ports through zookeeper, so random mapping would break clients; confirm.
    setPortBindings(
        Arrays.asList(
            String.format("%d:%d", masterPort, masterPort),
            String.format("%d:%d", regionPort, regionPort)));
  }

  /**
   * Returns the client {@link Configuration} for connecting to this container. Only fully
   * populated after the container has started ({@link #doStart()} sets the mapped zookeeper
   * client port). Previously this field was private and write-only, making it unusable by tests.
   */
  public Configuration getHbase2Configuration() {
    return hbase2Configuration;
  }

  @Override
  protected void doStart() {
    super.doStart();

    // Tight client timeouts so test failures surface quickly instead of hanging.
    hbase2Configuration.set("hbase.client.pause", "200");
    hbase2Configuration.set("hbase.client.retries.number", "10");
    hbase2Configuration.set("hbase.rpc.timeout", "3000");
    hbase2Configuration.set("hbase.client.operation.timeout", "3000");
    hbase2Configuration.set("hbase.client.scanner.timeout.period", "10000");
    hbase2Configuration.set("zookeeper.session.timeout", "10000");
    hbase2Configuration.set("hbase.zookeeper.quorum", "localhost");
    // Zookeeper client port is randomly mapped, so resolve it at start time.
    hbase2Configuration.set(
        "hbase.zookeeper.property.clientPort", Integer.toString(getMappedPort(2181)));
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
package dev.caraml.serving.store.bigtable;

import static org.junit.jupiter.api.Assertions.assertEquals;

import com.google.common.hash.Hashing;
import dev.caraml.serving.store.Feature;
import dev.caraml.store.protobuf.serving.ServingServiceProto;
import dev.caraml.store.protobuf.types.ValueProto;
import dev.caraml.store.testutils.it.DataGenerator;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.stream.Stream;
import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;
import org.apache.avro.generic.GenericDatumWriter;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.generic.GenericRecordBuilder;
import org.apache.avro.io.Encoder;
import org.apache.avro.io.EncoderFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hbase.HBaseConfiguration;
import org.apache.hadoop.hbase.TableName;
import org.apache.hadoop.hbase.client.*;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;

@Testcontainers
public class HbaseOnlineRetrieverTest {
static Connection hbaseClient;
static HBaseAdmin admin;
static Configuration hbaseConfiguration = HBaseConfiguration.create();
static final String FEAST_PROJECT = "default";

@Container public static GenericHbase2Container hbase = new GenericHbase2Container();

@BeforeAll
public static void setup() throws IOException {
hbaseConfiguration.set("hbase.zookeeper.quorum", hbase.getHost());
hbaseConfiguration.set("hbase.zookeeper.property.clientPort", "2181");
hbaseClient = ConnectionFactory.createConnection(hbaseConfiguration);
admin = (HBaseAdmin) hbaseClient.getAdmin();
ingestData();
}

private static void ingestData() throws IOException {
String featureTableName = "rides";

/** Single Entity Ingestion Workflow */
Schema schema =
SchemaBuilder.record("DriverData")
.namespace(featureTableName)
.fields()
.requiredLong("trip_cost")
.requiredDouble("trip_distance")
.nullableString("trip_empty", "null")
.requiredString("trip_wrong_type")
.endRecord();
createTable(FEAST_PROJECT, List.of("driver"), List.of(featureTableName));
insertSchema(FEAST_PROJECT, List.of("driver"), schema);

GenericRecord record =
new GenericRecordBuilder(schema)
.set("trip_cost", 5L)
.set("trip_distance", 3.5)
.set("trip_empty", null)
.set("trip_wrong_type", "test")
.build();
String entityKey = String.valueOf(DataGenerator.createInt64Value(1).getInt64Val());
insertRow(FEAST_PROJECT, List.of("driver"), entityKey, featureTableName, schema, record);
}

private static String getTableName(String project, List<String> entityNames) {
return String.format("%s__%s", project, String.join("__", entityNames));
}

private static byte[] serializedSchemaReference(Schema schema) {
return Hashing.murmur3_32().hashBytes(schema.toString().getBytes()).asBytes();
}

private static void createTable(
String project, List<String> entityNames, List<String> featureTables) {
String tableName = getTableName(project, entityNames);

List<String> columnFamilies =
Stream.concat(featureTables.stream(), Stream.of("metadata")).toList();
TableDescriptorBuilder tb = TableDescriptorBuilder.newBuilder(TableName.valueOf(tableName));
columnFamilies.forEach(cf -> tb.setColumnFamily(ColumnFamilyDescriptorBuilder.of(cf)));
try {
if (admin.tableExists(TableName.valueOf(tableName))) {
return;
}
admin.createTable(tb.build());
} catch (IOException e) {
throw new RuntimeException(e);
}
}

private static void insertSchema(String project, List<String> entityNames, Schema schema)
throws IOException {
String tableName = getTableName(project, entityNames);
byte[] schemaReference = serializedSchemaReference(schema);
byte[] schemaKey = createSchemaKey(schemaReference);
Table table = hbaseClient.getTable(TableName.valueOf(tableName));
Put put = new Put(schemaKey);
put.addColumn("metadata".getBytes(), "avro".getBytes(), schema.toString().getBytes());
table.put(put);
table.close();
}

private static byte[] createSchemaKey(byte[] schemaReference) throws IOException {
String schemaKeyPrefix = "schema#";
ByteArrayOutputStream concatOutputStream = new ByteArrayOutputStream();
concatOutputStream.write(schemaKeyPrefix.getBytes());
concatOutputStream.write(schemaReference);
return concatOutputStream.toByteArray();
}

private static byte[] createEntityValue(Schema schema, GenericRecord record) throws IOException {
byte[] schemaReference = serializedSchemaReference(schema);
// Entity-Feature Row
byte[] avroSerializedFeatures = recordToAvro(record, schema);

ByteArrayOutputStream concatOutputStream = new ByteArrayOutputStream();
concatOutputStream.write(schemaReference);
concatOutputStream.write("".getBytes());
concatOutputStream.write(avroSerializedFeatures);
byte[] entityFeatureValue = concatOutputStream.toByteArray();

return entityFeatureValue;
}

private static byte[] recordToAvro(GenericRecord datum, Schema schema) throws IOException {
GenericDatumWriter<Object> writer = new GenericDatumWriter<>(schema);
ByteArrayOutputStream output = new ByteArrayOutputStream();
Encoder encoder = EncoderFactory.get().binaryEncoder(output, null);
writer.write(datum, encoder);
encoder.flush();

return output.toByteArray();
}

private static void insertRow(
String project,
List<String> entityNames,
String entityKey,
String featureTableName,
Schema schema,
GenericRecord record)
throws IOException {
byte[] entityFeatureValue = createEntityValue(schema, record);
String tableName = getTableName(project, entityNames);

// Update Compound Entity-Feature Row
Table table = hbaseClient.getTable(TableName.valueOf(tableName));
Put put = new Put(entityKey.getBytes());
put.addColumn(featureTableName.getBytes(), "".getBytes(), entityFeatureValue);
table.put(put);
table.close();
}

@Test
public void shouldRetrieveFeaturesSuccessfully() {
HBaseOnlineRetriever retriever = new HBaseOnlineRetriever(hbaseClient);
List<ServingServiceProto.FeatureReference> featureReferences =
Stream.of("trip_cost", "trip_distance")
.map(
f ->
ServingServiceProto.FeatureReference.newBuilder()
.setFeatureTable("rides")
.setName(f)
.build())
.toList();
List<String> entityNames = List.of("driver");
List<ServingServiceProto.GetOnlineFeaturesRequest.EntityRow> entityRows =
List.of(DataGenerator.createEntityRow("driver", DataGenerator.createInt64Value(1), 100));
List<List<Feature>> featuresForRows =
retriever.getOnlineFeatures(FEAST_PROJECT, entityRows, featureReferences, entityNames);
assertEquals(1, featuresForRows.size());
List<Feature> features = featuresForRows.get(0);
assertEquals(2, features.size());
assertEquals(
5L, features.get(0).getFeatureValue(ValueProto.ValueType.Enum.INT64).getInt64Val());
assertEquals(featureReferences.get(0), features.get(0).getFeatureReference());
assertEquals(
3.5, features.get(1).getFeatureValue(ValueProto.ValueType.Enum.DOUBLE).getDoubleVal());
assertEquals(featureReferences.get(1), features.get(1).getFeatureReference());
}

@Test
public void shouldFilterOutMissingFeatureRefUsingHbase() {
HBaseOnlineRetriever retriever = new HBaseOnlineRetriever(hbaseClient);
List<ServingServiceProto.FeatureReference> featureReferences =
List.of(
ServingServiceProto.FeatureReference.newBuilder()
.setFeatureTable("rides")
.setName("not_exists")
.build());
List<String> entityNames = List.of("driver");
List<ServingServiceProto.GetOnlineFeaturesRequest.EntityRow> entityRows =
List.of(DataGenerator.createEntityRow("driver", DataGenerator.createInt64Value(1), 100));
List<List<Feature>> features =
retriever.getOnlineFeatures(FEAST_PROJECT, entityRows, featureReferences, entityNames);
assertEquals(1, features.size());
assertEquals(0, features.get(0).size());
}
}

0 comments on commit fbc5c3f

Please sign in to comment.