From a666b605bf6195cf481beb08a921caadcc6b4950 Mon Sep 17 00:00:00 2001 From: nael Date: Mon, 16 Sep 2024 07:40:36 +0200 Subject: [PATCH] :sparkles: Added multi tenancy to vector dbs --- .env.example | 5 +- .../RAGSettings/RAGItemDisplay.tsx | 11 +- .../Configuration/RAGSettings/RAGItemList.tsx | 2 +- .../Configuration/RAGSettings/RAGLayout.tsx | 2 +- .../RAGSettings/RAGSettingsPage.tsx | 2 +- docker-compose.dev.yml | 4 + docker-compose.source.yml | 4 + docker-compose.yml | 4 + .../environment/environment.service.ts | 12 +- .../api/src/@core/rag/document.processor.ts | 11 +- packages/api/src/@core/rag/rag.controller.ts | 35 +++-- packages/api/src/@core/rag/rag.service.ts | 3 +- .../rag/vecdb/chromadb/chromadb.service.ts | 47 +++--- .../@core/rag/vecdb/milvus/milvus.service.ts | 140 ++++++++++-------- .../rag/vecdb/pinecone/pinecone.service.ts | 11 +- .../@core/rag/vecdb/qdrant/qdrant.service.ts | 113 +++++++++++--- .../rag/vecdb/vecdb.credentials.service.ts | 23 ++- .../api/src/@core/rag/vecdb/vecdb.service.ts | 10 +- .../rag/vecdb/weaviate/weaviate.service.ts | 108 +++++++++++--- packages/api/swagger/swagger-spec.yaml | 8 +- 20 files changed, 386 insertions(+), 169 deletions(-) diff --git a/.env.example b/.env.example index 0c4fe65b2..85cc0f535 100644 --- a/.env.example +++ b/.env.example @@ -178,19 +178,22 @@ NEXT_PUBLIC_DISTRIBUTION=selfhost # selfhost or managed ## pinecone PINECONE_API_KEY= PINECONE_INDEX_NAME= - ## qdrant QDRANT_BASE_URL= QDRANT_API_KEY= +QDRANT_COLLECTION_NAME= ## chroma CHROMADB_URL= +CHROMADB_COLLECTION_NAME= ## weaviate WEAVIATE_URL= WEAVIATE_API_KEY= +WEAVIATE_CLASS_NAME= # turbopuffer TURBOPUFFER_API_KEY= # milvus MILVUS_ADDRESS= +MILVUS_COLLECTION_NAME= # EMBEDDINGS JINA_API_KEY= diff --git a/apps/webapp/src/components/Configuration/RAGSettings/RAGItemDisplay.tsx b/apps/webapp/src/components/Configuration/RAGSettings/RAGItemDisplay.tsx index 99e56567c..8e7e7711c 100644 --- 
a/apps/webapp/src/components/Configuration/RAGSettings/RAGItemDisplay.tsx +++ b/apps/webapp/src/components/Configuration/RAGSettings/RAGItemDisplay.tsx @@ -24,6 +24,8 @@ const formSchema = z.object({ url: z.string().optional(), indexName: z.string().optional(), embeddingApiKey: z.string().optional(), + collectionName: z.string().optional(), + className: z.string().optional(), }); interface ItemDisplayProps { @@ -183,11 +185,11 @@ export function RAGItemDisplay({ item, type }: ItemDisplayProps) { case 'pinecone': return ['apiKey', 'indexName']; case 'qdrant': - return ['apiKey', 'baseUrl']; + return ['apiKey', 'baseUrl', 'collectionName']; case 'chromadb': - return ['url']; + return ['url', 'collectionName']; case 'weaviate': - return ['apiKey', 'url']; + return ['apiKey', 'url', 'className']; case 'openai_ada_small_1536': case 'openai_ada_large_3072': case 'openai_ada_002': @@ -210,13 +212,16 @@ export function RAGItemDisplay({ item, type }: ItemDisplayProps) { case 'qdrant': form.setValue("apiKey", data[0]); form.setValue("baseUrl", data[1]); + form.setValue("collectionName", data[2]); /* third credential: order is apiKey, baseUrl, collectionName */ break; case 'chromadb': form.setValue("url", data[0]); + form.setValue("collectionName", data[1]); break; case 'weaviate': form.setValue("apiKey", data[0]); form.setValue("url", data[1]); + form.setValue("className", data[2]); /* third credential: order is apiKey, url, className */ break; case 'openai_ada_small_1536': case 'openai_ada_large_3072': diff --git a/apps/webapp/src/components/Configuration/RAGSettings/RAGItemList.tsx b/apps/webapp/src/components/Configuration/RAGSettings/RAGItemList.tsx index d5a8d79c6..cf0e1b17d 100644 --- a/apps/webapp/src/components/Configuration/RAGSettings/RAGItemList.tsx +++ b/apps/webapp/src/components/Configuration/RAGSettings/RAGItemList.tsx @@ -3,7 +3,7 @@ import { cn } from "@/lib/utils" import { Button } from "@/components/ui/button" import { ScrollArea } from "@/components/ui/scroll-area" import { vectorDatabases, embeddingModels } from "./utils" -import { useRagItem } from "./useRAGItem" +import 
{ useRagItem } from "./useRagItem" interface RAGItemListProps { items: (typeof vectorDatabases[number] | typeof embeddingModels[number])[]; diff --git a/apps/webapp/src/components/Configuration/RAGSettings/RAGLayout.tsx b/apps/webapp/src/components/Configuration/RAGSettings/RAGLayout.tsx index a64a8d279..5dc141338 100644 --- a/apps/webapp/src/components/Configuration/RAGSettings/RAGLayout.tsx +++ b/apps/webapp/src/components/Configuration/RAGSettings/RAGLayout.tsx @@ -8,7 +8,7 @@ import * as React from "react" import { RAGItemDisplay } from "./RAGItemDisplay" import { RAGItemList } from "./RAGItemList" import { embeddingModels, vectorDatabases } from "./utils" -import { useRagItem } from "./useRAGItem" +import { useRagItem } from "./useRagItem" interface Props { items: (typeof vectorDatabases[number] | typeof embeddingModels[number])[]; diff --git a/apps/webapp/src/components/Configuration/RAGSettings/RAGSettingsPage.tsx b/apps/webapp/src/components/Configuration/RAGSettings/RAGSettingsPage.tsx index 22b96e7ef..d71e53aae 100644 --- a/apps/webapp/src/components/Configuration/RAGSettings/RAGSettingsPage.tsx +++ b/apps/webapp/src/components/Configuration/RAGSettings/RAGSettingsPage.tsx @@ -1,5 +1,5 @@ import * as React from "react" -import { RAGLayout } from "./RagLayout"; +import { RAGLayout } from "./RAGLayout"; import { embeddingModels, TabType, vectorDatabases } from "./utils"; export default function RAGSettingsPage() { diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index 325e8e0f7..6bfe5872b 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -205,11 +205,15 @@ services: PINECONE_INDEX_NAME: ${PINECONE_INDEX_NAME} QDRANT_BASE_URL: ${QDRANT_BASE_URL} QDRANT_API_KEY: ${QDRANT_API_KEY} + QDRANT_COLLECTION_NAME: ${QDRANT_COLLECTION_NAME} CHROMADB_URL: ${CHROMADB_URL} + CHROMADB_COLLECTION_NAME: ${CHROMADB_COLLECTION_NAME} WEAVIATE_URL: ${WEAVIATE_URL} WEAVIATE_API_KEY: ${WEAVIATE_API_KEY} + WEAVIATE_CLASS_NAME: ${WEAVIATE_CLASS_NAME} 
TURBOPUFFER_API_KEY: ${TURBOPUFFER_API_KEY} MILVUS_ADDRESS: ${MILVUS_ADDRESS} + MILVUS_COLLECTION_NAME: ${MILVUS_COLLECTION_NAME} restart: unless-stopped ports: diff --git a/docker-compose.source.yml b/docker-compose.source.yml index 6ebe20104..4de1e850b 100644 --- a/docker-compose.source.yml +++ b/docker-compose.source.yml @@ -205,11 +205,15 @@ services: PINECONE_INDEX_NAME: ${PINECONE_INDEX_NAME} QDRANT_BASE_URL: ${QDRANT_BASE_URL} QDRANT_API_KEY: ${QDRANT_API_KEY} + QDRANT_COLLECTION_NAME: ${QDRANT_COLLECTION_NAME} CHROMADB_URL: ${CHROMADB_URL} + CHROMADB_COLLECTION_NAME: ${CHROMADB_COLLECTION_NAME} WEAVIATE_URL: ${WEAVIATE_URL} WEAVIATE_API_KEY: ${WEAVIATE_API_KEY} + WEAVIATE_CLASS_NAME: ${WEAVIATE_CLASS_NAME} TURBOPUFFER_API_KEY: ${TURBOPUFFER_API_KEY} MILVUS_ADDRESS: ${MILVUS_ADDRESS} + MILVUS_COLLECTION_NAME: ${MILVUS_COLLECTION_NAME} restart: unless-stopped ports: diff --git a/docker-compose.yml b/docker-compose.yml index 65c57ff5f..e91cf788e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -199,11 +199,15 @@ services: PINECONE_INDEX_NAME: ${PINECONE_INDEX_NAME} QDRANT_BASE_URL: ${QDRANT_BASE_URL} QDRANT_API_KEY: ${QDRANT_API_KEY} + QDRANT_COLLECTION_NAME: ${QDRANT_COLLECTION_NAME} CHROMADB_URL: ${CHROMADB_URL} + CHROMADB_COLLECTION_NAME: ${CHROMADB_COLLECTION_NAME} WEAVIATE_URL: ${WEAVIATE_URL} WEAVIATE_API_KEY: ${WEAVIATE_API_KEY} + WEAVIATE_CLASS_NAME: ${WEAVIATE_CLASS_NAME} TURBOPUFFER_API_KEY: ${TURBOPUFFER_API_KEY} MILVUS_ADDRESS: ${MILVUS_ADDRESS} + MILVUS_COLLECTION_NAME: ${MILVUS_COLLECTION_NAME} restart: unless-stopped ports: diff --git a/packages/api/src/@core/@core-services/environment/environment.service.ts b/packages/api/src/@core/@core-services/environment/environment.service.ts index fb0b19662..9e52c75eb 100644 --- a/packages/api/src/@core/@core-services/environment/environment.service.ts +++ b/packages/api/src/@core/@core-services/environment/environment.service.ts @@ -65,13 +65,19 @@ export class EnvironmentService { }; } - 
getChromaCreds(): string { - return this.configService.get('CHROMADB_URL'); + getChromaCreds() { + return { + url: this.configService.get('CHROMADB_URL'), + collectionName: this.configService.get( + 'CHROMADB_COLLECTION_NAME', + ), + }; } getMilvusCreds() { return { address: this.configService.get('MILVUS_ADDRESS'), + collectionName: this.configService.get('MILVUS_COLLECTION_NAME'), }; } getPineconeCreds() { @@ -85,6 +91,7 @@ export class EnvironmentService { return { url: this.configService.get('WEAVIATE_URL'), apiKey: this.configService.get('WEAVIATE_API_KEY'), + className: this.configService.get('WEAVIATE_CLASS_NAME'), }; } @@ -96,6 +103,7 @@ export class EnvironmentService { return { baseUrl: this.configService.get('QDRANT_BASE_URL'), apiKey: this.configService.get('QDRANT_API_KEY'), + collectionName: this.configService.get('QDRANT_COLLECTION_NAME'), }; } diff --git a/packages/api/src/@core/rag/document.processor.ts b/packages/api/src/@core/rag/document.processor.ts index e7e465b9b..ef5191c2c 100644 --- a/packages/api/src/@core/rag/document.processor.ts +++ b/packages/api/src/@core/rag/document.processor.ts @@ -19,9 +19,13 @@ export class ProcessDocumentProcessor { @Process('batchDocs') async processDocuments( - job: Job<{ filesInfo: FileInfo[]; projectId: string }>, + job: Job<{ + filesInfo: FileInfo[]; + projectId: string; + linkedUserId: string; + }>, ) { - const { filesInfo, projectId } = job.data; + const { filesInfo, projectId, linkedUserId } = job.data; const results = []; for (const fileInfo of filesInfo) { @@ -40,7 +44,7 @@ export class ProcessDocumentProcessor { // console.log(`chunks for ${fileInfo.id} are ` + JSON.stringify(chunks)); const embeddings = await this.embeddingService.generateEmbeddings( chunks, - projectId + projectId, ); // Split embeddings into smaller batches const batchSize = 100; // Adjust this value as needed @@ -52,6 +56,7 @@ export class ProcessDocumentProcessor { batchChunks, batchEmbeddings, projectId, + linkedUserId, ); } 
results.push(`Successfully processed document ${fileInfo.id}`); diff --git a/packages/api/src/@core/rag/rag.controller.ts b/packages/api/src/@core/rag/rag.controller.ts index 74ee8c29a..89abcb634 100644 --- a/packages/api/src/@core/rag/rag.controller.ts +++ b/packages/api/src/@core/rag/rag.controller.ts @@ -1,16 +1,8 @@ -import { Controller, Post, Body, UseGuards } from '@nestjs/common'; -import { RagService } from './rag.service'; import { ApiKeyAuthGuard } from '@@core/auth/guards/api-key.guard'; -import { - ApiBody, - ApiOperation, - ApiParam, - ApiQuery, - ApiTags, - ApiHeader, - //ApiKeyAuth, -} from '@nestjs/swagger'; import { ConnectionUtils } from '@@core/connections/@utils'; +import { Body, Controller, Headers, Post, UseGuards } from '@nestjs/common'; +import { ApiHeader } from '@nestjs/swagger'; +import { RagService } from './rag.service'; @Controller('rag') export class RagController { @@ -21,9 +13,26 @@ export class RagController { ) {} @Post('query') + @ApiHeader({ + name: 'x-connection-token', + required: true, + description: 'The connection token', + example: 'b008e199-eda9-4629-bd41-a01b6195864a', + }) @UseGuards(ApiKeyAuthGuard) - async queryEmbeddings(@Body() body: { query: string; topK?: number }) { - return this.documentEmbeddingService.queryEmbeddings(body.query, body.topK); + async queryEmbeddings( + @Body() body: { query: string; topK?: number }, + @Headers('x-connection-token') connection_token: string, + ) { + const { linkedUserId, remoteSource, connectionId, projectId } = + await this.connectionUtils.getConnectionMetadataFromConnectionToken( + connection_token, + ); + return this.documentEmbeddingService.queryEmbeddings( + body.query, + body.topK, + linkedUserId, + ); } /* diff --git a/packages/api/src/@core/rag/rag.service.ts b/packages/api/src/@core/rag/rag.service.ts index 24452ee49..1c6057ef1 100644 --- a/packages/api/src/@core/rag/rag.service.ts +++ b/packages/api/src/@core/rag/rag.service.ts @@ -14,11 +14,12 @@ export class 
RagService { private s3Service: S3Service, ) {} - async queryEmbeddings(query: string, topK = 5) { + async queryEmbeddings(query: string, topK = 5, linkedUserId: string) { const queryEmbedding = await this.embeddingService.embedQuery(query); const results = await this.vectorDatabaseService.queryEmbeddings( queryEmbedding, topK, + linkedUserId, ); return results.map((match: any) => ({ chunk: match.metadata.text, diff --git a/packages/api/src/@core/rag/vecdb/chromadb/chromadb.service.ts b/packages/api/src/@core/rag/vecdb/chromadb/chromadb.service.ts index 5c088694e..acdb29dcb 100644 --- a/packages/api/src/@core/rag/vecdb/chromadb/chromadb.service.ts +++ b/packages/api/src/@core/rag/vecdb/chromadb/chromadb.service.ts @@ -1,19 +1,25 @@ import { EnvironmentService } from '@@core/@core-services/environment/environment.service'; import { ProcessedChunk } from '@@core/rag/types'; import { Injectable } from '@nestjs/common'; -import { ChromaClient } from 'chromadb'; +import { ChromaClient, Collection } from 'chromadb'; @Injectable() export class ChromaDBService { private client: ChromaClient; + private collection: Collection; - constructor(private envService: EnvironmentService) { - //this.initialize(); + constructor(private envService: EnvironmentService) {} + + async onModuleInit() { + return; } - async initialize() { + async initialize(credentials: string[]) { this.client = new ChromaClient({ - path: this.envService.getChromaCreds(), + path: credentials[0], + }); + this.collection = await this.client.getOrCreateCollection({ + name: credentials[1], }); } @@ -21,32 +27,29 @@ export class ChromaDBService { fileId: string, chunks: ProcessedChunk[], embeddings: number[][], + linkedUserId: string, ) { - const collection = await this.client.createCollection({ name: fileId }); - await collection.add({ + await this.collection.add({ ids: chunks.map((_, i) => `${fileId}_${i}`), embeddings: embeddings, metadatas: chunks.map((chunk) => ({ text: chunk.text, ...chunk.metadata, + 
user_id: `ns_${linkedUserId}`, })), }); } - async queryEmbeddings(queryEmbedding: number[], topK: number) { - const collections = await this.client.listCollections(); - const results = await Promise.all( - collections.map(async (collection) => { - const collectionInstance = await this.client.getCollection({ - name: collection.name, - }); - const result = await collectionInstance.query({ - queryEmbeddings: [queryEmbedding], - nResults: topK, - }); - return result.metadatas[0]; - }), - ); - return results.flat().slice(0, topK); + async queryEmbeddings( + queryEmbedding: number[], + topK: number, + linkedUserId: string, + ) { + const result = await this.collection.query({ + queryEmbeddings: [queryEmbedding], + nResults: topK, + where: { user_id: `ns_${linkedUserId}` }, + }); + return result.metadatas[0]; } } diff --git a/packages/api/src/@core/rag/vecdb/milvus/milvus.service.ts b/packages/api/src/@core/rag/vecdb/milvus/milvus.service.ts index d35fafd92..fd740a94b 100644 --- a/packages/api/src/@core/rag/vecdb/milvus/milvus.service.ts +++ b/packages/api/src/@core/rag/vecdb/milvus/milvus.service.ts @@ -6,16 +6,19 @@ import { DataType, MilvusClient } from '@zilliz/milvus2-sdk-node'; @Injectable() export class MilvusService { private client: MilvusClient; + private collectionName: string; - constructor(private envService: EnvironmentService) { - //this.initialize(); + constructor(private envService: EnvironmentService) {} + + async onModuleInit() { + return; } - async initialize() { - const milvus_creds = this.envService.getMilvusCreds(); + async initialize(credentials: string[]) { this.client = new MilvusClient({ - address: milvus_creds.address, + address: credentials[0], }); + this.collectionName = credentials[1]; await this.client.connectPromise; } @@ -23,78 +26,91 @@ export class MilvusService { fileId: string, chunks: ProcessedChunk[], embeddings: number[][], + linkedUserId: string, ) { - const collection_name = fileId; - await this.client.createCollection({ - 
collection_name, - fields: [ - { - name: 'id', - description: 'ID field', - data_type: DataType.VarChar, - is_primary_key: true, - max_length: 100, - }, - { - name: 'text', - description: 'Text field', - data_type: DataType.VarChar, - max_length: 65535, - }, - { - name: 'embedding', - description: 'Vector field', - data_type: DataType.FloatVector, - dim: embeddings[0].length, - }, - ], + const tenant = `ns_${linkedUserId}`; + const hasCollection = await this.client.hasCollection({ + collection_name: this.collectionName, }); + if (!hasCollection.value) { /* hasCollection() resolves to a response object (always truthy); the boolean result is in .value */ + await this.client.createCollection({ + collection_name: this.collectionName, + fields: [ + { + name: 'id', + description: 'ID field', + data_type: DataType.VarChar, + is_primary_key: true, + max_length: 100, + }, + { + name: 'tenant', + description: 'Tenant field', + data_type: DataType.VarChar, + max_length: 100, + }, + { + name: 'text', + description: 'Text field', + data_type: DataType.VarChar, + max_length: 65535, + }, + { + name: 'embedding', + description: 'Vector field', + data_type: DataType.FloatVector, + dim: embeddings[0].length, + }, + ], + enable_dynamic_field: true, + }); + + // Create index + await this.client.createIndex({ + collection_name: this.collectionName, + field_name: 'embedding', + index_type: 'HNSW', + params: { efConstruction: 10, M: 4 }, + metric_type: 'L2', + }); + } + const data = chunks.map((chunk, i) => ({ id: `${fileId}_${i}`, + tenant, text: chunk.text, embedding: embeddings[i], })); await this.client.insert({ - collection_name, + collection_name: this.collectionName, data, }); - await this.client.createIndex({ - collection_name, - field_name: 'embedding', - index_type: 'HNSW', - params: { efConstruction: 10, M: 4 }, - metric_type: 'L2', - }); - await this.client.loadCollectionSync({ - collection_name, + collection_name: this.collectionName, }); } - async queryEmbeddings(queryEmbedding: number[], topK: number) { - const collections = await this.client.listCollections(); - const results 
= await Promise.all( - collections.data.map(async (collection) => { - const res = await this.client.search({ - collection_name: collection.name, - vector: queryEmbedding, - filter: '', - params: { nprobe: 10 }, - limit: topK, - output_fields: ['text'], - }); - return res.results.map((hit) => ({ - id: hit.id, - text: hit.text, - score: hit.score, - })); - }), - ); - return results - .flat() - .sort((a, b) => b.score - a.score) - .slice(0, topK); + async queryEmbeddings( + queryEmbedding: number[], + topK: number, + linkedUserId: string, + ) { + const tenant = `ns_${linkedUserId}`; + const res = await this.client.search({ + collection_name: this.collectionName, + vector: queryEmbedding, + filter: `tenant == "${tenant}"`, + params: { nprobe: 10 }, + limit: topK, + output_fields: ['text'], + }); + + return res.results.map((hit) => ({ + id: hit.id, + text: hit.text, + score: hit.score, + })); } } diff --git a/packages/api/src/@core/rag/vecdb/pinecone/pinecone.service.ts b/packages/api/src/@core/rag/vecdb/pinecone/pinecone.service.ts index 76c85a306..c4e9f96bc 100644 --- a/packages/api/src/@core/rag/vecdb/pinecone/pinecone.service.ts +++ b/packages/api/src/@core/rag/vecdb/pinecone/pinecone.service.ts @@ -18,6 +18,7 @@ export class PineconeService { fileId: string, chunks: ProcessedChunk[], embeddings: number[][], + linkedUserId: string, ) { const index = this.client.Index(this.indexName); const vectors = chunks.map((chunk, i) => ({ @@ -28,7 +29,7 @@ export class PineconeService { ...chunk.metadata, }), })); - await index.upsert(vectors); + await index.namespace(`ns_${linkedUserId}`).upsert(vectors); console.log(`Inserted embeddings on Pinecone for fileId ${fileId}`); } private sanitizeMetadata(metadata: Record): Record { @@ -53,9 +54,13 @@ export class PineconeService { return sanitized; } - async queryEmbeddings(queryEmbedding: number[], topK: number) { + async queryEmbeddings( + queryEmbedding: number[], + topK: number, + linkedUserId: string, + ) { const index = 
this.client.Index(this.indexName); - const queryResponse = await index.query({ + const queryResponse = await index.namespace(`ns_${linkedUserId}`).query({ vector: queryEmbedding, topK, includeMetadata: true, diff --git a/packages/api/src/@core/rag/vecdb/qdrant/qdrant.service.ts b/packages/api/src/@core/rag/vecdb/qdrant/qdrant.service.ts index a26714db7..34653b4b0 100644 --- a/packages/api/src/@core/rag/vecdb/qdrant/qdrant.service.ts +++ b/packages/api/src/@core/rag/vecdb/qdrant/qdrant.service.ts @@ -1,33 +1,54 @@ import { EnvironmentService } from '@@core/@core-services/environment/environment.service'; import { ProcessedChunk } from '@@core/rag/types'; -import { Injectable } from '@nestjs/common'; +import { Injectable, OnModuleInit } from '@nestjs/common'; import { QdrantClient } from '@qdrant/js-client-rest'; @Injectable() -export class QdrantDBService { +export class QdrantDBService implements OnModuleInit { private client: QdrantClient; + private collectionName: string; - constructor(private envService: EnvironmentService) { - //this.initialize(); + constructor(private envService: EnvironmentService) {} + + async onModuleInit() { + return; } - async initialize() { - const creds = this.envService.getQdrantCreds(); + async initialize(credentials: string[]) { this.client = new QdrantClient({ - url: `https://${creds.baseUrl}.us-east-0-1.aws.cloud.qdrant.io`, - apiKey: creds.apiKey, + url: `https://${credentials[1]}.us-east-0-1.aws.cloud.qdrant.io`, + apiKey: credentials[0], }); + this.collectionName = credentials[2]; + await this.ensureCollectionExists(); + } + + private async ensureCollectionExists() { + try { + await this.client.getCollection(this.collectionName); + } catch (error) { + if (error.status === 404) { + await this.client.createCollection(this.collectionName, { + vectors: { size: 1536, distance: 'Cosine' }, // Adjust size as needed + optimizers_config: { + indexing_threshold: 20000, + }, + replication_factor: 2, + }); + } else { + throw error; + } + } 
} async storeEmbeddings( fileId: string, chunks: ProcessedChunk[], embeddings: number[][], + linkedUserId: string, ) { - await this.client.createCollection(fileId, { - vectors: { size: embeddings[0].length, distance: 'Cosine' }, - }); - await this.client.upsert(fileId, { + const tenantId = `ns_${linkedUserId}`; + await this.client.upsert(this.collectionName, { wait: true, points: chunks.map((chunk, i) => ({ id: `${fileId}_${i}`, @@ -35,22 +56,66 @@ export class QdrantDBService { payload: { text: chunk.text, ...chunk.metadata, + tenant_id: tenantId, + file_id: fileId, }, })), }); } - async queryEmbeddings(queryEmbedding: number[], topK: number) { - const { collections } = await this.client.getCollections(); - const results = await Promise.all( - collections.map(async (collection) => { - const result = await this.client.search(collection.name, { - vector: queryEmbedding, - limit: topK, - }); - return result.map((item) => item.payload); - }), - ); - return results.flat().slice(0, topK); + async queryEmbeddings( + queryEmbedding: number[], + topK: number, + linkedUserId: string, + ) { + const tenantId = `ns_${linkedUserId}`; + const result = await this.client.search(this.collectionName, { + vector: queryEmbedding, + limit: topK, + filter: { + must: [ + { + key: 'tenant_id', + match: { value: tenantId }, + }, + ], + }, + }); + return result.map((item) => item.payload); + } + + /*async deleteEmbeddings(tenantId: string, fileId: string) { + await this.client.delete(this.collectionName, { + filter: { + must: [ + { + key: 'tenant_id', + match: { value: tenantId }, + }, + { + key: 'file_id', + match: { value: fileId }, + }, + ], + }, + }); } + + async listTenants() { + const result = await this.client.scroll(this.collectionName, { + filter: { + must: [ + { + key: 'tenant_id', + match: { value: '' }, + }, + ], + }, + limit: 100, + }); + const tenants = new Set( + result.points.map((point) => point.payload.tenant_id), + ); + return Array.from(tenants); + }*/ } diff --git 
a/packages/api/src/@core/rag/vecdb/vecdb.credentials.service.ts b/packages/api/src/@core/rag/vecdb/vecdb.credentials.service.ts index b574f4fb4..023243d4d 100644 --- a/packages/api/src/@core/rag/vecdb/vecdb.credentials.service.ts +++ b/packages/api/src/@core/rag/vecdb/vecdb.credentials.service.ts @@ -48,15 +48,26 @@ export class VectorDbCredentialsService { this.envService.getPineconeCreds().indexName, ]; case 'chromadb': - return [this.envService.getChromaCreds()]; + return [ + this.envService.getChromaCreds().url, + this.envService.getChromaCreds().collectionName, + ]; case 'weaviate': const weaviateCreds = this.envService.getWeaviateCreds(); - return [weaviateCreds.apiKey, weaviateCreds.url]; + return [ + weaviateCreds.apiKey, + weaviateCreds.url, + weaviateCreds.className, + ]; case 'turbopuffer': return [this.envService.getTurboPufferApiKey()]; case 'qdrant': const qdrantCreds = this.envService.getQdrantCreds(); - return [qdrantCreds.apiKey, qdrantCreds.baseUrl]; + return [ + qdrantCreds.apiKey, + qdrantCreds.baseUrl, + qdrantCreds.collectionName, + ]; default: throw new Error(`Unsupported vector database: ${vectorDb}`); } @@ -69,11 +80,11 @@ export class VectorDbCredentialsService { case 'turbopuffer': return ['apiKey']; case 'qdrant': - return ['apiKey', 'baseUrl']; + return ['apiKey', 'baseUrl', 'collectionName']; case 'chromadb': - return ['url']; + return ['url', 'collectionName']; case 'weaviate': - return ['apiKey', 'url']; + return ['apiKey', 'url', 'className']; default: throw new Error(`Unsupported vector database: ${vectorDb}`); } diff --git a/packages/api/src/@core/rag/vecdb/vecdb.service.ts b/packages/api/src/@core/rag/vecdb/vecdb.service.ts index 7ddbb398e..1b03f7760 100644 --- a/packages/api/src/@core/rag/vecdb/vecdb.service.ts +++ b/packages/api/src/@core/rag/vecdb/vecdb.service.ts @@ -94,6 +94,7 @@ export class VectorDatabaseService implements OnModuleInit { chunks: Document>[], embeddings: number[][], projectId: string, + linkedUserId: 
string, ) { await this.init(projectId); const processedChunks: ProcessedChunk[] = chunks.map((chunk) => ({ @@ -105,10 +106,15 @@ export class VectorDatabaseService implements OnModuleInit { processedChunks, embeddings, projectId, + linkedUserId, ); } - async queryEmbeddings(queryEmbedding: number[], topK: number) { - return this.vectorDb.queryEmbeddings(queryEmbedding, topK); + async queryEmbeddings( + queryEmbedding: number[], + topK: number, + linkedUserId: string, + ) { + return this.vectorDb.queryEmbeddings(queryEmbedding, topK, linkedUserId); } } diff --git a/packages/api/src/@core/rag/vecdb/weaviate/weaviate.service.ts b/packages/api/src/@core/rag/vecdb/weaviate/weaviate.service.ts index ae92bf0e5..e3a8cd15d 100644 --- a/packages/api/src/@core/rag/vecdb/weaviate/weaviate.service.ts +++ b/packages/api/src/@core/rag/vecdb/weaviate/weaviate.service.ts @@ -1,49 +1,111 @@ import { EnvironmentService } from '@@core/@core-services/environment/environment.service'; import { ProcessedChunk } from '@@core/rag/types'; -import { Injectable } from '@nestjs/common'; -import weaviate from 'weaviate-client'; +import { Injectable, OnModuleInit } from '@nestjs/common'; +import weaviate, { WeaviateClient, ApiKey } from 'weaviate-ts-client'; @Injectable() -export class WeaviateService { - private client: any; +export class WeaviateService implements OnModuleInit { + private client: WeaviateClient; + private className: string; - constructor(private envService: EnvironmentService) { - //this.initialize(); + constructor(private envService: EnvironmentService) {} + + async onModuleInit() { + return; } - async initialize() { - const weaviate_creds = this.envService.getWeaviateCreds(); - this.client = weaviate.connectToWeaviateCloud(weaviate_creds.url, { - authCredentials: new weaviate.ApiKey(weaviate_creds.apiKey), + async initialize(credentials: string[]) { + this.client = weaviate.client({ + scheme: 'https', + host: credentials[1], + apiKey: new ApiKey(credentials[0]), }); + 
this.className = credentials[2]; + await this.ensureClassExists(); + } + + private async ensureClassExists() { + const classObj = { + class: this.className, + vectorizer: 'none', // assuming you're providing your own vectors + multiTenancyConfig: { + enabled: true, + }, + properties: [ + { name: 'text', dataType: ['text'] }, + { name: 'metadata', dataType: ['object'] }, + ], + }; + + try { + await this.client.schema.classCreator().withClass(classObj).do(); + } catch (error) { + // Class might already exist, which is fine + console.log( + `Class ${this.className} might already exist:`, + error.message, + ); + } } async storeEmbeddings( fileId: string, chunks: ProcessedChunk[], embeddings: number[][], + linkedUserId: string, ) { - const className = 'Document'; - for (let i = 0; i < chunks.length; i++) { - await this.client.data - .creator() - .withClassName(className) - .withId(`${fileId}_${i}`) - .withProperties({ text: chunks[i].text, ...chunks[i].metadata }) - .withVector(embeddings[i]) - .do(); + const batchSize = 100; + const tenant = `ns_${linkedUserId}`; + for (let i = 0; i < chunks.length; i += batchSize) { + const batcher = this.client.batch.objectsBatcher(); + const batch = chunks.slice(i, i + batchSize); + + batch.forEach((chunk, index) => { + batcher.withObject({ + class: this.className, + id: `${fileId}_${i + index}`, + properties: { + text: chunk.text, + metadata: chunk.metadata, + }, + vector: embeddings[i + index], + tenant: tenant, + }); + }); + + await batcher.do(); } } - async queryEmbeddings(queryEmbedding: number[], topK: number) { - const className = 'Document'; + async queryEmbeddings( + queryEmbedding: number[], + topK: number, + linkedUserId: string, + ) { + const tenant = `ns_${linkedUserId}`; const result = await this.client.graphql .get() - .withClassName(className) + .withClassName(this.className) + .withTenant(tenant) .withFields('text metadata') .withNearVector({ vector: queryEmbedding }) .withLimit(topK) .do(); - return 
result.data.Get[className] || []; + + return result.data.Get[this.className] || []; } + + /*async deleteEmbeddings(fileId: string, tenant: string) { + await this.client.batch + .objectsBatcher() + .withTenant(tenant) + .withClassName(this.className) + .withWhere({ + operator: 'Like', + path: ['id'], + valueString: `${fileId}*`, + }) + .withDelete() + .do(); + }*/ } diff --git a/packages/api/swagger/swagger-spec.yaml b/packages/api/swagger/swagger-spec.yaml index ba9083962..c9a2e5847 100644 --- a/packages/api/swagger/swagger-spec.yaml +++ b/packages/api/swagger/swagger-spec.yaml @@ -28,7 +28,13 @@ paths: /rag/query: post: operationId: RagController_queryEmbeddings - parameters: [] + parameters: + - name: x-connection-token + required: true + in: header + description: The connection token + schema: + type: string responses: '201': description: ''