diff --git a/go/coordinator/internal/utils/pulsar_admin.go b/go/coordinator/internal/utils/pulsar_admin.go
index f6a38f88267..c8258ecbf54 100644
--- a/go/coordinator/internal/utils/pulsar_admin.go
+++ b/go/coordinator/internal/utils/pulsar_admin.go
@@ -34,7 +34,7 @@ func CreateTopics(pulsarAdminURL string, tenant string, namespace string, topics
 			log.Info("Topic already exists", zap.String("topic", topic), zap.Any("metadata", metadata))
 			continue
 		}
-		err = admin.Topics().Create(*topicName, 1)
+		err = admin.Topics().Create(*topicName, 0)
 		if err != nil {
 			log.Error("Failed to create topic", zap.Error(err))
 			return err
diff --git a/idl/chromadb/proto/chroma.proto b/idl/chromadb/proto/chroma.proto
index 51579aae921..5676c0efb74 100644
--- a/idl/chromadb/proto/chroma.proto
+++ b/idl/chromadb/proto/chroma.proto
@@ -98,7 +98,7 @@ message VectorEmbeddingRecord {
 
 message VectorQueryResult {
   string id = 1;
   bytes seq_id = 2;
-  double distance = 3;
+  float distance = 3;
   optional Vector vector = 4;
 }
diff --git a/k8s/deployment/segment-server.yaml b/k8s/deployment/segment-server.yaml
index 1df7cec9ff4..0f2c6e02858 100644
--- a/k8s/deployment/segment-server.yaml
+++ b/k8s/deployment/segment-server.yaml
@@ -32,32 +32,18 @@ spec:
     spec:
       containers:
         - name: segment-server
-          image: server
+          image: worker
           imagePullPolicy: IfNotPresent
-          command: ["python", "-m", "chromadb.segment.impl.distributed.server"]
+          command: ["cargo", "run"]
           ports:
             - containerPort: 50051
           volumeMounts:
             - name: chroma
              mountPath: /index_data
           env:
-            - name: IS_PERSISTENT
-              value: "TRUE"
-            - name: CHROMA_PRODUCER_IMPL
-              value: "chromadb.ingest.impl.pulsar.PulsarProducer"
-            - name: CHROMA_CONSUMER_IMPL
-              value: "chromadb.ingest.impl.pulsar.PulsarConsumer"
-            - name: PULSAR_BROKER_URL
-              value: "pulsar.chroma"
-            - name: PULSAR_BROKER_PORT
-              value: "6650"
-            - name: PULSAR_ADMIN_PORT
-              value: "8080"
-            - name: CHROMA_SERVER_GRPC_PORT
-              value: "50051"
-            - name: CHROMA_COLLECTION_ASSIGNMENT_POLICY_IMPL
-              value: "chromadb.ingest.impl.simple_policy.RendezvousHashingAssignmentPolicy"
-            - name: MY_POD_IP
+            - name: CHROMA_WORKER__PULSAR_URL
+              value: pulsar://pulsar.chroma:6650
+            - name: CHROMA_WORKER__MY_IP
              valueFrom:
                fieldRef:
                  fieldPath: status.podIP
diff --git a/k8s/test/coordinator_service.yaml b/k8s/test/coordinator_service.yaml
index 710a53fa1ad..37334b12187 100644
--- a/k8s/test/coordinator_service.yaml
+++ b/k8s/test/coordinator_service.yaml
@@ -1,7 +1,7 @@
 apiVersion: v1
 kind: Service
 metadata:
-  name: coordinator
+  name: coordinator-lb
   namespace: chroma
 spec:
   ports:
diff --git a/k8s/test/minio.yaml b/k8s/test/minio.yaml
new file mode 100644
index 00000000000..148c5170fd8
--- /dev/null
+++ b/k8s/test/minio.yaml
@@ -0,0 +1,52 @@
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+  name: minio-deployment
+  namespace: chroma
+spec:
+  selector:
+    matchLabels:
+      app: minio
+  strategy:
+    type: Recreate
+  template:
+    metadata:
+      labels:
+        app: minio
+    spec:
+      volumes:
+        - name: minio
+          emptyDir: {}
+      containers:
+        - name: minio
+          image: minio/minio:latest
+          args:
+            - server
+            - /storage
+          env:
+            - name: MINIO_ACCESS_KEY
+              value: "minio"
+            - name: MINIO_SECRET_KEY
+              value: "minio123"
+          ports:
+            - containerPort: 9000
+              hostPort: 9000
+          volumeMounts:
+            - name: minio
+              mountPath: /storage
+
+---
+
+apiVersion: v1
+kind: Service
+metadata:
+  name: minio-lb
+  namespace: chroma
+spec:
+  ports:
+    - name: http
+      port: 9000
+      targetPort: 9000
+  selector:
+    app: minio
+  type: LoadBalancer
diff --git a/k8s/test/pulsar_service.yaml b/k8s/test/pulsar_service.yaml
index 1053c709afa..56ff6440db2 100644
--- a/k8s/test/pulsar_service.yaml
+++ b/k8s/test/pulsar_service.yaml
@@ -5,7 +5,7 @@
 apiVersion: v1
 kind: Service
 metadata:
-  name: pulsar
+  name: pulsar-lb
   namespace: chroma
 spec:
   ports:
diff --git a/k8s/test/segment_server_service.yml b/k8s/test/segment_server_service.yml
new file mode 100644
index 00000000000..7463333deef
--- /dev/null
+++ b/k8s/test/segment_server_service.yml
@@ -0,0 +1,13 @@
+apiVersion: v1
+kind: Service
+metadata:
+  name: segment-server-lb
+  namespace: chroma
+spec:
+  ports:
+    - name: segment-server-port
+      port: 50052
+      targetPort: 50051
+  selector:
+    app: segment-server
+  type: LoadBalancer
diff --git a/rust/worker/Dockerfile b/rust/worker/Dockerfile
index 9fec202fda1..96c4b08a4f0 100644
--- a/rust/worker/Dockerfile
+++ b/rust/worker/Dockerfile
@@ -1,5 +1,8 @@
 FROM rust:1.74.1 as builder
 
+WORKDIR /
+RUN git clone https://github.com/chroma-core/hnswlib.git
+
 WORKDIR /chroma/
 COPY . .
 
@@ -11,5 +14,6 @@ RUN curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v25.1
 
 RUN cargo build
 
-# For now this runs cargo test since we have no main binary
-CMD ["cargo", "test"]
+WORKDIR /chroma/rust/worker
+
+CMD ["cargo", "run"]
diff --git a/rust/worker/chroma_config.yaml b/rust/worker/chroma_config.yaml
index 90f36970d06..e760dcdb97c 100644
--- a/rust/worker/chroma_config.yaml
+++ b/rust/worker/chroma_config.yaml
@@ -4,7 +4,8 @@
 # for now we nest it in the worker directory
 
 worker:
-  my_ip: "10.244.0.90"
+  my_ip: "10.244.0.9"
+  my_port: 50051
   num_indexing_threads: 4
   pulsar_url: "pulsar://127.0.0.1:6650"
   pulsar_tenant: "public"
@@ -18,10 +19,10 @@ worker:
     memberlist_name: "worker-memberlist"
     queue_size: 100
   ingest:
-    queue_size: 100
+    queue_size: 10000
   sysdb:
     Grpc:
-      host: "localhost"
+      host: "coordinator.chroma"
       port: 50051
   segment_manager:
     storage_path: "./tmp/segment_manager/"
diff --git a/rust/worker/src/config.rs b/rust/worker/src/config.rs
index ab5212cf7bf..4d82cc472ad 100644
--- a/rust/worker/src/config.rs
+++ b/rust/worker/src/config.rs
@@ -4,7 +4,7 @@ use serde::Deserialize;
 
 use crate::errors::ChromaError;
 
-const DEFAULT_CONFIG_PATH: &str = "chroma_config.yaml";
+const DEFAULT_CONFIG_PATH: &str = "./chroma_config.yaml";
 const ENV_PREFIX: &str = "CHROMA_";
 
 #[derive(Deserialize)]
@@ -97,6 +97,7 @@ impl RootConfig {
 /// have its own field in this struct for its Config struct.
 pub(crate) struct WorkerConfig {
     pub(crate) my_ip: String,
+    pub(crate) my_port: u16,
     pub(crate) num_indexing_threads: u32,
     pub(crate) pulsar_tenant: String,
     pub(crate) pulsar_namespace: String,
@@ -134,6 +135,7 @@ mod tests {
             r#"
             worker:
                 my_ip: "192.0.0.1"
+                my_port: 50051
                 num_indexing_threads: 4
                 pulsar_tenant: "public"
                 pulsar_namespace: "default"
@@ -175,6 +177,7 @@ mod tests {
             r#"
             worker:
                 my_ip: "192.0.0.1"
+                my_port: 50051
                 num_indexing_threads: 4
                 pulsar_tenant: "public"
                 pulsar_namespace: "default"
@@ -232,6 +235,7 @@ mod tests {
             r#"
             worker:
                 my_ip: "192.0.0.1"
+                my_port: 50051
                 pulsar_tenant: "public"
                 pulsar_namespace: "default"
                 kube_namespace: "chroma"
@@ -265,6 +269,7 @@ mod tests {
     fn test_config_with_env_override() {
         Jail::expect_with(|jail| {
             let _ = jail.set_env("CHROMA_WORKER__MY_IP", "192.0.0.1");
+            let _ = jail.set_env("CHROMA_WORKER__MY_PORT", 50051);
             let _ = jail.set_env("CHROMA_WORKER__PULSAR_TENANT", "A");
             let _ = jail.set_env("CHROMA_WORKER__PULSAR_NAMESPACE", "B");
             let _ = jail.set_env("CHROMA_WORKER__KUBE_NAMESPACE", "C");
@@ -292,6 +297,7 @@ mod tests {
             );
             let config = RootConfig::load();
             assert_eq!(config.worker.my_ip, "192.0.0.1");
+            assert_eq!(config.worker.my_port, 50051);
             assert_eq!(config.worker.num_indexing_threads, num_cpus::get() as u32);
             assert_eq!(config.worker.pulsar_tenant, "A");
             assert_eq!(config.worker.pulsar_namespace, "B");
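
Note on the new `my_port` plumbing: the `Jail`-based tests above suggest the config loader is built on figment, where an `Env` provider prefixed with `CHROMA_` and split on `__` maps `CHROMA_WORKER__MY_PORT` onto the nested `worker.my_port` field. A minimal sketch of that pattern, with the structs reduced to the two fields relevant here (the real `RootConfig`/`WorkerConfig` carry many more):

```rust
use figment::{
    providers::{Env, Format, Yaml},
    Figment,
};
use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct WorkerConfig {
    my_ip: String,
    my_port: u16,
}

#[derive(Deserialize, Debug)]
struct RootConfig {
    worker: WorkerConfig,
}

fn load() -> RootConfig {
    // CHROMA_WORKER__MY_PORT=50051 splits on "__" into `worker.my_port`,
    // overriding whatever ./chroma_config.yaml provides.
    Figment::new()
        .merge(Yaml::file("./chroma_config.yaml"))
        .merge(Env::prefixed("CHROMA_").split("__"))
        .extract()
        .expect("invalid config")
}
```
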
diff --git a/rust/worker/src/ingest/ingest.rs b/rust/worker/src/ingest/ingest.rs
index 8bbc1d86ee3..e13689a9633 100644
--- a/rust/worker/src/ingest/ingest.rs
+++ b/rust/worker/src/ingest/ingest.rs
@@ -82,6 +82,7 @@ impl Configurable for Ingest {
             worker_config.pulsar_namespace.clone(),
         );
 
+        println!("Pulsar connection url: {}", worker_config.pulsar_url);
         let pulsar = match Pulsar::builder(worker_config.pulsar_url.clone(), TokioExecutor)
             .build()
             .await
diff --git a/rust/worker/src/lib.rs b/rust/worker/src/lib.rs
index e24bf64c416..39a90984c81 100644
--- a/rust/worker/src/lib.rs
+++ b/rust/worker/src/lib.rs
@@ -5,6 +5,7 @@ mod index;
 mod ingest;
 mod memberlist;
 mod segment;
+mod server;
 mod sysdb;
 mod system;
 mod types;
@@ -12,6 +13,8 @@ mod types;
 use config::Configurable;
 use memberlist::MemberlistProvider;
 
+use crate::sysdb::sysdb::SysDb;
+
 mod chroma_proto {
     tonic::include_proto!("chroma");
 }
@@ -61,8 +64,18 @@ pub async fn worker_entrypoint() {
         segment_ingestor_receivers.push(recv);
     }
 
+    let mut worker_server = match server::WorkerServer::try_from_config(&config.worker).await {
+        Ok(worker_server) => worker_server,
+        Err(err) => {
+            println!("Failed to create worker server component: {:?}", err);
+            return;
+        }
+    };
+    worker_server.set_segment_manager(segment_manager.clone());
+
     // Boot the system
-    // memberlist -> ingest -> scheduler -> NUM_THREADS x segment_ingestor
+    // memberlist -> ingest -> scheduler -> NUM_THREADS x segment_ingestor -> segment_manager
+    // server <- segment_manager
 
     for recv in segment_ingestor_receivers {
         scheduler.subscribe(recv);
@@ -76,10 +89,14 @@ pub async fn worker_entrypoint() {
     memberlist.subscribe(recv);
     let mut memberlist_handle = system.start_component(memberlist);
 
+    let server_join_handle = tokio::spawn(async move {
+        crate::server::WorkerServer::run(worker_server).await;
+    });
+
     // Join on all handles
     let _ = tokio::join!(
         ingest_handle.join(),
         memberlist_handle.join(),
-        scheduler_handler.join()
+        scheduler_handler.join(),
     );
 }
diff --git a/rust/worker/src/segment/distributed_hnsw_segment.rs b/rust/worker/src/segment/distributed_hnsw_segment.rs
index 1336436fa8b..d6f9ca26525 100644
--- a/rust/worker/src/segment/distributed_hnsw_segment.rs
+++ b/rust/worker/src/segment/distributed_hnsw_segment.rs
@@ -1,3 +1,4 @@
+use num_bigint::BigInt;
 use parking_lot::{Mutex, RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
 use std::collections::HashMap;
 use std::sync::atomic::AtomicUsize;
@@ -5,12 +6,13 @@ use std::sync::Arc;
 
 use crate::errors::ChromaError;
 use crate::index::{HnswIndex, HnswIndexConfig, Index, IndexConfig};
-use crate::types::{EmbeddingRecord, Operation, Segment};
+use crate::types::{EmbeddingRecord, Operation, Segment, VectorEmbeddingRecord};
 
 pub(crate) struct DistributedHNSWSegment {
     index: Arc<RwLock<HnswIndex>>,
     id: AtomicUsize,
     user_id_to_id: Arc<RwLock<HashMap<String, usize>>>,
+    id_to_user_id: Arc<RwLock<HashMap<usize, String>>>,
     index_config: IndexConfig,
     hnsw_config: HnswIndexConfig,
 }
@@ -33,6 +35,7 @@ impl DistributedHNSWSegment {
             index: index,
             id: AtomicUsize::new(0),
             user_id_to_id: Arc::new(RwLock::new(HashMap::new())),
+            id_to_user_id: Arc::new(RwLock::new(HashMap::new())),
             index_config: index_config,
             hnsw_config,
         });
@@ -63,7 +66,10 @@ impl DistributedHNSWSegment {
                     self.user_id_to_id
                         .write()
                         .insert(record.id.clone(), next_id);
-                    println!("DIS SEGMENT Adding item: {}", next_id);
+                    self.id_to_user_id
+                        .write()
+                        .insert(next_id, record.id.clone());
+                    println!("Segment adding item: {}", next_id);
                     self.index.read().add(next_id, &vector);
                 }
                 None => {
@@ -81,4 +87,50 @@ impl DistributedHNSWSegment {
             }
         }
     }
+
+    pub(crate) fn get_records(&self, ids: Vec<String>) -> Vec<Box<VectorEmbeddingRecord>> {
+        let mut records = Vec::new();
+        let user_id_to_id = self.user_id_to_id.read();
+        let index = self.index.read();
+        for id in ids {
+            let internal_id = match user_id_to_id.get(&id) {
+                Some(internal_id) => internal_id,
+                None => {
+                    // TODO: Error
+                    return records;
+                }
+            };
+            let vector = index.get(*internal_id);
+            match vector {
+                Some(vector) => {
+                    let record = VectorEmbeddingRecord {
+                        id: id,
+                        seq_id: BigInt::from(0),
+                        vector,
+                    };
+                    records.push(Box::new(record));
+                }
+                None => {
+                    // TODO: error
+                }
+            }
+        }
+        return records;
+    }
+
+    pub(crate) fn query(&self, vector: &[f32], k: usize) -> (Vec<String>, Vec<f32>) {
+        let index = self.index.read();
+        let mut return_user_ids = Vec::new();
+        let (ids, distances) = index.query(vector, k);
+        let user_ids = self.id_to_user_id.read();
+        for id in ids {
+            match user_ids.get(&id) {
+                Some(user_id) => return_user_ids.push(user_id.clone()),
+                None => {
+                    // TODO: error
+                }
+            };
+        }
+        return (return_user_ids, distances);
+    }
 }
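
The new `id_to_user_id` map exists because the HNSW index only speaks dense internal ids: the write path assigns them, and `query` gets internal ids back from the index and needs the reverse mapping to return user-facing ids. A self-contained illustration of the pattern (toy data, no HNSW involved):

```rust
use std::collections::HashMap;

fn main() {
    // Forward map: user id -> internal index id (write path).
    let mut user_id_to_id: HashMap<String, usize> = HashMap::new();
    // Reverse map: internal index id -> user id (query path).
    let mut id_to_user_id: HashMap<usize, String> = HashMap::new();

    for (next_id, user_id) in ["a", "b", "c"].iter().enumerate() {
        user_id_to_id.insert(user_id.to_string(), next_id);
        id_to_user_id.insert(next_id, user_id.to_string());
    }

    // The index returns internal ids; translate them back for the caller.
    let hits = [2usize, 0];
    let user_hits: Vec<&String> = hits.iter().filter_map(|id| id_to_user_id.get(id)).collect();
    assert_eq!(user_hits, [&"c".to_string(), &"a".to_string()]);
}
```
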
diff --git a/rust/worker/src/segment/segment_manager.rs b/rust/worker/src/segment/segment_manager.rs
index ee8443c14be..314a5b99cdf 100644
--- a/rust/worker/src/segment/segment_manager.rs
+++ b/rust/worker/src/segment/segment_manager.rs
@@ -2,9 +2,11 @@ use crate::{
     config::{Configurable, WorkerConfig},
     errors::ChromaError,
     sysdb::sysdb::{GrpcSysDb, SysDb},
+    types::VectorQueryResult,
 };
 use async_trait::async_trait;
 use k8s_openapi::api::node;
+use num_bigint::BigInt;
 use parking_lot::{
     MappedRwLockReadGuard, RwLock, RwLockReadGuard, RwLockUpgradableReadGuard, RwLockWriteGuard,
 };
@@ -13,7 +15,7 @@ use std::sync::Arc;
 use uuid::Uuid;
 
 use super::distributed_hnsw_segment::DistributedHNSWSegment;
-use crate::types::{EmbeddingRecord, MetadataValue, Segment, SegmentScope};
+use crate::types::{EmbeddingRecord, MetadataValue, Segment, SegmentScope, VectorEmbeddingRecord};
 
 #[derive(Clone)]
 pub(crate) struct SegmentManager {
@@ -68,6 +70,8 @@ impl SegmentManager {
             }
         };
 
+        println!("Writing to segment id {}", target_segment.id);
+
         let segment_cache = self.inner.vector_segments.upgradable_read();
         match segment_cache.get(&target_segment.id) {
             Some(segment) => {
@@ -97,6 +101,84 @@ impl SegmentManager {
         }
     }
 
+    pub(crate) async fn get_records(
+        &self,
+        segment_id: &Uuid,
+        ids: Vec<String>,
+    ) -> Result<Vec<Box<VectorEmbeddingRecord>>, &'static str> {
+        // TODO: Load segment if not in cache
+        let segment_cache = self.inner.vector_segments.read();
+        match segment_cache.get(segment_id) {
+            Some(segment) => {
+                return Ok(segment.get_records(ids));
+            }
+            None => {
+                return Err("No segment found");
+            }
+        }
+    }
+
+    pub(crate) async fn query_vector(
+        &self,
+        segment_id: &Uuid,
+        vectors: &[f32],
+        k: usize,
+        include_vector: bool,
+    ) -> Result<Vec<Box<VectorQueryResult>>, &'static str> {
+        let segment_cache = self.inner.vector_segments.read();
+        match segment_cache.get(segment_id) {
+            Some(segment) => {
+                let mut results = Vec::new();
+                let (ids, distances) = segment.query(vectors, k);
+                for (id, distance) in ids.iter().zip(distances.iter()) {
+                    let fetched_vector = match include_vector {
+                        true => Some(segment.get_records(vec![id.clone()])),
+                        false => None,
+                    };
+
+                    let mut target_record = None;
+                    if include_vector {
+                        target_record = match fetched_vector {
+                            Some(fetched_vectors) => {
+                                if fetched_vectors.len() == 0 {
+                                    return Err("No vector found");
+                                }
+                                let mut target_vec = None;
+                                for vec in fetched_vectors.into_iter() {
+                                    if vec.id == *id {
+                                        target_vec = Some(vec);
+                                        break;
+                                    }
+                                }
+                                target_vec
+                            }
+                            None => {
+                                return Err("No vector found");
+                            }
+                        };
+                    }
+
+                    let ret_vec = match target_record {
+                        Some(target_record) => Some(target_record.vector),
+                        None => None,
+                    };
+
+                    let result = Box::new(VectorQueryResult {
+                        id: id.to_string(),
+                        seq_id: BigInt::from(0),
+                        distance: *distance,
+                        vector: ret_vec,
+                    });
+                    results.push(result);
+                }
+                return Ok(results);
+            }
+            None => {
+                return Err("No segment found");
+            }
+        }
+    }
+
     async fn get_segments(
         &mut self,
         collection_uuid: &Uuid,
diff --git a/rust/worker/src/server.rs b/rust/worker/src/server.rs
new file mode 100644
index 00000000000..1ecc6ba2e70
--- /dev/null
+++ b/rust/worker/src/server.rs
@@ -0,0 +1,185 @@
+use crate::chroma_proto;
+use crate::chroma_proto::{
+    GetVectorsRequest, GetVectorsResponse, QueryVectorsRequest, QueryVectorsResponse,
+};
+use crate::config::{Configurable, WorkerConfig};
+use crate::errors::ChromaError;
+use crate::segment::SegmentManager;
+use crate::types::ScalarEncoding;
+use async_trait::async_trait;
+use tonic::{transport::Server, Request, Response, Status};
+use uuid::Uuid;
+
+pub struct WorkerServer {
+    segment_manager: Option<SegmentManager>,
+    port: u16,
+}
+
+#[async_trait]
+impl Configurable for WorkerServer {
+    async fn try_from_config(config: &WorkerConfig) -> Result<Self, Box<dyn ChromaError>> {
+        Ok(WorkerServer {
+            segment_manager: None,
+            port: config.my_port,
+        })
+    }
+}
+
+impl WorkerServer {
+    pub(crate) async fn run(worker: WorkerServer) -> Result<(), Box<dyn std::error::Error>> {
+        let addr = format!("[::]:{}", worker.port).parse().unwrap();
+        println!("Worker listening on {}", addr);
+        let server = Server::builder()
+            .add_service(chroma_proto::vector_reader_server::VectorReaderServer::new(
+                worker,
+            ))
+            .serve(addr)
+            .await?;
+        println!("Worker shutting down");
+
+        Ok(())
+    }
+
+    pub(crate) fn set_segment_manager(&mut self, segment_manager: SegmentManager) {
+        self.segment_manager = Some(segment_manager);
+    }
+}
+
+#[tonic::async_trait]
+impl chroma_proto::vector_reader_server::VectorReader for WorkerServer {
+    async fn get_vectors(
+        &self,
+        request: Request<GetVectorsRequest>,
+    ) -> Result<Response<GetVectorsResponse>, Status> {
+        let request = request.into_inner();
+        let segment_uuid = match Uuid::parse_str(&request.segment_id) {
+            Ok(uuid) => uuid,
+            Err(_) => {
+                return Err(Status::invalid_argument("Invalid UUID"));
+            }
+        };
+
+        let segment_manager = match self.segment_manager {
+            Some(ref segment_manager) => segment_manager,
+            None => {
+                return Err(Status::internal("No segment manager found"));
+            }
+        };
+
+        let records = match segment_manager
+            .get_records(&segment_uuid, request.ids)
+            .await
+        {
+            Ok(records) => records,
+            Err(e) => {
+                return Err(Status::internal(format!("Error getting records: {}", e)));
+            }
+        };
+
+        let mut proto_records = Vec::new();
+        for record in records {
+            let seq_id_bytes = record.seq_id.to_bytes_le();
+            let dim = record.vector.len();
+            let proto_vector = (record.vector, ScalarEncoding::FLOAT32, dim).try_into();
+            match proto_vector {
+                Ok(proto_vector) => {
+                    let proto_record = chroma_proto::VectorEmbeddingRecord {
+                        id: record.id,
+                        seq_id: seq_id_bytes.1,
+                        vector: Some(proto_vector),
+                    };
+                    proto_records.push(proto_record);
+                }
+                Err(e) => {
+                    return Err(Status::internal(format!("Error converting vector: {}", e)));
+                }
+            }
+        }
+
+        let resp = chroma_proto::GetVectorsResponse {
+            records: proto_records,
+        };
+
+        Ok(Response::new(resp))
+    }
+
+    async fn query_vectors(
+        &self,
+        request: Request<QueryVectorsRequest>,
+    ) -> Result<Response<QueryVectorsResponse>, Status> {
+        let request = request.into_inner();
+        let segment_uuid = match Uuid::parse_str(&request.segment_id) {
+            Ok(uuid) => uuid,
+            Err(_) => {
+                return Err(Status::invalid_argument("Invalid Segment UUID"));
+            }
+        };
+
+        let segment_manager = match self.segment_manager {
+            Some(ref segment_manager) => segment_manager,
+            None => {
+                return Err(Status::internal("No segment manager found"));
+            }
+        };
+
+        let mut proto_results_for_all = Vec::new();
+        for proto_query_vector in request.vectors {
+            let (query_vector, encoding) = match proto_query_vector.try_into() {
+                Ok((vector, encoding)) => (vector, encoding),
+                Err(e) => {
+                    return Err(Status::internal(format!("Error converting vector: {}", e)));
+                }
+            };
+
+            let results = match segment_manager
+                .query_vector(
+                    &segment_uuid,
+                    &query_vector,
+                    request.k as usize,
+                    request.include_embeddings,
+                )
+                .await
+            {
+                Ok(results) => results,
+                Err(e) => {
+                    return Err(Status::internal(format!("Error querying segment: {}", e)));
+                }
+            };
+
+            let mut proto_results = Vec::new();
+            for query_result in results {
+                let proto_result = chroma_proto::VectorQueryResult {
+                    id: query_result.id,
+                    seq_id: query_result.seq_id.to_bytes_le().1,
+                    distance: query_result.distance,
+                    vector: match query_result.vector {
+                        Some(vector) => {
+                            match (vector, ScalarEncoding::FLOAT32, query_vector.len()).try_into() {
+                                Ok(proto_vector) => Some(proto_vector),
+                                Err(e) => {
+                                    return Err(Status::internal(format!(
+                                        "Error converting vector: {}",
+                                        e
+                                    )));
+                                }
+                            }
+                        }
+                        None => None,
+                    },
+                };
+                proto_results.push(proto_result);
+            }
+
+            let vector_query_results = chroma_proto::VectorQueryResults {
+                results: proto_results,
+            };
+            proto_results_for_all.push(vector_query_results);
+        }
+
+        let resp = chroma_proto::QueryVectorsResponse {
+            results: proto_results_for_all,
+        };
+
+        return Ok(Response::new(resp));
+    }
+}
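
For manual testing against the new endpoints, a hedged client sketch: it assumes the generated client module is `chroma_proto::vector_reader_client`, mirroring the `vector_reader_server` module used above, and that the worker is reachable through the `segment-server-lb` service on port 50052. Field names follow the proto as used in this diff; any fields not shown fall back to prost defaults.

```rust
mod chroma_proto {
    tonic::include_proto!("chroma");
}

use chroma_proto::vector_reader_client::VectorReaderClient;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mut client = VectorReaderClient::connect("http://localhost:50052").await?;
    let request = chroma_proto::QueryVectorsRequest {
        segment_id: "00000000-0000-0000-0000-000000000000".to_string(),
        vectors: vec![], // populate with chroma_proto::Vector query embeddings
        k: 10,
        include_embeddings: true,
        ..Default::default() // remaining proto fields keep their defaults
    };
    let response = client.query_vectors(request).await?;
    for result_set in response.into_inner().results {
        for r in result_set.results {
            println!("id={} distance={}", r.id, r.distance);
        }
    }
    Ok(())
}
```
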
= format!("http://{}:{}", host, port); let client = sys_db_client::SysDbClient::connect(connection_string).await; match client { diff --git a/rust/worker/src/types/embedding_record.rs b/rust/worker/src/types/embedding_record.rs index 2b4f2361e0a..14957a85349 100644 --- a/rust/worker/src/types/embedding_record.rs +++ b/rust/worker/src/types/embedding_record.rs @@ -166,6 +166,58 @@ fn vec_to_f32(bytes: &[u8]) -> Result<&[f32], VectorConversionError> { } } +fn f32_to_vec(vector: &[f32]) -> Vec { + unsafe { + std::slice::from_raw_parts( + vector.as_ptr() as *const u8, + vector.len() * std::mem::size_of::(), + ) + } + .to_vec() +} + +impl TryFrom<(Vec, ScalarEncoding, usize)> for chroma_proto::Vector { + type Error = VectorConversionError; + + fn try_from( + (vector, encoding, dimension): (Vec, ScalarEncoding, usize), + ) -> Result { + let proto_vector = chroma_proto::Vector { + vector: f32_to_vec(&vector), + encoding: encoding as i32, + dimension: dimension as i32, + }; + Ok(proto_vector) + } +} + +/* +=========================================== +Vector Embedding Record +=========================================== +*/ + +#[derive(Debug)] +pub(crate) struct VectorEmbeddingRecord { + pub(crate) id: String, + pub(crate) seq_id: SeqId, + pub(crate) vector: Vec, +} + +/* +=========================================== +Vector Query Result +=========================================== + */ + +#[derive(Debug)] +pub(crate) struct VectorQueryResult { + pub(crate) id: String, + pub(crate) seq_id: SeqId, + pub(crate) distance: f32, + pub(crate) vector: Option>, +} + #[cfg(test)] mod tests { use std::collections::HashMap;