diff --git a/src/core/rag/embedding.ts b/src/core/rag/embedding.ts
index 03d3380..831fc8d 100644
--- a/src/core/rag/embedding.ts
+++ b/src/core/rag/embedding.ts
@@ -30,6 +30,7 @@ export const getEmbeddingModel = (
       return {
         id: settings.embeddingModelId,
         dimension: modelInfo.dimensions,
+        supportsBatch: true,
         getEmbedding: async (text: string) => {
           try {
             if (!openai.apiKey) {
@@ -54,6 +55,31 @@ export const getEmbeddingModel = (
             throw error
           }
         },
+        getBatchEmbeddings: async (texts: string[]) => {
+          console.log("use getBatchEmbeddings", texts.length)
+          try {
+            if (!openai.apiKey) {
+              throw new LLMAPIKeyNotSetException(
+                'OpenAI API key is missing. Please set it in settings menu.',
+              )
+            }
+            const embedding = await openai.embeddings.create({
+              model: settings.embeddingModelId,
+              input: texts,
+            })
+            return embedding.data.map(item => item.embedding)
+          } catch (error) {
+            if (
+              error.status === 429 &&
+              error.message.toLowerCase().includes('rate limit')
+            ) {
+              throw new LLMRateLimitExceededException(
+                'OpenAI API rate limit exceeded. Please try again later.',
+              )
+            }
+            throw error
+          }
+        },
       }
     }
     case ApiProvider.OpenAI: {
@@ -67,6 +93,7 @@ export const getEmbeddingModel = (
       return {
         id: settings.embeddingModelId,
         dimension: modelInfo.dimensions,
+        supportsBatch: true,
         getEmbedding: async (text: string) => {
           try {
             if (!openai.apiKey) {
@@ -91,6 +118,30 @@ export const getEmbeddingModel = (
             throw error
           }
         },
+        getBatchEmbeddings: async (texts: string[]) => {
+          try {
+            if (!openai.apiKey) {
+              throw new LLMAPIKeyNotSetException(
+                'OpenAI API key is missing. Please set it in settings menu.',
+              )
+            }
+            const embedding = await openai.embeddings.create({
+              model: settings.embeddingModelId,
+              input: texts,
+            })
+            return embedding.data.map(item => item.embedding)
+          } catch (error) {
+            if (
+              error.status === 429 &&
+              error.message.toLowerCase().includes('rate limit')
+            ) {
+              throw new LLMRateLimitExceededException(
+                'OpenAI API rate limit exceeded. Please try again later.',
+              )
+            }
+            throw error
+          }
+        },
       }
     }
     case ApiProvider.SiliconFlow: {
@@ -104,6 +155,7 @@ export const getEmbeddingModel = (
       return {
         id: settings.embeddingModelId,
         dimension: modelInfo.dimensions,
+        supportsBatch: true,
         getEmbedding: async (text: string) => {
           try {
             if (!openai.apiKey) {
@@ -128,6 +180,30 @@ export const getEmbeddingModel = (
             throw error
           }
         },
+        getBatchEmbeddings: async (texts: string[]) => {
+          try {
+            if (!openai.apiKey) {
+              throw new LLMAPIKeyNotSetException(
+                'SiliconFlow API key is missing. Please set it in settings menu.',
+              )
+            }
+            const embedding = await openai.embeddings.create({
+              model: settings.embeddingModelId,
+              input: texts,
+            })
+            return embedding.data.map(item => item.embedding)
+          } catch (error) {
+            if (
+              error.status === 429 &&
+              error.message.toLowerCase().includes('rate limit')
+            ) {
+              throw new LLMRateLimitExceededException(
+                'SiliconFlow API rate limit exceeded. Please try again later.',
+              )
+            }
+            throw error
+          }
+        },
       }
     }
     case ApiProvider.AlibabaQwen: {
@@ -141,6 +217,7 @@ export const getEmbeddingModel = (
       return {
         id: settings.embeddingModelId,
         dimension: modelInfo.dimensions,
+        supportsBatch: false,
         getEmbedding: async (text: string) => {
           try {
             if (!openai.apiKey) {
@@ -165,6 +242,30 @@ export const getEmbeddingModel = (
             throw error
           }
         },
+        getBatchEmbeddings: async (texts: string[]) => {
+          try {
+            if (!openai.apiKey) {
+              throw new LLMAPIKeyNotSetException(
+                'Alibaba Qwen API key is missing. Please set it in settings menu.',
+              )
+            }
+            const embedding = await openai.embeddings.create({
+              model: settings.embeddingModelId,
+              input: texts,
+            })
+            return embedding.data.map(item => item.embedding)
+          } catch (error) {
+            if (
+              error.status === 429 &&
+              error.message.toLowerCase().includes('rate limit')
+            ) {
+              throw new LLMRateLimitExceededException(
+                'Alibaba Qwen API rate limit exceeded. Please try again later.',
+              )
+            }
+            throw error
+          }
+        },
       }
     }
     case ApiProvider.Google: {
@@ -174,6 +275,7 @@ export const getEmbeddingModel = (
       return {
         id: settings.embeddingModelId,
         dimension: modelInfo.dimensions,
+        supportsBatch: false,
         getEmbedding: async (text: string) => {
           try {
             const response = await model.embedContent(text)
@@ -190,6 +292,27 @@ export const getEmbeddingModel = (
             throw error
           }
         },
+        getBatchEmbeddings: async (texts: string[]) => {
+          try {
+            const embeddings = await Promise.all(
+              texts.map(async (text) => {
+                const response = await model.embedContent(text)
+                return response.embedding.values
+              })
+            )
+            return embeddings
+          } catch (error) {
+            if (
+              error.status === 429 &&
+              error.message.includes('RATE_LIMIT_EXCEEDED')
+            ) {
+              throw new LLMRateLimitExceededException(
+                'Gemini API rate limit exceeded. Please try again later.',
+              )
+            }
+            throw error
+          }
+        },
       }
     }
     case ApiProvider.Ollama: {
@@ -201,6 +324,7 @@ export const getEmbeddingModel = (
       return {
         id: settings.embeddingModelId,
         dimension: 0,
+        supportsBatch: false,
         getEmbedding: async (text: string) => {
           if (!settings.ollamaProvider.baseUrl) {
             throw new LLMBaseUrlNotSetException(
@@ -213,6 +337,18 @@ export const getEmbeddingModel = (
           })
           return embedding.data[0].embedding
         },
+        getBatchEmbeddings: async (texts: string[]) => {
+          if (!settings.ollamaProvider.baseUrl) {
+            throw new LLMBaseUrlNotSetException(
+              'Ollama Address is missing. Please set it in settings menu.',
+            )
+          }
+          const embedding = await openai.embeddings.create({
+            model: settings.embeddingModelId,
+            input: texts,
+          })
+          return embedding.data.map(item => item.embedding)
+        },
       }
     }
     case ApiProvider.OpenAICompatible: {
@@ -224,6 +360,7 @@ export const getEmbeddingModel = (
       return {
         id: settings.embeddingModelId,
         dimension: 0,
+        supportsBatch: false,
         getEmbedding: async (text: string) => {
           try {
             if (!openai.apiKey) {
@@ -249,6 +386,31 @@ export const getEmbeddingModel = (
             throw error
           }
         },
+        getBatchEmbeddings: async (texts: string[]) => {
+          try {
+            if (!openai.apiKey) {
+              throw new LLMAPIKeyNotSetException(
+                'OpenAI Compatible API key is missing. Please set it in settings menu.',
+              )
+            }
+            const embedding = await openai.embeddings.create({
+              model: settings.embeddingModelId,
+              input: texts,
+              encoding_format: "float",
+            })
+            return embedding.data.map(item => item.embedding)
+          } catch (error) {
+            if (
+              error.status === 429 &&
+              error.message.toLowerCase().includes('rate limit')
+            ) {
+              throw new LLMRateLimitExceededException(
+                'OpenAI Compatible API rate limit exceeded. Please try again later.',
+              )
+            }
+            throw error
+          }
+        },
       }
     }
     default:
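Note on the OpenAI-compatible getBatchEmbeddings implementations above: each one maps over embedding.data in the order the API returns it. The OpenAI embeddings response also tags every item with an index field, so a more defensive variant of the create-and-map step (a sketch only, not part of this patch, reusing the same openai client, settings, and texts from each provider closure) could sort explicitly before extracting the vectors:

    // Sketch of a defensive mapping for the batched responses above.
    // Each response item carries the index of the input text it belongs to,
    // so sorting makes the input-order assumption explicit.
    const embedding = await openai.embeddings.create({
      model: settings.embeddingModelId,
      input: texts,
    })
    return embedding.data
      .sort((a, b) => a.index - b.index)
      .map((item) => item.embedding)
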
diff --git a/src/database/modules/vector/vector-manager.ts b/src/database/modules/vector/vector-manager.ts
index 61ab56e..962ba50 100644
--- a/src/database/modules/vector/vector-manager.ts
+++ b/src/database/modules/vector/vector-manager.ts
@@ -131,27 +131,34 @@ export class VectorManager {
     const embeddingProgress = { completed: 0 }

     const embeddingChunks: InsertVector[] = []
-    const batchSize = 100
-    const limit = pLimit(50)
-    const abortController = new AbortController()
-    const tasks = contentChunks.map((chunk) =>
-      limit(async () => {
-        if (abortController.signal.aborted) {
-          throw new Error('Operation was aborted')
-        }
-        try {
+    const insertBatchSize = 100 // batch size for database inserts
+
+    try {
+      if (embeddingModel.supportsBatch) {
+        // providers with batch support: embed texts in batches
+        const embeddingBatchSize = 100 // batch size for embedding API requests
+
+        for (let i = 0; i < contentChunks.length; i += embeddingBatchSize) {
+          const batchChunks = contentChunks.slice(i, Math.min(i + embeddingBatchSize, contentChunks.length))
+          const batchTexts = batchChunks.map(chunk => chunk.content)
+
           await backOff(
             async () => {
-              const embedding = await embeddingModel.getEmbedding(chunk.content)
-              const embeddedChunk = {
-                path: chunk.path,
-                mtime: chunk.mtime,
-                content: chunk.content,
-                embedding,
-                metadata: chunk.metadata,
+              const batchEmbeddings = await embeddingModel.getBatchEmbeddings(batchTexts)
+
+              // merge embedding results back into chunk data
+              for (let j = 0; j < batchChunks.length; j++) {
+                const embeddedChunk: InsertVector = {
+                  path: batchChunks[j].path,
+                  mtime: batchChunks[j].mtime,
+                  content: batchChunks[j].content,
+                  embedding: batchEmbeddings[j],
+                  metadata: batchChunks[j].metadata,
+                }
+                embeddingChunks.push(embeddedChunk)
               }
-              embeddingChunks.push(embeddedChunk)
-              embeddingProgress.completed++
+
+              embeddingProgress.completed += batchChunks.length
               updateProgress?.({
                 completedChunks: embeddingProgress.completed,
                 totalChunks: contentChunks.length,
@@ -165,15 +172,51 @@ export class VectorManager {
               jitter: 'full',
             },
           )
-        } catch (error) {
-          abortController.abort()
-          throw error
         }
-      }),
-    )
-
-    try {
-      await Promise.all(tasks)
+      } else {
+        // providers without batch support: fall back to per-chunk processing
+        const limit = pLimit(50)
+        const abortController = new AbortController()
+        const tasks = contentChunks.map((chunk) =>
+          limit(async () => {
+            if (abortController.signal.aborted) {
+              throw new Error('Operation was aborted')
+            }
+            try {
+              await backOff(
+                async () => {
+                  const embedding = await embeddingModel.getEmbedding(chunk.content)
+                  const embeddedChunk = {
+                    path: chunk.path,
+                    mtime: chunk.mtime,
+                    content: chunk.content,
+                    embedding,
+                    metadata: chunk.metadata,
+                  }
+                  embeddingChunks.push(embeddedChunk)
+                  embeddingProgress.completed++
+                  updateProgress?.({
+                    completedChunks: embeddingProgress.completed,
+                    totalChunks: contentChunks.length,
+                    totalFiles: filesToIndex.length,
+                  })
+                },
+                {
+                  numOfAttempts: 5,
+                  startingDelay: 1000,
+                  timeMultiple: 1.5,
+                  jitter: 'full',
+                },
+              )
+            } catch (error) {
+              abortController.abort()
+              throw error
+            }
+          }),
+        )
+
+        await Promise.all(tasks)
+      }

       // all embedding generated, batch insert
       if (embeddingChunks.length > 0) {
@@ -182,7 +225,7 @@ export class VectorManager {
         while (inserted < embeddingChunks.length) {
           const chunksToInsert = embeddingChunks.slice(
             inserted,
-            Math.min(inserted + batchSize, embeddingChunks.length)
+            Math.min(inserted + insertBatchSize, embeddingChunks.length)
           )
           await this.repository.insertVectors(chunksToInsert, embeddingModel)
           inserted += chunksToInsert.length
@@ -242,25 +285,32 @@ export class VectorManager {
     })

     const embeddingChunks: InsertVector[] = []
-    const limit = pLimit(50)
-    const abortController = new AbortController()
-    const tasks = contentChunks.map((chunk) =>
-      limit(async () => {
-        if (abortController.signal.aborted) {
-          throw new Error('Operation was aborted')
-        }
-        try {
+    const insertBatchSize = 100 // batch size for database inserts
+
+    try {
+      if (embeddingModel.supportsBatch) {
+        // providers with batch support: embed texts in batches
+        const embeddingBatchSize = 100 // batch size for embedding API requests
+
+        for (let i = 0; i < contentChunks.length; i += embeddingBatchSize) {
+          const batchChunks = contentChunks.slice(i, Math.min(i + embeddingBatchSize, contentChunks.length))
+          const batchTexts = batchChunks.map(chunk => chunk.content)
+
           await backOff(
             async () => {
-              const embedding = await embeddingModel.getEmbedding(chunk.content)
-              const embeddedChunk = {
-                path: chunk.path,
-                mtime: chunk.mtime,
-                content: chunk.content,
-                embedding,
-                metadata: chunk.metadata,
+              const batchEmbeddings = await embeddingModel.getBatchEmbeddings(batchTexts)
+
+              // merge embedding results back into chunk data
+              for (let j = 0; j < batchChunks.length; j++) {
+                const embeddedChunk: InsertVector = {
+                  path: batchChunks[j].path,
+                  mtime: batchChunks[j].mtime,
+                  content: batchChunks[j].content,
+                  embedding: batchEmbeddings[j],
+                  metadata: batchChunks[j].metadata,
+                }
+                embeddingChunks.push(embeddedChunk)
               }
-              embeddingChunks.push(embeddedChunk)
             },
             {
               numOfAttempts: 5,
@@ -269,22 +319,51 @@ export class VectorManager {
               jitter: 'full',
             },
           )
-        } catch (error) {
-          abortController.abort()
-          throw error
         }
-      }),
-    )
-
-    try {
-      await Promise.all(tasks)
+      } else {
+        // providers without batch support: fall back to per-chunk processing
+        const limit = pLimit(50)
+        const abortController = new AbortController()
+        const tasks = contentChunks.map((chunk) =>
+          limit(async () => {
+            if (abortController.signal.aborted) {
+              throw new Error('Operation was aborted')
+            }
+            try {
+              await backOff(
+                async () => {
+                  const embedding = await embeddingModel.getEmbedding(chunk.content)
+                  const embeddedChunk = {
+                    path: chunk.path,
+                    mtime: chunk.mtime,
+                    content: chunk.content,
+                    embedding,
+                    metadata: chunk.metadata,
+                  }
+                  embeddingChunks.push(embeddedChunk)
+                },
+                {
+                  numOfAttempts: 5,
+                  startingDelay: 1000,
+                  timeMultiple: 1.5,
+                  jitter: 'full',
+                },
+              )
+            } catch (error) {
+              abortController.abort()
+              throw error
+            }
+          }),
+        )
+
+        await Promise.all(tasks)
+      }

       // all embedding generated, batch insert
       if (embeddingChunks.length > 0) {
-        const batchSize = 100
         let inserted = 0
         while (inserted < embeddingChunks.length) {
-          const chunksToInsert = embeddingChunks.slice(inserted, Math.min(inserted + batchSize, embeddingChunks.length))
+          const chunksToInsert = embeddingChunks.slice(inserted, Math.min(inserted + insertBatchSize, embeddingChunks.length))
           await this.repository.insertVectors(chunksToInsert, embeddingModel)
           inserted += chunksToInsert.length
         }
diff --git a/src/types/embedding.ts b/src/types/embedding.ts
index 9a20cf6..cacf66b 100644
--- a/src/types/embedding.ts
+++ b/src/types/embedding.ts
@@ -17,5 +17,7 @@ export type EmbeddingModelOption = {
 export type EmbeddingModel = {
   id: string
   dimension: number
+  supportsBatch: boolean
   getEmbedding: (text: string) => Promise<number[]>
+  getBatchEmbeddings: (texts: string[]) => Promise<number[][]>
 }
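For reference, a minimal sketch of how a caller can use the extended EmbeddingModel type (the helper name and import path are hypothetical; the real consumer is VectorManager above, which additionally applies concurrency limits and exponential backoff):

    import type { EmbeddingModel } from './types/embedding'

    // Hypothetical helper: dispatch on the new supportsBatch flag.
    async function embedTexts(
      model: EmbeddingModel,
      texts: string[],
    ): Promise<number[][]> {
      if (model.supportsBatch) {
        // Providers that accept an array of inputs: one request per batch.
        return model.getBatchEmbeddings(texts)
      }
      // Providers without batch support: one request per text.
      return Promise.all(texts.map((text) => model.getEmbedding(text)))
    }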