From 4e139ecc4fff594b416d26e5e434cd0b8f863bd7 Mon Sep 17 00:00:00 2001
From: duanfuxiang
Date: Fri, 4 Jul 2025 15:52:00 +0800
Subject: [PATCH] Update RAGEngine and the embedding manager so an embedding
 manager can be passed through, add embedding-model loading logic for the
 local provider, and improve the error-handling and message-processing
 mechanisms
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/core/rag/embedding.ts                     | 57 ++
 src/core/rag/rag-engine.ts                    | 16 +-
 src/database/modules/vector/vector-manager.ts |  2 +-
 src/embedworker/EmbeddingManager.ts           | 35 +-
 src/embedworker/embed.worker.ts               | 652 ++++++++++--------
 src/main.ts                                   |  2 +-
 src/utils/provider-urls.ts                    |  1 +
 7 files changed, 453 insertions(+), 312 deletions(-)

diff --git a/src/core/rag/embedding.ts b/src/core/rag/embedding.ts
index 897cc6a..5583843 100644
--- a/src/core/rag/embedding.ts
+++ b/src/core/rag/embedding.ts
@@ -16,10 +16,67 @@ import {
 } from '../llm/exception'
 import { NoStainlessOpenAI } from '../llm/ollama'
 
+// EmbeddingManager type definition
+type EmbeddingManager = {
+  modelLoaded: boolean
+  currentModel: string | null
+  loadModel(modelId: string, useGpu: boolean): Promise<void>
+  embed(text: string): Promise<{ vec: number[] }>
+  embedBatch(texts: string[]): Promise<{ vec: number[] }[]>
+}
+
 export const getEmbeddingModel = (
   settings: InfioSettings,
+  embeddingManager?: EmbeddingManager,
 ): EmbeddingModel => {
   switch (settings.embeddingModelProvider) {
+    case ApiProvider.LocalProvider: {
+      if (!embeddingManager) {
+        throw new Error('EmbeddingManager is required for LocalProvider')
+      }
+
+      const modelInfo = GetEmbeddingModelInfo(settings.embeddingModelProvider, settings.embeddingModelId)
+      if (!modelInfo) {
+        throw new Error(`Embedding model ${settings.embeddingModelId} not found for provider ${settings.embeddingModelProvider}`)
+      }
+
+      return {
+        id: settings.embeddingModelId,
+        dimension: modelInfo.dimensions,
+        supportsBatch: true,
+        getEmbedding: async (text: string) => {
+          try {
+            // Make sure the model is loaded
+            if (!embeddingManager.modelLoaded || embeddingManager.currentModel !== settings.embeddingModelId) {
+              console.log(`Loading model: ${settings.embeddingModelId}`)
+              await embeddingManager.loadModel(settings.embeddingModelId, true)
+            }
+
+            const result = await embeddingManager.embed(text)
+            return result.vec
+          } catch (error) {
+            console.error('LocalProvider embedding error:', error)
+            throw new Error(`LocalProvider embedding failed: ${error.message}`)
+          }
+        },
+        getBatchEmbeddings: async (texts: string[]) => {
+          try {
+            // Make sure the model is loaded
+            if (!embeddingManager.modelLoaded || embeddingManager.currentModel !== settings.embeddingModelId) {
+              console.log(`Loading model: ${settings.embeddingModelId}`)
+              await embeddingManager.loadModel(settings.embeddingModelId, false)
+            }
+
+            const results = await embeddingManager.embedBatch(texts)
+            console.log('results', results)
+            return results.map(result => result.vec)
+          } catch (error) {
+            console.error('LocalProvider batch embedding error:', error)
+            throw new Error(`LocalProvider batch embedding failed: ${error.message}`)
+          }
+        },
+      }
+    }
     case ApiProvider.Infio: {
       const openai = new OpenAI({
         apiKey: settings.infioProvider.apiKey,
diff --git a/src/core/rag/rag-engine.ts b/src/core/rag/rag-engine.ts
index becd4f8..292f684 100644
--- a/src/core/rag/rag-engine.ts
+++ b/src/core/rag/rag-engine.ts
@@ -10,9 +10,19 @@ import { InfioSettings } from '../../types/settings'
 
 import { getEmbeddingModel } from './embedding'
 
+// EmbeddingManager type definition
+type EmbeddingManager = {
+  modelLoaded: boolean
+  currentModel: string | null
+  loadModel(modelId: string, useGpu: boolean): Promise<void>
+  embed(text: string): Promise<{ vec: number[] }>
+  embedBatch(texts: string[]): Promise<{ vec: number[] }[]>
+}
+
 export class RAGEngine {
   private app: App
   private settings: InfioSettings
+  private embeddingManager?: EmbeddingManager
   private vectorManager: VectorManager | null = null
   private embeddingModel: EmbeddingModel | null = null
   private initialized = false
@@ -21,13 +31,15 @@ export class RAGEngine {
     app: App,
     settings: InfioSettings,
     dbManager: DBManager,
+    embeddingManager?: EmbeddingManager,
   ) {
     this.app = app
     this.settings = settings
+    this.embeddingManager = embeddingManager
     this.vectorManager = dbManager.getVectorManager()
     if (settings.embeddingModelId && settings.embeddingModelId.trim() !== '') {
       try {
-        this.embeddingModel = getEmbeddingModel(settings)
+        this.embeddingModel = getEmbeddingModel(settings, embeddingManager)
       } catch (error) {
         console.warn('Failed to initialize embedding model:', error)
         this.embeddingModel = null
@@ -46,7 +58,7 @@ export class RAGEngine {
     this.settings = settings
     if (settings.embeddingModelId && settings.embeddingModelId.trim() !== '') {
       try {
-        this.embeddingModel = getEmbeddingModel(settings)
+        this.embeddingModel = getEmbeddingModel(settings, this.embeddingManager)
       } catch (error) {
         console.warn('Failed to initialize embedding model:', error)
         this.embeddingModel = null
diff --git a/src/database/modules/vector/vector-manager.ts b/src/database/modules/vector/vector-manager.ts
index 6a69a62..1421f3e 100644
--- a/src/database/modules/vector/vector-manager.ts
+++ b/src/database/modules/vector/vector-manager.ts
@@ -113,7 +113,7 @@ export class VectorManager {
 
     const textSplitter = new MarkdownTextSplitter({
       chunkSize: options.chunkSize,
-      chunkOverlap: Math.floor(options.chunkSize * 0.15)
+      // chunkOverlap: Math.floor(options.chunkSize * 0.15)
     })
 
     const skippedFiles: string[] = []
diff --git a/src/embedworker/EmbeddingManager.ts b/src/embedworker/EmbeddingManager.ts
index 76b74b8..c3c6145 100644
--- a/src/embedworker/EmbeddingManager.ts
+++ b/src/embedworker/EmbeddingManager.ts
@@ -34,19 +34,28 @@ export class EmbeddingManager {
 
     // Listen for all messages from the Worker in one place
     this.worker.onmessage = (event) => {
-      const { id, result, error } = event.data;
+      try {
+        const { id, result, error } = event.data;
 
-      // Find the matching Promise callbacks by the returned id
-      const request = this.requests.get(id);
+        // Find the matching Promise callbacks by the returned id
+        const request = this.requests.get(id);
 
-      if (request) {
-        if (error) {
-          request.reject(new Error(error));
-        } else {
-          request.resolve(result);
+        if (request) {
+          if (error) {
+            request.reject(new Error(error));
+          } else {
+            request.resolve(result);
+          }
+          // Delete the entry from the Map once settled
+          this.requests.delete(id);
         }
-        // Delete the entry from the Map once settled
-        this.requests.delete(id);
+      } catch (err) {
+        console.error("Error processing worker message:", err);
+        // Reject all pending requests
+        this.requests.forEach(request => {
+          request.reject(new Error(`Worker message processing error: ${err.message}`));
+        });
+        this.requests.clear();
       }
     };
 
@@ -54,9 +63,13 @@
       console.error("EmbeddingWorker error:", error);
      	// Reject all pending requests
       this.requests.forEach(request => {
-        request.reject(error);
+        request.reject(new Error(`Worker error: ${error.message || 'Unknown worker error'}`));
       });
       this.requests.clear();
+
+      // Reset state
+      this.isModelLoaded = false;
+      this.currentModelId = null;
     };
   }
diff --git a/src/embedworker/embed.worker.ts b/src/embedworker/embed.worker.ts
index c7ecd0a..9306739 100644
--- a/src/embedworker/embed.worker.ts
+++ b/src/embedworker/embed.worker.ts
@@ -3,27 +3,27 @@ console.log('Embedding worker loaded');
 
 // Type definitions
 interface EmbedInput {
-  embed_input: string;
+    embed_input: string;
 }
 
 interface EmbedResult {
-  vec: number[];
-  tokens: number;
-  embed_input?: string;
+    vec: number[];
+    tokens: number;
+    embed_input?: string;
 }
 
 interface WorkerMessage {
-  method: string;
-  params: any;
-  id: number;
-  worker_id?: string;
+    method: string;
+    params: any;
+    id: number;
+    worker_id?: string;
 }
 
 interface WorkerResponse {
-  id: number;
-  result?: any;
-  error?: string;
-  worker_id?: string;
+    id: number;
+    result?: any;
+    error?: string;
+    worker_id?: string;
 }
 
 // Global variables
@@ -35,319 +35,377 @@ let transformersLoaded = false;
 
 // Dynamically import Transformers.js
 async function loadTransformers() {
-  if (transformersLoaded) return;
-
-  try {
-    console.log('Loading Transformers.js...');
-
-    // Use the older Transformers.js release; it is more stable inside a Worker
-    const { pipeline: pipelineFactory, env, AutoTokenizer } = await import('@xenova/transformers');
-
-    // Configure the environment for a browser Worker
-    env.allowLocalModels = false;
-    env.allowRemoteModels = true;
-
-    // Configure the WASM backend
-    env.backends.onnx.wasm.numThreads = 2; // keep the Worker's WASM thread count low
-    env.backends.onnx.wasm.simd = true;
-
-    // Disable Node.js-specific features
-    env.useFS = false;
-    env.useBrowserCache = true;
-
-    // Stash the imported functions
-    (globalThis as any).pipelineFactory = pipelineFactory;
-    (globalThis as any).AutoTokenizer = AutoTokenizer;
-    (globalThis as any).env = env;
-
-    transformersLoaded = true;
-    console.log('Transformers.js loaded successfully');
-  } catch (error) {
-    console.error('Failed to load Transformers.js:', error);
-    throw new Error(`Failed to load Transformers.js: ${error}`);
-  }
+    if (transformersLoaded) return;
+
+    try {
+        console.log('Loading Transformers.js...');
+
+        // Use the older Transformers.js release; it is more stable inside a Worker
+        const { pipeline: pipelineFactory, env, AutoTokenizer } = await import('@xenova/transformers');
+
+        // Configure the environment for a browser Worker
+        env.allowLocalModels = false;
+        env.allowRemoteModels = true;
+
+        // Configure the WASM backend (fixed thread configuration)
+        env.backends.onnx.wasm.numThreads = 4; // use up to 4 WASM threads in the Worker
+        env.backends.onnx.wasm.simd = true;
+
+        // Disable Node.js-specific features
+        env.useFS = false;
+        env.useBrowserCache = true;
+
+        // Stash the imported functions
+        (globalThis as any).pipelineFactory = pipelineFactory;
+        (globalThis as any).AutoTokenizer = AutoTokenizer;
+        (globalThis as any).env = env;
+
+        transformersLoaded = true;
+        console.log('Transformers.js loaded successfully');
+    } catch (error) {
+        console.error('Failed to load Transformers.js:', error);
+        throw new Error(`Failed to load Transformers.js: ${error}`);
+    }
 }
 
 // Load the model
 async function loadModel(modelKey: string, useGpu: boolean = false) {
-  try {
-    console.log(`Loading model: ${modelKey}, GPU: ${useGpu}`);
-
-    // Make sure Transformers.js is loaded
-    await loadTransformers();
-
-    const pipelineFactory = (globalThis as any).pipelineFactory;
-    const AutoTokenizer = (globalThis as any).AutoTokenizer;
-    const env = (globalThis as any).env;
-
-    // Configure the pipeline options
-    const pipelineOpts: any = {
-      quantized: true,
-      progress_callback: (progress: any) => {
-        console.log('Model loading progress:', progress);
-      }
-    };
-
-    if (useGpu && typeof navigator !== 'undefined' && 'gpu' in navigator) {
-      console.log('[Transformers] Attempting to use GPU');
-      try {
-        pipelineOpts.device = 'webgpu';
-        pipelineOpts.dtype = 'fp32';
-      } catch (error) {
-        console.warn('[Transformers] GPU not available, falling back to CPU');
-      }
-    } else {
-      console.log('[Transformers] Using CPU');
-    }
-
-    // Create the embedding pipeline
-    pipeline = await pipelineFactory('feature-extraction', modelKey, pipelineOpts);
-
-    // Create the tokenizer
-    tokenizer = await AutoTokenizer.from_pretrained(modelKey);
-
-    model = {
-      loaded: true,
-      model_key: modelKey,
-      use_gpu: useGpu
-    };
-
-    console.log(`Model ${modelKey} loaded successfully`);
-    return { model_loaded: true };
-
-  } catch (error) {
-    console.error('Error loading model:', error);
-    throw new Error(`Failed to load model: ${error}`);
-  }
+    try {
+        console.log(`Loading model: ${modelKey}, GPU: ${useGpu}`);
+
+        // Make sure Transformers.js is loaded
+        await loadTransformers();
+
+        const pipelineFactory = (globalThis as any).pipelineFactory;
+        const AutoTokenizer = (globalThis as any).AutoTokenizer;
+        const env = (globalThis as any).env;
+
+        // Configure the pipeline options
+        const pipelineOpts: any = {
+            quantized: true,
+            // Progress callback with error handling
+            progress_callback: (progress: any) => {
+                try {
+                    if (progress && typeof progress === 'object') {
+                        console.log('Model loading progress:', progress);
+                    }
+                } catch (error) {
+                    // Ignore progress-callback errors so they cannot interrupt model loading
+                    console.warn('Progress callback error (ignored):', error);
+                }
+            }
+        };
+
+        // Be more cautious about enabling the GPU
+        if (useGpu) {
+            try {
+                // Check for WebGPU support
+                console.log("useGpu", useGpu)
+                if (typeof navigator !== 'undefined' && 'gpu' in navigator) {
+                    const gpu = (navigator as any).gpu;
+                    if (gpu && typeof gpu.requestAdapter === 'function') {
+                        console.log('[Transformers] Attempting to use GPU');
+                        pipelineOpts.device = 'webgpu';
+                        pipelineOpts.dtype = 'fp32';
+                    } else {
+                        console.log('[Transformers] WebGPU not fully supported, using CPU');
+                    }
+                } else {
+                    console.log('[Transformers] WebGPU not available, using CPU');
+                }
+            } catch (error) {
+                console.warn('[Transformers] Error checking GPU support, falling back to CPU:', error);
+            }
+        } else {
+            console.log('[Transformers] Using CPU');
+        }
+
+        // Create the embedding pipeline
+        pipeline = await pipelineFactory('feature-extraction', modelKey, pipelineOpts);
+
+        // Create the tokenizer
+        tokenizer = await AutoTokenizer.from_pretrained(modelKey);
+
+        model = {
+            loaded: true,
+            model_key: modelKey,
+            use_gpu: useGpu
+        };
+
+        console.log(`Model ${modelKey} loaded successfully`);
+        return { model_loaded: true };
+
+    } catch (error) {
+        console.error('Error loading model:', error);
+        throw new Error(`Failed to load model: ${error}`);
+    }
 }
 
 // Unload the model
 async function unloadModel() {
-  try {
-    console.log('Unloading model...');
-
-    if (pipeline) {
-      if (pipeline.destroy) {
-        pipeline.destroy();
-      }
-      pipeline = null;
-    }
-
-    if (tokenizer) {
-      tokenizer = null;
-    }
-
-    model = null;
-
-    console.log('Model unloaded successfully');
-    return { model_unloaded: true };
-
-  } catch (error) {
-    console.error('Error unloading model:', error);
-    throw new Error(`Failed to unload model: ${error}`);
-  }
+    try {
+        console.log('Unloading model...');
+
+        if (pipeline) {
+            if (pipeline.destroy) {
+                pipeline.destroy();
+            }
+            pipeline = null;
+        }
+
+        if (tokenizer) {
+            tokenizer = null;
+        }
+
+        model = null;
+
+        console.log('Model unloaded successfully');
+        return { model_unloaded: true };
+
+    } catch (error) {
+        console.error('Error unloading model:', error);
+        throw new Error(`Failed to unload model: ${error}`);
+    }
 }
 
 // Count tokens
 async function countTokens(input: string) {
-  try {
-    if (!tokenizer) {
-      throw new Error('Tokenizer not loaded');
-    }
-
-    const { input_ids } = await tokenizer(input);
-    return { tokens: input_ids.data.length };
-
-  } catch (error) {
-    console.error('Error counting tokens:', error);
-    throw new Error(`Failed to count tokens: ${error}`);
-  }
+    try {
+        if (!tokenizer) {
+            throw new Error('Tokenizer not loaded');
+        }
+
+        const { input_ids } = await tokenizer(input);
+        return { tokens: input_ids.data.length };
+
+    } catch (error) {
+        console.error('Error counting tokens:', error);
+        throw new Error(`Failed to count tokens: ${error}`);
+    }
 }
 
 // Generate embedding vectors
 async function embedBatch(inputs: EmbedInput[]): Promise<EmbedResult[]> {
-  try {
-    if (!pipeline || !tokenizer) {
-      throw new Error('Model not loaded');
-    }
-
-    console.log(`Processing ${inputs.length} inputs`);
-
-    // Filter out empty inputs
-    const filteredInputs = inputs.filter(item => item.embed_input && item.embed_input.length > 0);
-
-    if (filteredInputs.length === 0) {
-      return [];
-    }
-
-    // Batch size (adjust as needed)
-    const batchSize = 1;
-
-    if (filteredInputs.length > batchSize) {
-      console.log(`Processing ${filteredInputs.length} inputs in batches of ${batchSize}`);
-      const results: EmbedResult[] = [];
-
-      for (let i = 0; i < filteredInputs.length; i += batchSize) {
-        const batch = filteredInputs.slice(i, i + batchSize);
-        const batchResults = await processBatch(batch);
-        results.push(...batchResults);
-      }
-
-      return results;
-    }
-
-    return await processBatch(filteredInputs);
-
-  } catch (error) {
-    console.error('Error in embed batch:', error);
-    throw new Error(`Failed to generate embeddings: ${error}`);
-  }
+    try {
+        if (!pipeline || !tokenizer) {
+            throw new Error('Model not loaded');
+        }
+
+        console.log(`Processing ${inputs.length} inputs`);
+
+        // Filter out empty inputs
+        const filteredInputs = inputs.filter(item => item.embed_input && item.embed_input.length > 0);
+
+        if (filteredInputs.length === 0) {
+            return [];
+        }
+
+        // Batch size (adjust as needed)
+        const batchSize = 1;
+
+        if (filteredInputs.length > batchSize) {
+            console.log(`Processing ${filteredInputs.length} inputs in batches of ${batchSize}`);
+            const results: EmbedResult[] = [];
+
+            for (let i = 0; i < filteredInputs.length; i += batchSize) {
+                const batch = filteredInputs.slice(i, i + batchSize);
+                const batchResults = await processBatch(batch);
+                results.push(...batchResults);
+            }
+
+            return results;
+        }
+
+        return await processBatch(filteredInputs);
+
+    } catch (error) {
+        console.error('Error in embed batch:', error);
+        throw new Error(`Failed to generate embeddings: ${error}`);
+    }
 }
 
 // Process a single batch
 async function processBatch(batchInputs: EmbedInput[]): Promise<EmbedResult[]> {
-  try {
-    // Count tokens for each input
-    const tokens = await Promise.all(
-      batchInputs.map(item => countTokens(item.embed_input))
-    );
-
-    // Prepare the embedding inputs (handle over-long text)
-    const maxTokens = 512; // max token limit for most models
-    const embedInputs = await Promise.all(
-      batchInputs.map(async (item, i) => {
-        if (tokens[i].tokens < maxTokens) {
-          return item.embed_input;
-        }
-
-        // Truncate over-long text
-        let tokenCt = tokens[i].tokens;
-        let truncatedInput = item.embed_input;
-
-        while (tokenCt > maxTokens) {
-          const pct = maxTokens / tokenCt;
-          const maxChars = Math.floor(truncatedInput.length * pct * 0.9);
-          truncatedInput = truncatedInput.substring(0, maxChars) + '...';
-          tokenCt = (await countTokens(truncatedInput)).tokens;
-        }
-
-        tokens[i].tokens = tokenCt;
-        return truncatedInput;
-      })
-    );
-
-    // Generate the embedding vectors
-    const resp = await pipeline(embedInputs, { pooling: 'mean', normalize: true });
-
-    // Assemble the results
-    return batchInputs.map((item, i) => ({
-      vec: Array.from(resp[i].data).map((val: number) => Math.round(val * 1e8) / 1e8),
-      tokens: tokens[i].tokens,
-      embed_input: item.embed_input
-    }));
-
-  } catch (error) {
-    console.error('Error processing batch:', error);
-
-    // If the batch fails, fall back to processing items one by one
-    return Promise.all(
-      batchInputs.map(async (item) => {
-        try {
-          const result = await pipeline(item.embed_input, { pooling: 'mean', normalize: true });
-          const tokenCount = await countTokens(item.embed_input);
-
-          return {
-            vec: Array.from(result[0].data).map((val: number) => Math.round(val * 1e8) / 1e8),
-            tokens: tokenCount.tokens,
-            embed_input: item.embed_input
-          };
-        } catch (singleError) {
-          console.error('Error processing single item:', singleError);
-          return {
-            vec: [],
-            tokens: 0,
-            embed_input: item.embed_input,
-            error: (singleError as Error).message
-          } as any;
-        }
-      })
-    );
-  }
+    try {
+        // Count tokens for each input
+        const tokens = await Promise.all(
+            batchInputs.map(item => countTokens(item.embed_input))
+        );
+
+        // Prepare the embedding inputs (handle over-long text)
+        const maxTokens = 512; // max token limit for most models
+        const embedInputs = await Promise.all(
+            batchInputs.map(async (item, i) => {
+                if (tokens[i].tokens < maxTokens) {
+                    return item.embed_input;
+                }
+
+                // Truncate over-long text
+                let tokenCt = tokens[i].tokens;
+                let truncatedInput = item.embed_input;
+
+                while (tokenCt > maxTokens) {
+                    const pct = maxTokens / tokenCt;
+                    const maxChars = Math.floor(truncatedInput.length * pct * 0.9);
+                    truncatedInput = truncatedInput.substring(0, maxChars) + '...';
+                    tokenCt = (await countTokens(truncatedInput)).tokens;
+                }
+
+                tokens[i].tokens = tokenCt;
+                return truncatedInput;
+            })
+        );
+
+        // Generate the embedding vectors
+        const resp = await pipeline(embedInputs, { pooling: 'mean', normalize: true });
+
+        // Assemble the results
+        return batchInputs.map((item, i) => ({
+            vec: Array.from(resp[i].data).map((val: number) => Math.round(val * 1e8) / 1e8),
+            tokens: tokens[i].tokens,
+            embed_input: item.embed_input
+        }));
+
+    } catch (error) {
+        console.error('Error processing batch:', error);
+
+        // If the batch fails, fall back to processing items one by one
+        return Promise.all(
+            batchInputs.map(async (item) => {
+                try {
+                    const result = await pipeline(item.embed_input, { pooling: 'mean', normalize: true });
+                    const tokenCount = await countTokens(item.embed_input);
+
+                    return {
+                        vec: Array.from(result[0].data).map((val: number) => Math.round(val * 1e8) / 1e8),
+                        tokens: tokenCount.tokens,
+                        embed_input: item.embed_input
+                    };
+                } catch (singleError) {
+                    console.error('Error processing single item:', singleError);
+                    return {
+                        vec: [],
+                        tokens: 0,
+                        embed_input: item.embed_input,
+                        error: (singleError as Error).message
+                    } as any;
+                }
+            })
+        );
+    }
 }
 
 // Handle messages
 async function processMessage(data: WorkerMessage): Promise<WorkerResponse> {
-  const { method, params, id, worker_id } = data;
-
-  try {
-    let result: any;
-
-    switch (method) {
-      case 'load':
-        console.log('Load method called with params:', params);
-        result = await loadModel(params.model_key, params.use_gpu || false);
-        break;
-
-      case 'unload':
-        console.log('Unload method called');
-        result = await unloadModel();
-        break;
-
-      case 'embed_batch':
-        console.log('Embed batch method called');
-        if (!model) {
-          throw new Error('Model not loaded');
-        }
-
-        // Wait for any earlier processing to finish
-        if (processing_message) {
-          while (processing_message) {
-            await new Promise(resolve => setTimeout(resolve, 100));
-          }
-        }
-
-        processing_message = true;
-        result = await embedBatch(params.inputs);
-        processing_message = false;
-        break;
-
-      case 'count_tokens':
-        console.log('Count tokens method called');
-        if (!model) {
-          throw new Error('Model not loaded');
-        }
-
-        // Wait for any earlier processing to finish
-        if (processing_message) {
-          while (processing_message) {
-            await new Promise(resolve => setTimeout(resolve, 100));
-          }
-        }
-
-        processing_message = true;
-        result = await countTokens(params);
-        processing_message = false;
-        break;
-
-      default:
-        throw new Error(`Unknown method: ${method}`);
-    }
-
-    return { id, result, worker_id };
-
-  } catch (error) {
-    console.error('Error processing message:', error);
-    processing_message = false;
-    return { id, error: (error as Error).message, worker_id };
-  }
+    const { method, params, id, worker_id } = data;
+
+    try {
+        let result: any;
+
+        switch (method) {
+            case 'load':
+                console.log('Load method called with params:', params);
+                result = await loadModel(params.model_key, params.use_gpu || false);
+                break;
+
+            case 'unload':
+                console.log('Unload method called');
+                result = await unloadModel();
+                break;
+
+            case 'embed_batch':
+                console.log('Embed batch method called');
+                if (!model) {
+                    throw new Error('Model not loaded');
+                }
+
+                // Wait for any earlier processing to finish
+                if (processing_message) {
+                    while (processing_message) {
+                        await new Promise(resolve => setTimeout(resolve, 100));
+                    }
+                }
+
+                processing_message = true;
+                result = await embedBatch(params.inputs);
+                processing_message = false;
+                break;
+
+            case 'count_tokens':
+                console.log('Count tokens method called');
+                if (!model) {
+                    throw new Error('Model not loaded');
+                }
+
+                // Wait for any earlier processing to finish
+                if (processing_message) {
+                    while (processing_message) {
+                        await new Promise(resolve => setTimeout(resolve, 100));
+                    }
+                }
+
+                processing_message = true;
+                result = await countTokens(params);
+                processing_message = false;
+                break;
+
+            default:
+                throw new Error(`Unknown method: ${method}`);
+        }
+
+        return { id, result, worker_id };
+
+    } catch (error) {
+        console.error('Error processing message:', error);
+        processing_message = false;
+        return { id, error: (error as Error).message, worker_id };
+    }
 }
 
 // Listen for messages
 self.addEventListener('message', async (event) => {
-  console.log('Worker received message:', event.data);
-  const response = await processMessage(event.data);
-  console.log('Worker sending response:', response);
-  self.postMessage(response);
+    try {
+        console.log('Worker received message:', event.data);
+
+        // Validate the message format
+        if (!event.data || typeof event.data !== 'object') {
+            console.error('Invalid message format received');
+            self.postMessage({
+                id: -1,
+                error: 'Invalid message format'
+            });
+            return;
+        }
+
+        const response = await processMessage(event.data);
+        console.log('Worker sending response:', response);
+        self.postMessage(response);
+    } catch (error) {
+        console.error('Unhandled error in worker message handler:', error);
+        self.postMessage({
+            id: event.data?.id || -1,
+            error: `Worker error: ${error.message || 'Unknown error'}`
+        });
+    }
+});
+
+// Global error handling
+self.addEventListener('error', (event) => {
+    console.error('Worker global error:', event);
+    self.postMessage({
+        id: -1,
+        error: `Worker global error: ${event.message || 'Unknown error'}`
+    });
+});
+
+// Handle unhandled promise rejections
+self.addEventListener('unhandledrejection', (event) => {
+    console.error('Worker unhandled promise rejection:', event);
+    self.postMessage({
+        id: -1,
+        error: `Worker unhandled rejection: ${event.reason || 'Unknown error'}`
+    });
+    event.preventDefault(); // Prevent the default console error
 });
 
 console.log('Embedding worker ready');
diff --git a/src/main.ts b/src/main.ts
index f7cc0fa..25b1457 100644
--- a/src/main.ts
+++ b/src/main.ts
@@ -623,7 +623,7 @@ export default class InfioPlugin extends Plugin {
     if (!this.ragEngineInitPromise) {
       this.ragEngineInitPromise = (async () => {
         const dbManager = await this.getDbManager()
-        this.ragEngine = new RAGEngine(this.app, this.settings, dbManager)
+        this.ragEngine = new RAGEngine(this.app, this.settings, dbManager, this.embeddingManager)
         return this.ragEngine
       })()
     }
diff --git a/src/utils/provider-urls.ts b/src/utils/provider-urls.ts
index fe9879f..0952534 100644
--- a/src/utils/provider-urls.ts
+++ b/src/utils/provider-urls.ts
@@ -14,6 +14,7 @@ export const providerApiUrls: Record<ApiProvider, string> = {
   [ApiProvider.Grok]: 'https://console.x.ai/',
   [ApiProvider.Ollama]: '', // Ollama does not need an API key
   [ApiProvider.OpenAICompatible]: '', // Custom OpenAI-compatible API, no fixed URL
+  [ApiProvider.LocalProvider]: '', // Local provider, no fixed URL
 };
 
 // Get the URL where the API key for the given provider can be obtained
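
Note for reviewers: a minimal sketch of how the pieces in this patch fit
together, mirroring the main.ts hunk above. It assumes the plugin already
holds an initialized `embeddingManager` field (its construction is outside
this diff) and that `ApiProvider.LocalProvider` is the configured embedding
provider:

    // Sketch only, not part of the patch: wiring inside InfioPlugin
    const dbManager = await this.getDbManager()
    // The manager is threaded through the RAGEngine constructor...
    const ragEngine = new RAGEngine(this.app, this.settings, dbManager, this.embeddingManager)
    // ...which forwards it to getEmbeddingModel(), so LocalProvider requests
    // are served by the local worker instead of a remote API:
    const model = getEmbeddingModel(this.settings, this.embeddingManager)
    const vec = await model.getEmbedding('hello world') // loads the local model on first use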
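The EmbeddingManager changes assume an id-correlated request/response
protocol with the worker: each call posts a WorkerMessage carrying a fresh
id and parks its Promise callbacks in the `requests` map that the onmessage
handler above settles. A minimal sketch of the sending side, with `send` and
`nextId` as hypothetical names (the actual fields are outside this diff):

    type Pending = { resolve: (value: unknown) => void; reject: (reason: Error) => void }
    const requests = new Map<number, Pending>()
    let nextId = 0

    function send(worker: Worker, method: string, params: unknown): Promise<unknown> {
        const id = ++nextId
        return new Promise((resolve, reject) => {
            requests.set(id, { resolve, reject })
            // Shape matches the WorkerMessage interface in embed.worker.ts
            worker.postMessage({ method, params, id })
        })
    }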