mirror of
https://github.com/EthanMarti/infio-copilot.git
synced 2026-01-16 16:31:56 +00:00
更新 RAGEngine 和嵌入管理器以支持嵌入管理器的传递,添加本地提供者的嵌入模型加载逻辑,优化错误处理和消息处理机制。
This commit is contained in:
parent
bed96a5233
commit
4e139ecc4f
@ -16,10 +16,67 @@ import {
|
||||
} from '../llm/exception'
|
||||
import { NoStainlessOpenAI } from '../llm/ollama'
|
||||
|
||||
// EmbeddingManager 类型定义
|
||||
type EmbeddingManager = {
|
||||
modelLoaded: boolean
|
||||
currentModel: string | null
|
||||
loadModel(modelId: string, useGpu: boolean): Promise<any>
|
||||
embed(text: string): Promise<{ vec: number[] }>
|
||||
embedBatch(texts: string[]): Promise<{ vec: number[] }[]>
|
||||
}
|
||||
|
||||
export const getEmbeddingModel = (
|
||||
settings: InfioSettings,
|
||||
embeddingManager?: EmbeddingManager,
|
||||
): EmbeddingModel => {
|
||||
switch (settings.embeddingModelProvider) {
|
||||
case ApiProvider.LocalProvider: {
|
||||
if (!embeddingManager) {
|
||||
throw new Error('EmbeddingManager is required for LocalProvider')
|
||||
}
|
||||
|
||||
const modelInfo = GetEmbeddingModelInfo(settings.embeddingModelProvider, settings.embeddingModelId)
|
||||
if (!modelInfo) {
|
||||
throw new Error(`Embedding model ${settings.embeddingModelId} not found for provider ${settings.embeddingModelProvider}`)
|
||||
}
|
||||
|
||||
return {
|
||||
id: settings.embeddingModelId,
|
||||
dimension: modelInfo.dimensions,
|
||||
supportsBatch: true,
|
||||
getEmbedding: async (text: string) => {
|
||||
try {
|
||||
// 确保模型已加载
|
||||
if (!embeddingManager.modelLoaded || embeddingManager.currentModel !== settings.embeddingModelId) {
|
||||
console.log(`Loading model: ${settings.embeddingModelId}`)
|
||||
await embeddingManager.loadModel(settings.embeddingModelId, true)
|
||||
}
|
||||
|
||||
const result = await embeddingManager.embed(text)
|
||||
return result.vec
|
||||
} catch (error) {
|
||||
console.error('LocalProvider embedding error:', error)
|
||||
throw new Error(`LocalProvider embedding failed: ${error.message}`)
|
||||
}
|
||||
},
|
||||
getBatchEmbeddings: async (texts: string[]) => {
|
||||
try {
|
||||
// 确保模型已加载
|
||||
if (!embeddingManager.modelLoaded || embeddingManager.currentModel !== settings.embeddingModelId) {
|
||||
console.log(`Loading model: ${settings.embeddingModelId}`)
|
||||
await embeddingManager.loadModel(settings.embeddingModelId, false)
|
||||
}
|
||||
|
||||
const results = await embeddingManager.embedBatch(texts)
|
||||
console.log('results', results)
|
||||
return results.map(result => result.vec)
|
||||
} catch (error) {
|
||||
console.error('LocalProvider batch embedding error:', error)
|
||||
throw new Error(`LocalProvider batch embedding failed: ${error.message}`)
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
case ApiProvider.Infio: {
|
||||
const openai = new OpenAI({
|
||||
apiKey: settings.infioProvider.apiKey,
|
||||
|
||||
@ -10,9 +10,19 @@ import { InfioSettings } from '../../types/settings'
|
||||
|
||||
import { getEmbeddingModel } from './embedding'
|
||||
|
||||
// EmbeddingManager 类型定义
|
||||
type EmbeddingManager = {
|
||||
modelLoaded: boolean
|
||||
currentModel: string | null
|
||||
loadModel(modelId: string, useGpu: boolean): Promise<any>
|
||||
embed(text: string): Promise<{ vec: number[] }>
|
||||
embedBatch(texts: string[]): Promise<{ vec: number[] }[]>
|
||||
}
|
||||
|
||||
export class RAGEngine {
|
||||
private app: App
|
||||
private settings: InfioSettings
|
||||
private embeddingManager?: EmbeddingManager
|
||||
private vectorManager: VectorManager | null = null
|
||||
private embeddingModel: EmbeddingModel | null = null
|
||||
private initialized = false
|
||||
@ -21,13 +31,15 @@ export class RAGEngine {
|
||||
app: App,
|
||||
settings: InfioSettings,
|
||||
dbManager: DBManager,
|
||||
embeddingManager?: EmbeddingManager,
|
||||
) {
|
||||
this.app = app
|
||||
this.settings = settings
|
||||
this.embeddingManager = embeddingManager
|
||||
this.vectorManager = dbManager.getVectorManager()
|
||||
if (settings.embeddingModelId && settings.embeddingModelId.trim() !== '') {
|
||||
try {
|
||||
this.embeddingModel = getEmbeddingModel(settings)
|
||||
this.embeddingModel = getEmbeddingModel(settings, embeddingManager)
|
||||
} catch (error) {
|
||||
console.warn('Failed to initialize embedding model:', error)
|
||||
this.embeddingModel = null
|
||||
@ -46,7 +58,7 @@ export class RAGEngine {
|
||||
this.settings = settings
|
||||
if (settings.embeddingModelId && settings.embeddingModelId.trim() !== '') {
|
||||
try {
|
||||
this.embeddingModel = getEmbeddingModel(settings)
|
||||
this.embeddingModel = getEmbeddingModel(settings, this.embeddingManager)
|
||||
} catch (error) {
|
||||
console.warn('Failed to initialize embedding model:', error)
|
||||
this.embeddingModel = null
|
||||
|
||||
@ -113,7 +113,7 @@ export class VectorManager {
|
||||
|
||||
const textSplitter = new MarkdownTextSplitter({
|
||||
chunkSize: options.chunkSize,
|
||||
chunkOverlap: Math.floor(options.chunkSize * 0.15)
|
||||
// chunkOverlap: Math.floor(options.chunkSize * 0.15)
|
||||
})
|
||||
|
||||
const skippedFiles: string[] = []
|
||||
|
||||
@ -34,6 +34,7 @@ export class EmbeddingManager {
|
||||
|
||||
// 统一监听来自 Worker 的所有消息
|
||||
this.worker.onmessage = (event) => {
|
||||
try {
|
||||
const { id, result, error } = event.data;
|
||||
|
||||
// 根据返回的 id 找到对应的 Promise 回调
|
||||
@ -48,15 +49,27 @@ export class EmbeddingManager {
|
||||
// 完成后从 Map 中删除
|
||||
this.requests.delete(id);
|
||||
}
|
||||
} catch (err) {
|
||||
console.error("Error processing worker message:", err);
|
||||
// 拒绝所有待处理的请求
|
||||
this.requests.forEach(request => {
|
||||
request.reject(new Error(`Worker message processing error: ${err.message}`));
|
||||
});
|
||||
this.requests.clear();
|
||||
}
|
||||
};
|
||||
|
||||
this.worker.onerror = (error) => {
|
||||
console.error("EmbeddingWorker error:", error);
|
||||
// 拒绝所有待处理的请求
|
||||
this.requests.forEach(request => {
|
||||
request.reject(error);
|
||||
request.reject(new Error(`Worker error: ${error.message || 'Unknown worker error'}`));
|
||||
});
|
||||
this.requests.clear();
|
||||
|
||||
// 重置状态
|
||||
this.isModelLoaded = false;
|
||||
this.currentModelId = null;
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@ -47,8 +47,8 @@ async function loadTransformers() {
|
||||
env.allowLocalModels = false;
|
||||
env.allowRemoteModels = true;
|
||||
|
||||
// 配置 WASM 后端
|
||||
env.backends.onnx.wasm.numThreads = 2; // 在 Worker 中使用单线程
|
||||
// 配置 WASM 后端 - 修复线程配置
|
||||
env.backends.onnx.wasm.numThreads = 4; // 在 Worker 中使用单线程,避免竞态条件
|
||||
env.backends.onnx.wasm.simd = true;
|
||||
|
||||
// 禁用 Node.js 特定功能
|
||||
@ -83,18 +83,38 @@ async function loadModel(modelKey: string, useGpu: boolean = false) {
|
||||
// 配置管道选项
|
||||
const pipelineOpts: any = {
|
||||
quantized: true,
|
||||
// 修复进度回调,添加错误处理
|
||||
progress_callback: (progress: any) => {
|
||||
try {
|
||||
if (progress && typeof progress === 'object') {
|
||||
console.log('Model loading progress:', progress);
|
||||
}
|
||||
} catch (error) {
|
||||
// 忽略进度回调错误,避免中断模型加载
|
||||
console.warn('Progress callback error (ignored):', error);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if (useGpu && typeof navigator !== 'undefined' && 'gpu' in navigator) {
|
||||
console.log('[Transformers] Attempting to use GPU');
|
||||
// GPU 配置更加谨慎
|
||||
if (useGpu) {
|
||||
try {
|
||||
// 检查 WebGPU 支持
|
||||
console.log("useGpu", useGpu)
|
||||
if (typeof navigator !== 'undefined' && 'gpu' in navigator) {
|
||||
const gpu = (navigator as any).gpu;
|
||||
if (gpu && typeof gpu.requestAdapter === 'function') {
|
||||
console.log('[Transformers] Attempting to use GPU');
|
||||
pipelineOpts.device = 'webgpu';
|
||||
pipelineOpts.dtype = 'fp32';
|
||||
} else {
|
||||
console.log('[Transformers] WebGPU not fully supported, using CPU');
|
||||
}
|
||||
} else {
|
||||
console.log('[Transformers] WebGPU not available, using CPU');
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn('[Transformers] GPU not available, falling back to CPU');
|
||||
console.warn('[Transformers] Error checking GPU support, falling back to CPU:', error);
|
||||
}
|
||||
} else {
|
||||
console.log('[Transformers] Using CPU');
|
||||
@ -344,10 +364,48 @@ async function processMessage(data: WorkerMessage): Promise<WorkerResponse> {
|
||||
|
||||
// 监听消息
|
||||
self.addEventListener('message', async (event) => {
|
||||
try {
|
||||
console.log('Worker received message:', event.data);
|
||||
|
||||
// 验证消息格式
|
||||
if (!event.data || typeof event.data !== 'object') {
|
||||
console.error('Invalid message format received');
|
||||
self.postMessage({
|
||||
id: -1,
|
||||
error: 'Invalid message format'
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const response = await processMessage(event.data);
|
||||
console.log('Worker sending response:', response);
|
||||
self.postMessage(response);
|
||||
} catch (error) {
|
||||
console.error('Unhandled error in worker message handler:', error);
|
||||
self.postMessage({
|
||||
id: event.data?.id || -1,
|
||||
error: `Worker error: ${error.message || 'Unknown error'}`
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// 添加全局错误处理
|
||||
self.addEventListener('error', (event) => {
|
||||
console.error('Worker global error:', event);
|
||||
self.postMessage({
|
||||
id: -1,
|
||||
error: `Worker global error: ${event.message || 'Unknown error'}`
|
||||
});
|
||||
});
|
||||
|
||||
// 添加未处理的 Promise 拒绝处理
|
||||
self.addEventListener('unhandledrejection', (event) => {
|
||||
console.error('Worker unhandled promise rejection:', event);
|
||||
self.postMessage({
|
||||
id: -1,
|
||||
error: `Worker unhandled rejection: ${event.reason || 'Unknown error'}`
|
||||
});
|
||||
event.preventDefault(); // 防止默认的控制台错误
|
||||
});
|
||||
|
||||
console.log('Embedding worker ready');
|
||||
|
||||
@ -623,7 +623,7 @@ export default class InfioPlugin extends Plugin {
|
||||
if (!this.ragEngineInitPromise) {
|
||||
this.ragEngineInitPromise = (async () => {
|
||||
const dbManager = await this.getDbManager()
|
||||
this.ragEngine = new RAGEngine(this.app, this.settings, dbManager)
|
||||
this.ragEngine = new RAGEngine(this.app, this.settings, dbManager, this.embeddingManager)
|
||||
return this.ragEngine
|
||||
})()
|
||||
}
|
||||
|
||||
@ -14,6 +14,7 @@ export const providerApiUrls: Record<ApiProvider, string> = {
|
||||
[ApiProvider.Grok]: 'https://console.x.ai/',
|
||||
[ApiProvider.Ollama]: '', // Ollama 不需要API Key
|
||||
[ApiProvider.OpenAICompatible]: '', // 自定义兼容API,无固定URL
|
||||
[ApiProvider.LocalProvider]: '', // 本地提供者,无固定URL
|
||||
};
|
||||
|
||||
// 获取指定provider的API Key获取URL
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user