mirror of https://github.com/EthanMarti/infio-copilot.git
synced 2026-01-18 00:47:51 +00:00
Update RAGEngine and the embedding manager so an embedding manager can be passed through, add embedding-model loading logic for the local provider, and improve error handling and message processing.
This commit is contained in:
parent
bed96a5233
commit
4e139ecc4f
@@ -16,10 +16,67 @@ import {
 } from '../llm/exception'
 import { NoStainlessOpenAI } from '../llm/ollama'
 
+// EmbeddingManager type definition
+type EmbeddingManager = {
+  modelLoaded: boolean
+  currentModel: string | null
+  loadModel(modelId: string, useGpu: boolean): Promise<any>
+  embed(text: string): Promise<{ vec: number[] }>
+  embedBatch(texts: string[]): Promise<{ vec: number[] }[]>
+}
+
 export const getEmbeddingModel = (
   settings: InfioSettings,
+  embeddingManager?: EmbeddingManager,
 ): EmbeddingModel => {
   switch (settings.embeddingModelProvider) {
+    case ApiProvider.LocalProvider: {
+      if (!embeddingManager) {
+        throw new Error('EmbeddingManager is required for LocalProvider')
+      }
+
+      const modelInfo = GetEmbeddingModelInfo(settings.embeddingModelProvider, settings.embeddingModelId)
+      if (!modelInfo) {
+        throw new Error(`Embedding model ${settings.embeddingModelId} not found for provider ${settings.embeddingModelProvider}`)
+      }
+
+      return {
+        id: settings.embeddingModelId,
+        dimension: modelInfo.dimensions,
+        supportsBatch: true,
+        getEmbedding: async (text: string) => {
+          try {
+            // Make sure the model is loaded
+            if (!embeddingManager.modelLoaded || embeddingManager.currentModel !== settings.embeddingModelId) {
+              console.log(`Loading model: ${settings.embeddingModelId}`)
+              await embeddingManager.loadModel(settings.embeddingModelId, true)
+            }
+
+            const result = await embeddingManager.embed(text)
+            return result.vec
+          } catch (error) {
+            console.error('LocalProvider embedding error:', error)
+            throw new Error(`LocalProvider embedding failed: ${error.message}`)
+          }
+        },
+        getBatchEmbeddings: async (texts: string[]) => {
+          try {
+            // Make sure the model is loaded
+            if (!embeddingManager.modelLoaded || embeddingManager.currentModel !== settings.embeddingModelId) {
+              console.log(`Loading model: ${settings.embeddingModelId}`)
+              await embeddingManager.loadModel(settings.embeddingModelId, false)
+            }
+
+            const results = await embeddingManager.embedBatch(texts)
+            console.log('results', results)
+            return results.map(result => result.vec)
+          } catch (error) {
+            console.error('LocalProvider batch embedding error:', error)
+            throw new Error(`LocalProvider batch embedding failed: ${error.message}`)
+          }
+        },
+      }
+    }
     case ApiProvider.Infio: {
       const openai = new OpenAI({
         apiKey: settings.infioProvider.apiKey,
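Note: a minimal usage sketch of the new signature, assuming the provider is set to LocalProvider; `plugin.embeddingManager` here is a stand-in for wherever the plugin keeps its manager instance, not a name from this commit.

    // Hypothetical caller (inside an async context):
    const embeddingModel = getEmbeddingModel(settings, plugin.embeddingManager)
    const vec = await embeddingModel.getEmbedding('hello world')          // single text
    const vecs = await embeddingModel.getBatchEmbeddings(['foo', 'bar'])  // batch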
@@ -10,9 +10,19 @@ import { InfioSettings } from '../../types/settings'
 import { getEmbeddingModel } from './embedding'
 
+// EmbeddingManager type definition
+type EmbeddingManager = {
+  modelLoaded: boolean
+  currentModel: string | null
+  loadModel(modelId: string, useGpu: boolean): Promise<any>
+  embed(text: string): Promise<{ vec: number[] }>
+  embedBatch(texts: string[]): Promise<{ vec: number[] }[]>
+}
+
 export class RAGEngine {
   private app: App
   private settings: InfioSettings
+  private embeddingManager?: EmbeddingManager
   private vectorManager: VectorManager | null = null
   private embeddingModel: EmbeddingModel | null = null
   private initialized = false
@@ -21,13 +31,15 @@ export class RAGEngine {
     app: App,
     settings: InfioSettings,
     dbManager: DBManager,
+    embeddingManager?: EmbeddingManager,
   ) {
     this.app = app
     this.settings = settings
+    this.embeddingManager = embeddingManager
     this.vectorManager = dbManager.getVectorManager()
     if (settings.embeddingModelId && settings.embeddingModelId.trim() !== '') {
       try {
-        this.embeddingModel = getEmbeddingModel(settings)
+        this.embeddingModel = getEmbeddingModel(settings, embeddingManager)
       } catch (error) {
         console.warn('Failed to initialize embedding model:', error)
         this.embeddingModel = null
@@ -46,7 +58,7 @@ export class RAGEngine {
     this.settings = settings
     if (settings.embeddingModelId && settings.embeddingModelId.trim() !== '') {
       try {
-        this.embeddingModel = getEmbeddingModel(settings)
+        this.embeddingModel = getEmbeddingModel(settings, this.embeddingManager)
       } catch (error) {
         console.warn('Failed to initialize embedding model:', error)
         this.embeddingModel = null
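Note: the new constructor parameter is optional, so existing call sites keep working. A sketch of the resulting wiring, with `dbManager` and `embeddingManager` standing in for whatever the plugin already holds:

    // Hypothetical wiring:
    const rag = new RAGEngine(app, settings, dbManager)                         // cloud providers only
    const localRag = new RAGEngine(app, settings, dbManager, embeddingManager)  // enables LocalProvider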
@@ -113,7 +113,7 @@ export class VectorManager {
 
     const textSplitter = new MarkdownTextSplitter({
       chunkSize: options.chunkSize,
-      chunkOverlap: Math.floor(options.chunkSize * 0.15)
+      // chunkOverlap: Math.floor(options.chunkSize * 0.15)
     })
 
     const skippedFiles: string[] = []
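Note: commenting out `chunkOverlap` does not necessarily mean zero overlap; LangChain's JS text splitters fall back to a built-in default (200 characters at the time of writing, hedged, and the import path varies by LangChain version). If "no overlap" is the intent, it has to be explicit; a sketch:

    // Assumption: omitted options inherit LangChain's TextSplitter defaults,
    // so zero overlap must be stated explicitly.
    import { MarkdownTextSplitter } from 'langchain/text_splitter'

    const noOverlapSplitter = new MarkdownTextSplitter({
      chunkSize: options.chunkSize,
      chunkOverlap: 0,
    })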
@@ -34,19 +34,28 @@ export class EmbeddingManager {
 
     // Single listener for all messages coming back from the Worker
     this.worker.onmessage = (event) => {
-      const { id, result, error } = event.data;
+      try {
+        const { id, result, error } = event.data;
 
         // Find the matching Promise callbacks by the returned id
         const request = this.requests.get(id);
 
         if (request) {
           if (error) {
             request.reject(new Error(error));
           } else {
             request.resolve(result);
+          }
+          // Delete from the Map once settled
+          this.requests.delete(id);
         }
-      // Delete from the Map once settled
-      this.requests.delete(id);
+      } catch (err) {
+        console.error("Error processing worker message:", err);
+        // Reject all pending requests
+        this.requests.forEach(request => {
+          request.reject(new Error(`Worker message processing error: ${err.message}`));
+        });
+        this.requests.clear();
+      }
     };
 
@@ -54,9 +63,13 @@ export class EmbeddingManager {
       console.error("EmbeddingWorker error:", error);
       // Reject all pending requests
       this.requests.forEach(request => {
-        request.reject(error);
+        request.reject(new Error(`Worker error: ${error.message || 'Unknown worker error'}`));
       });
       this.requests.clear();
 
+      // Reset state
+      this.isModelLoaded = false;
+      this.currentModelId = null;
     };
 }
 
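Note: the `requests` Map is the request/response correlation half of the pattern; each outgoing message carries an id, and `onmessage` above settles the matching Promise. A minimal self-contained sketch of the sending side (assumed shape, not the plugin's actual class; the `send` helper is hypothetical):

    type Pending = { resolve: (v: unknown) => void; reject: (e: Error) => void };

    class WorkerClient {
      private nextId = 0;
      private requests = new Map<number, Pending>();

      constructor(private worker: Worker) {}

      // Each call gets a unique id; the onmessage handler settles it later.
      send(method: string, params: unknown): Promise<unknown> {
        const id = this.nextId++;
        return new Promise((resolve, reject) => {
          this.requests.set(id, { resolve, reject });
          this.worker.postMessage({ method, params, id });
        });
      }
    }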
@@ -3,27 +3,27 @@ console.log('Embedding worker loaded');
 
 // Type definitions
 interface EmbedInput {
   embed_input: string;
 }
 
 interface EmbedResult {
   vec: number[];
   tokens: number;
   embed_input?: string;
 }
 
 interface WorkerMessage {
   method: string;
   params: any;
   id: number;
   worker_id?: string;
 }
 
 interface WorkerResponse {
   id: number;
   result?: any;
   error?: string;
   worker_id?: string;
 }
 
 // Global state
@@ -35,319 +35,377 @@ let transformersLoaded = false;
 
 // Dynamically import Transformers.js
 async function loadTransformers() {
   if (transformersLoaded) return;
 
   try {
     console.log('Loading Transformers.js...');
 
     // Use the older Transformers.js build; it is more stable inside a Worker
     const { pipeline: pipelineFactory, env, AutoTokenizer } = await import('@xenova/transformers');
 
     // Configure the environment for a browser Worker
     env.allowLocalModels = false;
     env.allowRemoteModels = true;
 
-    // Configure the WASM backend
-    env.backends.onnx.wasm.numThreads = 2; // use a single thread in the Worker
+    // Configure the WASM backend - fix the thread configuration
+    env.backends.onnx.wasm.numThreads = 4; // avoid race conditions in the Worker
     env.backends.onnx.wasm.simd = true;
 
     // Disable Node.js-specific features
     env.useFS = false;
     env.useBrowserCache = true;
 
     // Stash the imported functions
     (globalThis as any).pipelineFactory = pipelineFactory;
     (globalThis as any).AutoTokenizer = AutoTokenizer;
     (globalThis as any).env = env;
 
     transformersLoaded = true;
     console.log('Transformers.js loaded successfully');
   } catch (error) {
     console.error('Failed to load Transformers.js:', error);
     throw new Error(`Failed to load Transformers.js: ${error}`);
   }
 }
 
 // Load the model
 async function loadModel(modelKey: string, useGpu: boolean = false) {
   try {
     console.log(`Loading model: ${modelKey}, GPU: ${useGpu}`);
 
     // Make sure Transformers.js is loaded
     await loadTransformers();
 
     const pipelineFactory = (globalThis as any).pipelineFactory;
     const AutoTokenizer = (globalThis as any).AutoTokenizer;
     const env = (globalThis as any).env;
 
     // Configure pipeline options
     const pipelineOpts: any = {
       quantized: true,
+      // Harden the progress callback with error handling
       progress_callback: (progress: any) => {
-        console.log('Model loading progress:', progress);
+        try {
+          if (progress && typeof progress === 'object') {
+            console.log('Model loading progress:', progress);
+          }
+        } catch (error) {
+          // Ignore progress-callback errors so they cannot abort model loading
+          console.warn('Progress callback error (ignored):', error);
+        }
       }
     };
 
-    if (useGpu && typeof navigator !== 'undefined' && 'gpu' in navigator) {
-      console.log('[Transformers] Attempting to use GPU');
-      try {
-        pipelineOpts.device = 'webgpu';
-        pipelineOpts.dtype = 'fp32';
-      } catch (error) {
-        console.warn('[Transformers] GPU not available, falling back to CPU');
-      }
-    } else {
-      console.log('[Transformers] Using CPU');
-    }
+    // Be more careful about GPU configuration
+    if (useGpu) {
+      try {
+        // Check for WebGPU support
+        console.log("useGpu", useGpu)
+        if (typeof navigator !== 'undefined' && 'gpu' in navigator) {
+          const gpu = (navigator as any).gpu;
+          if (gpu && typeof gpu.requestAdapter === 'function') {
+            console.log('[Transformers] Attempting to use GPU');
+            pipelineOpts.device = 'webgpu';
+            pipelineOpts.dtype = 'fp32';
+          } else {
+            console.log('[Transformers] WebGPU not fully supported, using CPU');
+          }
+        } else {
+          console.log('[Transformers] WebGPU not available, using CPU');
+        }
+      } catch (error) {
+        console.warn('[Transformers] Error checking GPU support, falling back to CPU:', error);
+      }
+    } else {
+      console.log('[Transformers] Using CPU');
+    }
 
     // Create the embedding pipeline
     pipeline = await pipelineFactory('feature-extraction', modelKey, pipelineOpts);
 
     // Create the tokenizer
     tokenizer = await AutoTokenizer.from_pretrained(modelKey);
 
     model = {
       loaded: true,
       model_key: modelKey,
       use_gpu: useGpu
     };
 
     console.log(`Model ${modelKey} loaded successfully`);
     return { model_loaded: true };
 
   } catch (error) {
     console.error('Error loading model:', error);
     throw new Error(`Failed to load model: ${error}`);
   }
 }
 
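Note: for orientation, this is roughly how the main thread invokes `loadModel` through the worker protocol defined above (method/params/id per WorkerMessage); the model id below is an illustrative assumption, not a value from this commit.

    // Hypothetical call from the main thread:
    worker.postMessage({
      method: 'load',
      params: { model_key: 'Xenova/all-MiniLM-L6-v2', use_gpu: false },
      id: 1,
    });
    // On success the worker answers { id: 1, result: { model_loaded: true } };
    // on failure, { id: 1, error: '...' } (see WorkerResponse above).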
 // Unload the model
 async function unloadModel() {
   try {
     console.log('Unloading model...');
 
     if (pipeline) {
       if (pipeline.destroy) {
         pipeline.destroy();
       }
       pipeline = null;
     }
 
     if (tokenizer) {
       tokenizer = null;
     }
 
     model = null;
 
     console.log('Model unloaded successfully');
     return { model_unloaded: true };
 
   } catch (error) {
     console.error('Error unloading model:', error);
     throw new Error(`Failed to unload model: ${error}`);
   }
 }
 
 // Count tokens
 async function countTokens(input: string) {
   try {
     if (!tokenizer) {
       throw new Error('Tokenizer not loaded');
     }
 
     const { input_ids } = await tokenizer(input);
     return { tokens: input_ids.data.length };
 
   } catch (error) {
     console.error('Error counting tokens:', error);
     throw new Error(`Failed to count tokens: ${error}`);
   }
 }
 
 // Generate embedding vectors
 async function embedBatch(inputs: EmbedInput[]): Promise<EmbedResult[]> {
   try {
     if (!pipeline || !tokenizer) {
       throw new Error('Model not loaded');
     }
 
     console.log(`Processing ${inputs.length} inputs`);
 
     // Filter out empty inputs
     const filteredInputs = inputs.filter(item => item.embed_input && item.embed_input.length > 0);
 
     if (filteredInputs.length === 0) {
       return [];
     }
 
     // Batch size (adjust as needed)
     const batchSize = 1;
 
     if (filteredInputs.length > batchSize) {
       console.log(`Processing ${filteredInputs.length} inputs in batches of ${batchSize}`);
       const results: EmbedResult[] = [];
 
       for (let i = 0; i < filteredInputs.length; i += batchSize) {
         const batch = filteredInputs.slice(i, i + batchSize);
         const batchResults = await processBatch(batch);
         results.push(...batchResults);
       }
 
       return results;
     }
 
     return await processBatch(filteredInputs);
 
   } catch (error) {
     console.error('Error in embed batch:', error);
     throw new Error(`Failed to generate embeddings: ${error}`);
   }
 }
 
 // Process a single batch
 async function processBatch(batchInputs: EmbedInput[]): Promise<EmbedResult[]> {
   try {
     // Count the tokens for each input
     const tokens = await Promise.all(
       batchInputs.map(item => countTokens(item.embed_input))
     );
 
     // Prepare the embedding inputs (handle over-long texts)
     const maxTokens = 512; // max token limit for most models
     const embedInputs = await Promise.all(
       batchInputs.map(async (item, i) => {
         if (tokens[i].tokens < maxTokens) {
           return item.embed_input;
         }
 
         // Truncate over-long text
         let tokenCt = tokens[i].tokens;
         let truncatedInput = item.embed_input;
 
         while (tokenCt > maxTokens) {
           const pct = maxTokens / tokenCt;
           const maxChars = Math.floor(truncatedInput.length * pct * 0.9);
           truncatedInput = truncatedInput.substring(0, maxChars) + '...';
           tokenCt = (await countTokens(truncatedInput)).tokens;
         }
 
         tokens[i].tokens = tokenCt;
         return truncatedInput;
       })
     );
 
     // Generate embedding vectors
     const resp = await pipeline(embedInputs, { pooling: 'mean', normalize: true });
 
     // Shape the results
     return batchInputs.map((item, i) => ({
       vec: Array.from(resp[i].data).map((val: number) => Math.round(val * 1e8) / 1e8),
       tokens: tokens[i].tokens,
       embed_input: item.embed_input
     }));
 
   } catch (error) {
     console.error('Error processing batch:', error);
 
     // If batch processing fails, fall back to processing items one by one
     return Promise.all(
       batchInputs.map(async (item) => {
         try {
           const result = await pipeline(item.embed_input, { pooling: 'mean', normalize: true });
           const tokenCount = await countTokens(item.embed_input);
 
           return {
             vec: Array.from(result[0].data).map((val: number) => Math.round(val * 1e8) / 1e8),
             tokens: tokenCount.tokens,
             embed_input: item.embed_input
           };
         } catch (singleError) {
           console.error('Error processing single item:', singleError);
           return {
             vec: [],
             tokens: 0,
             embed_input: item.embed_input,
             error: (singleError as Error).message
           } as any;
         }
       })
     );
   }
 }
 
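Note: the truncation loop shrinks the text proportionally rather than cutting at an exact token boundary. Worked example: 1,024 tokens over a 2,048-character string gives pct = 512/1024 = 0.5, so the text is cut to floor(2048 × 0.5 × 0.9) = 921 characters (the 0.9 is a safety margin because characters-per-token varies), then re-counted and cut again if still over. A standalone sketch of the same loop, with an approximate character-based counter standing in for the worker's tokenizer:

    // Self-contained sketch of the proportional-truncation idea; countTokensApprox
    // is a stand-in (~4 chars/token), not the worker's tokenizer-based version.
    const MAX_TOKENS = 512;

    const countTokensApprox = (text: string): number => Math.ceil(text.length / 4);

    function truncateToTokenLimit(input: string): string {
      let tokenCt = countTokensApprox(input);
      let truncated = input;
      while (tokenCt > MAX_TOKENS) {
        const pct = MAX_TOKENS / tokenCt;                          // target shrink ratio
        const maxChars = Math.floor(truncated.length * pct * 0.9); // 10% safety margin
        truncated = truncated.substring(0, maxChars) + '...';
        tokenCt = countTokensApprox(truncated);                    // re-check after cutting
      }
      return truncated;
    }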
 // Process a message
 async function processMessage(data: WorkerMessage): Promise<WorkerResponse> {
   const { method, params, id, worker_id } = data;
 
   try {
     let result: any;
 
     switch (method) {
       case 'load':
         console.log('Load method called with params:', params);
         result = await loadModel(params.model_key, params.use_gpu || false);
         break;
 
       case 'unload':
         console.log('Unload method called');
         result = await unloadModel();
         break;
 
       case 'embed_batch':
         console.log('Embed batch method called');
         if (!model) {
           throw new Error('Model not loaded');
         }
 
         // Wait for the previous job to finish
         if (processing_message) {
           while (processing_message) {
             await new Promise(resolve => setTimeout(resolve, 100));
           }
         }
 
         processing_message = true;
         result = await embedBatch(params.inputs);
         processing_message = false;
         break;
 
       case 'count_tokens':
         console.log('Count tokens method called');
         if (!model) {
           throw new Error('Model not loaded');
         }
 
         // Wait for the previous job to finish
         if (processing_message) {
           while (processing_message) {
             await new Promise(resolve => setTimeout(resolve, 100));
           }
         }
 
         processing_message = true;
         result = await countTokens(params);
         processing_message = false;
         break;
 
       default:
         throw new Error(`Unknown method: ${method}`);
     }
 
     return { id, result, worker_id };
 
   } catch (error) {
     console.error('Error processing message:', error);
     processing_message = false;
     return { id, error: (error as Error).message, worker_id };
   }
 }
 
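Note: the `processing_message` flag serializes work by polling every 100 ms. A common alternative (not what this commit does) is to chain jobs onto a promise queue, which avoids the polling delay; a minimal sketch:

    // Hypothetical alternative: serialize jobs by chaining onto the previous promise.
    let queue: Promise<unknown> = Promise.resolve();

    function enqueue<T>(job: () => Promise<T>): Promise<T> {
      const run = queue.then(job, job);     // start after the previous job settles
      queue = run.catch(() => undefined);   // keep the chain alive on failure
      return run;
    }

    // usage: result = await enqueue(() => embedBatch(params.inputs));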
 // Listen for messages
 self.addEventListener('message', async (event) => {
-  console.log('Worker received message:', event.data);
-  const response = await processMessage(event.data);
-  console.log('Worker sending response:', response);
-  self.postMessage(response);
+  try {
+    console.log('Worker received message:', event.data);
+
+    // Validate the message format
+    if (!event.data || typeof event.data !== 'object') {
+      console.error('Invalid message format received');
+      self.postMessage({
+        id: -1,
+        error: 'Invalid message format'
+      });
+      return;
+    }
+
+    const response = await processMessage(event.data);
+    console.log('Worker sending response:', response);
+    self.postMessage(response);
+  } catch (error) {
+    console.error('Unhandled error in worker message handler:', error);
+    self.postMessage({
+      id: event.data?.id || -1,
+      error: `Worker error: ${error.message || 'Unknown error'}`
+    });
+  }
 });
 
+// Global error handling
+self.addEventListener('error', (event) => {
+  console.error('Worker global error:', event);
+  self.postMessage({
+    id: -1,
+    error: `Worker global error: ${event.message || 'Unknown error'}`
+  });
+});
+
+// Handle unhandled promise rejections
+self.addEventListener('unhandledrejection', (event) => {
+  console.error('Worker unhandled promise rejection:', event);
+  self.postMessage({
+    id: -1,
+    error: `Worker unhandled rejection: ${event.reason || 'Unknown error'}`
+  });
+  event.preventDefault(); // suppress the default console error
+});
 
 console.log('Embedding worker ready');
@@ -623,7 +623,7 @@ export default class InfioPlugin extends Plugin {
 if (!this.ragEngineInitPromise) {
   this.ragEngineInitPromise = (async () => {
     const dbManager = await this.getDbManager()
-    this.ragEngine = new RAGEngine(this.app, this.settings, dbManager)
+    this.ragEngine = new RAGEngine(this.app, this.settings, dbManager, this.embeddingManager)
     return this.ragEngine
   })()
 }
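Note: the `ragEngineInitPromise` guard is a memoized async initializer: the first caller kicks off construction and every later caller awaits the same promise, so the engine is built exactly once. A generic sketch of the pattern (names here are illustrative, not from the plugin):

    // Hypothetical generic form of the same pattern:
    class Lazy<T> {
      private promise: Promise<T> | null = null;
      constructor(private init: () => Promise<T>) {}

      get(): Promise<T> {
        if (!this.promise) {
          this.promise = this.init(); // first caller triggers init; the rest reuse it
        }
        return this.promise;
      }
    }

    // usage: const ragEngine = new Lazy(() => buildRagEngine()); await ragEngine.get();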
@@ -14,6 +14,7 @@ export const providerApiUrls: Record<ApiProvider, string> = {
 [ApiProvider.Grok]: 'https://console.x.ai/',
 [ApiProvider.Ollama]: '', // Ollama does not need an API key
 [ApiProvider.OpenAICompatible]: '', // custom OpenAI-compatible API, no fixed URL
+[ApiProvider.LocalProvider]: '', // local provider, no fixed URL
 };
 
 // Returns the URL where the API key for a given provider can be obtained