duanfuxiang 2025-06-12 13:35:00 +08:00
parent 3ce55899df
commit b20b4f9e19
3 changed files with 296 additions and 53 deletions

View File

@ -30,6 +30,7 @@ export const getEmbeddingModel = (
return {
id: settings.embeddingModelId,
dimension: modelInfo.dimensions,
supportsBatch: true,
getEmbedding: async (text: string) => {
try {
if (!openai.apiKey) {
@ -54,6 +55,31 @@ export const getEmbeddingModel = (
throw error
}
},
getBatchEmbeddings: async (texts: string[]) => {
try {
if (!openai.apiKey) {
throw new LLMAPIKeyNotSetException(
'OpenAI API key is missing. Please set it in settings menu.',
)
}
const embedding = await openai.embeddings.create({
model: settings.embeddingModelId,
input: texts,
})
return embedding.data.map(item => item.embedding)
} catch (error) {
if (
error.status === 429 &&
error.message.toLowerCase().includes('rate limit')
) {
throw new LLMRateLimitExceededException(
'OpenAI API rate limit exceeded. Please try again later.',
)
}
throw error
}
},
}
}
case ApiProvider.OpenAI: {
@ -67,6 +93,7 @@ export const getEmbeddingModel = (
return {
id: settings.embeddingModelId,
dimension: modelInfo.dimensions,
supportsBatch: true,
getEmbedding: async (text: string) => {
try {
if (!openai.apiKey) {
@ -91,6 +118,30 @@ export const getEmbeddingModel = (
throw error
}
},
getBatchEmbeddings: async (texts: string[]) => {
try {
if (!openai.apiKey) {
throw new LLMAPIKeyNotSetException(
'OpenAI API key is missing. Please set it in settings menu.',
)
}
const embedding = await openai.embeddings.create({
model: settings.embeddingModelId,
input: texts,
})
return embedding.data.map(item => item.embedding)
} catch (error) {
if (
error.status === 429 &&
error.message.toLowerCase().includes('rate limit')
) {
throw new LLMRateLimitExceededException(
'OpenAI API rate limit exceeded. Please try again later.',
)
}
throw error
}
},
}
}
case ApiProvider.SiliconFlow: {
@ -104,6 +155,7 @@ export const getEmbeddingModel = (
return {
id: settings.embeddingModelId,
dimension: modelInfo.dimensions,
supportsBatch: true,
getEmbedding: async (text: string) => {
try {
if (!openai.apiKey) {
@ -128,6 +180,30 @@ export const getEmbeddingModel = (
throw error
}
},
getBatchEmbeddings: async (texts: string[]) => {
try {
if (!openai.apiKey) {
throw new LLMAPIKeyNotSetException(
'SiliconFlow API key is missing. Please set it in settings menu.',
)
}
const embedding = await openai.embeddings.create({
model: settings.embeddingModelId,
input: texts,
})
return embedding.data.map(item => item.embedding)
} catch (error) {
if (
error.status === 429 &&
error.message.toLowerCase().includes('rate limit')
) {
throw new LLMRateLimitExceededException(
'SiliconFlow API rate limit exceeded. Please try again later.',
)
}
throw error
}
},
}
}
case ApiProvider.AlibabaQwen: {
@ -141,6 +217,7 @@ export const getEmbeddingModel = (
return {
id: settings.embeddingModelId,
dimension: modelInfo.dimensions,
supportsBatch: false,
getEmbedding: async (text: string) => {
try {
if (!openai.apiKey) {
@ -165,6 +242,30 @@ export const getEmbeddingModel = (
throw error
}
},
getBatchEmbeddings: async (texts: string[]) => {
try {
if (!openai.apiKey) {
throw new LLMAPIKeyNotSetException(
'Alibaba Qwen API key is missing. Please set it in settings menu.',
)
}
const embedding = await openai.embeddings.create({
model: settings.embeddingModelId,
input: texts,
})
return embedding.data.map(item => item.embedding)
} catch (error) {
if (
error.status === 429 &&
error.message.toLowerCase().includes('rate limit')
) {
throw new LLMRateLimitExceededException(
'Alibaba Qwen API rate limit exceeded. Please try again later.',
)
}
throw error
}
},
}
}
case ApiProvider.Google: {
@ -174,6 +275,7 @@ export const getEmbeddingModel = (
return {
id: settings.embeddingModelId,
dimension: modelInfo.dimensions,
supportsBatch: false,
getEmbedding: async (text: string) => {
try {
const response = await model.embedContent(text)
@ -190,6 +292,27 @@ export const getEmbeddingModel = (
throw error
}
},
getBatchEmbeddings: async (texts: string[]) => {
try {
const embeddings = await Promise.all(
texts.map(async (text) => {
const response = await model.embedContent(text)
return response.embedding.values
})
)
return embeddings
} catch (error) {
if (
error.status === 429 &&
error.message.includes('RATE_LIMIT_EXCEEDED')
) {
throw new LLMRateLimitExceededException(
'Gemini API rate limit exceeded. Please try again later.',
)
}
throw error
}
},
}
}
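Side note on the Google case above: supportsBatch is set to false, and getBatchEmbeddings fans out one embedContent call per text via Promise.all. The @google/generative-ai SDK also ships a native batchEmbedContents method; below is a minimal sketch of what a true batch path might look like (an illustration, not part of this commit; getBatchEmbeddingsNative is a hypothetical name):

import { GoogleGenerativeAI } from '@google/generative-ai'

// Hypothetical sketch: one batched request instead of N parallel calls.
const getBatchEmbeddingsNative = async (
  apiKey: string,
  modelId: string,
  texts: string[],
): Promise<number[][]> => {
  const genAI = new GoogleGenerativeAI(apiKey)
  const model = genAI.getGenerativeModel({ model: modelId })
  const result = await model.batchEmbedContents({
    requests: texts.map((text) => ({
      content: { role: 'user', parts: [{ text }] },
    })),
  })
  // One embedding per request, in request order.
  return result.embeddings.map((embedding) => embedding.values)
}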
case ApiProvider.Ollama: {
@ -201,6 +324,7 @@ export const getEmbeddingModel = (
return {
id: settings.embeddingModelId,
dimension: 0,
supportsBatch: false,
getEmbedding: async (text: string) => {
if (!settings.ollamaProvider.baseUrl) {
throw new LLMBaseUrlNotSetException(
@ -213,6 +337,18 @@ export const getEmbeddingModel = (
})
return embedding.data[0].embedding
},
getBatchEmbeddings: async (texts: string[]) => {
if (!settings.ollamaProvider.baseUrl) {
throw new LLMBaseUrlNotSetException(
'Ollama Address is missing. Please set it in settings menu.',
)
}
const embedding = await openai.embeddings.create({
model: settings.embeddingModelId,
input: texts,
})
return embedding.data.map(item => item.embedding)
},
}
}
case ApiProvider.OpenAICompatible: {
@ -224,6 +360,7 @@ export const getEmbeddingModel = (
return {
id: settings.embeddingModelId,
dimension: 0,
supportsBatch: false,
getEmbedding: async (text: string) => {
try {
if (!openai.apiKey) {
@ -249,6 +386,31 @@ export const getEmbeddingModel = (
throw error
}
},
getBatchEmbeddings: async (texts: string[]) => {
try {
if (!openai.apiKey) {
throw new LLMAPIKeyNotSetException(
'OpenAI Compatible API key is missing. Please set it in settings menu.',
)
}
const embedding = await openai.embeddings.create({
model: settings.embeddingModelId,
input: texts,
encoding_format: "float",
})
return embedding.data.map(item => item.embedding)
} catch (error) {
if (
error.status === 429 &&
error.message.toLowerCase().includes('rate limit')
) {
throw new LLMRateLimitExceededException(
'OpenAI Compatible API rate limit exceeded. Please try again later.',
)
}
throw error
}
},
}
}
default:
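A note on the batch calls throughout this file: the OpenAI-style embeddings endpoint returns one item per input, and each item carries an index field pointing back into the input array. The .map(item => item.embedding) calls above assume the provider preserves input order; a defensive sort by index keeps the chunk-to-embedding pairing correct even if it does not. A minimal sketch (illustrative only; embedBatchOrdered, client, and modelId are placeholder names):

import OpenAI from 'openai'

// Hypothetical helper: batch-embed texts and map results back by index,
// so that result[i] is always the embedding of texts[i].
const embedBatchOrdered = async (
  client: OpenAI,
  modelId: string,
  texts: string[],
): Promise<number[][]> => {
  const response = await client.embeddings.create({
    model: modelId,
    input: texts,
  })
  // Sort by the index field instead of assuming response order.
  const ordered = [...response.data].sort((a, b) => a.index - b.index)
  return ordered.map((item) => item.embedding)
}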

View File

@ -131,27 +131,34 @@ export class VectorManager {
const embeddingProgress = { completed: 0 }
const embeddingChunks: InsertVector[] = []
const insertBatchSize = 100 // database insert batch size
try {
if (embeddingModel.supportsBatch) {
// Providers that support batching: embed the chunks in batches
const embeddingBatchSize = 100 // API batch size
for (let i = 0; i < contentChunks.length; i += embeddingBatchSize) {
const batchChunks = contentChunks.slice(i, Math.min(i + embeddingBatchSize, contentChunks.length))
const batchTexts = batchChunks.map(chunk => chunk.content)
await backOff(
async () => {
const batchEmbeddings = await embeddingModel.getBatchEmbeddings(batchTexts)
// Merge the embedding results back into the chunk data
for (let j = 0; j < batchChunks.length; j++) {
const embeddedChunk: InsertVector = {
path: batchChunks[j].path,
mtime: batchChunks[j].mtime,
content: batchChunks[j].content,
embedding: batchEmbeddings[j],
metadata: batchChunks[j].metadata,
}
embeddingChunks.push(embeddedChunk)
}
embeddingProgress.completed += batchChunks.length
updateProgress?.({
completedChunks: embeddingProgress.completed,
totalChunks: contentChunks.length,
@ -165,15 +172,51 @@ export class VectorManager {
jitter: 'full',
},
)
}
} else {
// Providers without batch support: keep the original per-chunk logic
const limit = pLimit(50)
const abortController = new AbortController()
const tasks = contentChunks.map((chunk) =>
limit(async () => {
if (abortController.signal.aborted) {
throw new Error('Operation was aborted')
}
try {
await backOff(
async () => {
const embedding = await embeddingModel.getEmbedding(chunk.content)
const embeddedChunk = {
path: chunk.path,
mtime: chunk.mtime,
content: chunk.content,
embedding,
metadata: chunk.metadata,
}
embeddingChunks.push(embeddedChunk)
embeddingProgress.completed++
updateProgress?.({
completedChunks: embeddingProgress.completed,
totalChunks: contentChunks.length,
totalFiles: filesToIndex.length,
})
},
{
numOfAttempts: 5,
startingDelay: 1000,
timeMultiple: 1.5,
jitter: 'full',
},
)
} catch (error) {
abortController.abort()
throw error
}
}),
)
await Promise.all(tasks)
}
// all embedding generated, batch insert
if (embeddingChunks.length > 0) {
@ -182,7 +225,7 @@ export class VectorManager {
while (inserted < embeddingChunks.length) {
const chunksToInsert = embeddingChunks.slice(
inserted,
Math.min(inserted + insertBatchSize, embeddingChunks.length)
)
await this.repository.insertVectors(chunksToInsert, embeddingModel)
inserted += chunksToInsert.length
@ -242,25 +285,32 @@ export class VectorManager {
})
const embeddingChunks: InsertVector[] = []
const insertBatchSize = 100 // database insert batch size
try {
if (embeddingModel.supportsBatch) {
// Providers that support batching: embed the chunks in batches
const embeddingBatchSize = 100 // API batch size
for (let i = 0; i < contentChunks.length; i += embeddingBatchSize) {
const batchChunks = contentChunks.slice(i, Math.min(i + embeddingBatchSize, contentChunks.length))
const batchTexts = batchChunks.map(chunk => chunk.content)
await backOff(
async () => {
const batchEmbeddings = await embeddingModel.getBatchEmbeddings(batchTexts)
// Merge the embedding results back into the chunk data
for (let j = 0; j < batchChunks.length; j++) {
const embeddedChunk: InsertVector = {
path: batchChunks[j].path,
mtime: batchChunks[j].mtime,
content: batchChunks[j].content,
embedding: batchEmbeddings[j],
metadata: batchChunks[j].metadata,
}
embeddingChunks.push(embeddedChunk)
}
},
{
numOfAttempts: 5,
@ -269,22 +319,51 @@ export class VectorManager {
jitter: 'full',
},
)
}
} else {
// Providers without batch support: keep the original per-chunk logic
const limit = pLimit(50)
const abortController = new AbortController()
const tasks = contentChunks.map((chunk) =>
limit(async () => {
if (abortController.signal.aborted) {
throw new Error('Operation was aborted')
}
try {
await backOff(
async () => {
const embedding = await embeddingModel.getEmbedding(chunk.content)
const embeddedChunk = {
path: chunk.path,
mtime: chunk.mtime,
content: chunk.content,
embedding,
metadata: chunk.metadata,
}
embeddingChunks.push(embeddedChunk)
},
{
numOfAttempts: 5,
startingDelay: 1000,
timeMultiple: 1.5,
jitter: 'full',
},
)
} catch (error) {
abortController.abort()
throw error
}
}),
)
await Promise.all(tasks)
}
// all embedding generated, batch insert
if (embeddingChunks.length > 0) {
let inserted = 0
while (inserted < embeddingChunks.length) {
const chunksToInsert = embeddingChunks.slice(inserted, Math.min(inserted + insertBatchSize, embeddingChunks.length))
await this.repository.insertVectors(chunksToInsert, embeddingModel)
inserted += chunksToInsert.length
}
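Design note on the new batch path: batches of 100 are awaited one after another, while the per-chunk fallback keeps pLimit(50) concurrency. If a provider's rate limits allow it, the same p-limit pattern could be applied per batch. A standalone sketch under that assumption (embedAllConcurrently and embedBatch are hypothetical names; the batch and concurrency sizes are illustrative):

import pLimit from 'p-limit'
import { backOff } from 'exponential-backoff'

// Hypothetical sketch: issue batch embedding requests concurrently, reusing
// the retry policy the indexing code uses. `embedBatch` stands in for
// embeddingModel.getBatchEmbeddings.
const embedAllConcurrently = async (
  texts: string[],
  embedBatch: (batch: string[]) => Promise<number[][]>,
  batchSize = 100,
  concurrency = 5,
): Promise<number[][]> => {
  const limit = pLimit(concurrency)
  const batches: string[][] = []
  for (let i = 0; i < texts.length; i += batchSize) {
    batches.push(texts.slice(i, i + batchSize))
  }
  const results = await Promise.all(
    batches.map((batch) =>
      limit(() =>
        backOff(() => embedBatch(batch), {
          numOfAttempts: 5,
          startingDelay: 1000,
          timeMultiple: 1.5,
          jitter: 'full',
        }),
      ),
    ),
  )
  // Flatten per-batch results back into one array aligned with `texts`.
  return results.flat()
}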

View File

@ -17,5 +17,7 @@ export type EmbeddingModelOption = {
export type EmbeddingModel = {
id: string
dimension: number
supportsBatch: boolean
getEmbedding: (text: string) => Promise<number[]>
getBatchEmbeddings: (texts: string[]) => Promise<number[][]>
}
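Since getBatchEmbeddings is now a required member, every caller still has to branch on supportsBatch (as VectorManager does above). A minimal sketch of a helper that hides that branch behind one call (embedTexts is a hypothetical name, not part of this commit):

// Hypothetical helper built on the EmbeddingModel type above: call
// getBatchEmbeddings when the provider supports it, otherwise fall back to
// sequential getEmbedding calls so rate limits are not hammered.
export const embedTexts = async (
  model: EmbeddingModel,
  texts: string[],
): Promise<number[][]> => {
  if (model.supportsBatch) {
    return model.getBatchEmbeddings(texts)
  }
  const embeddings: number[][] = []
  for (const text of texts) {
    embeddings.push(await model.getEmbedding(text))
  }
  return embeddings
}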