Commit
Showing 3 changed files with 286 additions and 13 deletions.
63 changes: 63 additions & 0 deletions
...store-serving/src/test/java/dev/caraml/serving/store/bigtable/GenericHbase2Container.java
@@ -0,0 +1,63 @@
```java
package dev.caraml.serving.store.bigtable;

import java.net.InetAddress;
import java.net.UnknownHostException;
import java.time.Duration;
import java.util.Arrays;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hbase.HBaseConfiguration;
import org.testcontainers.containers.GenericContainer;
import org.testcontainers.containers.wait.strategy.Wait;
import org.testcontainers.utility.DockerImageName;

public class GenericHbase2Container extends GenericContainer<GenericHbase2Container> {

  private final String hostName;
  private final Configuration hbase2Configuration = HBaseConfiguration.create();

  public GenericHbase2Container() {
    super(DockerImageName.parse("jcjabouille/hbase-standalone:2.4.9"));
    // Resolve the local host name; the container advertises it so that the
    // region server addresses registered in ZooKeeper are reachable from the test JVM.
    try {
      hostName = InetAddress.getLocalHost().getHostName();
    } catch (UnknownHostException e) {
      throw new RuntimeException(e);
    }

    int masterPort = 16010;
    addExposedPort(masterPort);
    int regionPort = 16011;
    addExposedPort(regionPort);
    addExposedPort(2181);

    withCreateContainerCmdModifier(
        cmd -> {
          cmd.withHostName(hostName);
        });

    waitingFor(Wait.forLogMessage(".*running regionserver.*", 1));
    withStartupTimeout(Duration.ofMinutes(10));

    withEnv("HBASE_MASTER_PORT", Integer.toString(masterPort));
    withEnv("HBASE_REGION_PORT", Integer.toString(regionPort));
    // Bind master and region server ports 1:1 on the host so clients can reach
    // them at the advertised host name.
    setPortBindings(
        Arrays.asList(
            String.format("%d:%d", masterPort, masterPort),
            String.format("%d:%d", regionPort, regionPort)));
  }

  @Override
  protected void doStart() {
    super.doStart();

    // Client-side configuration pointing at the standalone instance; the
    // ZooKeeper client port is the host port mapped to the container's 2181.
    hbase2Configuration.set("hbase.client.pause", "200");
    hbase2Configuration.set("hbase.client.retries.number", "10");
    hbase2Configuration.set("hbase.rpc.timeout", "3000");
    hbase2Configuration.set("hbase.client.operation.timeout", "3000");
    hbase2Configuration.set("hbase.client.scanner.timeout.period", "10000");
    hbase2Configuration.set("zookeeper.session.timeout", "10000");
    hbase2Configuration.set("hbase.zookeeper.quorum", "localhost");
    hbase2Configuration.set(
        "hbase.zookeeper.property.clientPort", Integer.toString(getMappedPort(2181)));
  }
}
```
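For reference, a minimal sketch (not part of the commit) of how a client connection to this container can be obtained once it has started. It mirrors the wiring in `HbaseOnlineRetrieverTest.setup()` below, except that it resolves the mapped ZooKeeper port via `getMappedPort(2181)` the same way `doStart()` does rather than hardcoding 2181; the class and method names here are illustrative only.

```java
// Sketch only: illustrative helper, not part of the commit.
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hbase.HBaseConfiguration;
import org.apache.hadoop.hbase.client.Connection;
import org.apache.hadoop.hbase.client.ConnectionFactory;

class Hbase2ContainerUsageSketch {
  // Assumes the container has already been started (e.g. by the @Container/@Testcontainers extension).
  static Connection connect(GenericHbase2Container hbase) throws java.io.IOException {
    Configuration conf = HBaseConfiguration.create();
    conf.set("hbase.zookeeper.quorum", hbase.getHost());
    conf.set(
        "hbase.zookeeper.property.clientPort", Integer.toString(hbase.getMappedPort(2181)));
    return ConnectionFactory.createConnection(conf);
  }
}
```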
209 changes: 209 additions & 0 deletions
...ore-serving/src/test/java/dev/caraml/serving/store/bigtable/HbaseOnlineRetrieverTest.java
@@ -0,0 +1,209 @@
```java
package dev.caraml.serving.store.bigtable;

import static org.junit.jupiter.api.Assertions.assertEquals;

import com.google.common.hash.Hashing;
import dev.caraml.serving.store.Feature;
import dev.caraml.store.protobuf.serving.ServingServiceProto;
import dev.caraml.store.protobuf.types.ValueProto;
import dev.caraml.store.testutils.it.DataGenerator;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.List;
import java.util.stream.Stream;
import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;
import org.apache.avro.generic.GenericDatumWriter;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.generic.GenericRecordBuilder;
import org.apache.avro.io.Encoder;
import org.apache.avro.io.EncoderFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hbase.HBaseConfiguration;
import org.apache.hadoop.hbase.TableName;
import org.apache.hadoop.hbase.client.*;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;

@Testcontainers
public class HbaseOnlineRetrieverTest {
  static Connection hbaseClient;
  static HBaseAdmin admin;
  static Configuration hbaseConfiguration = HBaseConfiguration.create();
  static final String FEAST_PROJECT = "default";

  @Container public static GenericHbase2Container hbase = new GenericHbase2Container();

  @BeforeAll
  public static void setup() throws IOException {
    hbaseConfiguration.set("hbase.zookeeper.quorum", hbase.getHost());
    hbaseConfiguration.set("hbase.zookeeper.property.clientPort", "2181");
    hbaseClient = ConnectionFactory.createConnection(hbaseConfiguration);
    admin = (HBaseAdmin) hbaseClient.getAdmin();
    ingestData();
  }

  private static void ingestData() throws IOException {
    String featureTableName = "rides";

    // Single-entity ingestion workflow
    Schema schema =
        SchemaBuilder.record("DriverData")
            .namespace(featureTableName)
            .fields()
            .requiredLong("trip_cost")
            .requiredDouble("trip_distance")
            .nullableString("trip_empty", "null")
            .requiredString("trip_wrong_type")
            .endRecord();
    createTable(FEAST_PROJECT, List.of("driver"), List.of(featureTableName));
    insertSchema(FEAST_PROJECT, List.of("driver"), schema);

    GenericRecord record =
        new GenericRecordBuilder(schema)
            .set("trip_cost", 5L)
            .set("trip_distance", 3.5)
            .set("trip_empty", null)
            .set("trip_wrong_type", "test")
            .build();
    String entityKey = String.valueOf(DataGenerator.createInt64Value(1).getInt64Val());
    insertRow(FEAST_PROJECT, List.of("driver"), entityKey, featureTableName, schema, record);
  }

  private static String getTableName(String project, List<String> entityNames) {
    return String.format("%s__%s", project, String.join("__", entityNames));
  }

  private static byte[] serializedSchemaReference(Schema schema) {
    return Hashing.murmur3_32().hashBytes(schema.toString().getBytes()).asBytes();
  }

  private static void createTable(
      String project, List<String> entityNames, List<String> featureTables) {
    String tableName = getTableName(project, entityNames);

    // One column family per feature table, plus a "metadata" family for Avro schemas.
    List<String> columnFamilies =
        Stream.concat(featureTables.stream(), Stream.of("metadata")).toList();
    TableDescriptorBuilder tb = TableDescriptorBuilder.newBuilder(TableName.valueOf(tableName));
    columnFamilies.forEach(cf -> tb.setColumnFamily(ColumnFamilyDescriptorBuilder.of(cf)));
    try {
      if (admin.tableExists(TableName.valueOf(tableName))) {
        return;
      }
      admin.createTable(tb.build());
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }

  private static void insertSchema(String project, List<String> entityNames, Schema schema)
      throws IOException {
    String tableName = getTableName(project, entityNames);
    byte[] schemaReference = serializedSchemaReference(schema);
    byte[] schemaKey = createSchemaKey(schemaReference);
    Table table = hbaseClient.getTable(TableName.valueOf(tableName));
    Put put = new Put(schemaKey);
    put.addColumn("metadata".getBytes(), "avro".getBytes(), schema.toString().getBytes());
    table.put(put);
    table.close();
  }

  private static byte[] createSchemaKey(byte[] schemaReference) throws IOException {
    String schemaKeyPrefix = "schema#";
    ByteArrayOutputStream concatOutputStream = new ByteArrayOutputStream();
    concatOutputStream.write(schemaKeyPrefix.getBytes());
    concatOutputStream.write(schemaReference);
    return concatOutputStream.toByteArray();
  }

  private static byte[] createEntityValue(Schema schema, GenericRecord record) throws IOException {
    byte[] schemaReference = serializedSchemaReference(schema);
    // Entity-feature row value: schema reference followed by the Avro-encoded features.
    byte[] avroSerializedFeatures = recordToAvro(record, schema);

    ByteArrayOutputStream concatOutputStream = new ByteArrayOutputStream();
    concatOutputStream.write(schemaReference);
    concatOutputStream.write("".getBytes());
    concatOutputStream.write(avroSerializedFeatures);
    byte[] entityFeatureValue = concatOutputStream.toByteArray();

    return entityFeatureValue;
  }

  private static byte[] recordToAvro(GenericRecord datum, Schema schema) throws IOException {
    GenericDatumWriter<Object> writer = new GenericDatumWriter<>(schema);
    ByteArrayOutputStream output = new ByteArrayOutputStream();
    Encoder encoder = EncoderFactory.get().binaryEncoder(output, null);
    writer.write(datum, encoder);
    encoder.flush();

    return output.toByteArray();
  }

  private static void insertRow(
      String project,
      List<String> entityNames,
      String entityKey,
      String featureTableName,
      Schema schema,
      GenericRecord record)
      throws IOException {
    byte[] entityFeatureValue = createEntityValue(schema, record);
    String tableName = getTableName(project, entityNames);

    // Update the compound entity-feature row.
    Table table = hbaseClient.getTable(TableName.valueOf(tableName));
    Put put = new Put(entityKey.getBytes());
    put.addColumn(featureTableName.getBytes(), "".getBytes(), entityFeatureValue);
    table.put(put);
    table.close();
  }

  @Test
  public void shouldRetrieveFeaturesSuccessfully() {
    HBaseOnlineRetriever retriever = new HBaseOnlineRetriever(hbaseClient);
    List<ServingServiceProto.FeatureReference> featureReferences =
        Stream.of("trip_cost", "trip_distance")
            .map(
                f ->
                    ServingServiceProto.FeatureReference.newBuilder()
                        .setFeatureTable("rides")
                        .setName(f)
                        .build())
            .toList();
    List<String> entityNames = List.of("driver");
    List<ServingServiceProto.GetOnlineFeaturesRequest.EntityRow> entityRows =
        List.of(DataGenerator.createEntityRow("driver", DataGenerator.createInt64Value(1), 100));
    List<List<Feature>> featuresForRows =
        retriever.getOnlineFeatures(FEAST_PROJECT, entityRows, featureReferences, entityNames);
    assertEquals(1, featuresForRows.size());
    List<Feature> features = featuresForRows.get(0);
    assertEquals(2, features.size());
    assertEquals(
        5L, features.get(0).getFeatureValue(ValueProto.ValueType.Enum.INT64).getInt64Val());
    assertEquals(featureReferences.get(0), features.get(0).getFeatureReference());
    assertEquals(
        3.5, features.get(1).getFeatureValue(ValueProto.ValueType.Enum.DOUBLE).getDoubleVal());
    assertEquals(featureReferences.get(1), features.get(1).getFeatureReference());
  }

  @Test
  public void shouldFilterOutMissingFeatureRefUsingHbase() {
    HBaseOnlineRetriever retriever = new HBaseOnlineRetriever(hbaseClient);
    List<ServingServiceProto.FeatureReference> featureReferences =
        List.of(
            ServingServiceProto.FeatureReference.newBuilder()
                .setFeatureTable("rides")
                .setName("not_exists")
                .build());
    List<String> entityNames = List.of("driver");
    List<ServingServiceProto.GetOnlineFeaturesRequest.EntityRow> entityRows =
        List.of(DataGenerator.createEntityRow("driver", DataGenerator.createInt64Value(1), 100));
    List<List<Feature>> features =
        retriever.getOnlineFeatures(FEAST_PROJECT, entityRows, featureReferences, entityNames);
    assertEquals(1, features.size());
    assertEquals(0, features.get(0).size());
  }
}
```