Skip to content

Commit

Permalink
done
Browse files Browse the repository at this point in the history
  • Loading branch information
HammadB committed Dec 19, 2023
1 parent 2ece206 commit 0550e4f
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 12 deletions.
2 changes: 1 addition & 1 deletion k8s/deployment/segment-server.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ spec:
env:
- name: CHROMA_WORKER__PULSAR_URL
value: pulsar://pulsar.chroma:6650
- name: CHROMA_WORKER_MY_POD_IP
- name: CHROMA_WORKER__MY_IP
valueFrom:
fieldRef:
fieldPath: status.podIP
Expand Down
13 changes: 13 additions & 0 deletions k8s/test/segment_server_service.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
apiVersion: v1
kind: Service
metadata:
name: segment-server-lb
namespace: chroma
spec:
ports:
- name: segment-server-port
port: 50052
targetPort: 50051
selector:
app: segment-server
type: LoadBalancer
29 changes: 22 additions & 7 deletions rust/worker/src/segment/distributed_hnsw_segment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub(crate) struct DistributedHNSWSegment {
index: Arc<RwLock<HnswIndex>>,
id: AtomicUsize,
user_id_to_id: Arc<RwLock<HashMap<String, usize>>>,
id_to_user_id: Arc<RwLock<HashMap<usize, String>>>,
index_config: IndexConfig,
hnsw_config: HnswIndexConfig,
}
Expand All @@ -34,6 +35,7 @@ impl DistributedHNSWSegment {
index: index,
id: AtomicUsize::new(0),
user_id_to_id: Arc::new(RwLock::new(HashMap::new())),
id_to_user_id: Arc::new(RwLock::new(HashMap::new())),
index_config: index_config,
hnsw_config,
});
Expand Down Expand Up @@ -64,7 +66,10 @@ impl DistributedHNSWSegment {
self.user_id_to_id
.write()
.insert(record.id.clone(), next_id);
println!("DIS SEGMENT Adding item: {}", next_id);
self.id_to_user_id
.write()
.insert(next_id, record.id.clone());
println!("Segment adding item: {}", next_id);
self.index.read().add(next_id, &vector);
}
None => {
Expand All @@ -88,18 +93,18 @@ impl DistributedHNSWSegment {
let user_id_to_id = self.user_id_to_id.read();
let index = self.index.read();
for id in ids {
let id = match user_id_to_id.get(&id) {
Some(id) => id,
let internal_id = match user_id_to_id.get(&id) {
Some(internal_id) => internal_id,
None => {
// TODO: Error
return records;
}
};
let vector = index.get(*id);
let vector = index.get(*internal_id);
match vector {
Some(vector) => {
let record = VectorEmbeddingRecord {
id: id.to_string(),
id: id,
seq_id: BigInt::from(0),
vector,
};
Expand All @@ -113,9 +118,19 @@ impl DistributedHNSWSegment {
return records;
}

pub(crate) fn query(&self, vector: &[f32], k: usize) -> (Vec<usize>, Vec<f32>) {
pub(crate) fn query(&self, vector: &[f32], k: usize) -> (Vec<String>, Vec<f32>) {
let index = self.index.read();
let mut return_user_ids = Vec::new();
let (ids, distances) = index.query(vector, k);
return (ids, distances);
let user_ids = self.id_to_user_id.read();
for id in ids {
match user_ids.get(&id) {
Some(user_id) => return_user_ids.push(user_id.clone()),
None => {
// TODO: error
}
};
}
return (return_user_ids, distances);
}
}
5 changes: 2 additions & 3 deletions rust/worker/src/segment/segment_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ impl SegmentManager {
let mut results = Vec::new();
let (ids, distances) = segment.query(vectors, k);
for (id, distance) in ids.iter().zip(distances.iter()) {
let id: String = id.to_string();
let fetched_vector = match include_vector {
true => Some(segment.get_records(vec![id.clone()])),
false => None,
Expand All @@ -146,7 +145,7 @@ impl SegmentManager {
}
let mut target_vec = None;
for vec in fetched_vectors.into_iter() {
if vec.id == id {
if vec.id == *id {
target_vec = Some(vec);
break;
}
Expand All @@ -165,7 +164,7 @@ impl SegmentManager {
};

let result = Box::new(VectorQueryResult {
id,
id: id.to_string(),
seq_id: BigInt::from(0),
distance: *distance,
vector: ret_vec,
Expand Down
2 changes: 1 addition & 1 deletion rust/worker/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl Configurable for WorkerServer {

impl WorkerServer {
pub(crate) async fn run(worker: WorkerServer) -> Result<(), Box<dyn std::error::Error>> {
let addr = format!("[::1]:{}", worker.port).parse().unwrap();
let addr = format!("[::]:{}", worker.port).parse().unwrap();
println!("Worker listening on {}", addr);
let server = Server::builder()
.add_service(chroma_proto::vector_reader_server::VectorReaderServer::new(
Expand Down

0 comments on commit 0550e4f

Please sign in to comment.