
Commit

Merge pull request #410 from fujaba/feat/update-openai
OpenAI Model Selection
Clashsoft authored Feb 17, 2024
2 parents 6a97622 + c6a17ca commit 5cfe28f
Showing 12 changed files with 133 additions and 46 deletions.
1 change: 1 addition & 0 deletions frontend/src/app/assignment/model/assignment.ts
@@ -10,6 +10,7 @@ export interface ClassroomInfo {
mossLanguage?: string;
mossResult?: string;
openaiApiKey?: string;
openaiModel?: string;
openaiConsent?: boolean;
openaiIgnore?: string;
}
@@ -29,7 +29,29 @@
<a href="https://platform.openai.com/account/api-keys" target="_blank">OpenAI API Keys</a>.
</div>
</div>

<div class="mb-3">
<label class="form-label bi-cpu" for="openaiApiKeyInput">
OpenAI Embedding Model
</label>
@for (model of embeddingModels; track model.id) {
<div class="form-check">
<input class="form-check-input" type="radio" name="openaiModel" id="openaiModel-{{model.id}}" [value]="model.id" [ngModel]="classroom.openaiModel">
<label class="form-check-label" for="openaiModel-{{model.id}}">
{{ model.id }}
<span class="badge bg-{{ model.labelBg }}">
{{ model.label }}
</span>
</label>
</div>
}
<div class="form-text">
The embedding model for use with the OpenAI API.
Learn more in the
<a href="https://platform.openai.com/docs/guides/embeddings/embedding-models" target="_blank">OpenAI Embedding Models Documentation</a>
and check out the
<a href="https://openai.com/pricing#embedding-models" target="_blank">OpenAI Pricing</a>.
</div>
</div>
<div class="mb-3">
<label class="form-label bi-ui-checks" for="openaiConsentCheck">
OpenAI Consent
@@ -10,10 +10,18 @@ import {ClassroomInfo} from "../../../model/assignment";
export class CodeSearchComponent {
classroom: ClassroomInfo;

// TODO use a shared constant when frontend and backend are merged
embeddingModels = [
{id: 'text-embedding-3-small', label: 'Cheapest', labelBg: 'success'},
{id: 'text-embedding-3-large', label: 'Most accurate', labelBg: 'primary'},
{id: 'text-embedding-ada-002', label: 'Legacy', labelBg: 'secondary'},
] as const;

constructor(
readonly context: AssignmentContext,
) {
this.classroom = this.context.assignment.classroom ||= {};
this.classroom.openaiConsent ??= true;
this.classroom.openaiModel ??= 'text-embedding-ada-002';
}
}
@@ -19,12 +19,12 @@
</p>
}
@if (costs) {
<div class="row">
<div class="row mb-3">
<app-statistic-value class="col" label="Solutions" [value]="costs.solutions" [standalone]="true"></app-statistic-value>
<app-statistic-value class="col" label="Files" [value]="costs.files" [standalone]="true"></app-statistic-value>
<app-statistic-value class="col" label="Functions" [value]="costs.functions.length" [standalone]="true"></app-statistic-value>
<app-statistic-value class="col" label="Tokens" [value]="costs.tokens" [standalone]="true"></app-statistic-value>
<app-statistic-value class="col" [label]="costsAreFinal ? 'Total Cost' : 'Estimated Cost'" [value]="costs.estimatedCost | currency:'USD':true:'0.7'" [standalone]="true"></app-statistic-value>
<app-statistic-value class="col" [label]="costsAreFinal ? 'Total Cost' : 'Estimated Cost'" [value]="costs.estimatedCost | currency:'USD'" [standalone]="true"></app-statistic-value>
</div>
<div class="mb-3">
<label for="functions">Imported Functions</label>
@@ -43,3 +43,9 @@
</div>
}
}
@if (!costsAreFinal && costs && costs.functions.length > rateLimit) {
<div class="alert alert-warning">
The number of functions exceeds the OpenAI rate limit.
The import may take up to {{ ceil(costs.functions.length / rateLimit) }} minutes.
</div>
}
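For orientation, the wait-time figure in the warning above boils down to a single ceiling division. A minimal sketch, assuming the hard-coded limit of 3000 embedding requests per minute that the component below introduces (estimatedImportMinutes is a hypothetical helper, not part of the commit):

// Sketch of the warning's arithmetic: one batch of up to `rateLimit`
// embeddings is sent per minute, so the import needs ceil(n / rateLimit) minutes.
const rateLimit = 3000;
function estimatedImportMinutes(functionCount: number): number {
  return Math.ceil(functionCount / rateLimit);
}
estimatedImportMinutes(7200); // -> 3 minutes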
@@ -13,6 +13,10 @@ export class ImportEmbeddingsComponent implements OnInit {
costs?: EstimatedCosts;
costsAreFinal = false;

// TODO use a shared constant when frontend and backend are merged
readonly rateLimit = 3000;
readonly ceil = Math.ceil;

constructor(
private embeddingService: EmbeddingService,
private route: ActivatedRoute,
@@ -22,6 +22,7 @@ export class ImportModalComponent {
this.importing = true;
component.import().subscribe({
next: results => {
this.importing = false;
if (typeof results === 'string') {
this.toastService.success('Import', 'Successfully ran MOSS');
} else if (results && typeof results === 'object' && 'length' in results) {
@@ -30,8 +31,10 @@
this.toastService.success('Import', 'Successfully imported embeddings');
}
},
error: error => this.toastService.error('Import', 'Failed to import solutions', error),
complete: () => this.importing = false,
error: error => {
this.importing = false;
this.toastService.error('Import', 'Failed to import solutions', error);
},
});
}
}
7 changes: 7 additions & 0 deletions services/apps/assignments/src/assignment/assignment.schema.ts
@@ -20,6 +20,7 @@ import {
import {Types} from 'mongoose';
import {MOSS_LANGUAGES} from "../search/search.constants";
import {Doc} from "@mean-stream/nestx";
import {EmbeddingModel, EMBEDDING_MODELS} from "../embedding/openai.service";

@Schema({id: false, _id: false})
export class Task {
@@ -117,6 +118,12 @@ export class ClassroomInfo {
@Transform(({value}) => value === '***' ? undefined : value)
openaiApiKey?: string;

@Prop({type: String})
@ApiPropertyOptional()
@IsOptional()
@IsIn(Object.keys(EMBEDDING_MODELS))
openaiModel?: EmbeddingModel;

@Prop()
@ApiPropertyOptional()
@IsOptional()
23 changes: 11 additions & 12 deletions services/apps/assignments/src/embedding/embedding.handler.ts
@@ -1,8 +1,9 @@
import {Injectable} from "@nestjs/common";
import {OnEvent} from "@nestjs/event-emitter";
import {AssignmentDocument, Task} from "../assignment/assignment.schema";
import {Assignment, AssignmentDocument, Task} from "../assignment/assignment.schema";
import {EmbeddingService} from "./embedding.service";
import {SolutionDocument} from "../solution/solution.schema";
import {DEFAULT_MODEL} from "./openai.service";

@Injectable()
export class EmbeddingHandler {
@@ -14,39 +15,37 @@ export class EmbeddingHandler {
@OnEvent('assignments.*.created')
@OnEvent('assignments.*.updated')
async onAssignment(assignment: AssignmentDocument) {
const apiKey = assignment.classroom?.openaiApiKey;
if (!apiKey) {
if (!assignment.classroom?.openaiApiKey) {
return;
}

const taskIds = new Set<string>();
const assignmentId = assignment._id.toString();
this.upsertTasks(apiKey, assignmentId, assignment.tasks, '', taskIds);
await this.embeddingService.deleteTasksNotIn(assignmentId, [...taskIds]);
this.upsertTasks(assignment, assignment.tasks, '', taskIds);
await this.embeddingService.deleteTasksNotIn(assignment._id.toString(), [...taskIds]);
}

@OnEvent('assignments.*.deleted')
async onAssignmentDeleted(assignment: AssignmentDocument) {
await this.embeddingService.deleteAll(assignment._id.toString());
}

private upsertTasks(apiKey: string, assignment: string, tasks: Task[], prefix: string, taskIds: Set<string>) {
private upsertTasks(assignment: Assignment, tasks: Task[], prefix: string, taskIds: Set<string>) {
for (const task of tasks) {
taskIds.add(task._id);
this.upsertTask(apiKey, assignment, task, prefix);
this.upsertTasks(apiKey, assignment, task.children, `${prefix}${task.description} > `, taskIds);
this.upsertTask(assignment, task, prefix);
this.upsertTasks(assignment, task.children, `${prefix}${task.description} > `, taskIds);
}
}

private upsertTask(apiKey: string, assignment: string, task: Task, prefix: string) {
private upsertTask(assignment: Assignment, task: Task, prefix: string) {
return this.embeddingService.upsert({
id: task._id,
assignment,
assignment: assignment._id.toString(),
type: 'task',
task: task._id,
text: prefix + task.description,
embedding: [],
}, apiKey);
}, assignment.classroom!.openaiApiKey!, assignment.classroom!.openaiModel ?? DEFAULT_MODEL);
}

@OnEvent('assignments.*.solutions.*.deleted')
26 changes: 19 additions & 7 deletions services/apps/assignments/src/embedding/embedding.service.ts
@@ -2,7 +2,7 @@ import {ForbiddenException, Injectable, OnModuleInit} from '@nestjs/common';
import {ElasticsearchService} from "@nestjs/elasticsearch";
import {FileDocument, SearchService} from "../search/search.service";
import {Embeddable, EmbeddableSearch, EmbeddingEstimate, SnippetEmbeddable} from "./embedding.dto";
import {OpenAIService} from "./openai.service";
import {DEFAULT_MODEL, EmbeddingModel, OpenAIService} from "./openai.service";
import {QueryDslQueryContainer} from "@elastic/elasticsearch/lib/api/types";
import {SolutionService} from "../solution/solution.service";
import {Assignment} from "../assignment/assignment.schema";
@@ -72,6 +72,7 @@ export class EmbeddingService implements OnModuleInit {
if (!apiKey) {
throw new ForbiddenException('No OpenAI API key configured for this assignment.');
}
const model = assignment.classroom?.openaiModel || DEFAULT_MODEL;

const {solutions, documents, ignoreFn, ignoredFiles} = await this.getDocuments(assignment);
const ignoredFunctions = new Set<string>();
@@ -94,18 +95,29 @@
let tokens = 0;
if (estimate) {
for (const func of functions) {
tokens += this.openaiService.countTokens(func.text);
tokens += this.openaiService.countTokens(func.text, model);
}
} else {
tokens = (await Promise.all(functions.map(async func => this.upsert(func, apiKey).then(({tokens}) => tokens))))
.reduce((a, b) => a + b, 0);
for (let i = 0; i < functions.length; i += this.openaiService.rateLimitPerMinute) {
const start = Date.now();
const batch = await Promise.all(functions
.slice(i, i + this.openaiService.rateLimitPerMinute)
.map(async func => this.upsert(func, apiKey, model).then(({tokens}) => tokens))
);
tokens += batch.reduce((a, b) => a + b, 0);
const elapsed = Date.now() - start;
if (elapsed < 60100) {
// wait for the minute to pass to avoid rate limiting
await new Promise(resolve => setTimeout(resolve, 60100 - elapsed));
}
}
}

return {
solutions,
files: documents.length,
tokens,
estimatedCost: this.openaiService.estimateCost(tokens),
estimatedCost: this.openaiService.estimateCost(tokens, model),
functions: functions.map(f => `${f.file}#${f.name}`),
ignoredFiles: Array.from(ignoredFiles),
ignoredFunctions: Array.from(ignoredFunctions),
@@ -167,12 +179,12 @@
return results;
}

async upsert(embeddable: Embeddable, apiKey: string): Promise<{ embeddable: Embeddable, tokens: number }> {
async upsert(embeddable: Embeddable, apiKey: string, model: EmbeddingModel): Promise<{ embeddable: Embeddable, tokens: number }> {
const existing = await this.find(embeddable.id);
if (existing && existing.text === embeddable.text) {
return {embeddable: existing, tokens: 0};
}
const {embedding, tokens} = await this.openaiService.getEmbedding(embeddable.text, apiKey);
const {embedding, tokens} = await this.openaiService.getEmbedding(embeddable.text, apiKey, model);
embeddable.embedding = embedding;
await this.index(embeddable);
return {embeddable, tokens};
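Read on its own, the new batching in embedAll above is a simple throttle: send at most rateLimitPerMinute embedding requests in parallel, then sleep out the remainder of the minute before the next batch. A standalone sketch of that pattern, with embed standing in for openaiService.getEmbedding and the 3000-per-minute limit assumed from the service (hypothetical helper, not part of the commit):

// Hypothetical helper illustrating the per-minute batching used in embedAll.
async function embedInBatches<T>(
  items: T[],
  embed: (item: T) => Promise<number>, // resolves to the token count of one item
  perMinute = 3000,
): Promise<number> {
  let tokens = 0;
  for (let i = 0; i < items.length; i += perMinute) {
    const start = Date.now();
    const batch = await Promise.all(items.slice(i, i + perMinute).map(embed));
    tokens += batch.reduce((a, b) => a + b, 0);
    const elapsed = Date.now() - start;
    if (elapsed < 60_100) {
      // Wait until the minute (plus a small margin) has passed to stay under the rate limit.
      await new Promise(resolve => setTimeout(resolve, 60_100 - elapsed));
    }
  }
  return tokens;
}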
47 changes: 36 additions & 11 deletions services/apps/assignments/src/embedding/openai.service.ts
@@ -1,38 +1,63 @@
import {Injectable, OnModuleDestroy} from "@nestjs/common";
import * as tiktoken from "tiktoken";
import {File} from "@app/moss/moss-api";
import OpenAI from "openai";
import {TEXT_EXTENSIONS} from "../search/search.constants";

const model = 'text-embedding-ada-002';
// https://platform.openai.com/docs/guides/embeddings/embedding-models
// https://openai.com/pricing#embedding-models
export const EMBEDDING_MODELS = {
'text-embedding-ada-002': {
tokenCost: 0.00010 / 1000,
dimensions: undefined,
},
'text-embedding-3-small': {
tokenCost: 0.00002 / 1000,
dimensions: undefined,
},
'text-embedding-3-large': {
tokenCost: 0.00013 / 1000,
dimensions: 1536,
},
} as const;
export type EmbeddingModel = keyof typeof EMBEDDING_MODELS;
export const DEFAULT_MODEL: EmbeddingModel = 'text-embedding-ada-002';

@Injectable()
export class OpenAIService implements OnModuleDestroy {
enc = tiktoken.encoding_for_model(model);
readonly rateLimitPerMinute = 3000;

private encoders: Record<EmbeddingModel, tiktoken.Tiktoken> = {
// https://platform.openai.com/docs/guides/embeddings/how-can-i-tell-how-many-tokens-a-string-has-before-i-embed-it
'text-embedding-ada-002': tiktoken.encoding_for_model('text-embedding-ada-002'),
'text-embedding-3-small': tiktoken.get_encoding('cl100k_base'),
'text-embedding-3-large': tiktoken.get_encoding('cl100k_base'),
};

onModuleDestroy(): any {
this.enc.free();
for (const enc of Object.values(this.encoders)) {
enc.free();
}
}

countTokens(text: string): number {
return this.enc.encode(text).length;
countTokens(text: string, model: EmbeddingModel): number {
return this.encoders[model].encode(text).length;
}

isSupportedExtension(filename: string) {
const extension = filename.substring(filename.lastIndexOf('.') + 1);
return TEXT_EXTENSIONS.has(extension);
}

estimateCost(tokens: number): number {
// https://platform.openai.com/docs/guides/embeddings/embedding-models
return tokens * 0.0000004;
estimateCost(tokens: number, model: EmbeddingModel): number {
return tokens * EMBEDDING_MODELS[model].tokenCost;
}

async getEmbedding(text: string, apiKey: string): Promise<{ tokens: number, embedding: number[] }> {
async getEmbedding(text: string, apiKey: string, model: EmbeddingModel): Promise<{ tokens: number, embedding: number[] }> {
const api = new OpenAI({apiKey});
const result = await api.embeddings.create({
model: model,
model,
input: text,
dimensions: EMBEDDING_MODELS[model].dimensions,
});
return {tokens: result.usage.total_tokens, embedding: result.data[0].embedding};
}
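The per-model prices in EMBEDDING_MODELS above are expressed as dollars per token, so estimateCost is a straight multiplication. A quick worked example (prices as committed; check the OpenAI pricing page before relying on them):

// Cost of embedding 1,000,000 tokens with each model:
// text-embedding-3-small: 1_000_000 * (0.00002 / 1000) = $0.02
// text-embedding-ada-002: 1_000_000 * (0.00010 / 1000) = $0.10
// text-embedding-3-large: 1_000_000 * (0.00013 / 1000) = $0.13
const tokens = 1_000_000;
const cost = tokens * (0.00002 / 1000); // estimateCost(tokens, 'text-embedding-3-small')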
4 changes: 2 additions & 2 deletions services/package.json
@@ -53,7 +53,7 @@
"mongoose": "^8.0.3",
"multer": "1.4.5-lts.1",
"nats": "^2.18.0",
"openai": "^4.21.0",
"openai": "^4.27.0",
"openapi-merge": "^1.3.2",
"passport": "^0.7.0",
"passport-jwt": "^4.0.1",
@@ -62,7 +62,7 @@
"rxjs": "^7.8.1",
"swagger-ui-express": "^5.0.0",
"textextensions": "^6.9.0",
"tiktoken": "^1.0.11",
"tiktoken": "^1.0.13",
"tslib": "^2.6.2",
"unzipper": "^0.10.14"
},
