Skip to content

Commit

Permalink
fix: updated methods to match docs (#14)
Browse files Browse the repository at this point in the history
* feat: setup basic dataset versions

* feat: setup concurrent annotation upload

* feat: setup datasets base class

* feat: setup image classification dataset classes

* feat: setup bulk upload utils

* feat: setup bulkupload

* feat: upload from folder method

* feat: upload from csv

* fix: search test cases

* fix: lint issues

* feat: added algorithm to search

* fix: test cases

* feat: setup tests for search

* chore: fix lint issues

* fix: failing tests

* chore: upgrade grpc module

* chore: replace deprecated uuid package

* chore: setup base rag class

* feat: setup llma powered utilities

* feat: setup chat method

* fix: lint issues

* feat: setup rag with tests

* fix: failing tests with url helper

* fix: lint issues

* fix: addressed review comments

* feat: use event emitter for bulk upload progress

* chore: fix lint issues

* refactor: use array notation for polygon type

* chore: remove unused server state management code

* fix: increase client state test timeout

* fix: lint issues

* fix: based on doc examples

* fix: add missing external dependency

* fix: updated model test cases

* test: added more model test case

* feat: added test case for multimodal input

* chore: set default search metric to euclidean

* fix: constructors
  • Loading branch information
DaniAkash authored Jul 2, 2024
1 parent 7354e93 commit 3d46b72
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 13 deletions.
1 change: 1 addition & 0 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
"@grpc/grpc-js": "^1.10.1",
"@grpc/proto-loader": "^0.7.10",
"async": "^3.2.5",
"axios": "^1.6.8",
"chalk": "^5.3.0",
"clarifai-nodejs-grpc": "^10.3.2",
"csv-parse": "^5.5.5",
Expand Down
2 changes: 1 addition & 1 deletion src/client/dataset.ts
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ export class Dataset extends Lister {
}: {
folderPath: string;
inputType: "image" | "text";
labels: boolean;
labels?: boolean;
batchSize?: number;
uploadProgressEmitter?: InputBulkUpload;
}): Promise<void> {
Expand Down
22 changes: 12 additions & 10 deletions src/client/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,23 @@ import {

interface BaseModelConfig {
modelVersion?: { id: string };
modelUserAppId?: {
userId: string;
appId: string;
};
}

interface ModelConfigWithUrl extends BaseModelConfig {
url: ClarifaiUrl;
modelId?: undefined;
authConfig?: Omit<AuthConfig, "userId" | "appId">;
modelUserAppId?: undefined;
}

interface ModelConfigWithModelId extends BaseModelConfig {
url?: undefined;
modelId: string;
authConfig?: AuthConfig;
modelUserAppId?: {
userId: string;
appId: string;
};
}

type ModelConfig = ModelConfigWithUrl | ModelConfigWithModelId;
Expand Down Expand Up @@ -98,6 +99,9 @@ export class Model extends Lister {
if (config.url && config.modelId) {
throw new UserError("You can only specify one of url or model_id.");
}
if (config.url && modelUserAppId) {
throw new UserError("You can only specify one of url or modelUserAppId.");
}
if (!config.url && !config.modelId) {
throw new UserError("You must specify one of url or model_id.");
}
Expand All @@ -108,7 +112,7 @@ export class Model extends Lister {
if (isModelConfigWithUrl(config)) {
const { url } = config;
const [userId, appId] = ClarifaiUrlHelper.splitClarifaiUrl(url);
[, , _destructuredModelId, _destructuredModelVersionId] =
[, , , _destructuredModelId, _destructuredModelVersionId] =
ClarifaiUrlHelper.splitClarifaiUrl(url);
_authConfig = config.authConfig
? {
Expand Down Expand Up @@ -149,11 +153,9 @@ export class Model extends Lister {
this.modelInfo.setModelVersion(grpcModelVersion);
}
this.trainingParams = {};
if (modelUserAppId) {
this.modelUserAppId = new UserAppIDSet()
.setAppId(modelUserAppId.appId)
.setUserId(modelUserAppId.userId);
}
this.modelUserAppId = new UserAppIDSet()
.setAppId(_authConfig.appId)
.setUserId(_authConfig.userId);
}

/**
Expand Down
2 changes: 1 addition & 1 deletion src/constants/search.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
export const DEFAULT_TOP_K = 10;
export const DEFAULT_SEARCH_METRIC = "cosine";
export const DEFAULT_SEARCH_METRIC = "euclidean";
export const DEFAULT_SEARCH_ALGORITHM = "nearest_neighbor";
58 changes: 58 additions & 0 deletions tests/client/model.integration.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -255,5 +255,63 @@ describe(
clipDim,
);
});

it("should predict from a model outside the app", async () => {
const model = new Model({
url: "https://clarifai.com/clarifai/main/models/general-image-recognition",
authConfig: {
pat: CLARIFAI_PAT,
},
});
const imageUrl = "https://samples.clarifai.com/metro-north.jpg";
const modelPrediction = await model.predictByUrl({
url: imageUrl,
inputType: "image",
});
expect(modelPrediction.length).toBeGreaterThan(0);
});

it("should convert image to text", async () => {
const modelUrl =
"https://clarifai.com/salesforce/blip/models/general-english-image-caption-blip";
const imageUrl =
"https://s3.amazonaws.com/samples.clarifai.com/featured-models/image-captioning-statue-of-liberty.jpeg";

const model = new Model({
url: modelUrl,
authConfig: {
pat: CLARIFAI_PAT,
},
});
const modelPrediction = await model.predictByUrl({
url: imageUrl,
inputType: "image",
});
expect(modelPrediction?.[0]?.data?.text?.raw).toBeTruthy();
});

it("should predict multimodal with image and text", async () => {
const prompt = "What time of day is it?";
const imageUrl = "https://samples.clarifai.com/metro-north.jpg";
const modelUrl =
"https://clarifai.com/openai/chat-completion/models/openai-gpt-4-vision";
const inferenceParams = { temperature: 0.2, maxTokens: 100 };
const multiInputs = Input.getMultimodalInput({
inputId: "",
imageUrl,
rawText: prompt,
});
const model = new Model({
url: modelUrl,
authConfig: { pat: CLARIFAI_PAT },
});

const modelPrediction = await model.predict({
inputs: [multiInputs],
inferenceParams,
});

expect(modelPrediction?.[0]?.data?.text?.raw).toBeTruthy();
});
},
);
2 changes: 1 addition & 1 deletion tests/client/rag.integration.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ describe("Rag", async () => {
const messages = [{ role: "human", content: "What is 1 + 1?" }];
const newMessages = await rag.chat({ messages, clientManageState: true });
expect(newMessages.length).toBe(2);
}, 10000);
}, 15000);

// TODO: Server side state management is not supported yet - work in progress
it.skip("should predict & manage state on the server", async () => {
Expand Down

0 comments on commit 3d46b72

Please sign in to comment.