Skip to content

Commit

Permalink
Refactor collection's cgo call
Browse files Browse the repository at this point in the history
Signed-off-by: Enwei Jiao <[email protected]>
  • Loading branch information
jiaoew1991 committed Nov 1, 2023
1 parent 873b29e commit 1a6f644
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 74 deletions.
50 changes: 23 additions & 27 deletions internal/core/src/segcore/Collection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,52 +17,48 @@

namespace milvus::segcore {

Collection::Collection(const std::string_view collection_proto)
: schema_proto_(collection_proto) {
parse();
Collection::Collection(const milvus::proto::schema::CollectionSchema* schema) {
Assert(schema != nullptr);
collection_name_ = schema->name();
schema_ = Schema::ParseFrom(*schema);
}

void
Collection::parse() {
// if (schema_proto_.empty()) {
// // TODO: remove hard code use unittests are ready
// std::cout << "WARN: Use default schema" << std::endl;
// auto schema = std::make_shared<Schema>();
// schema->AddDebugField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
// schema->AddDebugField("age", DataType::INT32);
// collection_name_ = "default-collection";
// schema_ = schema;
// return;
// }

Assert(!schema_proto_.empty());
Collection::Collection(const std::string_view schema_proto) {
milvus::proto::schema::CollectionSchema collection_schema;
auto suc = google::protobuf::TextFormat::ParseFromString(
schema_proto_, &collection_schema);
std::string(schema_proto), &collection_schema);
if (!suc) {
LOG_SEGCORE_WARNING_ << "unmarshal schema string failed";
}
collection_name_ = collection_schema.name();
schema_ = Schema::ParseFrom(collection_schema);
}

Collection::Collection(const void* schema_proto, const int64_t length) {
Assert(schema_proto != nullptr);
milvus::proto::schema::CollectionSchema collection_schema;
auto suc = collection_schema.ParseFromArray(schema_proto, length);
if (!suc) {
std::cerr << "unmarshal schema string failed" << std::endl;
LOG_SEGCORE_WARNING_ << "unmarshal schema string failed";
}

collection_name_ = collection_schema.name();
schema_ = Schema::ParseFrom(collection_schema);
}

void
Collection::parseIndexMeta(const std::string_view index_meta_proto_) {
Assert(!index_meta_proto_.empty());
Collection::parseIndexMeta(const void* index_proto, const int64_t length) {
Assert(index_proto != nullptr);

milvus::proto::segcore::CollectionIndexMeta protobuf_indexMeta;
auto suc = google::protobuf::TextFormat::ParseFromString(
std::string(index_meta_proto_), &protobuf_indexMeta);
milvus::proto::segcore::CollectionIndexMeta indexMeta;
auto suc = indexMeta.ParseFromArray(index_proto, length);

if (!suc) {
LOG_SEGCORE_ERROR_ << "unmarshal index meta string failed" << std::endl;
LOG_SEGCORE_ERROR_ << "unmarshal index meta string failed";
return;
}

index_meta_ = std::shared_ptr<CollectionIndexMeta>(
new CollectionIndexMeta(protobuf_indexMeta));
index_meta_ = std::make_shared<CollectionIndexMeta>(indexMeta);
LOG_SEGCORE_INFO_ << "index meta info : " << index_meta_->ToString();
}

Expand Down
10 changes: 4 additions & 6 deletions internal/core/src/segcore/Collection.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,12 @@ namespace milvus::segcore {

class Collection {
public:
explicit Collection(const std::string_view collection_proto);
explicit Collection(const milvus::proto::schema::CollectionSchema* schema);
explicit Collection(const std::string_view schema_proto);
explicit Collection(const void* collection_proto, const int64_t length);

void
parse();

void
parseIndexMeta(const std::string_view index_meta_proto_blob);
parseIndexMeta(const void* index_meta_proto_blob, const int64_t length);

public:
SchemaPtr&
Expand All @@ -47,7 +46,6 @@ class Collection {

private:
std::string collection_name_;
std::string schema_proto_;
SchemaPtr schema_;
IndexMetaPtr index_meta_;
};
Expand Down
13 changes: 7 additions & 6 deletions internal/core/src/segcore/collection_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,18 @@
#include "segcore/Collection.h"

CCollection
NewCollection(const char* schema_proto_blob) {
auto proto = std::string(schema_proto_blob);
auto collection = std::make_unique<milvus::segcore::Collection>(proto);
NewCollection(const void* schema_proto_blob, const int64_t length) {
auto collection = std::make_unique<milvus::segcore::Collection>(
schema_proto_blob, length);
return (void*)collection.release();
}

void
SetIndexMeta(CCollection collection, const char* index_meta_proto_blob) {
SetIndexMeta(CCollection collection,
const void* proto_blob,
const int64_t length) {
auto col = (milvus::segcore::Collection*)collection;
auto proto = std::string_view(index_meta_proto_blob);
col->parseIndexMeta(proto);
col->parseIndexMeta(proto_blob, length);
}

void
Expand Down
8 changes: 6 additions & 2 deletions internal/core/src/segcore/collection_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,21 @@

#pragma once

#include <stdint.h>

#ifdef __cplusplus
extern "C" {
#endif

typedef void* CCollection;

CCollection
NewCollection(const char* schema_proto_blob);
NewCollection(const void* schema_proto_blob, const int64_t length);

void
SetIndexMeta(CCollection collection, const char* index_meta_proto_blob);
SetIndexMeta(CCollection collection,
const void* proto_blob,
const int64_t length);

void
DeleteCollection(CCollection collection);
Expand Down
7 changes: 6 additions & 1 deletion internal/core/unittest/test_c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,12 @@ TEST(CApiTest, CollectionTest) {

TEST(CApiTest, SetIndexMetaTest) {
auto collection = NewCollection(get_default_schema_config());
SetIndexMeta(collection, get_default_index_meta());

milvus::proto::segcore::CollectionIndexMeta indexMeta;
indexMeta.ParseFromString(get_default_index_meta());
char buffer[indexMeta.ByteSizeLong()];
indexMeta.SerializeToArray(buffer, indexMeta.ByteSizeLong());
SetIndexMeta(collection, buffer, indexMeta.ByteSizeLong());
DeleteCollection(collection);
}

Expand Down
27 changes: 2 additions & 25 deletions internal/core/unittest/test_float16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "query/ExprImpl.h"
#include "segcore/Reduce.h"
#include "segcore/reduce_c.h"
#include "test_utils/DataGen.h"
#include "test_utils/PbHelper.h"
#include "test_utils/indexbuilder_test_utils.h"

Expand Down Expand Up @@ -253,35 +254,11 @@ generate_collection_schema(std::string metric_type, int dim, bool is_fp16) {
return schema_string;
}

CCollection
NewCollection(const char* schema_proto_blob) {
auto proto = std::string(schema_proto_blob);
auto collection = std::make_unique<milvus::segcore::Collection>(proto);
return (void*)collection.release();
}

TEST(Float16, CApiCPlan) {
std::string schema_string =
generate_collection_schema(knowhere::metric::L2, 16, true);
auto collection = NewCollection(schema_string.c_str());

// const char* dsl_string = R"(
// {
// "bool": {
// "vector": {
// "fakevec": {
// "metric_type": "L2",
// "params": {
// "nprobe": 10
// },
// "query": "$0",
// "topk": 10,
// "round_decimal": 3
// }
// }
// }
// })";

milvus::proto::plan::PlanNode plan_node;
auto vector_anns = plan_node.mutable_vector_anns();
vector_anns->set_vector_type(
Expand Down Expand Up @@ -416,4 +393,4 @@ TEST(Float16, ExecWithPredicate) {

query::Json json = SearchResultToJson(*sr);
std::cout << json.dump(2);
}
}
9 changes: 9 additions & 0 deletions internal/core/unittest/test_utils/DataGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@
#include "index/StringIndexSort.h"
#include "index/VectorMemIndex.h"
#include "query/SearchOnIndex.h"
#include "segcore/Collection.h"
#include "segcore/SegmentGrowingImpl.h"
#include "segcore/SegmentSealedImpl.h"
#include "segcore/Utils.h"
#include "knowhere/comp/index_param.h"

#include "PbHelper.h"
#include "segcore/collection_c.h"

using boost::algorithm::starts_with;

Expand Down Expand Up @@ -1012,4 +1014,11 @@ GenRandomIds(int rows, int64_t seed = 42) {
return ids_ds;
}

inline CCollection
NewCollection(const char* schema_proto_blob) {
auto proto = std::string(schema_proto_blob);
auto collection = std::make_unique<milvus::segcore::Collection>(proto);
return (void*)collection.release();
}

} // namespace milvus::segcore
19 changes: 12 additions & 7 deletions internal/querynodev2/segments/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,16 +195,21 @@ func NewCollection(collectionID int64, schema *schemapb.CollectionSchema, indexM
CCollection
NewCollection(const char* schema_proto_blob);
*/
schemaBlob := proto.MarshalTextString(schema)
cSchemaBlob := C.CString(schemaBlob)
defer C.free(unsafe.Pointer(cSchemaBlob))
schemaBlob, err := proto.Marshal(schema)
if err != nil {
log.Warn("marshal schema failed", zap.Error(err))
return nil
}

collection := C.NewCollection(cSchemaBlob)
collection := C.NewCollection(unsafe.Pointer(&schemaBlob[0]), (C.int64_t)(len(schemaBlob)))

if indexMeta != nil && len(indexMeta.GetIndexMetas()) > 0 && indexMeta.GetMaxIndexRowCount() > 0 {
indexMetaBlob := proto.MarshalTextString(indexMeta)
cIndexMetaBlob := C.CString(indexMetaBlob)
C.SetIndexMeta(collection, cIndexMetaBlob)
indexMetaBlob, err := proto.Marshal(indexMeta)
if err != nil {
log.Warn("marshal index meta failed", zap.Error(err))
return nil
}
C.SetIndexMeta(collection, unsafe.Pointer(&indexMetaBlob[0]), (C.int64_t)(len(indexMetaBlob)))
}

return &Collection{
Expand Down

0 comments on commit 1a6f644

Please sign in to comment.