diff --git a/src/core/rag/embedding.ts b/src/core/rag/embedding.ts
index 5583843..f8f0543 100644
--- a/src/core/rag/embedding.ts
+++ b/src/core/rag/embedding.ts
@@ -64,7 +64,7 @@ export const getEmbeddingModel = (
       // Make sure the model is loaded
       if (!embeddingManager.modelLoaded || embeddingManager.currentModel !== settings.embeddingModelId) {
         console.log(`Loading model: ${settings.embeddingModelId}`)
-        await embeddingManager.loadModel(settings.embeddingModelId, false)
+        await embeddingManager.loadModel(settings.embeddingModelId, true)
       }

       const results = await embeddingManager.embedBatch(texts)
diff --git a/src/database/modules/vector/vector-manager.ts b/src/database/modules/vector/vector-manager.ts
index 151c220..172db6a 100644
--- a/src/database/modules/vector/vector-manager.ts
+++ b/src/database/modules/vector/vector-manager.ts
@@ -27,7 +27,7 @@ export class VectorManager {
   constructor(app: App, dbManager: DBManager) {
     this.app = app
     this.dbManager = dbManager
-    this.repository = new VectorRepository(app, dbManager.getPgClient())
+    this.repository = new VectorRepository(app, dbManager.getPgClient() as any)
   }

   async performSimilaritySearch(
@@ -88,6 +88,7 @@ export class VectorManager {
   ): Promise<void> {
     let filesToIndex: TFile[]
     if (options.reindexAll) {
+      console.log("updateVaultIndex reindexAll")
       filesToIndex = await this.getFilesToIndex({
         embeddingModel: embeddingModel,
         excludePatterns: options.excludePatterns,
@@ -96,17 +97,22 @@ export class VectorManager {
       })
       await this.repository.clearAllVectors(embeddingModel)
     } else {
+      console.log("updateVaultIndex for update files")
       await this.cleanVectorsForDeletedFiles(embeddingModel)
+      console.log("updateVaultIndex cleanVectorsForDeletedFiles")
       filesToIndex = await this.getFilesToIndex({
         embeddingModel: embeddingModel,
         excludePatterns: options.excludePatterns,
         includePatterns: options.includePatterns,
       })
+      console.log("get files to index: ", filesToIndex.length)
       await this.repository.deleteVectorsForMultipleFiles(
         filesToIndex.map((file) => file.path),
         embeddingModel,
       )
+      console.log("delete vectors for multiple files: ", filesToIndex.length)
     }
+    console.log("get files to index: ", filesToIndex.length)

     if (filesToIndex.length === 0) {
       return
@@ -131,6 +137,7 @@ export class VectorManager {
         "",
       ],
     });
+    console.log("textSplitter chunkSize: ", options.chunkSize, "overlap: ", overlap)

     const skippedFiles: string[] = []
     const contentChunks: InsertVector[] = (
@@ -145,15 +152,16 @@ export class VectorManager {
         ])
         return fileDocuments
           .map((chunk): InsertVector | null => {
-            const content = removeMarkdown(chunk.pageContent).replace(/\0/g, '')
-            if (!content || content.trim().length === 0) {
+            // Keep the raw content; do not call removeMarkdown here
+            const rawContent = chunk.pageContent.replace(/\0/g, '')
+            if (!rawContent || rawContent.trim().length === 0) {
               console.log("skipped chunk", chunk.pageContent)
               return null
             }
             return {
               path: file.path,
               mtime: file.stat.mtime,
-              content,
+              content: rawContent, // store the raw content
               embedding: [],
               metadata: {
                 startLine: Number(chunk.metadata.loc.lines.from),
@@ -171,6 +179,8 @@ export class VectorManager {
       )
     ).flat()

+    console.log("contentChunks: ", contentChunks.length)
+
     if (skippedFiles.length > 0) {
       console.warn(`Skipped ${skippedFiles.length} problematic files:`, skippedFiles)
       new Notice(`Skipped ${skippedFiles.length} problematic files`)
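Note on the `rawContent` change above: the vector table now stores the original markdown, and `removeMarkdown` is deferred to embedding time. A minimal sketch of that split, assuming a `removeMarkdown` helper with the same contract as the one this module already imports:

```ts
import removeMarkdown from 'remove-markdown' // assumption: the helper this file imports

// Stored in the vector table: the raw chunk text, with null bytes stripped
// so PostgreSQL's UTF8 encoding does not reject the row.
function toStoredContent(pageContent: string): string | null {
  const raw = pageContent.replace(/\0/g, '')
  return raw.trim().length > 0 ? raw : null
}

// Computed only when the chunk is embedded, so search results can still
// render the original markdown.
function toEmbeddingInput(storedContent: string): string | null {
  const clean = removeMarkdown(storedContent).replace(/\0/g, '')
  return clean.trim().length > 0 ? clean : null
}
```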
@@ -186,31 +196,42 @@ export class VectorManager {
    // Reduce the batch size to lower memory pressure
    const insertBatchSize = 32
    let batchCount = 0
-
+    try {
      if (embeddingModel.supportsBatch) {
        // Providers with batch support: use streaming logic
-        const embeddingBatchSize = 32
-
+        const embeddingBatchSize = 32
+
        for (let i = 0; i < contentChunks.length; i += embeddingBatchSize) {
          batchCount++
          const batchChunks = contentChunks.slice(i, Math.min(i + embeddingBatchSize, contentChunks.length))
-          const batchTexts = batchChunks.map(chunk => chunk.content)
-
          const embeddedBatch: InsertVector[] = []
-
          await backOff(
            async () => {
+              // Strip markdown right before embedding, and only once
+              const cleanedBatchData = batchChunks.map(chunk => {
+                const cleanContent = removeMarkdown(chunk.content).replace(/\0/g, '')
+                return { chunk, cleanContent }
+              }).filter(({ cleanContent }) => cleanContent && cleanContent.trim().length > 0)
+
+              if (cleanedBatchData.length === 0) {
+                return
+              }
+
+              const batchTexts = cleanedBatchData.map(({ cleanContent }) => cleanContent)
              const batchEmbeddings = await embeddingModel.getBatchEmbeddings(batchTexts)

              // Merge the embedding results into the chunk data
-              for (let j = 0; j < batchChunks.length; j++) {
+              for (let j = 0; j < cleanedBatchData.length; j++) {
+                const { chunk, cleanContent } = cleanedBatchData[j]
                const embeddedChunk: InsertVector = {
-                  path: batchChunks[j].path,
-                  mtime: batchChunks[j].mtime,
-                  content: batchChunks[j].content,
+                  path: chunk.path,
+                  mtime: chunk.mtime,
+                  content: cleanContent, // use the already-cleaned content
                  embedding: batchEmbeddings[j],
-                  metadata: batchChunks[j].metadata,
+                  metadata: chunk.metadata,
                }
                embeddedBatch.push(embeddedChunk)
              }
@@ -229,7 +250,7 @@ export class VectorManager {
            // Clear the batch buffer
            embeddedBatch.length = 0
          }
-
+
          embeddingProgress.completed += batchChunks.length
          updateProgress?.({
            completedChunks: embeddingProgress.completed,
@@ -244,17 +265,17 @@ export class VectorManager {
        // Providers without batch support: use streaming logic
        const limit = pLimit(32) // lowered from 50 to 32 to reduce concurrency pressure
        const abortController = new AbortController()
-
+
        // Streaming: process in batches and insert immediately
        for (let i = 0; i < contentChunks.length; i += insertBatchSize) {
          if (abortController.signal.aborted) {
            throw new Error('Operation was aborted')
          }
-
+
          batchCount++
          const batchChunks = contentChunks.slice(i, Math.min(i + insertBatchSize, contentChunks.length))
          const embeddedBatch: InsertVector[] = []
-
+
          const tasks = batchChunks.map((chunk) =>
            limit(async () => {
              if (abortController.signal.aborted) {
@@ -263,11 +284,18 @@ export class VectorManager {
              try {
                await backOff(
                  async () => {
-                    const embedding = await embeddingModel.getEmbedding(chunk.content)
+                    // Strip markdown right before embedding
+                    const cleanContent = removeMarkdown(chunk.content).replace(/\0/g, '')
+                    // Skip content that is empty after cleaning
+                    if (!cleanContent || cleanContent.trim().length === 0) {
+                      return
+                    }
+
+                    const embedding = await embeddingModel.getEmbedding(cleanContent)
                    const embeddedChunk = {
                      path: chunk.path,
                      mtime: chunk.mtime,
-                      content: chunk.content,
+                      content: cleanContent, // use the cleaned content
                      embedding,
                      metadata: chunk.metadata,
                    }
@@ -286,16 +314,16 @@ export class VectorManager {
              }
            }),
          )
-
+
          await Promise.all(tasks)
-
+
          // Insert the current batch immediately
          if (embeddedBatch.length > 0) {
            await this.repository.insertVectors(embeddedBatch, embeddingModel)
            // Clear the batch buffer
            embeddedBatch.length = 0
          }
-
+
          embeddingProgress.completed += batchChunks.length
          updateProgress?.({
            completedChunks: embeddingProgress.completed,
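For providers without batch support, the hunks above cap concurrency with `pLimit` and retry each `getEmbedding` call with `backOff`. A standalone sketch of that pattern, assuming the `p-limit` and `exponential-backoff` packages this file already uses:

```ts
import pLimit from 'p-limit'
import { backOff } from 'exponential-backoff'

// Embed texts with bounded concurrency and per-item retries; failures
// resolve to null so one bad chunk does not abort the whole batch.
async function embedAllWithRetry(
  texts: string[],
  getEmbedding: (text: string) => Promise<number[]>,
  concurrency = 32,
): Promise<(number[] | null)[]> {
  const limit = pLimit(concurrency) // cap in-flight requests
  return Promise.all(
    texts.map((text) =>
      limit(() =>
        backOff(() => getEmbedding(text), { numOfAttempts: 3 }).catch((error) => {
          console.warn('embedding failed after retries:', error)
          return null
        }),
      ),
    ),
  )
}
```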
@@ -339,9 +367,23 @@ export class VectorManager {
    )

    // Embed the files
-    const textSplitter = new MarkdownTextSplitter({
+    const overlap = Math.floor(chunkSize * 0.15)
+    const textSplitter = new RecursiveCharacterTextSplitter({
      chunkSize: chunkSize,
-      chunkOverlap: Math.floor(chunkSize * 0.15)
+      chunkOverlap: overlap,
+      separators: [
+        "\n\n",
+        "\n",
+        ".",
+        ",",
+        " ",
+        "\u200b", // Zero-width space
+        "\uff0c", // Fullwidth comma
+        "\u3001", // Ideographic comma
+        "\uff0e", // Fullwidth full stop
+        "\u3002", // Ideographic full stop
+        "",
+      ],
    });
    let fileContent = await this.app.vault.cachedRead(file)
    // Strip null bytes to prevent PostgreSQL UTF8 encoding errors
@@ -352,14 +394,15 @@ export class VectorManager {

    const contentChunks: InsertVector[] = fileDocuments
      .map((chunk): InsertVector | null => {
-        const content = removeMarkdown(chunk.pageContent).replace(/\0/g, '')
-        if (!content || content.trim().length === 0) {
+        // Keep the raw content; do not call removeMarkdown here
+        const rawContent = String(chunk.pageContent || '').replace(/\0/g, '')
+        if (!rawContent || rawContent.trim().length === 0) {
          return null
        }
        return {
          path: file.path,
          mtime: file.stat.mtime,
-          content,
+          content: rawContent, // store the raw content
          embedding: [],
          metadata: {
            startLine: Number(chunk.metadata.loc.lines.from),
@@ -372,32 +415,43 @@ export class VectorManager {
    // Reduce the batch size to lower memory pressure
    const insertBatchSize = 16 // lowered from 64 to 16
    let batchCount = 0
-
+    try {
      if (embeddingModel.supportsBatch) {
        // Providers with batch support: use streaming logic
        const embeddingBatchSize = 16 // lowered from 64 to 16
-
+
        for (let i = 0; i < contentChunks.length; i += embeddingBatchSize) {
          batchCount++
          console.log(`Embedding batch ${batchCount} of ${Math.ceil(contentChunks.length / embeddingBatchSize)}`)
          const batchChunks = contentChunks.slice(i, Math.min(i + embeddingBatchSize, contentChunks.length))
-          const batchTexts = batchChunks.map(chunk => chunk.content)
-
          const embeddedBatch: InsertVector[] = []
-
          await backOff(
            async () => {
+              // Strip markdown right before embedding, and only once
+              const cleanedBatchData = batchChunks.map(chunk => {
+                const cleanContent = removeMarkdown(chunk.content).replace(/\0/g, '')
+                return { chunk, cleanContent }
+              }).filter(({ cleanContent }) => cleanContent && cleanContent.trim().length > 0)
+
+              if (cleanedBatchData.length === 0) {
+                return
+              }
+
+              const batchTexts = cleanedBatchData.map(({ cleanContent }) => cleanContent)
              const batchEmbeddings = await embeddingModel.getBatchEmbeddings(batchTexts)

              // Merge the embedding results into the chunk data
-              for (let j = 0; j < batchChunks.length; j++) {
+              for (let j = 0; j < cleanedBatchData.length; j++) {
+                const { chunk, cleanContent } = cleanedBatchData[j]
                const embeddedChunk: InsertVector = {
-                  path: batchChunks[j].path,
-                  mtime: batchChunks[j].mtime,
-                  content: batchChunks[j].content,
+                  path: chunk.path,
+                  mtime: chunk.mtime,
+                  content: cleanContent, // use the already-cleaned content
                  embedding: batchEmbeddings[j],
-                  metadata: batchChunks[j].metadata,
+                  metadata: chunk.metadata,
                }
                embeddedBatch.push(embeddedChunk)
              }
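The splitter swap above replaces `MarkdownTextSplitter` with `RecursiveCharacterTextSplitter` and adds CJK punctuation to the separator list, so Chinese prose breaks at sentence boundaries instead of only at newlines. A minimal configuration sketch, assuming LangChain's splitter as imported by this module:

```ts
import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter'

function makeSplitter(chunkSize: number): RecursiveCharacterTextSplitter {
  const overlap = Math.floor(chunkSize * 0.15) // 15% overlap, as in the diff
  return new RecursiveCharacterTextSplitter({
    chunkSize: chunkSize,
    chunkOverlap: overlap,
    separators: [
      '\n\n', '\n', '.', ',', ' ',
      '\u200b', // zero-width space (scripts without inter-word spaces)
      '\uff0c', // fullwidth comma ,
      '\u3001', // ideographic comma 、
      '\uff0e', // fullwidth full stop .
      '\u3002', // ideographic full stop 。
      '',       // last resort: split anywhere
    ],
  })
}
```

The separators are tried in order, so a chunk is only cut mid-word when no sentence- or clause-level break fits within `chunkSize`.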
@@ -424,17 +478,17 @@ export class VectorManager {
        // Providers without batch support: use streaming logic
        const limit = pLimit(10) // lowered from 50 to 10
        const abortController = new AbortController()
-
+
        // Streaming: process in batches and insert immediately
        for (let i = 0; i < contentChunks.length; i += insertBatchSize) {
          if (abortController.signal.aborted) {
            throw new Error('Operation was aborted')
          }
-
+
          batchCount++
          const batchChunks = contentChunks.slice(i, Math.min(i + insertBatchSize, contentChunks.length))
          const embeddedBatch: InsertVector[] = []
-
+
          const tasks = batchChunks.map((chunk) =>
            limit(async () => {
              if (abortController.signal.aborted) {
@@ -443,11 +497,18 @@ export class VectorManager {
              try {
                await backOff(
                  async () => {
-                    const embedding = await embeddingModel.getEmbedding(chunk.content)
+                    // Strip markdown right before embedding
+                    const cleanContent = removeMarkdown(chunk.content).replace(/\0/g, '')
+                    // Skip content that is empty after cleaning
+                    if (!cleanContent || cleanContent.trim().length === 0) {
+                      return
+                    }
+
+                    const embedding = await embeddingModel.getEmbedding(cleanContent)
                    const embeddedChunk = {
                      path: chunk.path,
                      mtime: chunk.mtime,
-                      content: chunk.content,
+                      content: cleanContent, // use the cleaned content
                      embedding,
                      metadata: chunk.metadata,
                    }
@@ -466,9 +527,9 @@ export class VectorManager {
              }
            }),
          )
-
+
          await Promise.all(tasks)
-
+
          // Insert the current batch immediately
          if (embeddedBatch.length > 0) {
            await this.repository.insertVectors(embeddedBatch, embeddingModel)
@@ -522,8 +583,9 @@ export class VectorManager {
    excludePatterns: string[]
    includePatterns: string[]
    reindexAll?: boolean
-  }): Promise<TFile[]> {
+  }): Promise<TFile[]> {
    let filesToIndex = this.app.vault.getMarkdownFiles()
+    console.log("get all vault files: ", filesToIndex.length)

    filesToIndex = filesToIndex.filter((file) => {
      return !excludePatterns.some((pattern) => minimatch(file.path, pattern))
@@ -538,39 +600,24 @@ export class VectorManager {
    if (reindexAll) {
      return filesToIndex
    }
-    // Check for updated or new files
-    filesToIndex = await Promise.all(
-      filesToIndex.map(async (file) => {
-        try {
-          const fileChunks = await this.repository.getVectorsByFilePath(
-            file.path,
-            embeddingModel,
-          )
-          if (fileChunks.length === 0) {
-            // File is not indexed, so we need to index it
-            let fileContent = await this.app.vault.cachedRead(file)
-            // Strip null bytes to prevent PostgreSQL UTF8 encoding errors
-            fileContent = fileContent.replace(/\0/g, '')
-            if (fileContent.length === 0) {
-              // Ignore empty files
-              return null
-            }
-            return file
-          }
-          const outOfDate = file.stat.mtime > fileChunks[0].mtime
-          if (outOfDate) {
-            // File has changed, so we need to re-index it
-            console.log("File has changed, so we need to re-index it", file.path)
-            return file
-          }
-          return null
-        } catch (error) {
-          console.warn(`Skipping file ${file.path}:`, error.message)
-          return null
-        }
-      }),
-    ).then((files) => files.filter(Boolean))
-    return filesToIndex
+    // Optimized flow: use the database's max mtime to filter the files that need updating
+    try {
+      const maxMtime = await this.repository.getMaxMtime(embeddingModel)
+      console.log("Database max mtime:", maxMtime)
+
+      if (maxMtime === null) {
+        // No vectors in the database yet, so index all files
+        return filesToIndex
+      }
+
+      // Keep only the files modified after the database's last update time
+      return filesToIndex.filter((file) => {
+        return file.stat.mtime > maxMtime
+      })
+    } catch (error) {
+      console.error("Error getting max mtime from database:", error)
+      return []
+    }
  }
}
diff --git a/src/database/modules/vector/vector-repository.ts b/src/database/modules/vector/vector-repository.ts
index 435c93e..e7d07d8 100644
--- a/src/database/modules/vector/vector-repository.ts
+++ b/src/database/modules/vector/vector-repository.ts
@@ -33,6 +33,17 @@ export class VectorRepository {
    return result.rows.map((row: { path: string }) => row.path)
  }

+  async getMaxMtime(embeddingModel: EmbeddingModel): Promise<number | null> {
+    if (!this.db) {
+      throw new DatabaseNotInitializedException()
+    }
+    const tableName = this.getTableName(embeddingModel)
+    const result = await this.db.query<{ max_mtime: number | null }>(
+      `SELECT MAX(mtime) as max_mtime FROM "${tableName}"`
+    )
+    return result.rows[0]?.max_mtime || null
+  }
+
  async getVectorsByFilePath(
    filePath: string,
    embeddingModel: EmbeddingModel,
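The rewritten `getFilesToIndex` trades N `getVectorsByFilePath` round-trips for a single `MAX(mtime)` scan. The heuristic as a standalone sketch, with its one blind spot called out:

```ts
// Sketch of the max-mtime filter above. Assumption: any file whose mtime is
// not newer than the newest indexed vector is already up to date. A file
// copied into the vault with an old mtime would be missed until a full reindex.
async function filterByMaxMtime<F extends { stat: { mtime: number } }>(
  files: F[],
  getMaxMtime: () => Promise<number | null>,
): Promise<F[]> {
  const maxMtime = await getMaxMtime()
  if (maxMtime === null) return files // empty index: take every file
  return files.filter((file) => file.stat.mtime > maxMtime)
}
```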
diff --git a/src/embedworker/EmbeddingManager.ts b/src/embedworker/EmbeddingManager.ts
index c3c6145..caf600f 100644
--- a/src/embedworker/EmbeddingManager.ts
+++ b/src/embedworker/EmbeddingManager.ts
@@ -4,284 +4,285 @@ import EmbedWorker from './embed.worker';

 // Type definitions
 export interface EmbedResult {
-    vec: number[];
-    tokens: number;
-    embed_input?: string;
+  vec: number[];
+  tokens: number;
+  embed_input?: string;
 }

 export interface ModelLoadResult {
-    model_loaded: boolean;
+  model_loaded: boolean;
 }

 export interface ModelUnloadResult {
-    model_unloaded: boolean;
+  model_unloaded: boolean;
 }

 export interface TokenCountResult {
-    tokens: number;
+  tokens: number;
 }

 export class EmbeddingManager {
-    private worker: Worker;
-    private requests = new Map<number, { resolve: (value: any) => void; reject: (reason?: any) => void }>();
-    private nextRequestId = 0;
-    private isModelLoaded = false;
-    private currentModelId: string | null = null;
+  private worker: Worker;
+  private requests = new Map<number, { resolve: (value: any) => void; reject: (reason?: any) => void }>();
+  private nextRequestId = 0;
+  private isModelLoaded = false;
+  private currentModelId: string | null = null;

-    constructor() {
-        // Create the Worker, following the same pattern as pgworker
-        this.worker = new EmbedWorker();
+  constructor() {
+    // Create the Worker, following the same pattern as pgworker
+    this.worker = new EmbedWorker();

-        // Listen for all messages from the Worker in one place
-        this.worker.onmessage = (event) => {
-            try {
-                const { id, result, error } = event.data;
+    // Listen for all messages from the Worker in one place
+    this.worker.onmessage = (event) => {
+      try {
+        const { id, result, error } = event.data;

-                // Find the matching Promise callbacks by the returned id
-                const request = this.requests.get(id);
+        // Find the matching Promise callbacks by the returned id
+        const request = this.requests.get(id);

-                if (request) {
-                    if (error) {
-                        request.reject(new Error(error));
-                    } else {
-                        request.resolve(result);
-                    }
-                    // Remove it from the Map once settled
-                    this.requests.delete(id);
-                }
-            } catch (err) {
-                console.error("Error processing worker message:", err);
-                // Reject all pending requests
-                this.requests.forEach(request => {
-                    request.reject(new Error(`Worker message processing error: ${err.message}`));
-                });
-                this.requests.clear();
-            }
-        };
+        if (request) {
+          if (error) {
+            request.reject(new Error(error));
+          } else {
+            request.resolve(result);
+          }
+          // Remove it from the Map once settled
+          this.requests.delete(id);
+        }
+      } catch (err) {
+        console.error("Error processing worker message:", err);
+        // Reject all pending requests
+        this.requests.forEach(request => {
+          request.reject(new Error(`Worker message processing error: ${err.message}`));
+        });
+        this.requests.clear();
+      }
+    };

-        this.worker.onerror = (error) => {
-            console.error("EmbeddingWorker error:", error);
-            // Reject all pending requests
-            this.requests.forEach(request => {
-                request.reject(new Error(`Worker error: ${error.message || 'Unknown worker error'}`));
-            });
-            this.requests.clear();
-
-            // Reset state
-            this.isModelLoaded = false;
-            this.currentModelId = null;
-        };
-    }
+    this.worker.onerror = (error) => {
+      console.error("EmbeddingWorker error:", error);
+      // Reject all pending requests
+      this.requests.forEach(request => {
+        request.reject(new Error(`Worker error: ${error.message || 'Unknown worker error'}`));
+      });
+      this.requests.clear();
+
+      // Reset state
+      this.isModelLoaded = false;
+      this.currentModelId = null;
+    };
+  }

-    /**
-     * Send a request to the Worker and return a Promise that resolves when the response arrives.
-     * @param method The method to call (e.g., 'load', 'embed_batch')
-     * @param params The parameters for that method
-     */
-    private postRequest(method: string, params: any): Promise<any> {
-        return new Promise((resolve, reject) => {
-            const id = this.nextRequestId++;
-            this.requests.set(id, { resolve, reject });
-            this.worker.postMessage({ method, params, id });
-        });
-    }
+  /**
+   * Send a request to the Worker and return a Promise that resolves when the response arrives.
+   * @param method The method to call (e.g., 'load', 'embed_batch')
+   * @param params The parameters for that method
+   */
+  private postRequest(method: string, params: any): Promise<any> {
+    return new Promise((resolve, reject) => {
+      const id = this.nextRequestId++;
+      this.requests.set(id, { resolve, reject });
+      this.worker.postMessage({ method, params, id });
+    });
+  }
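`postRequest` assigns each message a monotonically increasing id and parks the Promise callbacks in `requests` until the worker echoes that id back, so callers see an ordinary async API. A usage sketch:

```ts
const manager = new EmbeddingManager()

// Load once and reuse; loadModel returns early if the model is already active.
await manager.loadModel('TaylorAI/bge-micro-v2', false)

const results = await manager.embedBatch(['first note chunk', 'second note chunk'])
console.log(results[0].vec.length, results[0].tokens)

manager.terminate() // release the worker when finished
```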
-    /**
-     * Load the specified embedding model into the Worker.
-     * @param modelId The model ID, e.g. 'TaylorAI/bge-micro-v2'
-     * @param useGpu Whether to use GPU acceleration; defaults to false
-     */
-    public async loadModel(modelId: string, useGpu: boolean = false): Promise<ModelLoadResult> {
-        console.log(`Loading embedding model: ${modelId}, GPU: ${useGpu}`);
-
-        try {
-            // If the same model is already loaded, return immediately
-            if (this.isModelLoaded && this.currentModelId === modelId) {
-                console.log(`Model ${modelId} already loaded`);
-                return { model_loaded: true };
-            }
-
-            // If a different model is loaded, unload it first
-            if (this.isModelLoaded && this.currentModelId !== modelId) {
-                console.log(`Unloading previous model: ${this.currentModelId}`);
-                await this.unloadModel();
-            }
-
-            const result = await this.postRequest('load', {
-                model_key: modelId,
-                use_gpu: useGpu
-            });
-
-            this.isModelLoaded = result.model_loaded;
-            this.currentModelId = result.model_loaded ? modelId : null;
-
-            if (result.model_loaded) {
-                console.log(`Model ${modelId} loaded successfully`);
-            }
-
-            return result;
-        } catch (error) {
-            console.error(`Failed to load model ${modelId}:`, error);
-            this.isModelLoaded = false;
-            this.currentModelId = null;
-            throw error;
-        }
-    }
+  /**
+   * Load the specified embedding model into the Worker.
+   * @param modelId The model ID, e.g. 'TaylorAI/bge-micro-v2'
+   * @param useGpu Whether to use GPU acceleration; defaults to false
+   */
+  public async loadModel(modelId: string, useGpu: boolean = false): Promise<ModelLoadResult> {
+    console.log(`Loading embedding model: ${modelId}, GPU: ${useGpu}`);

-    /**
-     * Generate embedding vectors for a batch of texts.
-     * @param texts The array of texts to process
-     * @returns An array of objects containing the vector and token info
-     */
-    public async embedBatch(texts: string[]): Promise<EmbedResult[]> {
-        if (!this.isModelLoaded) {
-            throw new Error('Model not loaded. Please call loadModel() first.');
-        }
-
-        if (!texts || texts.length === 0) {
-            return [];
-        }
-
-        console.log(`Generating embeddings for ${texts.length} texts`);
-
-        try {
-            const inputs = texts.map(text => ({ embed_input: text }));
-            const results = await this.postRequest('embed_batch', { inputs });
-
-            console.log(`Generated ${results.length} embeddings`);
-            return results;
-        } catch (error) {
-            console.error('Failed to generate embeddings:', error);
-            throw error;
-        }
-    }
+    try {
+      // If the same model is already loaded, return immediately
+      if (this.isModelLoaded && this.currentModelId === modelId) {
+        console.log(`Model ${modelId} already loaded`);
+        return { model_loaded: true };
+      }

-    /**
-     * Generate an embedding vector for a single text.
-     * @param text The text to process
-     * @returns An object containing the vector and token info
-     */
-    public async embed(text: string): Promise<EmbedResult> {
-        if (!text || text.trim().length === 0) {
-            throw new Error('Text cannot be empty');
-        }
-
-        const results = await this.embedBatch([text]);
-        if (results.length === 0) {
-            throw new Error('Failed to generate embedding');
-        }
-
-        return results[0];
-    }
+      // If a different model is loaded, unload it first
+      if (this.isModelLoaded && this.currentModelId !== modelId) {
+        console.log(`Unloading previous model: ${this.currentModelId}`);
+        await this.unloadModel();
+      }
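`loadModel` is also the model-switching path: loading a different id unloads the current model first, while reloading the same id is a cheap no-op. A brief sketch of that lifecycle:

```ts
const manager = new EmbeddingManager()

await manager.loadModel('Xenova/all-MiniLM-L6-v2')
await manager.loadModel('Xenova/all-MiniLM-L6-v2')      // no-op, already loaded
await manager.loadModel('Xenova/multilingual-e5-small') // unloads MiniLM first

console.log(manager.currentModel) // 'Xenova/multilingual-e5-small'
```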
-    /**
-     * Count the number of tokens in a text.
-     * @param text The text to count
-     */
-    public async countTokens(text: string): Promise<TokenCountResult> {
-        if (!this.isModelLoaded) {
-            throw new Error('Model not loaded. Please call loadModel() first.');
-        }
-
-        if (!text) {
-            return { tokens: 0 };
-        }
-
-        try {
-            return await this.postRequest('count_tokens', text);
-        } catch (error) {
-            console.error('Failed to count tokens:', error);
-            throw error;
-        }
-    }
+      const result = await this.postRequest('load', {
+        model_key: modelId,
+        use_gpu: useGpu
+      });

-    /**
-     * Unload the model and free memory.
-     */
-    public async unloadModel(): Promise<ModelUnloadResult> {
-        if (!this.isModelLoaded) {
-            console.log('No model to unload');
-            return { model_unloaded: true };
-        }
-
-        try {
-            console.log(`Unloading model: ${this.currentModelId}`);
-            const result = await this.postRequest('unload', {});
-
-            this.isModelLoaded = false;
-            this.currentModelId = null;
-
-            console.log('Model unloaded successfully');
-            return result;
-        } catch (error) {
-            console.error('Failed to unload model:', error);
-            // Reset state even if unloading fails
-            this.isModelLoaded = false;
-            this.currentModelId = null;
-            throw error;
-        }
-    }
+      this.isModelLoaded = result.model_loaded;
+      this.currentModelId = result.model_loaded ? modelId : null;

-    /**
-     * Check whether a model is loaded.
-     */
-    public get modelLoaded(): boolean {
-        return this.isModelLoaded;
-    }
+      if (result.model_loaded) {
+        console.log(`Model ${modelId} loaded successfully`);
+      }

-    /**
-     * Get the ID of the currently loaded model.
-     */
-    public get currentModel(): string | null {
-        return this.currentModelId;
-    }
+      return result;
+    } catch (error) {
+      console.error(`Failed to load model ${modelId}:`, error);
+      this.isModelLoaded = false;
+      this.currentModelId = null;
+      throw error;
+    }
+  }

-    /**
-     * Get the list of supported models.
-     */
-    public getSupportedModels(): string[] {
-        return [
-            'Xenova/all-MiniLM-L6-v2',
-            'Xenova/bge-small-en-v1.5',
-            'Xenova/bge-base-en-v1.5',
-            'Xenova/jina-embeddings-v2-base-zh',
-            'Xenova/jina-embeddings-v2-small-en',
-            'Xenova/multilingual-e5-small',
-            'Xenova/multilingual-e5-base',
-            'Xenova/gte-small',
-            'Xenova/e5-small-v2',
-            'Xenova/e5-base-v2'
-        ];
-    }
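`getSupportedModels` and `getModelInfo` expose the catalog that the vector layer needs when sizing its pgvector columns. A lookup sketch:

```ts
const manager = new EmbeddingManager()

for (const id of manager.getSupportedModels()) {
  const info = manager.getModelInfo(id)
  if (info) {
    // dims drives the vector column width; maxTokens bounds a single input
    console.log(`${id}: ${info.dims} dims, ${info.maxTokens} max tokens`)
  }
}
```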
-    /**
-     * Get model info.
-     */
-    public getModelInfo(modelId: string): { dims: number; maxTokens: number; description: string } | null {
-        const modelInfoMap: Record<string, { dims: number; maxTokens: number; description: string }> = {
-            'Xenova/all-MiniLM-L6-v2': { dims: 384, maxTokens: 512, description: 'All-MiniLM-L6-v2 (recommended, lightweight)' },
-            'Xenova/bge-small-en-v1.5': { dims: 384, maxTokens: 512, description: 'BGE-small-en-v1.5' },
-            'Xenova/bge-base-en-v1.5': { dims: 768, maxTokens: 512, description: 'BGE-base-en-v1.5 (higher quality)' },
-            'Xenova/jina-embeddings-v2-base-zh': { dims: 768, maxTokens: 8192, description: 'Jina-v2-base-zh (Chinese-English bilingual)' },
-            'Xenova/jina-embeddings-v2-small-en': { dims: 512, maxTokens: 8192, description: 'Jina-v2-small-en' },
-            'Xenova/multilingual-e5-small': { dims: 384, maxTokens: 512, description: 'E5-small (multilingual)' },
-            'Xenova/multilingual-e5-base': { dims: 768, maxTokens: 512, description: 'E5-base (multilingual, higher quality)' },
-            'Xenova/gte-small': { dims: 384, maxTokens: 512, description: 'GTE-small' },
-            'Xenova/e5-small-v2': { dims: 384, maxTokens: 512, description: 'E5-small-v2' },
-            'Xenova/e5-base-v2': { dims: 768, maxTokens: 512, description: 'E5-base-v2 (higher quality)' }
-        };
-
-        return modelInfoMap[modelId] || null;
-    }
+    if (!texts || texts.length === 0) {
+      return [];
+    }

-    /**
-     * Terminate the Worker and release resources.
-     */
-    public terminate() {
-        this.worker.terminate();
-        this.requests.clear();
-        this.isModelLoaded = false;
-    }
+    console.log(`Generating embeddings for ${texts.length} texts`);
+
+    try {
+      const inputs = texts.map(text => ({ embed_input: text }));
+      const results = await this.postRequest('embed_batch', { inputs });
+
+      console.log(`Generated ${results.length} embeddings`);
+      return results;
+    } catch (error) {
+      console.error('Failed to generate embeddings:', error);
+      throw error;
+    }
+  }
+
+  /**
+   * Generate an embedding vector for a single text.
+   * @param text The text to process
+   * @returns An object containing the vector and token info
+   */
+  public async embed(text: string): Promise<EmbedResult> {
+    if (!text || text.trim().length === 0) {
+      throw new Error('Text cannot be empty');
+    }
+
+    const results = await this.embedBatch([text]);
+    if (results.length === 0) {
+      throw new Error('Failed to generate embedding');
+    }
+
+    return results[0];
+  }
+
+  /**
+   * Count the number of tokens in a text.
+   * @param text The text to count
+   */
+  public async countTokens(text: string): Promise<TokenCountResult> {
+    if (!this.isModelLoaded) {
+      throw new Error('Model not loaded. Please call loadModel() first.');
+    }
+
+    if (!text) {
+      return { tokens: 0 };
+    }
+
+    try {
+      return await this.postRequest('count_tokens', text);
+    } catch (error) {
+      console.error('Failed to count tokens:', error);
+      throw error;
+    }
+  }
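`countTokens` rides the same worker round-trip, which makes a pre-flight check against a model's context length cheap. A hypothetical guard combining it with `getModelInfo`:

```ts
// Hypothetical helper: returns false for chunks that exceed the model's context.
async function fitsModel(
  manager: EmbeddingManager,
  modelId: string,
  text: string,
): Promise<boolean> {
  const info = manager.getModelInfo(modelId)
  if (!info) return true // unknown model: let the worker decide
  const { tokens } = await manager.countTokens(text)
  return tokens <= info.maxTokens
}
```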
+
+  /**
+   * Unload the model and free memory.
+   */
+  public async unloadModel(): Promise<ModelUnloadResult> {
+    if (!this.isModelLoaded) {
+      console.log('No model to unload');
+      return { model_unloaded: true };
+    }
+
+    try {
+      console.log(`Unloading model: ${this.currentModelId}`);
+      const result = await this.postRequest('unload', {});
+
+      this.isModelLoaded = false;
+      this.currentModelId = null;
+
+      console.log('Model unloaded successfully');
+      return result;
+    } catch (error) {
+      console.error('Failed to unload model:', error);
+      // Reset state even if unloading fails
+      this.isModelLoaded = false;
+      this.currentModelId = null;
+      throw error;
+    }
+  }
+
+  /**
+   * Check whether a model is loaded.
+   */
+  public get modelLoaded(): boolean {
+    return this.isModelLoaded;
+  }
+
+  /**
+   * Get the ID of the currently loaded model.
+   */
+  public get currentModel(): string | null {
+    return this.currentModelId;
+  }
+
+  /**
+   * Get the list of supported models.
+   */
+  public getSupportedModels(): string[] {
+    return [
+      'TaylorAI/bge-micro-v2',
+      'Xenova/all-MiniLM-L6-v2',
+      'Xenova/bge-small-en-v1.5',
+      'Xenova/bge-base-en-v1.5',
+      'Xenova/jina-embeddings-v2-base-zh',
+      'Xenova/jina-embeddings-v2-small-en',
+      'Xenova/multilingual-e5-small',
+      'Xenova/multilingual-e5-base',
+      'Xenova/gte-small',
+      'Xenova/e5-small-v2',
+      'Xenova/e5-base-v2'
+    ];
+  }
+
+  /**
+   * Get model info.
+   */
+  public getModelInfo(modelId: string): { dims: number; maxTokens: number; description: string } | null {
+    const modelInfoMap: Record<string, { dims: number; maxTokens: number; description: string }> = {
+      'Xenova/all-MiniLM-L6-v2': { dims: 384, maxTokens: 512, description: 'All-MiniLM-L6-v2 (recommended, lightweight)' },
+      'Xenova/bge-small-en-v1.5': { dims: 384, maxTokens: 512, description: 'BGE-small-en-v1.5' },
+      'Xenova/bge-base-en-v1.5': { dims: 768, maxTokens: 512, description: 'BGE-base-en-v1.5 (higher quality)' },
+      'Xenova/jina-embeddings-v2-base-zh': { dims: 768, maxTokens: 8192, description: 'Jina-v2-base-zh (Chinese-English bilingual)' },
+      'Xenova/jina-embeddings-v2-small-en': { dims: 512, maxTokens: 8192, description: 'Jina-v2-small-en' },
+      'Xenova/multilingual-e5-small': { dims: 384, maxTokens: 512, description: 'E5-small (multilingual)' },
+      'Xenova/multilingual-e5-base': { dims: 768, maxTokens: 512, description: 'E5-base (multilingual, higher quality)' },
+      'Xenova/gte-small': { dims: 384, maxTokens: 512, description: 'GTE-small' },
+      'Xenova/e5-small-v2': { dims: 384, maxTokens: 512, description: 'E5-small-v2' },
+      'Xenova/e5-base-v2': { dims: 768, maxTokens: 512, description: 'E5-base-v2 (higher quality)' }
+    };
+
+    return modelInfoMap[modelId] || null;
+  }
+
+  /**
+   * Terminate the Worker and release resources.
+   */
+  public terminate() {
+    this.worker.terminate();
+    this.requests.clear();
+    this.isModelLoaded = false;
+  }
 }
diff --git a/src/embedworker/embed.worker.ts b/src/embedworker/embed.worker.ts
index 9306739..fca571c 100644
--- a/src/embedworker/embed.worker.ts
+++ b/src/embedworker/embed.worker.ts
@@ -48,7 +48,7 @@ async function loadTransformers() {
  env.allowRemoteModels = true;

  // Configure the WASM backend - fix the thread configuration
-  env.backends.onnx.wasm.numThreads = 4; // use a single thread in the Worker to avoid race conditions
+  env.backends.onnx.wasm.numThreads = 1; // use a single thread in the Worker to avoid race conditions
  env.backends.onnx.wasm.simd = true;

  // Disable Node.js-specific features
@@ -201,7 +201,7 @@ async function embedBatch(inputs: EmbedInput[]): Promise<EmbedResult[]> {
  }

  // Batch size (tune as needed)
-  const batchSize = 1;
+  const batchSize = 8;

  if (filteredInputs.length > batchSize) {
    console.log(`Processing ${filteredInputs.length} inputs in batches of ${batchSize}`);
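The worker change pins the ONNX WASM backend to one thread but raises the internal batch size from 1 to 8, so throughput comes from batching rather than threads. A sketch of the mini-batch loop implied by the log message above, with the slice size as a tunable:

```ts
// Assumption: embedSlice stands in for the worker's per-batch pipeline call.
async function embedInBatches<T>(
  inputs: T[],
  embedSlice: (slice: T[]) => Promise<number[][]>,
  batchSize = 8, // matches the new worker default
): Promise<number[][]> {
  const out: number[][] = []
  for (let i = 0; i < inputs.length; i += batchSize) {
    const slice = inputs.slice(i, i + batchSize)
    out.push(...(await embedSlice(slice))) // sequential slices keep memory bounded
  }
  return out
}
```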
diff --git a/src/embedworker/index.ts b/src/embedworker/index.ts
index 645a422..e0010f5 100644
--- a/src/embedworker/index.ts
+++ b/src/embedworker/index.ts
@@ -8,8 +8,8 @@ export { EmbeddingManager };

 // Export type definitions
 export type {
-    EmbedResult,
-    ModelLoadResult,
-    ModelUnloadResult,
-    TokenCountResult
+  EmbedResult,
+  ModelLoadResult,
+  ModelUnloadResult,
+  TokenCountResult
 } from './EmbeddingManager';
diff --git a/src/utils/api.ts b/src/utils/api.ts
index d25c0de..cc41324 100644
--- a/src/utils/api.ts
+++ b/src/utils/api.ts
@@ -1641,6 +1641,7 @@ export const localProviderDefaultAutoCompleteModelId = null // this is not supported
 export const localProviderDefaultEmbeddingModelId: keyof typeof localProviderEmbeddingModels = "TaylorAI/bge-micro-v2"

 export const localProviderEmbeddingModels = {
+  'TaylorAI/bge-micro-v2': { dimensions: 384, description: 'BGE-micro-v2 (local, 512 tokens, 384 dims)' },
   'Xenova/all-MiniLM-L6-v2': { dimensions: 384, description: 'All-MiniLM-L6-v2 (recommended, lightweight)' },
   'Xenova/bge-small-en-v1.5': { dimensions: 384, description: 'BGE-small-en-v1.5' },
   'Xenova/bge-base-en-v1.5': { dimensions: 768, description: 'BGE-base-en-v1.5 (higher quality)' },
@@ -1651,8 +1652,6 @@ export const localProviderEmbeddingModels = {
   'Xenova/gte-small': { dimensions: 384, description: 'GTE-small' },
   'Xenova/e5-small-v2': { dimensions: 384, description: 'E5-small-v2' },
   'Xenova/e5-base-v2': { dimensions: 768, description: 'E5-base-v2 (higher quality)' },
-  // Newly added models
-  'TaylorAI/bge-micro-v2': { dimensions: 384, description: 'BGE-micro-v2 (local, 512 tokens, 384 dims)' },
   'Snowflake/snowflake-arctic-embed-xs': { dimensions: 384, description: 'Snowflake Arctic Embed XS (local, 512 tokens, 384 dims)' },
   'Snowflake/snowflake-arctic-embed-s': { dimensions: 384, description: 'Snowflake Arctic Embed Small (local, 512 tokens, 384 dims)' },
   'Snowflake/snowflake-arctic-embed-m': { dimensions: 768, description: 'Snowflake Arctic Embed Medium (local, 512 tokens, 768 dims)' },
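Reordering `localProviderEmbeddingModels` puts `TaylorAI/bge-micro-v2` first, next to the default-model constant that references it. A lookup sketch, assuming the import path matches this repo's layout:

```ts
import {
  localProviderDefaultEmbeddingModelId,
  localProviderEmbeddingModels,
} from 'src/utils/api' // assumption: alias-style import as used elsewhere in the plugin

const defaultModel = localProviderEmbeddingModels[localProviderDefaultEmbeddingModelId]
console.log(defaultModel.dimensions)  // 384, used to size the pgvector column
console.log(defaultModel.description) // 'BGE-micro-v2 (local, 512 tokens, 384 dims)'
```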