duanfuxiang 2025-06-12 13:35:00 +08:00
parent 3ce55899df
commit b20b4f9e19
3 changed files with 296 additions and 53 deletions

View File

@@ -30,6 +30,7 @@ export const getEmbeddingModel = (
       return {
         id: settings.embeddingModelId,
         dimension: modelInfo.dimensions,
+        supportsBatch: true,
         getEmbedding: async (text: string) => {
           try {
             if (!openai.apiKey) {
@@ -54,6 +55,31 @@ export const getEmbeddingModel = (
             throw error
           }
         },
+        getBatchEmbeddings: async (texts: string[]) => {
+          console.log("use getBatchEmbeddings", texts.length)
+          try {
+            if (!openai.apiKey) {
+              throw new LLMAPIKeyNotSetException(
+                'OpenAI API key is missing. Please set it in settings menu.',
+              )
+            }
+            const embedding = await openai.embeddings.create({
+              model: settings.embeddingModelId,
+              input: texts,
+            })
+            return embedding.data.map(item => item.embedding)
+          } catch (error) {
+            if (
+              error.status === 429 &&
+              error.message.toLowerCase().includes('rate limit')
+            ) {
+              throw new LLMRateLimitExceededException(
+                'OpenAI API rate limit exceeded. Please try again later.',
+              )
+            }
+            throw error
+          }
+        },
       }
     }
     case ApiProvider.OpenAI: {
@@ -67,6 +93,7 @@ export const getEmbeddingModel = (
       return {
         id: settings.embeddingModelId,
         dimension: modelInfo.dimensions,
+        supportsBatch: true,
         getEmbedding: async (text: string) => {
           try {
             if (!openai.apiKey) {
@@ -91,6 +118,30 @@ export const getEmbeddingModel = (
             throw error
           }
         },
+        getBatchEmbeddings: async (texts: string[]) => {
+          try {
+            if (!openai.apiKey) {
+              throw new LLMAPIKeyNotSetException(
+                'OpenAI API key is missing. Please set it in settings menu.',
+              )
+            }
+            const embedding = await openai.embeddings.create({
+              model: settings.embeddingModelId,
+              input: texts,
+            })
+            return embedding.data.map(item => item.embedding)
+          } catch (error) {
+            if (
+              error.status === 429 &&
+              error.message.toLowerCase().includes('rate limit')
+            ) {
+              throw new LLMRateLimitExceededException(
+                'OpenAI API rate limit exceeded. Please try again later.',
+              )
+            }
+            throw error
+          }
+        },
       }
     }
     case ApiProvider.SiliconFlow: {
@@ -104,6 +155,7 @@ export const getEmbeddingModel = (
       return {
         id: settings.embeddingModelId,
         dimension: modelInfo.dimensions,
+        supportsBatch: true,
         getEmbedding: async (text: string) => {
           try {
             if (!openai.apiKey) {
@@ -128,6 +180,30 @@ export const getEmbeddingModel = (
             throw error
           }
         },
+        getBatchEmbeddings: async (texts: string[]) => {
+          try {
+            if (!openai.apiKey) {
+              throw new LLMAPIKeyNotSetException(
+                'SiliconFlow API key is missing. Please set it in settings menu.',
+              )
+            }
+            const embedding = await openai.embeddings.create({
+              model: settings.embeddingModelId,
+              input: texts,
+            })
+            return embedding.data.map(item => item.embedding)
+          } catch (error) {
+            if (
+              error.status === 429 &&
+              error.message.toLowerCase().includes('rate limit')
+            ) {
+              throw new LLMRateLimitExceededException(
+                'SiliconFlow API rate limit exceeded. Please try again later.',
+              )
+            }
+            throw error
+          }
+        },
       }
     }
     case ApiProvider.AlibabaQwen: {
@@ -141,6 +217,7 @@ export const getEmbeddingModel = (
       return {
         id: settings.embeddingModelId,
         dimension: modelInfo.dimensions,
+        supportsBatch: false,
         getEmbedding: async (text: string) => {
           try {
             if (!openai.apiKey) {
@@ -165,6 +242,30 @@ export const getEmbeddingModel = (
             throw error
           }
         },
+        getBatchEmbeddings: async (texts: string[]) => {
+          try {
+            if (!openai.apiKey) {
+              throw new LLMAPIKeyNotSetException(
+                'Alibaba Qwen API key is missing. Please set it in settings menu.',
+              )
+            }
+            const embedding = await openai.embeddings.create({
+              model: settings.embeddingModelId,
+              input: texts,
+            })
+            return embedding.data.map(item => item.embedding)
+          } catch (error) {
+            if (
+              error.status === 429 &&
+              error.message.toLowerCase().includes('rate limit')
+            ) {
+              throw new LLMRateLimitExceededException(
+                'Alibaba Qwen API rate limit exceeded. Please try again later.',
+              )
+            }
+            throw error
+          }
+        },
       }
     }
     case ApiProvider.Google: {
@@ -174,6 +275,7 @@ export const getEmbeddingModel = (
       return {
         id: settings.embeddingModelId,
         dimension: modelInfo.dimensions,
+        supportsBatch: false,
         getEmbedding: async (text: string) => {
           try {
             const response = await model.embedContent(text)
@@ -190,6 +292,27 @@ export const getEmbeddingModel = (
             throw error
          }
         },
+        getBatchEmbeddings: async (texts: string[]) => {
+          try {
+            const embeddings = await Promise.all(
+              texts.map(async (text) => {
+                const response = await model.embedContent(text)
+                return response.embedding.values
+              })
+            )
+            return embeddings
+          } catch (error) {
+            if (
+              error.status === 429 &&
+              error.message.includes('RATE_LIMIT_EXCEEDED')
+            ) {
+              throw new LLMRateLimitExceededException(
+                'Gemini API rate limit exceeded. Please try again later.',
+              )
+            }
+            throw error
+          }
+        },
       }
     }
     case ApiProvider.Ollama: {
@@ -201,6 +324,7 @@ export const getEmbeddingModel = (
       return {
         id: settings.embeddingModelId,
         dimension: 0,
+        supportsBatch: false,
         getEmbedding: async (text: string) => {
           if (!settings.ollamaProvider.baseUrl) {
             throw new LLMBaseUrlNotSetException(
@@ -213,6 +337,18 @@ export const getEmbeddingModel = (
           })
           return embedding.data[0].embedding
         },
+        getBatchEmbeddings: async (texts: string[]) => {
+          if (!settings.ollamaProvider.baseUrl) {
+            throw new LLMBaseUrlNotSetException(
+              'Ollama Address is missing. Please set it in settings menu.',
+            )
+          }
+          const embedding = await openai.embeddings.create({
+            model: settings.embeddingModelId,
+            input: texts,
+          })
+          return embedding.data.map(item => item.embedding)
+        },
       }
     }
     case ApiProvider.OpenAICompatible: {
@@ -224,6 +360,7 @@ export const getEmbeddingModel = (
       return {
         id: settings.embeddingModelId,
         dimension: 0,
+        supportsBatch: false,
         getEmbedding: async (text: string) => {
           try {
             if (!openai.apiKey) {
@@ -249,6 +386,31 @@ export const getEmbeddingModel = (
             throw error
           }
         },
+        getBatchEmbeddings: async (texts: string[]) => {
+          try {
+            if (!openai.apiKey) {
+              throw new LLMAPIKeyNotSetException(
+                'OpenAI Compatible API key is missing. Please set it in settings menu.',
+              )
+            }
+            const embedding = await openai.embeddings.create({
+              model: settings.embeddingModelId,
+              input: texts,
+              encoding_format: "float",
+            })
+            return embedding.data.map(item => item.embedding)
+          } catch (error) {
+            if (
+              error.status === 429 &&
+              error.message.toLowerCase().includes('rate limit')
+            ) {
+              throw new LLMRateLimitExceededException(
+                'OpenAI Compatible API rate limit exceeded. Please try again later.',
+              )
+            }
+            throw error
+          }
+        },
       }
     }
     default:
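For context, this is how a caller can consume the new supportsBatch flag: batch-capable providers embed a whole slice of texts in one request, while the others fall back to per-text getEmbedding. A minimal sketch, not part of this commit; embedAll is a hypothetical helper name:

async function embedAll(
  model: EmbeddingModel,
  texts: string[],
): Promise<number[][]> {
  if (model.supportsBatch) {
    // One request embeds the whole array (OpenAI-style `input: string[]`)
    return model.getBatchEmbeddings(texts)
  }
  // Fallback: one request per text, preserving input order
  return Promise.all(texts.map((text) => model.getEmbedding(text)))
}

Note that the Google branch implements getBatchEmbeddings as a Promise.all fan-out over per-text embedContent calls rather than a single batched request, consistent with its supportsBatch: false flag.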

View File

@@ -131,27 +131,34 @@ export class VectorManager
     const embeddingProgress = { completed: 0 }
     const embeddingChunks: InsertVector[] = []
-    const batchSize = 100
-    const limit = pLimit(50)
-    const abortController = new AbortController()
-    const tasks = contentChunks.map((chunk) =>
-      limit(async () => {
-        if (abortController.signal.aborted) {
-          throw new Error('Operation was aborted')
-        }
-        try {
+    const insertBatchSize = 100 // batch size for database inserts
+
+    try {
+      if (embeddingModel.supportsBatch) {
+        // Providers that support batching: embed the chunks batch by batch
+        const embeddingBatchSize = 100 // batch size for embedding API calls
+
+        for (let i = 0; i < contentChunks.length; i += embeddingBatchSize) {
+          const batchChunks = contentChunks.slice(i, Math.min(i + embeddingBatchSize, contentChunks.length))
+          const batchTexts = batchChunks.map(chunk => chunk.content)
           await backOff(
             async () => {
-              const embedding = await embeddingModel.getEmbedding(chunk.content)
-              const embeddedChunk = {
-                path: chunk.path,
-                mtime: chunk.mtime,
-                content: chunk.content,
-                embedding,
-                metadata: chunk.metadata,
-              }
-              embeddingChunks.push(embeddedChunk)
-              embeddingProgress.completed++
+              const batchEmbeddings = await embeddingModel.getBatchEmbeddings(batchTexts)
+
+              // merge the embedding results back into the chunk data
+              for (let j = 0; j < batchChunks.length; j++) {
+                const embeddedChunk: InsertVector = {
+                  path: batchChunks[j].path,
+                  mtime: batchChunks[j].mtime,
+                  content: batchChunks[j].content,
+                  embedding: batchEmbeddings[j],
+                  metadata: batchChunks[j].metadata,
+                }
+                embeddingChunks.push(embeddedChunk)
+              }
+              embeddingProgress.completed += batchChunks.length
               updateProgress?.({
                 completedChunks: embeddingProgress.completed,
                 totalChunks: contentChunks.length,
@@ -165,15 +172,51 @@ export class VectorManager
               jitter: 'full',
             },
           )
-        } catch (error) {
-          abortController.abort()
-          throw error
         }
-      }),
-    )
-
-    try {
-      await Promise.all(tasks)
+      } else {
+        // Providers without batch support: keep the original per-chunk logic
+        const limit = pLimit(50)
+        const abortController = new AbortController()
+        const tasks = contentChunks.map((chunk) =>
+          limit(async () => {
+            if (abortController.signal.aborted) {
+              throw new Error('Operation was aborted')
+            }
+            try {
+              await backOff(
+                async () => {
+                  const embedding = await embeddingModel.getEmbedding(chunk.content)
+                  const embeddedChunk = {
+                    path: chunk.path,
+                    mtime: chunk.mtime,
+                    content: chunk.content,
+                    embedding,
+                    metadata: chunk.metadata,
+                  }
+                  embeddingChunks.push(embeddedChunk)
+                  embeddingProgress.completed++
+                  updateProgress?.({
+                    completedChunks: embeddingProgress.completed,
+                    totalChunks: contentChunks.length,
+                    totalFiles: filesToIndex.length,
+                  })
+                },
+                {
+                  numOfAttempts: 5,
+                  startingDelay: 1000,
+                  timeMultiple: 1.5,
+                  jitter: 'full',
+                },
+              )
+            } catch (error) {
+              abortController.abort()
+              throw error
+            }
+          }),
+        )
+        await Promise.all(tasks)
+      }

       // all embedding generated, batch insert
       if (embeddingChunks.length > 0) {
@@ -182,7 +225,7 @@ export class VectorManager
         while (inserted < embeddingChunks.length) {
           const chunksToInsert = embeddingChunks.slice(
             inserted,
-            Math.min(inserted + batchSize, embeddingChunks.length)
+            Math.min(inserted + insertBatchSize, embeddingChunks.length)
           )
           await this.repository.insertVectors(chunksToInsert, embeddingModel)
           inserted += chunksToInsert.length
@@ -242,25 +285,32 @@ export class VectorManager
     })
     const embeddingChunks: InsertVector[] = []
-    const limit = pLimit(50)
-    const abortController = new AbortController()
-    const tasks = contentChunks.map((chunk) =>
-      limit(async () => {
-        if (abortController.signal.aborted) {
-          throw new Error('Operation was aborted')
-        }
-        try {
+    const insertBatchSize = 100 // batch size for database inserts
+
+    try {
+      if (embeddingModel.supportsBatch) {
+        // Providers that support batching: embed the chunks batch by batch
+        const embeddingBatchSize = 100 // batch size for embedding API calls
+
+        for (let i = 0; i < contentChunks.length; i += embeddingBatchSize) {
+          const batchChunks = contentChunks.slice(i, Math.min(i + embeddingBatchSize, contentChunks.length))
+          const batchTexts = batchChunks.map(chunk => chunk.content)
           await backOff(
             async () => {
-              const embedding = await embeddingModel.getEmbedding(chunk.content)
-              const embeddedChunk = {
-                path: chunk.path,
-                mtime: chunk.mtime,
-                content: chunk.content,
-                embedding,
-                metadata: chunk.metadata,
-              }
-              embeddingChunks.push(embeddedChunk)
+              const batchEmbeddings = await embeddingModel.getBatchEmbeddings(batchTexts)
+
+              // merge the embedding results back into the chunk data
+              for (let j = 0; j < batchChunks.length; j++) {
+                const embeddedChunk: InsertVector = {
+                  path: batchChunks[j].path,
+                  mtime: batchChunks[j].mtime,
+                  content: batchChunks[j].content,
+                  embedding: batchEmbeddings[j],
+                  metadata: batchChunks[j].metadata,
+                }
+                embeddingChunks.push(embeddedChunk)
+              }
             },
             {
               numOfAttempts: 5,
@@ -269,22 +319,51 @@ export class VectorManager
               jitter: 'full',
             },
           )
-        } catch (error) {
-          abortController.abort()
-          throw error
         }
-      }),
-    )
-
-    try {
-      await Promise.all(tasks)
+      } else {
+        // Providers without batch support: keep the original per-chunk logic
+        const limit = pLimit(50)
+        const abortController = new AbortController()
+        const tasks = contentChunks.map((chunk) =>
+          limit(async () => {
+            if (abortController.signal.aborted) {
+              throw new Error('Operation was aborted')
+            }
+            try {
+              await backOff(
+                async () => {
+                  const embedding = await embeddingModel.getEmbedding(chunk.content)
+                  const embeddedChunk = {
+                    path: chunk.path,
+                    mtime: chunk.mtime,
+                    content: chunk.content,
+                    embedding,
+                    metadata: chunk.metadata,
+                  }
+                  embeddingChunks.push(embeddedChunk)
+                },
+                {
+                  numOfAttempts: 5,
+                  startingDelay: 1000,
+                  timeMultiple: 1.5,
+                  jitter: 'full',
+                },
+              )
+            } catch (error) {
+              abortController.abort()
+              throw error
+            }
+          }),
+        )
+        await Promise.all(tasks)
+      }

       // all embedding generated, batch insert
       if (embeddingChunks.length > 0) {
-        const batchSize = 100
         let inserted = 0
         while (inserted < embeddingChunks.length) {
-          const chunksToInsert = embeddingChunks.slice(inserted, Math.min(inserted + batchSize, embeddingChunks.length))
+          const chunksToInsert = embeddingChunks.slice(inserted, Math.min(inserted + insertBatchSize, embeddingChunks.length))
           await this.repository.insertVectors(chunksToInsert, embeddingModel)
           inserted += chunksToInsert.length
         }
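The batch path above slices contentChunks into groups of 100 and retries each group with exponential backoff. A stripped-down sketch of that pattern, assuming backOff comes from the exponential-backoff package (the retry options mirror this commit); embedInBatches is a hypothetical name:

import { backOff } from 'exponential-backoff'

async function embedInBatches(
  model: EmbeddingModel,
  texts: string[],
  batchSize = 100,
): Promise<number[][]> {
  const results: number[][] = []
  for (let i = 0; i < texts.length; i += batchSize) {
    const batch = texts.slice(i, i + batchSize)
    // The API call sits inside the backOff callback, so a failure
    // (e.g. a 429) re-runs only this slice, not the whole corpus
    const embeddings = await backOff(() => model.getBatchEmbeddings(batch), {
      numOfAttempts: 5,
      startingDelay: 1000,
      timeMultiple: 1.5,
      jitter: 'full',
    })
    results.push(...embeddings)
  }
  return results
}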

View File

@@ -17,5 +17,7 @@ export type EmbeddingModelOption = {
 export type EmbeddingModel = {
   id: string
   dimension: number
+  supportsBatch: boolean
   getEmbedding: (text: string) => Promise<number[]>
+  getBatchEmbeddings: (texts: string[]) => Promise<number[][]>
 }
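For reference, a toy object satisfying the extended type; the constant vectors are purely illustrative and stand in for real API calls:

const dummyModel: EmbeddingModel = {
  id: 'dummy-embedding',
  dimension: 3,
  supportsBatch: true,
  // Fake deterministic vectors; a real provider would call its embeddings API
  getEmbedding: async (text) => [text.length, 0, 0],
  getBatchEmbeddings: async (texts) => texts.map((text) => [text.length, 0, 0]),
}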