Update RAGEngine and the embedding manager to support passing an EmbeddingManager through, add embedding-model loading logic for the local provider, and improve error handling and worker message processing.

duanfuxiang 2025-07-04 15:52:00 +08:00
parent bed96a5233
commit 4e139ecc4f
7 changed files with 453 additions and 312 deletions
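At a high level, the change threads one shared EmbeddingManager from the plugin into the RAG stack. A minimal sketch of the new wiring, assuming a plugin instance that exposes the fields used below (not verbatim from the diffs):

// Hedged sketch of how the pieces fit together after this commit.
async function initRagEngine(plugin: InfioPlugin): Promise<RAGEngine> {
  const dbManager = await plugin.getDbManager()
  // RAGEngine now accepts the shared EmbeddingManager as an optional 4th argument
  // and forwards it to getEmbeddingModel(), which requires it for LocalProvider.
  return new RAGEngine(plugin.app, plugin.settings, dbManager, plugin.embeddingManager)
}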

View File

@@ -16,10 +16,67 @@ import {
 } from '../llm/exception'
 import { NoStainlessOpenAI } from '../llm/ollama'
 
+// EmbeddingManager type definition
+type EmbeddingManager = {
+  modelLoaded: boolean
+  currentModel: string | null
+  loadModel(modelId: string, useGpu: boolean): Promise<any>
+  embed(text: string): Promise<{ vec: number[] }>
+  embedBatch(texts: string[]): Promise<{ vec: number[] }[]>
+}
+
 export const getEmbeddingModel = (
   settings: InfioSettings,
+  embeddingManager?: EmbeddingManager,
 ): EmbeddingModel => {
   switch (settings.embeddingModelProvider) {
+    case ApiProvider.LocalProvider: {
+      if (!embeddingManager) {
+        throw new Error('EmbeddingManager is required for LocalProvider')
+      }
+      const modelInfo = GetEmbeddingModelInfo(settings.embeddingModelProvider, settings.embeddingModelId)
+      if (!modelInfo) {
+        throw new Error(`Embedding model ${settings.embeddingModelId} not found for provider ${settings.embeddingModelProvider}`)
+      }
+      return {
+        id: settings.embeddingModelId,
+        dimension: modelInfo.dimensions,
+        supportsBatch: true,
+        getEmbedding: async (text: string) => {
+          try {
+            // Make sure the model is loaded
+            if (!embeddingManager.modelLoaded || embeddingManager.currentModel !== settings.embeddingModelId) {
+              console.log(`Loading model: ${settings.embeddingModelId}`)
+              await embeddingManager.loadModel(settings.embeddingModelId, true)
+            }
+            const result = await embeddingManager.embed(text)
+            return result.vec
+          } catch (error) {
+            console.error('LocalProvider embedding error:', error)
+            throw new Error(`LocalProvider embedding failed: ${error.message}`)
+          }
+        },
+        getBatchEmbeddings: async (texts: string[]) => {
+          try {
+            // Make sure the model is loaded
+            if (!embeddingManager.modelLoaded || embeddingManager.currentModel !== settings.embeddingModelId) {
+              console.log(`Loading model: ${settings.embeddingModelId}`)
+              await embeddingManager.loadModel(settings.embeddingModelId, false)
+            }
+            const results = await embeddingManager.embedBatch(texts)
+            console.log('results', results)
+            return results.map(result => result.vec)
+          } catch (error) {
+            console.error('LocalProvider batch embedding error:', error)
+            throw new Error(`LocalProvider batch embedding failed: ${error.message}`)
+          }
+        },
+      }
+    }
     case ApiProvider.Infio: {
       const openai = new OpenAI({
         apiKey: settings.infioProvider.apiKey,
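For reference, a hedged usage sketch of the new LocalProvider branch. The stub below merely satisfies the EmbeddingManager shape declared above (which is module-local, not exported); the vectors and the settings object are placeholders:

// Illustrative stub, not the real EmbeddingManager
const manager: EmbeddingManager = {
  modelLoaded: false,
  currentModel: null,
  async loadModel(modelId: string, useGpu: boolean) {
    this.modelLoaded = true
    this.currentModel = modelId
    return { model_loaded: true }
  },
  async embed(text: string) {
    return { vec: [0.1, 0.2, 0.3] } // placeholder vector
  },
  async embedBatch(texts: string[]) {
    return texts.map(() => ({ vec: [0.1, 0.2, 0.3] }))
  },
}

// With embeddingModelProvider set to ApiProvider.LocalProvider the manager is
// mandatory; getEmbedding() lazily loads the model on first use (async context):
const model = getEmbeddingModel(settings, manager)
const vec = await model.getEmbedding('hello world')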

View File

@@ -10,9 +10,19 @@ import { InfioSettings } from '../../types/settings'
 import { getEmbeddingModel } from './embedding'
 
+// EmbeddingManager type definition
+type EmbeddingManager = {
+  modelLoaded: boolean
+  currentModel: string | null
+  loadModel(modelId: string, useGpu: boolean): Promise<any>
+  embed(text: string): Promise<{ vec: number[] }>
+  embedBatch(texts: string[]): Promise<{ vec: number[] }[]>
+}
+
 export class RAGEngine {
   private app: App
   private settings: InfioSettings
+  private embeddingManager?: EmbeddingManager
   private vectorManager: VectorManager | null = null
   private embeddingModel: EmbeddingModel | null = null
   private initialized = false
@@ -21,13 +31,15 @@ export class RAGEngine {
     app: App,
     settings: InfioSettings,
     dbManager: DBManager,
+    embeddingManager?: EmbeddingManager,
   ) {
     this.app = app
     this.settings = settings
+    this.embeddingManager = embeddingManager
     this.vectorManager = dbManager.getVectorManager()
     if (settings.embeddingModelId && settings.embeddingModelId.trim() !== '') {
       try {
-        this.embeddingModel = getEmbeddingModel(settings)
+        this.embeddingModel = getEmbeddingModel(settings, embeddingManager)
       } catch (error) {
         console.warn('Failed to initialize embedding model:', error)
         this.embeddingModel = null
@@ -46,7 +58,7 @@ export class RAGEngine {
     this.settings = settings
     if (settings.embeddingModelId && settings.embeddingModelId.trim() !== '') {
       try {
-        this.embeddingModel = getEmbeddingModel(settings)
+        this.embeddingModel = getEmbeddingModel(settings, this.embeddingManager)
       } catch (error) {
         console.warn('Failed to initialize embedding model:', error)
         this.embeddingModel = null

View File

@@ -113,7 +113,7 @@ export class VectorManager {
     const textSplitter = new MarkdownTextSplitter({
       chunkSize: options.chunkSize,
-      chunkOverlap: Math.floor(options.chunkSize * 0.15)
+      // chunkOverlap: Math.floor(options.chunkSize * 0.15)
     })
 
     const skippedFiles: string[] = []
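For context on the commented-out option: MarkdownTextSplitter appears to be LangChain's markdown splitter, where chunkOverlap controls how much trailing text adjacent chunks share. With the 15%-of-chunkSize line disabled, the library's own default overlap applies instead (200 characters in LangChain JS, though treat that figure and the import path as assumptions). A minimal sketch:

import { MarkdownTextSplitter } from 'langchain/text_splitter' // path assumed; adjust to the version in use

// Overlap left to the library default, mirroring the change above
const splitter = new MarkdownTextSplitter({ chunkSize: 1000 })
const chunks = await splitter.splitText('# Title\n\nSome long markdown document...')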

View File

@@ -34,19 +34,28 @@ export class EmbeddingManager {
     // Listen for all messages coming from the Worker in one place
     this.worker.onmessage = (event) => {
-      const { id, result, error } = event.data;
-      // Find the matching Promise callback by the returned id
-      const request = this.requests.get(id);
-      if (request) {
-        if (error) {
-          request.reject(new Error(error));
-        } else {
-          request.resolve(result);
-        }
-      }
-      // Remove the entry from the Map when done
-      this.requests.delete(id);
+      try {
+        const { id, result, error } = event.data;
+        // Find the matching Promise callback by the returned id
+        const request = this.requests.get(id);
+        if (request) {
+          if (error) {
+            request.reject(new Error(error));
+          } else {
+            request.resolve(result);
+          }
+          // Remove the entry from the Map when done
+          this.requests.delete(id);
+        }
+      } catch (err) {
+        console.error("Error processing worker message:", err);
+        // Reject all pending requests
+        this.requests.forEach(request => {
+          request.reject(new Error(`Worker message processing error: ${err.message}`));
+        });
+        this.requests.clear();
+      }
     };
@@ -54,9 +63,13 @@ export class EmbeddingManager {
       console.error("EmbeddingWorker error:", error);
       // Reject all pending requests
       this.requests.forEach(request => {
-        request.reject(error);
+        request.reject(new Error(`Worker error: ${error.message || 'Unknown worker error'}`));
       });
       this.requests.clear();
+      // Reset state
+      this.isModelLoaded = false;
+      this.currentModelId = null;
     };
   }
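The onmessage handler above is an id-correlated request map: every postMessage carries an id, and the reply with the same id settles the matching Promise. A condensed, self-contained sketch of the pattern (names simplified from the class):

type Pending = { resolve: (v: unknown) => void; reject: (e: Error) => void }

class WorkerRpc {
  private nextId = 0
  private requests = new Map<number, Pending>()

  constructor(private worker: Worker) {
    worker.onmessage = (event) => {
      const { id, result, error } = event.data
      const pending = this.requests.get(id)
      if (!pending) return
      if (error) pending.reject(new Error(error))
      else pending.resolve(result)
      this.requests.delete(id)
    }
  }

  call(method: string, params: unknown): Promise<unknown> {
    const id = this.nextId++
    return new Promise((resolve, reject) => {
      this.requests.set(id, { resolve, reject })
      this.worker.postMessage({ method, params, id })
    })
  }
}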

View File

@@ -3,27 +3,27 @@ console.log('Embedding worker loaded');
 
 // Type definitions
 interface EmbedInput {
   embed_input: string;
 }
 
 interface EmbedResult {
   vec: number[];
   tokens: number;
   embed_input?: string;
 }
 
 interface WorkerMessage {
   method: string;
   params: any;
   id: number;
   worker_id?: string;
 }
 
 interface WorkerResponse {
   id: number;
   result?: any;
   error?: string;
   worker_id?: string;
 }
 
 // Global variables
@@ -35,319 +35,377 @@ let transformersLoaded = false;
 
 // Dynamically import Transformers.js
 async function loadTransformers() {
   if (transformersLoaded) return;
 
   try {
     console.log('Loading Transformers.js...');
 
     // Use an older Transformers.js build; it is more stable in a Worker
     const { pipeline: pipelineFactory, env, AutoTokenizer } = await import('@xenova/transformers');
 
     // Configure the environment for a browser Worker
     env.allowLocalModels = false;
     env.allowRemoteModels = true;
 
-    // Configure the WASM backend
-    env.backends.onnx.wasm.numThreads = 2; // single thread in the Worker
+    // Configure the WASM backend - fix the thread configuration
+    env.backends.onnx.wasm.numThreads = 4; // avoid race conditions in the Worker
     env.backends.onnx.wasm.simd = true;
 
     // Disable Node.js-specific features
     env.useFS = false;
     env.useBrowserCache = true;
 
     // Store the imported functions
     (globalThis as any).pipelineFactory = pipelineFactory;
     (globalThis as any).AutoTokenizer = AutoTokenizer;
     (globalThis as any).env = env;
 
     transformersLoaded = true;
     console.log('Transformers.js loaded successfully');
 
   } catch (error) {
     console.error('Failed to load Transformers.js:', error);
     throw new Error(`Failed to load Transformers.js: ${error}`);
   }
 }
 // Load a model
 async function loadModel(modelKey: string, useGpu: boolean = false) {
   try {
     console.log(`Loading model: ${modelKey}, GPU: ${useGpu}`);
 
     // Make sure Transformers.js is loaded
     await loadTransformers();
 
     const pipelineFactory = (globalThis as any).pipelineFactory;
     const AutoTokenizer = (globalThis as any).AutoTokenizer;
     const env = (globalThis as any).env;
 
     // Configure pipeline options
     const pipelineOpts: any = {
       quantized: true,
-      progress_callback: (progress: any) => {
-        console.log('Model loading progress:', progress);
-      }
+      // Fix the progress callback by adding error handling
+      progress_callback: (progress: any) => {
+        try {
+          if (progress && typeof progress === 'object') {
+            console.log('Model loading progress:', progress);
+          }
+        } catch (error) {
+          // Ignore progress callback errors so they cannot interrupt model loading
+          console.warn('Progress callback error (ignored):', error);
+        }
+      }
     };
 
-    if (useGpu && typeof navigator !== 'undefined' && 'gpu' in navigator) {
-      console.log('[Transformers] Attempting to use GPU');
-      try {
-        pipelineOpts.device = 'webgpu';
-        pipelineOpts.dtype = 'fp32';
-      } catch (error) {
-        console.warn('[Transformers] GPU not available, falling back to CPU');
-      }
-    } else {
-      console.log('[Transformers] Using CPU');
-    }
+    // Be more conservative with the GPU configuration
+    if (useGpu) {
+      try {
+        // Check WebGPU support
+        console.log("useGpu", useGpu)
+        if (typeof navigator !== 'undefined' && 'gpu' in navigator) {
+          const gpu = (navigator as any).gpu;
+          if (gpu && typeof gpu.requestAdapter === 'function') {
+            console.log('[Transformers] Attempting to use GPU');
+            pipelineOpts.device = 'webgpu';
+            pipelineOpts.dtype = 'fp32';
+          } else {
+            console.log('[Transformers] WebGPU not fully supported, using CPU');
+          }
+        } else {
+          console.log('[Transformers] WebGPU not available, using CPU');
+        }
+      } catch (error) {
+        console.warn('[Transformers] Error checking GPU support, falling back to CPU:', error);
+      }
+    } else {
+      console.log('[Transformers] Using CPU');
+    }
 
     // Create the embedding pipeline
     pipeline = await pipelineFactory('feature-extraction', modelKey, pipelineOpts);
 
     // Create the tokenizer
     tokenizer = await AutoTokenizer.from_pretrained(modelKey);
 
     model = {
       loaded: true,
       model_key: modelKey,
       use_gpu: useGpu
     };
 
     console.log(`Model ${modelKey} loaded successfully`);
     return { model_loaded: true };
 
   } catch (error) {
     console.error('Error loading model:', error);
     throw new Error(`Failed to load model: ${error}`);
   }
 }
 // Unload the model
 async function unloadModel() {
   try {
     console.log('Unloading model...');
 
     if (pipeline) {
       if (pipeline.destroy) {
         pipeline.destroy();
       }
       pipeline = null;
     }
 
     if (tokenizer) {
       tokenizer = null;
     }
 
     model = null;
 
     console.log('Model unloaded successfully');
     return { model_unloaded: true };
 
   } catch (error) {
     console.error('Error unloading model:', error);
     throw new Error(`Failed to unload model: ${error}`);
   }
 }
 // Count tokens
 async function countTokens(input: string) {
   try {
     if (!tokenizer) {
       throw new Error('Tokenizer not loaded');
     }
     const { input_ids } = await tokenizer(input);
     return { tokens: input_ids.data.length };
   } catch (error) {
     console.error('Error counting tokens:', error);
     throw new Error(`Failed to count tokens: ${error}`);
   }
 }
 // Generate embedding vectors
 async function embedBatch(inputs: EmbedInput[]): Promise<EmbedResult[]> {
   try {
     if (!pipeline || !tokenizer) {
       throw new Error('Model not loaded');
     }
 
     console.log(`Processing ${inputs.length} inputs`);
 
     // Filter out empty inputs
     const filteredInputs = inputs.filter(item => item.embed_input && item.embed_input.length > 0);
 
     if (filteredInputs.length === 0) {
       return [];
     }
 
     // Batch size (adjust as needed)
     const batchSize = 1;
 
     if (filteredInputs.length > batchSize) {
       console.log(`Processing ${filteredInputs.length} inputs in batches of ${batchSize}`);
       const results: EmbedResult[] = [];
       for (let i = 0; i < filteredInputs.length; i += batchSize) {
         const batch = filteredInputs.slice(i, i + batchSize);
         const batchResults = await processBatch(batch);
         results.push(...batchResults);
       }
       return results;
     }
 
     return await processBatch(filteredInputs);
 
   } catch (error) {
     console.error('Error in embed batch:', error);
     throw new Error(`Failed to generate embeddings: ${error}`);
   }
 }
 // Process a single batch
 async function processBatch(batchInputs: EmbedInput[]): Promise<EmbedResult[]> {
   try {
     // Count the tokens of each input
     const tokens = await Promise.all(
       batchInputs.map(item => countTokens(item.embed_input))
     );
 
     // Prepare the embedding inputs (handle overly long text)
     const maxTokens = 512; // maximum token limit for most models
     const embedInputs = await Promise.all(
       batchInputs.map(async (item, i) => {
         if (tokens[i].tokens < maxTokens) {
           return item.embed_input;
         }
 
         // Truncate overly long text
         let tokenCt = tokens[i].tokens;
         let truncatedInput = item.embed_input;
 
         while (tokenCt > maxTokens) {
           const pct = maxTokens / tokenCt;
           const maxChars = Math.floor(truncatedInput.length * pct * 0.9);
           truncatedInput = truncatedInput.substring(0, maxChars) + '...';
           tokenCt = (await countTokens(truncatedInput)).tokens;
         }
 
         tokens[i].tokens = tokenCt;
         return truncatedInput;
       })
     );
 
     // Generate the embedding vectors
     const resp = await pipeline(embedInputs, { pooling: 'mean', normalize: true });
 
     // Process the results
     return batchInputs.map((item, i) => ({
       vec: Array.from(resp[i].data).map((val: number) => Math.round(val * 1e8) / 1e8),
       tokens: tokens[i].tokens,
       embed_input: item.embed_input
     }));
 
   } catch (error) {
     console.error('Error processing batch:', error);
 
     // If batch processing fails, fall back to handling items one by one
     return Promise.all(
       batchInputs.map(async (item) => {
         try {
           const result = await pipeline(item.embed_input, { pooling: 'mean', normalize: true });
           const tokenCount = await countTokens(item.embed_input);
           return {
             vec: Array.from(result[0].data).map((val: number) => Math.round(val * 1e8) / 1e8),
             tokens: tokenCount.tokens,
             embed_input: item.embed_input
           };
         } catch (singleError) {
           console.error('Error processing single item:', singleError);
           return {
             vec: [],
             tokens: 0,
             embed_input: item.embed_input,
             error: (singleError as Error).message
           } as any;
         }
       })
     );
   }
 }
 // Handle a message
 async function processMessage(data: WorkerMessage): Promise<WorkerResponse> {
   const { method, params, id, worker_id } = data;
 
   try {
     let result: any;
 
     switch (method) {
       case 'load':
         console.log('Load method called with params:', params);
         result = await loadModel(params.model_key, params.use_gpu || false);
         break;
 
       case 'unload':
         console.log('Unload method called');
         result = await unloadModel();
         break;
 
       case 'embed_batch':
         console.log('Embed batch method called');
         if (!model) {
           throw new Error('Model not loaded');
         }
         // Wait for the previous operation to finish
         if (processing_message) {
           while (processing_message) {
             await new Promise(resolve => setTimeout(resolve, 100));
           }
         }
         processing_message = true;
         result = await embedBatch(params.inputs);
         processing_message = false;
         break;
 
       case 'count_tokens':
         console.log('Count tokens method called');
         if (!model) {
           throw new Error('Model not loaded');
         }
         // Wait for the previous operation to finish
         if (processing_message) {
           while (processing_message) {
             await new Promise(resolve => setTimeout(resolve, 100));
           }
         }
         processing_message = true;
         result = await countTokens(params);
         processing_message = false;
         break;
 
       default:
         throw new Error(`Unknown method: ${method}`);
     }
 
     return { id, result, worker_id };
   } catch (error) {
     console.error('Error processing message:', error);
     processing_message = false;
     return { id, error: (error as Error).message, worker_id };
   }
 }
 // Listen for messages
 self.addEventListener('message', async (event) => {
-  console.log('Worker received message:', event.data);
-  const response = await processMessage(event.data);
-  console.log('Worker sending response:', response);
-  self.postMessage(response);
+  try {
+    console.log('Worker received message:', event.data);
+
+    // Validate the message format
+    if (!event.data || typeof event.data !== 'object') {
+      console.error('Invalid message format received');
+      self.postMessage({
+        id: -1,
+        error: 'Invalid message format'
+      });
+      return;
+    }
+
+    const response = await processMessage(event.data);
+    console.log('Worker sending response:', response);
+    self.postMessage(response);
+  } catch (error) {
+    console.error('Unhandled error in worker message handler:', error);
+    self.postMessage({
+      id: event.data?.id || -1,
+      error: `Worker error: ${error.message || 'Unknown error'}`
+    });
+  }
 });
+
+// Global error handling
+self.addEventListener('error', (event) => {
+  console.error('Worker global error:', event);
+  self.postMessage({
+    id: -1,
+    error: `Worker global error: ${event.message || 'Unknown error'}`
+  });
+});
+
+// Handle unhandled Promise rejections
+self.addEventListener('unhandledrejection', (event) => {
+  console.error('Worker unhandled promise rejection:', event);
+  self.postMessage({
+    id: -1,
+    error: `Worker unhandled rejection: ${event.reason || 'Unknown error'}`
+  });
+  event.preventDefault(); // prevent the default console error
+});
 
 console.log('Embedding worker ready');
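Driving the protocol above from the main thread looks roughly like this; the worker URL and model id are assumptions for illustration. Note that embed_batch must wait for the load response: the worker's handler is async, so a back-to-back embed_batch would hit the 'Model not loaded' guard.

// Hedged sketch of the worker round trip defined above.
const worker = new Worker(new URL('./embedding-worker.js', import.meta.url))

worker.onmessage = ({ data }) => {
  if (data.error) {
    console.error(`request ${data.id} failed:`, data.error)
  } else if (data.id === 1) {
    // Model is ready; only now is it safe to request embeddings
    worker.postMessage({
      method: 'embed_batch',
      params: { inputs: [{ embed_input: 'hello' }, { embed_input: 'world' }] },
      id: 2,
    })
  } else if (data.id === 2) {
    console.log('vectors:', data.result.map((r: { vec: number[] }) => r.vec))
  }
}

worker.postMessage({
  method: 'load',
  params: { model_key: 'Xenova/all-MiniLM-L6-v2', use_gpu: false }, // example model id
  id: 1,
})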

View File

@@ -623,7 +623,7 @@ export default class InfioPlugin extends Plugin {
     if (!this.ragEngineInitPromise) {
       this.ragEngineInitPromise = (async () => {
         const dbManager = await this.getDbManager()
-        this.ragEngine = new RAGEngine(this.app, this.settings, dbManager)
+        this.ragEngine = new RAGEngine(this.app, this.settings, dbManager, this.embeddingManager)
         return this.ragEngine
       })()
     }

View File

@@ -14,6 +14,7 @@ export const providerApiUrls: Record<ApiProvider, string> = {
   [ApiProvider.Grok]: 'https://console.x.ai/',
   [ApiProvider.Ollama]: '', // Ollama does not need an API key
   [ApiProvider.OpenAICompatible]: '', // custom OpenAI-compatible API, no fixed URL
+  [ApiProvider.LocalProvider]: '', // local provider, no fixed URL
 };
 
 // Get the URL for obtaining the specified provider's API key