console.log('Embedding worker loaded');

/** A single text to embed. */
interface EmbedInput {
  embed_input: string;
}

/** Embedding produced for one input. */
interface EmbedResult {
  /** Embedding values rounded to 8 decimal places; empty when `error` is set. */
  vec: number[];
  /** Token count of the text actually embedded (after any truncation). */
  tokens: number;
  /** The original, untruncated input text, echoed back so callers can match results. */
  embed_input?: string;
  /** Set only when the per-item fallback in `processBatch` failed for this input. */
  error?: string;
}

/** Request posted to the worker. */
interface WorkerMessage {
  method: string;
  params: any;
  id: number;
  worker_id?: string;
}

/** Response posted back by the worker; exactly one of `result` / `error` is meaningful. */
interface WorkerResponse {
  id: number;
  result?: any;
  error?: string;
  worker_id?: string;
}

// Module-level state
let model: any = null;
let pipeline: any = null;
let tokenizer: any = null;
// True while an embed_batch / count_tokens request is executing.
let processing_message = false;
let transformersLoaded = false;
// Tail of the promise chain used to run embed_batch / count_tokens one at a time.
let processingQueue: Promise<unknown> = Promise.resolve();

/**
 * Extracts a human-readable message from an arbitrary thrown value.
 * Falls back to 'Unknown error' when nothing usable is available.
 */
function errorMessage(error: unknown): string {
  let message = '';
  if (error instanceof Error) {
    message = error.message;
  } else if (error != null) {
    message = String(error);
  }
  return message || 'Unknown error';
}

/**
 * Runs `task` after every previously queued task has settled (FIFO order),
 * keeping `processing_message` set for exactly the duration of the task.
 * The flag is released in `finally`, so only the task that acquired it can
 * release it; failures elsewhere in the worker never unlock a running task.
 */
function runExclusive<T>(task: () => Promise<T>): Promise<T> {
  const run = processingQueue.then(async () => {
    processing_message = true;
    try {
      return await task();
    } finally {
      processing_message = false;
    }
  });
  // Keep the chain alive even if this task rejects.
  processingQueue = run.catch(() => undefined);
  return run;
}

/**
 * Dynamically imports Transformers.js once, configures it for use inside a
 * browser worker, and stashes the entry points on `globalThis`.
 *
 * @throws Error when the import or configuration fails.
 */
async function loadTransformers() {
  if (transformersLoaded) return;

  try {
    console.log('Loading Transformers.js...');

    // Use the older @xenova/transformers package; per the original author it
    // is more stable inside a Worker.
    const { pipeline: pipelineFactory, env, AutoTokenizer } = await import('@xenova/transformers');

    // Configure the environment for a browser worker.
    env.allowLocalModels = false;
    env.allowRemoteModels = true;

    // WASM backend: single thread inside the worker to avoid race conditions.
    env.backends.onnx.wasm.numThreads = 1;
    env.backends.onnx.wasm.simd = true;

    // Disable Node.js-specific features.
    env.useFS = false;
    env.useBrowserCache = true;

    (globalThis as any).pipelineFactory = pipelineFactory;
    (globalThis as any).AutoTokenizer = AutoTokenizer;
    (globalThis as any).env = env;

    transformersLoaded = true;
    console.log('Transformers.js loaded successfully');
  } catch (error) {
    console.error('Failed to load Transformers.js:', error);
    throw new Error(`Failed to load Transformers.js: ${error}`);
  }
}

/**
 * Loads the feature-extraction pipeline and tokenizer for `modelKey`.
 *
 * @param modelKey Model identifier passed to Transformers.js.
 * @param useGpu When true, WebGPU is requested if `navigator.gpu` looks usable.
 * @returns `{ model_loaded: true }` on success.
 * @throws Error when the library, pipeline, or tokenizer fails to load.
 */
async function loadModel(modelKey: string, useGpu: boolean = false) {
  try {
    console.log(`Loading model: ${modelKey}, GPU: ${useGpu}`);

    // Make sure Transformers.js is available.
    await loadTransformers();

    const pipelineFactory = (globalThis as any).pipelineFactory;
    const AutoTokenizer = (globalThis as any).AutoTokenizer;

    // Pipeline options.
    const pipelineOpts: any = {
      quantized: true,
      // Guarded progress callback: a logging failure must never abort model loading.
      progress_callback: (progress: any) => {
        try {
          if (progress && typeof progress === 'object') {
            console.log('Model loading progress:', progress);
          }
        } catch (error) {
          console.warn('Progress callback error (ignored):', error);
        }
      }
    };

    // Be conservative about GPU: only opt in when WebGPU is detectably present.
    if (useGpu) {
      try {
        console.log("useGpu", useGpu)
        if (typeof navigator !== 'undefined' && 'gpu' in navigator) {
          const gpu = (navigator as any).gpu;
          if (gpu && typeof gpu.requestAdapter === 'function') {
            console.log('[Transformers] Attempting to use GPU');
            pipelineOpts.device = 'webgpu';
            pipelineOpts.dtype = 'fp32';
          } else {
            console.log('[Transformers] WebGPU not fully supported, using CPU');
          }
        } else {
          console.log('[Transformers] WebGPU not available, using CPU');
        }
      } catch (error) {
        console.warn('[Transformers] Error checking GPU support, falling back to CPU:', error);
      }
    } else {
      console.log('[Transformers] Using CPU');
    }

    // Create the embedding pipeline.
    pipeline = await pipelineFactory('feature-extraction', modelKey, pipelineOpts);

    // Create the tokenizer.
    tokenizer = await AutoTokenizer.from_pretrained(modelKey);

    model = {
      loaded: true,
      model_key: modelKey,
      use_gpu: useGpu
    };

    console.log(`Model ${modelKey} loaded successfully`);
    return { model_loaded: true };
  } catch (error) {
    console.error('Error loading model:', error);
    throw new Error(`Failed to load model: ${error}`);
  }
}

/**
 * Releases the pipeline and clears all model state.
 *
 * @returns `{ model_unloaded: true }` on success.
 * @throws Error when releasing the pipeline fails.
 */
async function unloadModel() {
  try {
    console.log('Unloading model...');

    if (pipeline) {
      // Transformers.js pipelines expose an async `dispose()`; `destroy()` is
      // kept as a fallback for objects that only provide that name.
      if (typeof pipeline.dispose === 'function') {
        await pipeline.dispose();
      } else if (typeof pipeline.destroy === 'function') {
        await pipeline.destroy();
      }
      pipeline = null;
    }

    tokenizer = null;
    model = null;

    console.log('Model unloaded successfully');
    return { model_unloaded: true };
  } catch (error) {
    console.error('Error unloading model:', error);
    throw new Error(`Failed to unload model: ${error}`);
  }
}

/**
 * Counts the tokens the loaded tokenizer produces for `input`.
 *
 * @returns `{ tokens }` with the number of input ids.
 * @throws Error when no tokenizer is loaded or tokenization fails.
 */
async function countTokens(input: string) {
  try {
    if (!tokenizer) {
      throw new Error('Tokenizer not loaded');
    }

    const { input_ids } = await tokenizer(input);
    return { tokens: input_ids.data.length };
  } catch (error) {
    console.error('Error counting tokens:', error);
    throw new Error(`Failed to count tokens: ${error}`);
  }
}

/** Converts raw embedding data to a plain array rounded to 8 decimal places. */
function roundVec(data: ArrayLike<number>): number[] {
  return Array.from(data).map((val: number) => Math.round(val * 1e8) / 1e8);
}

/**
 * Shortens `text` until its token count is at most `maxTokens`, scaling the
 * character length by the token ratio (with a 10% safety margin) and
 * appending '...' each round. Text already within the limit is returned as is.
 */
async function truncateToTokenLimit(
  text: string,
  tokenCount: number,
  maxTokens: number
): Promise<{ text: string; tokens: number }> {
  let tokenCt = tokenCount;
  let truncated = text;

  while (tokenCt > maxTokens) {
    const pct = maxTokens / tokenCt;
    const maxChars = Math.floor(truncated.length * pct * 0.9);
    truncated = truncated.substring(0, maxChars) + '...';
    tokenCt = (await countTokens(truncated)).tokens;
  }

  return { text: truncated, tokens: tokenCt };
}

/**
 * Embeds every non-empty input, processing them in fixed-size batches.
 * Inputs with an empty `embed_input` are dropped, so the result may be
 * shorter than `inputs`; each result echoes its `embed_input`.
 *
 * @throws Error when the model is not loaded or embedding fails.
 */
async function embedBatch(inputs: EmbedInput[]): Promise<EmbedResult[]> {
  try {
    if (!pipeline || !tokenizer) {
      throw new Error('Model not loaded');
    }

    console.log(`Processing ${inputs.length} inputs`);

    // Drop empty inputs.
    const filteredInputs = inputs.filter(item => item.embed_input && item.embed_input.length > 0);

    if (filteredInputs.length === 0) {
      return [];
    }

    // Batch size (tune as needed).
    const batchSize = 1;

    if (filteredInputs.length > batchSize) {
      console.log(`Processing ${filteredInputs.length} inputs in batches of ${batchSize}`);
      const results: EmbedResult[] = [];

      for (let i = 0; i < filteredInputs.length; i += batchSize) {
        const batch = filteredInputs.slice(i, i + batchSize);
        const batchResults = await processBatch(batch);
        results.push(...batchResults);
      }

      return results;
    }

    return await processBatch(filteredInputs);
  } catch (error) {
    console.error('Error in embed batch:', error);
    throw new Error(`Failed to generate embeddings: ${error}`);
  }
}

/**
 * Embeds one batch. Over-long texts are truncated to the token limit first.
 * If the batched call fails, each item is retried on its own; an item that
 * still fails yields `{ vec: [], tokens: 0, error }` instead of rejecting.
 */
async function processBatch(batchInputs: EmbedInput[]): Promise<EmbedResult[]> {
  try {
    // Maximum token limit assumed for most models.
    const maxTokens = 512;

    // Count tokens for each input and truncate the ones that are too long.
    const prepared = await Promise.all(
      batchInputs.map(async (item) => {
        const { tokens } = await countTokens(item.embed_input);
        const fitted = await truncateToTokenLimit(item.embed_input, tokens, maxTokens);
        return { original: item.embed_input, text: fitted.text, tokens: fitted.tokens };
      })
    );

    // Generate the embedding vectors.
    const resp = await pipeline(prepared.map(entry => entry.text), { pooling: 'mean', normalize: true });

    // Shape the results.
    return prepared.map((entry, i) => ({
      vec: roundVec(resp[i].data),
      tokens: entry.tokens,
      embed_input: entry.original
    }));
  } catch (error) {
    console.error('Error processing batch:', error);

    // The batched call failed: fall back to processing items one by one.
    return Promise.all(
      batchInputs.map(async (item): Promise<EmbedResult> => {
        try {
          const result = await pipeline(item.embed_input, { pooling: 'mean', normalize: true });
          const tokenCount = await countTokens(item.embed_input);

          return {
            vec: roundVec(result[0].data),
            tokens: tokenCount.tokens,
            embed_input: item.embed_input
          };
        } catch (singleError) {
          console.error('Error processing single item:', singleError);
          return {
            vec: [],
            tokens: 0,
            embed_input: item.embed_input,
            error: errorMessage(singleError)
          };
        }
      })
    );
  }
}

/**
 * Dispatches one worker request and always resolves with a response object;
 * failures are reported through the `error` field rather than by rejecting.
 *
 * `embed_batch` and `count_tokens` run one at a time in arrival order.
 */
async function processMessage(data: WorkerMessage): Promise<WorkerResponse> {
  const { method, params, id, worker_id } = data;

  try {
    let result: any;

    switch (method) {
      case 'load':
        console.log('Load method called with params:', params);
        result = await loadModel(params.model_key, params.use_gpu || false);
        break;

      case 'unload':
        console.log('Unload method called');
        result = await unloadModel();
        break;

      case 'embed_batch':
        console.log('Embed batch method called');
        if (!model) {
          throw new Error('Model not loaded');
        }
        // Wait for earlier embed/count requests to finish before running.
        result = await runExclusive(() => embedBatch(params.inputs));
        break;

      case 'count_tokens':
        console.log('Count tokens method called');
        if (!model) {
          throw new Error('Model not loaded');
        }
        // Wait for earlier embed/count requests to finish before running.
        result = await runExclusive(() => countTokens(params));
        break;

      default:
        throw new Error(`Unknown method: ${method}`);
    }

    return { id, result, worker_id };
  } catch (error) {
    console.error('Error processing message:', error);
    // The exclusive section releases its own flag; do not touch it here, or an
    // unrelated failure would unlock a request that is still running.
    return { id, error: errorMessage(error), worker_id };
  }
}

// Register handlers only when a worker-like global scope exists, so the module
// can also be imported outside a worker (e.g. in unit tests) without throwing.
if (typeof self !== 'undefined' && typeof self.addEventListener === 'function') {
  self.addEventListener('message', async (event) => {
    try {
      console.log('Worker received message:', event.data);

      // Validate the message shape.
      if (!event.data || typeof event.data !== 'object') {
        console.error('Invalid message format received');
        self.postMessage({
          id: -1,
          error: 'Invalid message format'
        });
        return;
      }

      const response = await processMessage(event.data);
      console.log('Worker sending response:', response);
      self.postMessage(response);
    } catch (error) {
      console.error('Unhandled error in worker message handler:', error);
      self.postMessage({
        // `??` keeps a legitimate id of 0 instead of replacing it with -1.
        id: event.data?.id ?? -1,
        error: `Worker error: ${errorMessage(error)}`
      });
    }
  });

  self.addEventListener('error', (event) => {
    console.error('Worker global error:', event);
    self.postMessage({
      id: -1,
      error: `Worker global error: ${event.message || 'Unknown error'}`
    });
  });

  self.addEventListener('unhandledrejection', (event) => {
    console.error('Worker unhandled promise rejection:', event);
    self.postMessage({
      id: -1,
      error: `Worker unhandled rejection: ${event.reason || 'Unknown error'}`
    });
    event.preventDefault(); // Suppress the default console error.
  });
}

console.log('Embedding worker ready');