diff --git a/package-lock.json b/package-lock.json index e68248c..57372ef 100644 --- a/package-lock.json +++ b/package-lock.json @@ -11,23 +11,28 @@ "dependencies": { "@grpc/grpc-js": "^1.10.1", "@grpc/proto-loader": "^0.7.10", + "async": "^3.2.5", "chalk": "^5.3.0", - "clarifai-nodejs-grpc": "^10.0.9", + "clarifai-nodejs-grpc": "^10.3.2", "csv-parse": "^5.5.5", "from-protobuf-object": "^1.0.2", "google-protobuf": "^3.21.2", "js-yaml": "^4.1.0", + "lodash": "^4.17.21", "safe-flat": "^2.1.0", - "uuidv4": "^6.2.13", + "uuid": "^9.0.1", "winston": "^3.11.0", "zod": "^3.22.4" }, "devDependencies": { "@parcel/packager-ts": "^2.11.0", "@parcel/transformer-typescript-types": "^2.11.0", + "@types/async": "^3.2.24", "@types/google-protobuf": "^3.15.12", "@types/js-yaml": "^4.0.9", + "@types/lodash": "^4.17.0", "@types/node": "^20.11.16", + "@types/uuid": "^9.0.8", "@typescript-eslint/eslint-plugin": "^6.19.1", "@typescript-eslint/parser": "^6.19.1", "@vitest/coverage-v8": "^1.3.1", @@ -3258,6 +3263,12 @@ "node": ">=10.13.0" } }, + "node_modules/@types/async": { + "version": "3.2.24", + "resolved": "https://registry.npmjs.org/@types/async/-/async-3.2.24.tgz", + "integrity": "sha512-8iHVLHsCCOBKjCF2KwFe0p9Z3rfM9mL+sSP8btyR5vTjJRAqpBYD28/ZLgXPf0pjG1VxOvtCV/BgXkQbpSe8Hw==", + "dev": true + }, "node_modules/@types/estree": { "version": "1.0.5", "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.5.tgz", @@ -3288,6 +3299,12 @@ "integrity": "sha512-5+fP8P8MFNC+AyZCDxrB2pkZFPGzqQWUzpSeuuVLvm8VMcorNYavBqoFcxK8bQz4Qsbn4oUEEem4wDLfcysGHA==", "dev": true }, + "node_modules/@types/lodash": { + "version": "4.17.0", + "resolved": "https://registry.npmjs.org/@types/lodash/-/lodash-4.17.0.tgz", + "integrity": "sha512-t7dhREVv6dbNj0q17X12j7yDG4bD/DHYX7o5/DbDxobP0HnGPgpRz2Ej77aL7TZT3DSw13fqUTj8J4mMnqa7WA==", + "dev": true + }, "node_modules/@types/long": { "version": "4.0.2", "resolved": "https://registry.npmjs.org/@types/long/-/long-4.0.2.tgz", @@ -3313,9 +3330,10 @@ "integrity": "sha512-6WaYesThRMCl19iryMYP7/x2OVgCtbIVflDGFpWnb9irXI3UjYE4AzmYuiUKY1AJstGijoY+MgUszMgRxIYTYw==" }, "node_modules/@types/uuid": { - "version": "8.3.4", - "resolved": "https://registry.npmjs.org/@types/uuid/-/uuid-8.3.4.tgz", - "integrity": "sha512-c/I8ZRb51j+pYGAu5CrFMRxqZ2ke4y2grEBO5AUjgSkSk+qT2Ea+OdWElz/OiMf5MNpn2b17kuVBwZLQJXzihw==" + "version": "9.0.8", + "resolved": "https://registry.npmjs.org/@types/uuid/-/uuid-9.0.8.tgz", + "integrity": "sha512-jg+97EGIcY9AGHJJRaaPVgetKDsrTgbRjQ5Msgjh/DQKEFl0DtyRr/VCOyD1T2R1MNeWPK/u7JoGhlDZnKBAfA==", + "dev": true }, "node_modules/@typescript-eslint/eslint-plugin": { "version": "6.19.1", @@ -3932,9 +3950,9 @@ } }, "node_modules/clarifai-nodejs-grpc": { - "version": "10.0.9", - "resolved": "https://registry.npmjs.org/clarifai-nodejs-grpc/-/clarifai-nodejs-grpc-10.0.9.tgz", - "integrity": "sha512-CufiPJBifKS3pdPcEXzmrNVzsga1hMCDvxcrVQ9uenLLK9NSo2FMTb3ef4QdiLnWw0gJii0AQWaFhAdQY/OmLQ==", + "version": "10.3.2", + "resolved": "https://registry.npmjs.org/clarifai-nodejs-grpc/-/clarifai-nodejs-grpc-10.3.2.tgz", + "integrity": "sha512-uyC/ORz08hcEiW3AZ2VGmOU1saj1eMJE/W2zCu8zxVvbXsATdOMQSAawtgtnL4vTG6CxKItccuZqTpezYuX+dA==", "dependencies": { "@grpc/grpc-js": "^1.4.2", "@grpc/proto-loader": "^0.5.5", @@ -5688,6 +5706,11 @@ "url": "https://github.com/sponsors/antfu" } }, + "node_modules/lodash": { + "version": "4.17.21", + "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz", + "integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==" + }, "node_modules/lodash.camelcase": { "version": "4.3.0", "resolved": "https://registry.npmjs.org/lodash.camelcase/-/lodash.camelcase-4.3.0.tgz", @@ -7254,22 +7277,17 @@ } }, "node_modules/uuid": { - "version": "8.3.2", - "resolved": "https://registry.npmjs.org/uuid/-/uuid-8.3.2.tgz", - "integrity": "sha512-+NYs2QeMWy+GWFOEm9xnn6HCDp0l7QBD7ml8zLUmJ+93Q5NF0NocErnwkTkXVFNiX3/fpC6afS8Dhb/gz7R7eg==", + "version": "9.0.1", + "resolved": "https://registry.npmjs.org/uuid/-/uuid-9.0.1.tgz", + "integrity": "sha512-b+1eJOlsR9K8HJpow9Ok3fiWOWSIcIzXodvv0rQjVoOVNpWMpxf1wZNpt4y9h10odCNrqnYp1OBzRktckBe3sA==", + "funding": [ + "https://github.com/sponsors/broofa", + "https://github.com/sponsors/ctavan" + ], "bin": { "uuid": "dist/bin/uuid" } }, - "node_modules/uuidv4": { - "version": "6.2.13", - "resolved": "https://registry.npmjs.org/uuidv4/-/uuidv4-6.2.13.tgz", - "integrity": "sha512-AXyzMjazYB3ovL3q051VLH06Ixj//Knx7QnUSi1T//Ie3io6CpsPu9nVMOx5MoLWh6xV0B9J0hIaxungxXUbPQ==", - "dependencies": { - "@types/uuid": "8.3.4", - "uuid": "8.3.2" - } - }, "node_modules/v8-to-istanbul": { "version": "9.2.0", "resolved": "https://registry.npmjs.org/v8-to-istanbul/-/v8-to-istanbul-9.2.0.tgz", diff --git a/package.json b/package.json index 3f8f2fb..fc9ab53 100644 --- a/package.json +++ b/package.json @@ -38,9 +38,12 @@ "devDependencies": { "@parcel/packager-ts": "^2.11.0", "@parcel/transformer-typescript-types": "^2.11.0", + "@types/async": "^3.2.24", "@types/google-protobuf": "^3.15.12", "@types/js-yaml": "^4.0.9", + "@types/lodash": "^4.17.0", "@types/node": "^20.11.16", + "@types/uuid": "^9.0.8", "@typescript-eslint/eslint-plugin": "^6.19.1", "@typescript-eslint/parser": "^6.19.1", "@vitest/coverage-v8": "^1.3.1", @@ -60,14 +63,16 @@ "dependencies": { "@grpc/grpc-js": "^1.10.1", "@grpc/proto-loader": "^0.7.10", + "async": "^3.2.5", "chalk": "^5.3.0", - "clarifai-nodejs-grpc": "^10.0.9", + "clarifai-nodejs-grpc": "^10.3.2", "csv-parse": "^5.5.5", "from-protobuf-object": "^1.0.2", "google-protobuf": "^3.21.2", "js-yaml": "^4.1.0", + "lodash": "^4.17.21", "safe-flat": "^2.1.0", - "uuidv4": "^6.2.13", + "uuid": "^9.0.1", "winston": "^3.11.0", "zod": "^3.22.4" } diff --git a/src/client/app.ts b/src/client/app.ts index 57a384b..f9f3b79 100644 --- a/src/client/app.ts +++ b/src/client/app.ts @@ -44,7 +44,7 @@ import * as yaml from "js-yaml"; import { validateWorkflow } from "../workflows/validate"; import { getYamlOutputInfoProto } from "../workflows/utils"; import { Model as ModelConstructor } from "./model"; -import { uuid } from "uuidv4"; +import { v4 as uuid } from "uuid"; import { fromProtobufObject } from "from-protobuf-object"; import { fromPartialProtobufObject } from "../utils/fromPartialProtobufObject"; import { flatten } from "safe-flat"; diff --git a/src/client/dataset.ts b/src/client/dataset.ts new file mode 100644 index 0000000..80adf01 --- /dev/null +++ b/src/client/dataset.ts @@ -0,0 +1,220 @@ +import { + DatasetVersion, + Dataset as GrpcDataset, + Input as GrpcInput, +} from "clarifai-nodejs-grpc/proto/clarifai/api/resources_pb"; +import { UserError } from "../errors"; +import { ClarifaiUrl, ClarifaiUrlHelper } from "../urls/helper"; +import { AuthConfig } from "../utils/types"; +import { Lister } from "./lister"; +import { Input, InputBulkUpload } from "./input"; +import { + DeleteDatasetVersionsRequest, + ListDatasetVersionsRequest, + PostDatasetVersionsRequest, +} from "clarifai-nodejs-grpc/proto/clarifai/api/service_pb"; +import { + JavaScriptValue, + Struct, +} from "google-protobuf/google/protobuf/struct_pb"; +import { promisifyGrpcCall } from "../utils/misc"; +import { StatusCode } from "clarifai-nodejs-grpc/proto/clarifai/api/status/status_code_pb"; + +type DatasetConfig = + | { + authConfig?: AuthConfig; + datasetId: string; + datasetVersionId?: string; + url?: undefined; + } + | { + authConfig?: AuthConfig; + datasetId?: undefined; + datasetVersionId?: undefined; + url: ClarifaiUrl; + }; + +export class Dataset extends Lister { + private info: GrpcDataset = new GrpcDataset(); + private batchSize: number = 128; + private input: Input; + + constructor({ authConfig, datasetId, url, datasetVersionId }: DatasetConfig) { + if (url && datasetId) { + throw new UserError("You can only specify one of url or dataset_id."); + } + if (url) { + const [userId, appId, , _datasetId, _datasetVersionId] = + ClarifaiUrlHelper.splitClarifaiUrl(url); + if (authConfig) authConfig.userId = userId; + if (authConfig) authConfig.appId = appId; + datasetId = _datasetId; + datasetVersionId = _datasetVersionId; + } + + super({ authConfig }); + this.info.setId(datasetId!); + this.info.setVersion(new DatasetVersion().setId(datasetVersionId!)); + this.input = new Input({ authConfig }); + } + + async createVersion({ + id, + description, + metadata = {}, + }: { + id: string; + description: string; + metadata?: Record; + }): Promise { + const request = new PostDatasetVersionsRequest(); + request.setUserAppId(this.userAppId); + request.setDatasetId(this.info.getId()); + const datasetVersion = new DatasetVersion(); + datasetVersion.setId(id); + datasetVersion.setDescription(description); + datasetVersion.setMetadata(Struct.fromJavaScript(metadata)); + request.setDatasetVersionsList([datasetVersion]); + + const postDatasetVersions = promisifyGrpcCall( + this.STUB.client.postDatasetVersions, + this.STUB.client, + ); + + const response = await this.grpcRequest(postDatasetVersions, request); + const responseObject = response.toObject(); + if (responseObject.status?.code !== StatusCode.SUCCESS) { + throw new Error(responseObject.status?.description); + } + console.info("\nDataset Version created\n%s", response.getStatus()); + + return responseObject.datasetVersionsList[0]; + } + + async deleteVersion(versionId: string): Promise { + const request = new DeleteDatasetVersionsRequest(); + request.setUserAppId(this.userAppId); + request.setDatasetId(this.info.getId()); + request.setDatasetVersionIdsList([versionId]); + + const deleteDatasetVersions = promisifyGrpcCall( + this.STUB.client.deleteDatasetVersions, + this.STUB.client, + ); + const response = await this.grpcRequest(deleteDatasetVersions, request); + const responseObject = response.toObject(); + if (responseObject.status?.code !== StatusCode.SUCCESS) { + throw new Error(responseObject.status?.description); + } + console.info("\nDataset Version Deleted\n%s", response.getStatus()); + } + + async *listVersions( + pageNo?: number, + perPage?: number, + ): AsyncGenerator { + const request = new ListDatasetVersionsRequest(); + request.setUserAppId(this.userAppId); + request.setDatasetId(this.info.getId()); + + const listDatasetVersions = promisifyGrpcCall( + this.STUB.client.listDatasetVersions, + this.STUB.client, + ); + + const listDatasetVersionsGenerator = this.listPagesGenerator( + listDatasetVersions, + request, + pageNo, + perPage, + ); + + for await (const versions of listDatasetVersionsGenerator) { + yield versions.toObject().datasetVersionsList; + } + } + + async uploadFromFolder({ + folderPath, + inputType, + labels = false, + batchSize = this.batchSize, + uploadProgressEmitter, + }: { + folderPath: string; + inputType: "image" | "text"; + labels: boolean; + batchSize?: number; + uploadProgressEmitter?: InputBulkUpload; + }): Promise { + if (["image", "text"].indexOf(inputType) === -1) { + throw new UserError("Invalid input type"); + } + let inputProtos: GrpcInput[] = []; + if (inputType === "image") { + inputProtos = Input.getImageInputsFromFolder({ + folderPath: folderPath, + datasetId: this.info.getId(), + labels: labels, + }); + } + if (inputType === "text") { + inputProtos = Input.getTextInputsFromFolder({ + folderPath: folderPath, + datasetId: this.info.getId(), + labels: labels, + }); + } + await this.input.bulkUpload({ + inputs: inputProtos, + batchSize: batchSize, + uploadProgressEmitter, + }); + } + + async uploadFromCSV({ + csvPath, + inputType = "text", + csvType, + labels = true, + batchSize = 128, + uploadProgressEmitter, + }: { + csvPath: string; + inputType?: "image" | "text" | "video" | "audio"; + csvType: "raw" | "url" | "file"; + labels?: boolean; + batchSize?: number; + uploadProgressEmitter?: InputBulkUpload; + }): Promise { + if (!["image", "text", "video", "audio"].includes(inputType)) { + throw new UserError( + "Invalid input type, it should be image, text, audio, or video", + ); + } + if (!["raw", "url", "file"].includes(csvType)) { + throw new UserError( + "Invalid csv type, it should be raw, url, or file_path", + ); + } + if (!csvPath.endsWith(".csv")) { + throw new UserError("csvPath should be a csv file"); + } + if (csvType === "raw" && inputType !== "text") { + throw new UserError("Only text input type is supported for raw csv type"); + } + batchSize = Math.min(128, batchSize); + const inputProtos = await Input.getInputsFromCsv({ + csvPath: csvPath, + inputType: inputType, + csvType: csvType, + datasetId: this.info.getId(), + labels: labels, + }); + await this.input.bulkUpload({ + inputs: inputProtos, + batchSize: batchSize, + uploadProgressEmitter, + }); + } +} diff --git a/src/client/input.ts b/src/client/input.ts index bcc8f73..d16e01d 100644 --- a/src/client/input.ts +++ b/src/client/input.ts @@ -15,7 +15,7 @@ import { Text, Video, } from "clarifai-nodejs-grpc/proto/clarifai/api/resources_pb"; -import { AuthConfig } from "../utils/types"; +import { AuthConfig, Polygon as PolygonType } from "../utils/types"; import { Lister } from "./lister"; import { Buffer } from "buffer"; import fs from "fs"; @@ -27,14 +27,24 @@ import { } from "google-protobuf/google/protobuf/struct_pb"; import { parse } from "csv-parse"; import { finished } from "stream/promises"; -import { uuid } from "uuidv4"; +import { v4 as uuid } from "uuid"; import { + CancelInputsAddJobRequest, + DeleteInputsRequest, + GetInputsAddJobRequest, + ListInputsRequest, PatchInputsRequest, PostAnnotationsRequest, PostInputsRequest, } from "clarifai-nodejs-grpc/proto/clarifai/api/service_pb"; -import { promisifyGrpcCall } from "../utils/misc"; +import { BackoffIterator, promisifyGrpcCall } from "../utils/misc"; import { StatusCode } from "clarifai-nodejs-grpc/proto/clarifai/api/status/status_code_pb"; +import os from "os"; +import chunk from "lodash/chunk"; +import { Status } from "clarifai-nodejs-grpc/proto/clarifai/api/status/status_pb"; +import async from "async"; +import { MAX_RETRIES } from "../constants/dataset"; +import { EventEmitter } from "events"; interface CSVRecord { inputid: string; @@ -44,11 +54,36 @@ interface CSVRecord { geopoints: string; } +interface UploadEvents { + start: ProgressEvent; + progress: ProgressEvent; + error: ErrorEvent; + end: ProgressEvent; +} + +interface ProgressEvent { + current: number; + total: number; +} + +interface ErrorEvent { + error: Error; +} + +type BulkUploadEventEmitter = EventEmitter & { + emit(event: K, payload: T[K]): boolean; + on(event: K, listener: (payload: T[K]) => void): void; +}; + +export type InputBulkUpload = BulkUploadEventEmitter; + /** * Inputs is a class that provides access to Clarifai API endpoints related to Input information. * @noInheritDoc */ export class Input extends Lister { + private numOfWorkers: number = Math.min(os.cpus().length, 10); + /** * Initializes an input object. * @@ -480,6 +515,7 @@ export class Input extends Lister { imageUrl = null, imageBytes = null, datasetId = null, + labels = null, }: { inputId: string; rawText?: string | null; @@ -487,6 +523,7 @@ export class Input extends Lister { imageUrl?: string | null; imageBytes?: Uint8Array | null; datasetId?: string | null; + labels?: string[] | null; }): GrpcInput { if ((imageBytes && imageUrl) || (!imageBytes && !imageUrl)) { throw new Error( @@ -515,6 +552,7 @@ export class Input extends Lister { datasetId, imagePb, textPb, + labels, }); } @@ -717,12 +755,16 @@ export class Input extends Lister { return inputAnnotProto; } - static getMaskProto( - inputId: string, - label: string, - polygons: number[][][], - ): Annotation { - const polygonsSchema = z.array(z.array(z.array(z.number()))); + static getMaskProto({ + inputId, + label, + polygons, + }: { + inputId: string; + label: string; + polygons: PolygonType[]; + }): Annotation { + const polygonsSchema = z.array(z.array(z.tuple([z.number(), z.number()]))); try { polygonsSchema.parse(polygons); } catch { @@ -987,4 +1029,193 @@ export class Input extends Lister { } return retryUpload; } + + bulkUpload({ + inputs, + batchSize: providedBatchSize = 128, + uploadProgressEmitter, + }: { + inputs: GrpcInput[]; + batchSize?: number; + uploadProgressEmitter?: InputBulkUpload; + }): Promise { + const batchSize = Math.min(128, providedBatchSize); + const chunkedInputs = chunk(inputs, batchSize); + + let currentProgress = 0; + const total = chunkedInputs.length; + uploadProgressEmitter?.emit("start", { current: currentProgress, total }); + + return new Promise((resolve, reject) => { + async.mapLimit( + chunkedInputs, + this.numOfWorkers, + (batchInputs, callback) => { + this.uploadBatch({ inputs: batchInputs }) + .then((failedInputs) => { + this.retryUploads({ + failedInputs, + }).finally(() => { + currentProgress++; + uploadProgressEmitter?.emit("progress", { + current: currentProgress, + total, + }); + callback(null, failedInputs); + }); + }) + .catch((err) => { + callback(err); + }); + }, + (err) => { + if (err) { + console.error("Error processing batches", err); + uploadProgressEmitter?.emit("error"); + reject(err); + } + uploadProgressEmitter?.emit("end", { current: total, total }); + console.log("All inputs processed"); + resolve(); + }, + ); + }); + } + + private async uploadBatch({ + inputs, + }: { + inputs: GrpcInput[]; + }): Promise { + const inputJobId = await this.uploadInputs({ inputs, showLog: false }); + await this.waitForInputs({ inputJobId }); + const failedInputs = await this.deleteFailedInputs({ inputs }); + return failedInputs; + } + + private async waitForInputs({ + inputJobId, + }: { + inputJobId: string; + }): Promise { + const backoffIterator = new BackoffIterator({ + count: 10, + }); + let maxRetries = 10; + const startTime = Date.now(); + const thirtyMinutes = 60 * 30 * 1000; + // eslint-disable-next-line no-constant-condition + while (true) { + const getInputsAddJobRequest = new GetInputsAddJobRequest() + .setUserAppId(this.userAppId) + .setId(inputJobId); + + const getInputsAddJob = promisifyGrpcCall( + this.STUB.client.getInputsAddJob, + this.STUB.client, + ); + + const response = await this.grpcRequest( + getInputsAddJob, + getInputsAddJobRequest, + ); + + if (Date.now() - startTime > thirtyMinutes || maxRetries === 0) { + const cancelInputsAddJobRequest = new CancelInputsAddJobRequest() + .setUserAppId(this.userAppId) + .setId(inputJobId); + + const cancelInputsAddJob = promisifyGrpcCall( + this.STUB.client.cancelInputsAddJob, + this.STUB.client, + ); + + // 30 minutes timeout + await this.grpcRequest(cancelInputsAddJob, cancelInputsAddJobRequest); // Cancel Job + return false; + } + + const responseObject = response.toObject(); + + if (responseObject.status?.code !== StatusCode.SUCCESS) { + maxRetries -= 1; + console.warn( + `Get input job failed, status: ${responseObject.status?.description}\n`, + ); + continue; + } + if ( + responseObject.inputsAddJob?.progress?.inProgressCount === 0 && + responseObject.inputsAddJob.progress.pendingCount === 0 + ) { + return true; + } else { + await new Promise((resolve) => { + setTimeout(resolve, backoffIterator.next().value * 300); + }); + } + } + } + + private async deleteFailedInputs({ + inputs, + }: { + inputs: GrpcInput[]; + }): Promise { + const inputIds = inputs.map((input) => input.getId()); + const successStatus = new Status().setCode( + StatusCode.INPUT_DOWNLOAD_SUCCESS, // Status code for successful download + ); + const request = new ListInputsRequest(); + request.setIdsList(inputIds); + request.setPerPage(inputIds.length); + request.setUserAppId(this.userAppId); + request.setStatus(successStatus); + + const listInputs = promisifyGrpcCall( + this.STUB.client.listInputs, + this.STUB.client, + ); + + const response = await this.grpcRequest(listInputs, request); + const responseObject = response.toObject(); + const successInputs = responseObject.inputsList || []; + + const successInputIds = successInputs.map((input) => input.id); + const failedInputs = inputs.filter( + (input) => !successInputIds.includes(input.getId()), + ); + + const deleteInputs = promisifyGrpcCall( + this.STUB.client.deleteInputs, + this.STUB.client, + ); + + const deleteInputsRequest = new DeleteInputsRequest() + .setUserAppId(this.userAppId) + .setIdsList(failedInputs.map((input) => input.getId())); + + // Delete failed inputs + await this.grpcRequest(deleteInputs, deleteInputsRequest); + + return failedInputs; + } + + private async retryUploads({ + failedInputs, + }: { + failedInputs: GrpcInput[]; + }): Promise { + for (let retry = 0; retry < MAX_RETRIES; retry++) { + if (failedInputs.length > 0) { + console.log( + `Retrying upload for ${failedInputs.length} Failed inputs..\n`, + ); + failedInputs = await this.uploadBatch({ inputs: failedInputs }); + } + } + if (failedInputs.length > 0) { + console.log(`Failed to upload ${failedInputs.length} inputs..\n`); + } + } } diff --git a/src/client/search.ts b/src/client/search.ts index c7efa76..d998de0 100644 --- a/src/client/search.ts +++ b/src/client/search.ts @@ -1,4 +1,8 @@ -import { DEFAULT_SEARCH_METRIC, DEFAULT_TOP_K } from "../constants/search"; +import { + DEFAULT_SEARCH_ALGORITHM, + DEFAULT_SEARCH_METRIC, + DEFAULT_TOP_K, +} from "../constants/search"; import { Lister } from "./lister"; import { AuthConfig } from "../utils/types"; import { @@ -38,6 +42,8 @@ import { import { StatusCode } from "clarifai-nodejs-grpc/proto/clarifai/api/status/status_code_pb"; type FilterType = z.infer>; +type SupportedAlgorithm = "nearest_neighbor" | "brute_force"; +type SupportedMetric = "cosine" | "euclidean"; /** * @noInheritDoc @@ -47,15 +53,18 @@ export class Search extends Lister { private metricDistance: "COSINE_DISTANCE" | "EUCLIDEAN_DISTANCE"; private dataProto: Data; private inputProto: GrpcInput; + private algorithm: SupportedAlgorithm; constructor({ topK = DEFAULT_TOP_K, metric = DEFAULT_SEARCH_METRIC, authConfig, + algorithm = DEFAULT_SEARCH_ALGORITHM, }: { topK?: number; - metric?: string; + metric?: SupportedMetric; authConfig?: AuthConfig; + algorithm?: SupportedAlgorithm; }) { super({ pageSize: 1000, authConfig }); @@ -63,7 +72,14 @@ export class Search extends Lister { throw new UserError("Metric should be either cosine or euclidean"); } + if (algorithm !== "nearest_neighbor" && algorithm !== "brute_force") { + throw new UserError( + "Algorithm should be either nearest_neighbor or brute_force", + ); + } + this.topK = topK; + this.algorithm = algorithm; this.metricDistance = ( { cosine: "COSINE_DISTANCE", @@ -138,7 +154,9 @@ export class Search extends Lister { this.dataProto.setGeo(geoPointProto); } } else { - throw new UserError(`kwargs contain key that is not supported: ${key}`); + throw new UserError( + `arguments contain key that is not supported: ${key}`, + ); } } const annotation = new Annotation(); @@ -201,27 +219,30 @@ export class Search extends Lister { private async *listAllPagesGenerator< T extends PostInputsSearchesRequest | PostAnnotationsSearchesRequest, - >( + >({ + endpoint, + requestData, + page = 1, + perPage, + }: { endpoint: ( request: T, metadata: grpc.Metadata, options: Partial, - ) => Promise, - requestData: T, - ): AsyncGenerator< - MultiSearchResponse.AsObject & Record<"hits", unknown>, - void, - void - > { + ) => Promise; + requestData: T; + page?: number; + perPage?: number; + }): AsyncGenerator { const maxPages = Math.ceil(this.topK / this.defaultPageSize); let totalHits = 0; - let page = 1; - while (page <= maxPages) { - let perPage; - if (page === maxPages) { - perPage = this.topK - totalHits; - } else { - perPage = this.defaultPageSize; + while (page) { + if (!perPage) { + if (page === maxPages) { + perPage = this.topK - totalHits; + } else { + perPage = this.defaultPageSize; + } } const pagination = new Pagination(); @@ -247,7 +268,11 @@ export class Search extends Lister { } } - if (!("hits" in responseObject)) { + if ( + !("hitsList" in responseObject) || + responseObject.hitsList.length === 0 + ) { + yield responseObject; break; } page += 1; @@ -259,14 +284,14 @@ export class Search extends Lister { query({ ranks = [{}], filters = [{}], + page, + perPage, }: { ranks?: FilterType; filters?: FilterType; - }): AsyncGenerator< - MultiSearchResponse.AsObject & Record<"hits", unknown>, - void, - void - > { + page?: number; + perPage?: number; + }): AsyncGenerator { try { getSchema().parse(ranks); getSchema().parse(filters); @@ -286,7 +311,7 @@ export class Search extends Lister { if ( filters.length && - Object.prototype.hasOwnProperty.call(filters[0], "input") + Object.keys(filters[0]).some((k) => k.includes("input")) ) { const filtersInputProto: GrpcInput[] = []; for (const filterDict of filters) { @@ -304,6 +329,7 @@ export class Search extends Lister { const search = new GrpcSearch(); search.setQuery(query); + search.setAlgorithm(this.algorithm); search.setMetric(GrpcSearch["Metric"][this.metricDistance]); const postInputsSearches = promisifyGrpcCall( @@ -314,7 +340,12 @@ export class Search extends Lister { request.setUserAppId(this.userAppId); request.setSearchesList([search]); - return this.listAllPagesGenerator(postInputsSearches, request); + return this.listAllPagesGenerator({ + endpoint: postInputsSearches, + requestData: request, + page, + perPage, + }); } const filtersAnnotProto: Annotation[] = []; @@ -333,6 +364,7 @@ export class Search extends Lister { const search = new GrpcSearch(); search.setQuery(query); + search.setAlgorithm(this.algorithm); search.setMetric(GrpcSearch["Metric"][this.metricDistance]); const postAnnotationsSearches = promisifyGrpcCall( @@ -343,6 +375,11 @@ export class Search extends Lister { request.setUserAppId(this.userAppId); request.setSearchesList([search]); - return this.listAllPagesGenerator(postAnnotationsSearches, request); + return this.listAllPagesGenerator({ + endpoint: postAnnotationsSearches, + requestData: request, + page, + perPage, + }); } } diff --git a/src/constants/dataset.ts b/src/constants/dataset.ts new file mode 100644 index 0000000..ebd57fb --- /dev/null +++ b/src/constants/dataset.ts @@ -0,0 +1,27 @@ +export const DATASET_UPLOAD_TASKS = [ + "visual_classification", + "text_classification", + "visual_detection", + "visual_segmentation", + "visual_captioning", +]; + +export const TASK_TO_ANNOTATION_TYPE = { + visual_classification: { + concepts: "labels", + }, + text_classification: { + concepts: "labels", + }, + visual_captioning: { + concepts: "labels", + }, + visual_detection: { + bboxes: "bboxes", + }, + visual_segmentation: { + polygons: "polygons", + }, +}; + +export const MAX_RETRIES = 2; diff --git a/src/constants/search.ts b/src/constants/search.ts index f96471e..0f5d68d 100644 --- a/src/constants/search.ts +++ b/src/constants/search.ts @@ -1,2 +1,3 @@ export const DEFAULT_TOP_K = 10; export const DEFAULT_SEARCH_METRIC = "cosine"; +export const DEFAULT_SEARCH_ALGORITHM = "nearest_neighbor"; diff --git a/src/datasets/upload/base.ts b/src/datasets/upload/base.ts new file mode 100644 index 0000000..30bc8b4 --- /dev/null +++ b/src/datasets/upload/base.ts @@ -0,0 +1,54 @@ +import { + Annotation, + Input, +} from "clarifai-nodejs-grpc/proto/clarifai/api/resources_pb"; +import { + TextFeatures, + VisualClassificationFeatures, + VisualDetectionFeatures, + VisualSegmentationFeatures, +} from "./features"; + +type OutputFeaturesType = + | TextFeatures + | VisualClassificationFeatures + | VisualDetectionFeatures + | VisualSegmentationFeatures; + +export abstract class ClarifaiDataLoader { + abstract get task(): string; + abstract loadData(): void; + abstract get length(): number; + abstract getItem(index: number): OutputFeaturesType; +} + +export abstract class ClarifaiDataset { + protected dataGenerator: T; + protected datasetId: string; + protected allInputIds: Record = {}; + + constructor({ + dataGenerator, + datasetId, + }: { + dataGenerator: T; + datasetId: string; + }) { + this.dataGenerator = dataGenerator; + this.datasetId = datasetId; + } + + get length(): number { + return this.dataGenerator.length; + } + + // TODO: Plan for implementation + // protected abstract toList(inputProtos: Iterable): unknown[]; + + protected abstract extractProtos(_args: { + batchInputIds: string[]; + }): [Input[], Annotation[]]; + + // TODO: Plan for implementation + // abstract getProtos(_inputIds: number[]): [Input[], Annotation[]]; +} diff --git a/src/datasets/upload/features.ts b/src/datasets/upload/features.ts new file mode 100644 index 0000000..7c5099d --- /dev/null +++ b/src/datasets/upload/features.ts @@ -0,0 +1,42 @@ +import { JavaScriptValue } from "google-protobuf/google/protobuf/struct_pb"; +import { Polygon } from "../../utils/types"; + +export interface TextFeatures { + imagePath?: undefined; + geoInfo?: undefined; + imageBytes?: undefined; + text: string; + labels: Array; + id?: number; + metadata?: Record; + bboxes?: undefined; +} + +export interface VisualClassificationFeatures { + imagePath: string; + labels: Array; + geoInfo?: [number, number]; + id?: number; + metadata?: Record; + imageBytes?: Buffer; +} + +export interface VisualDetectionFeatures { + imagePath: string; + labels: Array; + bboxes: Array>; + geoInfo?: [number, number]; + id?: number; + metadata?: Record; + imageBytes?: Buffer; +} + +export interface VisualSegmentationFeatures { + imagePath: string; + labels: Array; + polygons: Polygon[]; + geoInfo?: [number, number]; + id?: number; + metadata?: Record; + imageBytes?: Buffer; +} diff --git a/src/datasets/upload/image.ts b/src/datasets/upload/image.ts new file mode 100644 index 0000000..a85006d --- /dev/null +++ b/src/datasets/upload/image.ts @@ -0,0 +1,241 @@ +import { + Input as GrpcInput, + Annotation, +} from "clarifai-nodejs-grpc/proto/clarifai/api/resources_pb"; +import { ClarifaiDataLoader, ClarifaiDataset } from "./base"; +import path from "path"; +import { v4 as uuid } from "uuid"; +import { Input } from "../../client/input"; +import { + VisualClassificationFeatures, + VisualDetectionFeatures, + VisualSegmentationFeatures, +} from "./features"; +import { JavaScriptValue } from "google-protobuf/google/protobuf/struct_pb"; + +export class VisualClassificationDataset extends ClarifaiDataset { + constructor(args: { dataGenerator: ClarifaiDataLoader; datasetId: string }) { + super(args); + } + + protected extractProtos({ + batchInputIds, + }: { + batchInputIds: string[]; + }): [GrpcInput[], Annotation[]] { + const inputProtos: GrpcInput[] = []; + const annotationProtos: Annotation[] = []; + + const processDataItem = (id: string) => { + const dataItem = this.dataGenerator.getItem( + Number(id), + ) as VisualClassificationFeatures; + let metadata: Record = {}; + const imagePath = dataItem.imagePath; + const labels = Array.isArray(dataItem.labels) + ? dataItem.labels.map((label) => label.toString()) + : [(dataItem.labels as string).toString()]; // clarifai concept expects labels to be an array + const inputId = `${this.datasetId}-${String(dataItem.id)}`; + const geoInfo = dataItem.geoInfo; + + if (dataItem.metadata) { + metadata = dataItem.metadata; + } else if (imagePath) { + metadata = { + filename: path.basename(imagePath), + }; + } + + this.allInputIds[id] = inputId; + + if (dataItem.imageBytes) { + inputProtos.push( + Input.getInputFromBytes({ + inputId, + imageBytes: dataItem.imageBytes, + datasetId: this.datasetId, + labels, + geoInfo: { + latitude: geoInfo?.[0] as number, + longitude: geoInfo?.[1] as number, + }, + metadata, + }), + ); + } else { + inputProtos.push( + Input.getInputFromFile({ + inputId, + imageFile: imagePath as string, + datasetId: this.datasetId, + labels, + geoInfo: { + latitude: geoInfo?.[0] as number, + longitude: geoInfo?.[1] as number, + }, + metadata, + }), + ); + } + }; + + batchInputIds.forEach((id) => processDataItem(id)); + + return [inputProtos, annotationProtos]; + } +} + +export class VisualDetectionDataset extends ClarifaiDataset { + constructor(args: { dataGenerator: ClarifaiDataLoader; datasetId: string }) { + super(args); + } + + protected extractProtos({ + batchInputIds, + }: { + batchInputIds: string[]; + }): [GrpcInput[], Annotation[]] { + const inputProtos: GrpcInput[] = []; + const annotationProtos: Annotation[] = []; + + const processDataItem = (id: string) => { + const dataItem = this.dataGenerator.getItem( + Number(id), + ) as VisualDetectionFeatures; + let metadata: Record = {}; + const image = dataItem.imagePath; + const labels = dataItem.labels.map((label) => label.toString()); + const bboxes = dataItem.bboxes; + const inputId = `${this.datasetId}-${dataItem.id ? dataItem.id.toString() : uuid().replace(/-/g, "").slice(0, 8)}`; + if (dataItem.metadata) { + metadata = dataItem.metadata; + } else if (image) { + metadata = { + filename: path.basename(image), + }; + } + const geoInfo = dataItem.geoInfo; + + this.allInputIds[id] = inputId; + + if (dataItem.imageBytes) { + inputProtos.push( + Input.getInputFromBytes({ + inputId, + imageBytes: dataItem.imageBytes, + datasetId: this.datasetId, + labels, + geoInfo: { + latitude: geoInfo?.[0] as number, + longitude: geoInfo?.[1] as number, + }, + metadata, + }), + ); + } else { + inputProtos.push( + Input.getInputFromFile({ + inputId, + imageFile: image, + datasetId: this.datasetId, + labels, + geoInfo: { + latitude: geoInfo?.[0] as number, + longitude: geoInfo?.[1] as number, + }, + metadata, + }), + ); + } + + for (let i = 0; i < bboxes.length; i++) { + annotationProtos.push( + Input.getBboxProto({ + inputId, + label: labels[i], + bbox: bboxes[i], + }), + ); + } + }; + + batchInputIds.forEach((id) => processDataItem(id)); + + return [inputProtos, annotationProtos]; + } +} + +export class VisualSegmentationDataset extends ClarifaiDataset { + constructor(args: { dataGenerator: ClarifaiDataLoader; datasetId: string }) { + super(args); + } + + protected extractProtos({ + batchInputIds, + }: { + batchInputIds: string[]; + }): [GrpcInput[], Annotation[]] { + const inputProtos: GrpcInput[] = []; + const annotationProtos: Annotation[] = []; + + const processDataItem = (id: string) => { + const dataItem = this.dataGenerator.getItem( + Number(id), + ) as VisualSegmentationFeatures; + let metadata: Record = {}; + const image = dataItem.imagePath; + const labels = dataItem.labels.map((label) => label.toString()); + const polygons = dataItem.polygons; + const inputId = `${this.datasetId}-${dataItem.id ? dataItem.id.toString() : uuid().replace(/-/g, "").slice(0, 8)}`; + if (dataItem.metadata) { + metadata = dataItem.metadata; + } else if (image) { + metadata = { + filename: path.basename(image), + }; + } + const geoInfo = dataItem.geoInfo; + this.allInputIds[id] = inputId; + if (dataItem.imageBytes) { + inputProtos.push( + Input.getInputFromBytes({ + inputId, + imageBytes: dataItem.imageBytes, + datasetId: this.datasetId, + geoInfo: { + latitude: geoInfo?.[0] as number, + longitude: geoInfo?.[1] as number, + }, + metadata, + }), + ); + } else { + inputProtos.push( + Input.getInputFromFile({ + inputId, + imageFile: image, + datasetId: this.datasetId, + geoInfo: { + latitude: geoInfo?.[0] as number, + longitude: geoInfo?.[1] as number, + }, + metadata, + }), + ); + } + for (let i = 0; i < polygons.length; i++) { + annotationProtos.push( + Input.getMaskProto({ + inputId, + label: labels[i], + polygons: [polygons[i]], + }), + ); + } + }; + + batchInputIds.forEach((id) => processDataItem(id)); + + return [inputProtos, annotationProtos]; + } +} diff --git a/src/index.ts b/src/index.ts index 537b368..d42ce7b 100644 --- a/src/index.ts +++ b/src/index.ts @@ -4,3 +4,4 @@ export * from "./client/input"; export * from "./client/model"; export * from "./client/search"; export * from "./client/workflow"; +export * from "./client/dataset"; diff --git a/src/schema/search.ts b/src/schema/search.ts index c17e530..2ac0684 100644 --- a/src/schema/search.ts +++ b/src/schema/search.ts @@ -46,37 +46,44 @@ export function getSchema(): z.ZodSchema< }> > { // Schema for a single concept - const conceptSchema = z.object({ - value: z.number().min(0).max(1).optional(), - id: z.string().min(1).optional(), - language: z.string().min(1).optional(), - name: z - .string() - .min(1) - .regex(/^[0-9A-Za-z]+([-_][0-9A-Za-z]+)*$/) // Non-empty string with dashes/underscores - .optional(), - }); + const conceptSchema = z + .object({ + value: z.number().min(0).max(1).optional(), + id: z.string().min(1).optional(), + language: z.string().min(1).optional(), + name: z + .string() + .min(1) + .regex(/^[0-9A-Za-z]+([-_][0-9A-Za-z]+)*$/) // Non-empty string with dashes/underscores + .optional(), + }) + .strict(); // Schema for a rank or filter item - const rankFilterItemSchema = z.object({ - imageUrl: z.string().url().optional(), - textRaw: z.string().min(1).optional(), - metadata: z.record(z.unknown()).optional(), - imageBytes: z.unknown().optional(), - geoPoint: z - .object({ - longitude: z.number(), - latitude: z.number(), - geoLimit: z.number().int(), - }) - .optional(), - concepts: z.array(conceptSchema).min(1).optional(), + const rankFilterItemSchema = z + .object({ + imageUrl: z.string().url().optional(), + textRaw: z.string().min(1).optional(), + metadata: z.record(z.unknown()).optional(), + imageBytes: z.unknown().optional(), + geoPoint: z + .object({ + longitude: z.number(), + latitude: z.number(), + geoLimit: z.number().int(), + }) + .strict() + .optional(), + concepts: z.array(conceptSchema).min(1).optional(), - // input filters - inputTypes: z.array(z.enum(["image", "video", "text", "audio"])).optional(), - inputDatasetIds: z.array(z.string()).optional(), - inputStatusCode: z.number().optional(), - }); + // input filters + inputTypes: z + .array(z.enum(["image", "video", "text", "audio"])) + .optional(), + inputDatasetIds: z.array(z.string()).optional(), + inputStatusCode: z.number().optional(), + }) + .strict(); // Schema for rank and filter args return z.array(rankFilterItemSchema); diff --git a/src/utils/misc.ts b/src/utils/misc.ts index 32c9431..ce83181 100644 --- a/src/utils/misc.ts +++ b/src/utils/misc.ts @@ -71,8 +71,8 @@ export function mergeObjects(obj1: AuthConfig, obj2: AuthConfig): AuthConfig { export class BackoffIterator { private count: number; - constructor() { - this.count = 0; + constructor({ count } = { count: 0 }) { + this.count = count; } [Symbol.iterator]() { diff --git a/src/utils/types.ts b/src/utils/types.ts index 310a7cd..6023b39 100644 --- a/src/utils/types.ts +++ b/src/utils/types.ts @@ -21,3 +21,6 @@ export type GrpcWithCallback = ( export type PaginationRequestParams> = | Omit, "userAppId" | "pageNo" | "perPage"> | Record; + +export type Point = [number, number]; +export type Polygon = Point[]; diff --git a/tests/client/app.integration.test.ts b/tests/client/app.integration.test.ts index 09362c5..cf7525c 100644 --- a/tests/client/app.integration.test.ts +++ b/tests/client/app.integration.test.ts @@ -1,7 +1,7 @@ import { describe, expect, it } from "vitest"; import { App, User } from "../../src/index"; -const NOW = Date.now().toString(); +const NOW = Date.now().toString() + "-app"; const MAIN_APP_ID = "main"; const MAIN_APP_USER_ID = "clarifai"; const GENERAL_MODEL_ID = "general-image-recognition"; diff --git a/tests/client/search.integration.test.ts b/tests/client/search.integration.test.ts new file mode 100644 index 0000000..5d65274 --- /dev/null +++ b/tests/client/search.integration.test.ts @@ -0,0 +1,407 @@ +import path from "path"; +import { getSchema } from "../../src/schema/search"; +import { z } from "zod"; +import { afterAll, beforeAll, describe, expect, it, vi } from "vitest"; +import { App, Dataset, Input, Search, User } from "../../src/index"; +import { Hit } from "clarifai-nodejs-grpc/proto/clarifai/api/resources_pb"; +import EventEmitter from "events"; + +const NOW = Date.now().toString() + "-search"; +const CREATE_APP_USER_ID = import.meta.env.VITE_CLARIFAI_USER_ID; +const CREATE_APP_ID = `ci_test_app_${NOW}`; +const CREATE_DATASET_ID = `ci_test_dataset_${NOW}`; +const DOG_IMG_URL = "https://samples.clarifai.com/dog.tiff"; +const DATASET_IMAGES_DIR = path.resolve(__dirname, "../assets/voc/images"); + +function getFiltersForTest(): [ + z.infer>, + number, +][] { + return [ + [ + [ + { + geoPoint: { + longitude: -29.0, + latitude: 40.0, + geoLimit: 100, + }, + }, + ], + 1, + ], + [ + [ + { + concepts: [ + { + name: "dog", + value: 1, + }, + ], + }, + ], + 1, + ], + [ + [ + { + concepts: [ + { + name: "deer", + value: 1, + }, + { + name: "dog", + value: 1, + }, + ], + }, + ], + 1, + ], + [ + [ + { + concepts: [ + { + name: "dog", + value: 1, + }, + ], + }, + { + concepts: [ + { + name: "deer", + value: 1, + }, + ], + }, + ], + 0, + ], + [ + [ + { + metadata: { + Breed: "Saint Bernard", + }, + }, + ], + 1, + ], + [ + [ + { + inputTypes: ["image"], + }, + { + inputStatusCode: 30000, + }, + ], + 1, + ], + [ + [ + { + inputTypes: ["text", "audio", "video"], + }, + ], + 0, + ], + [ + [ + { + inputTypes: ["text", "audio", "video"], + inputStatusCode: 30000, + }, + ], + 1, + ], + [ + [ + { + inputDatasetIds: ["random_dataset"], + }, + ], + 0, + ], + ]; +} + +describe("Search", () => { + const client = new User({ + userId: CREATE_APP_USER_ID, + appId: CREATE_APP_ID, + pat: import.meta.env.VITE_CLARIFAI_PAT, + }); + const search = new Search({ + authConfig: { + userId: CREATE_APP_USER_ID, + appId: CREATE_APP_ID, + pat: import.meta.env.VITE_CLARIFAI_PAT, + }, + topK: 1, + metric: "euclidean", + }); + // Initialize search without topK value for pagination with custom pages & page sizes + const searchWithPagination = new Search({ + authConfig: { + userId: CREATE_APP_USER_ID, + appId: CREATE_APP_ID, + pat: import.meta.env.VITE_CLARIFAI_PAT, + }, + metric: "euclidean", + }); + let app: App; + + beforeAll(async () => { + const appObj = await client.createApp({ + appId: CREATE_APP_ID, + baseWorkflow: "General", + }); + app = new App({ + authConfig: { + userId: CREATE_APP_USER_ID, + appId: appObj.id, + pat: import.meta.env.VITE_CLARIFAI_PAT, + }, + }); + const datasetObj = await app.createDataset({ + datasetId: CREATE_DATASET_ID, + }); + const metadata = { + Breed: "Saint Bernard", + }; + const inputProto = Input.getInputFromUrl({ + imageUrl: DOG_IMG_URL, + metadata, + datasetId: datasetObj.id, + inputId: "dog-tiff", + labels: ["dog"], + geoInfo: { longitude: -30.0, latitude: 40.0 }, + }); + const input = new Input({ + authConfig: { + userId: CREATE_APP_USER_ID, + appId: appObj.id, + pat: import.meta.env.VITE_CLARIFAI_PAT, + }, + }); + await input.uploadInputs({ inputs: [inputProto] }); + const dataset = new Dataset({ + authConfig: { + userId: CREATE_APP_USER_ID, + appId: appObj.id, + pat: import.meta.env.VITE_CLARIFAI_PAT, + }, + datasetId: datasetObj.id, + }); + const eventEmitter = new EventEmitter(); + const eventHandler = { + start: (...args: unknown[]) => console.log("start", args), + progress: (...args: unknown[]) => console.log("progress", args), + end: (...args: unknown[]) => console.log("end", args), + error: (...args: unknown[]) => console.log("error", args), + }; + const startSpy = vi.spyOn(eventHandler, "start"); + const progressSpy = vi.spyOn(eventHandler, "progress"); + const endSpy = vi.spyOn(eventHandler, "end"); + const errorSpy = vi.spyOn(eventHandler, "error"); + eventEmitter.on("start", (start) => { + eventHandler.start(start); + }); + eventEmitter.on("progress", (progress) => { + eventHandler.progress(progress); + }); + eventEmitter.on("end", (progress) => { + eventHandler.end(progress); + }); + eventEmitter.on("error", (error) => { + eventHandler.error(error); + }); + await dataset.uploadFromFolder({ + folderPath: DATASET_IMAGES_DIR, + inputType: "image", + labels: false, + uploadProgressEmitter: eventEmitter, + }); + expect(startSpy).toHaveBeenNthCalledWith( + 1, + expect.objectContaining({ current: 0, total: 1 }), + ); + expect(progressSpy).toHaveBeenNthCalledWith( + 1, + expect.objectContaining({ current: 1, total: 1 }), + ); + expect(endSpy).toHaveBeenNthCalledWith( + 1, + expect.objectContaining({ current: 1, total: 1 }), + ); + expect(errorSpy).not.toHaveBeenCalled(); + }, 50000); + + it("should get expected hits for filters", async () => { + const filtersWithHits = getFiltersForTest(); + for (const [filters, expectedHits] of filtersWithHits) { + const searchResponseGenerator = search.query({ + filters, + }); + const result = (await searchResponseGenerator.next())?.value ?? null; + expect(result).not.toBeNull(); + if (result) { + expect(result.hitsList.length).toBe(expectedHits); + } + } + }, 10000); + + it("should get expected hits for ranks", async () => { + const searchResponseGenerator = search.query({ + ranks: [ + { + imageUrl: DOG_IMG_URL, + }, + ], + }); + const result = (await searchResponseGenerator.next())?.value ?? null; + expect(result).not.toBeNull(); + if (result) { + expect(result.hitsList.length).toBe(1); + expect(result.hitsList[0].input?.id).toBe("dog-tiff"); + } + }); + + it("should get expected hits for filters and ranks", async () => { + const searchResponseGenerator = search.query({ + ranks: [ + { + imageUrl: DOG_IMG_URL, + }, + ], + filters: [ + { + inputTypes: ["image"], + }, + ], + }); + const result = (await searchResponseGenerator.next())?.value ?? null; + expect(result).not.toBeNull(); + if (result) { + expect(result.hitsList.length).toBe(1); + expect(result.hitsList[0].input?.id).toBe("dog-tiff"); + } + }); + + it("should get expected hits with corresponding pagination info", async () => { + const searchResponseGenerator = searchWithPagination.query({ + filters: [ + { + inputTypes: ["image"], + }, + ], + page: 1, + perPage: 3, + }); + const result = (await searchResponseGenerator.next())?.value ?? null; + expect(result).not.toBeNull(); + if (result) { + expect(result.hitsList.length).toBe(3); + } + }); + + it("should paginate through search results", async () => { + const searchResponseGenerator = searchWithPagination.query({ + filters: [{ inputTypes: ["image"] }], + }); + const hitsList: Hit.AsObject[] = []; + for await (const searchResponse of searchResponseGenerator) { + hitsList.push(...searchResponse.hitsList); + } + expect(hitsList.length).toBe(11); + }); + + it("should throw appropriate error for invalid arguments", async () => { + const invalidGeoPointFilters = () => { + return search.query({ + filters: [ + { + geoPoint: { + longitude: -29.0, + latitude: 40.0, + geoLimit: 10, + // @ts-expect-error - Invalid key + extra: 1, + }, + }, + ], + }); + }; + expect(invalidGeoPointFilters).toThrowError(); + const invalidConceptKeys = () => { + return search.query({ + filters: [ + { + concepts: [ + { + value: 1, + // @ts-expect-error - Missing required key + conceptId: "deer", + }, + { + name: "dog", + value: 1, + }, + ], + }, + ], + }); + }; + expect(invalidConceptKeys).toThrowError(); + const invalidConceptValues = () => { + return search.query({ + filters: [ + { + concepts: [ + { + name: "deer", + value: 2, + }, + { + name: "dog", + value: 1, + }, + ], + }, + ], + }); + }; + expect(invalidConceptValues).toThrowError(); + const incorrectInputTypes = () => { + return search.query({ + filters: [ + { + // @ts-expect-error - Invalid input type + inputTypes: ["imaage"], + }, + ], + }); + }; + expect(incorrectInputTypes).toThrowError(); + const invalidSearchFilterKey = () => { + return search.query({ + filters: [ + { + // @ts-expect-error - Invalid key + inputId: "test", + }, + ], + }); + }; + expect(invalidSearchFilterKey).toThrowError(); + }); + + afterAll(async () => { + await client.deleteApp({ appId: CREATE_APP_ID }); + }); +}); diff --git a/tests/client/workflow/workflowCrud.integration.test.ts b/tests/client/workflow/workflowCrud.integration.test.ts index a67dea3..930a7ac 100644 --- a/tests/client/workflow/workflowCrud.integration.test.ts +++ b/tests/client/workflow/workflowCrud.integration.test.ts @@ -4,10 +4,10 @@ import { Workflow as GrpcWorkflow } from "clarifai-nodejs-grpc/proto/clarifai/ap import path from "path"; import * as fs from "fs"; -const NOW = Date.now().toString(); +const NOW = Date.now().toString() + "-workflow"; const CREATE_APP_USER_ID = import.meta.env.VITE_CLARIFAI_USER_ID; const CLARIFAI_PAT = import.meta.env.VITE_CLARIFAI_PAT; -const CREATE_APP_ID = `test_workflow_create_delete_app_${NOW}`; +const CREATE_APP_ID = `test_create_delete_app_${NOW}`; const MAIN_APP_ID = "main"; const workflowFile = path.resolve(__dirname, "./fixtures/general.yml"); diff --git a/vitest.config.ts b/vitest.config.mjs similarity index 69% rename from vitest.config.ts rename to vitest.config.mjs index 95d9434..a642972 100644 --- a/vitest.config.ts +++ b/vitest.config.mjs @@ -1,9 +1,11 @@ import { defineConfig } from "vitest/config"; +/** @type {import('vitest/config').defineConfig} */ export default defineConfig({ test: { coverage: { reporter: ["text", "json", "html", "clover", "json-summary"], + include: ["src/**/*"], }, }, });