From 570e8d9564018366808b0508e16fe7d0cdb0110d Mon Sep 17 00:00:00 2001 From: duanfuxiang Date: Thu, 20 Mar 2025 12:44:53 +0800 Subject: [PATCH] update for custom model, ollama --- src/core/llm/manager.ts | 32 ++- src/core/llm/ollama.ts | 4 +- src/core/rag/embedding.ts | 3 +- src/core/rag/rag-engine.ts | 203 ++++++++++-------- .../components/ModelProviderSettings.tsx | 16 +- src/types/llm/model.ts | 1 - src/types/settings.ts | 6 +- src/utils/api.ts | 4 +- src/utils/web-search.ts | 1 - 9 files changed, 155 insertions(+), 115 deletions(-) diff --git a/src/core/llm/manager.ts b/src/core/llm/manager.ts index fdf488c..15f6152 100644 --- a/src/core/llm/manager.ts +++ b/src/core/llm/manager.ts @@ -49,15 +49,35 @@ class LLMManager implements LLMManagerInterface { constructor(settings: InfioSettings) { this.infioProvider = new InfioProvider(settings.infioProvider.apiKey) - this.openrouterProvider = new OpenAICompatibleProvider(settings.openrouterProvider.apiKey, OPENROUTER_BASE_URL) - this.siliconflowProvider = new OpenAICompatibleProvider(settings.siliconflowProvider.apiKey, SILICONFLOW_BASE_URL) - this.alibabaQwenProvider = new OpenAICompatibleProvider(settings.alibabaQwenProvider.apiKey, ALIBABA_QWEN_BASE_URL) - this.deepseekProvider = new OpenAICompatibleProvider(settings.deepseekProvider.apiKey, DEEPSEEK_BASE_URL) + this.openrouterProvider = new OpenAICompatibleProvider( + settings.openrouterProvider.apiKey, + settings.openrouterProvider.baseUrl && settings.openrouterProvider.useCustomUrl ? + settings.openrouterProvider.baseUrl + : OPENROUTER_BASE_URL + ) + this.siliconflowProvider = new OpenAICompatibleProvider( + settings.siliconflowProvider.apiKey, + settings.siliconflowProvider.baseUrl && settings.siliconflowProvider.useCustomUrl ? 
+ settings.siliconflowProvider.baseUrl + : SILICONFLOW_BASE_URL + ) + this.alibabaQwenProvider = new OpenAICompatibleProvider( + settings.alibabaQwenProvider.apiKey, + settings.alibabaQwenProvider.baseUrl && settings.alibabaQwenProvider.useCustomUrl ? + settings.alibabaQwenProvider.baseUrl + : ALIBABA_QWEN_BASE_URL + ) + this.deepseekProvider = new OpenAICompatibleProvider( + settings.deepseekProvider.apiKey, + settings.deepseekProvider.baseUrl && settings.deepseekProvider.useCustomUrl ? + settings.deepseekProvider.baseUrl + : DEEPSEEK_BASE_URL + ) this.openaiProvider = new OpenAIAuthenticatedProvider(settings.openaiProvider.apiKey) this.anthropicProvider = new AnthropicProvider(settings.anthropicProvider.apiKey) this.googleProvider = new GeminiProvider(settings.googleProvider.apiKey) this.groqProvider = new GroqProvider(settings.groqProvider.apiKey) - this.ollamaProvider = new OllamaProvider(settings.groqProvider.baseUrl) + this.ollamaProvider = new OllamaProvider(settings.ollamaProvider.baseUrl) this.openaiCompatibleProvider = new OpenAICompatibleProvider(settings.openaicompatibleProvider.apiKey, settings.openaicompatibleProvider.baseUrl) this.isInfioEnabled = !!settings.infioProvider.apiKey } @@ -125,6 +145,8 @@ class LLMManager implements LLMManagerInterface { request, options, ) + case ApiProvider.OpenAICompatible: + return await this.openaiCompatibleProvider.generateResponse(model, request, options) default: throw new Error(`Unsupported model provider: ${model.provider}`) } diff --git a/src/core/llm/ollama.ts b/src/core/llm/ollama.ts index 09daf35..bce4757 100644 --- a/src/core/llm/ollama.ts +++ b/src/core/llm/ollama.ts @@ -68,7 +68,7 @@ export class OllamaProvider implements BaseLLMProvider { const client = new NoStainlessOpenAI({ baseURL: `${this.baseUrl}/v1`, - apiKey: '', + apiKey: 'ollama', dangerouslyAllowBrowser: true, }) return this.adapter.generateResponse(client, request, options) @@ -87,7 +87,7 @@ export class OllamaProvider implements 
BaseLLMProvider { const client = new NoStainlessOpenAI({ baseURL: `${this.baseUrl}/v1`, - apiKey: '', + apiKey: 'ollama', dangerouslyAllowBrowser: true, }) return this.adapter.streamResponse(client, request, options) diff --git a/src/core/rag/embedding.ts b/src/core/rag/embedding.ts index 1d1527d..b1015da 100644 --- a/src/core/rag/embedding.ts +++ b/src/core/rag/embedding.ts @@ -159,10 +159,9 @@ export const getEmbeddingModel = ( dangerouslyAllowBrowser: true, baseURL: `${settings.ollamaProvider.baseUrl}/v1`, }) - const modelInfo = GetEmbeddingModelInfo(settings.embeddingModelProvider, settings.embeddingModelId) return { id: settings.embeddingModelId, - dimension: modelInfo.dimensions, + dimension: 0, getEmbedding: async (text: string) => { if (!settings.ollamaProvider.baseUrl) { throw new LLMBaseUrlNotSetException( diff --git a/src/core/rag/rag-engine.ts b/src/core/rag/rag-engine.ts index 5cee91d..a54eb74 100644 --- a/src/core/rag/rag-engine.ts +++ b/src/core/rag/rag-engine.ts @@ -5,128 +5,145 @@ import { DBManager } from '../../database/database-manager' import { VectorManager } from '../../database/modules/vector/vector-manager' import { SelectVector } from '../../database/schema' import { EmbeddingModel } from '../../types/embedding' +import { ApiProvider } from '../../types/llm/model' import { InfioSettings } from '../../types/settings' import { getEmbeddingModel } from './embedding' export class RAGEngine { - private app: App - private settings: InfioSettings - private vectorManager: VectorManager + private app: App + private settings: InfioSettings + private vectorManager: VectorManager private embeddingModel: EmbeddingModel | null = null private initialized = false - constructor( - app: App, - settings: InfioSettings, - dbManager: DBManager, - ) { - this.app = app - this.settings = settings - this.vectorManager = dbManager.getVectorManager() + constructor( + app: App, + settings: InfioSettings, + dbManager: DBManager, + ) { + this.app = app + this.settings 
= settings + this.vectorManager = dbManager.getVectorManager() this.embeddingModel = getEmbeddingModel(settings) - } + } - setSettings(settings: InfioSettings) { - this.settings = settings - this.embeddingModel = getEmbeddingModel(settings) - } + setSettings(settings: InfioSettings) { + this.settings = settings + this.embeddingModel = getEmbeddingModel(settings) + } - // TODO: Implement automatic vault re-indexing when settings are changed. - // Currently, users must manually re-index the vault. - async updateVaultIndex( - options: { reindexAll: boolean }, - onQueryProgressChange?: (queryProgress: QueryProgressState) => void, + async initializeDimension(): Promise<void> { + if (this.embeddingModel.dimension === 0 && this.settings.embeddingModelProvider === ApiProvider.Ollama) { + this.embeddingModel.dimension = (await this.embeddingModel.getEmbedding("hello world")).length + } + } + + // TODO: Implement automatic vault re-indexing when settings are changed. + // Currently, users must manually re-index the vault. 
+ async updateVaultIndex( + options: { reindexAll: boolean }, + onQueryProgressChange?: (queryProgress: QueryProgressState) => void, ): Promise<void> { if (!this.embeddingModel) { throw new Error('Embedding model is not set') - } - await this.vectorManager.updateVaultIndex( + } + await this.initializeDimension() + + await this.vectorManager.updateVaultIndex( this.embeddingModel, - { - chunkSize: this.settings.ragOptions.chunkSize, - excludePatterns: this.settings.ragOptions.excludePatterns, - includePatterns: this.settings.ragOptions.includePatterns, - reindexAll: options.reindexAll, - }, - (indexProgress) => { - onQueryProgressChange?.({ - type: 'indexing', - indexProgress, - }) - }, - ) + { + chunkSize: this.settings.ragOptions.chunkSize, + excludePatterns: this.settings.ragOptions.excludePatterns, + includePatterns: this.settings.ragOptions.includePatterns, + reindexAll: options.reindexAll, + }, + (indexProgress) => { + onQueryProgressChange?.({ + type: 'indexing', + indexProgress, + }) + }, + ) this.initialized = true } - async updateFileIndex(file: TFile) { - await this.vectorManager.UpdateFileVectorIndex( - this.embeddingModel, - this.settings.ragOptions.chunkSize, - file, - ) + async updateFileIndex(file: TFile) { + if (!this.embeddingModel) { + throw new Error('Embedding model is not set') + } + + await this.initializeDimension() + + await this.vectorManager.UpdateFileVectorIndex( + this.embeddingModel, + this.settings.ragOptions.chunkSize, + file, + ) } - + async deleteFileIndex(file: TFile) { + if (!this.embeddingModel) { + throw new Error('Embedding model is not set') + } + + await this.initializeDimension() + await this.vectorManager.DeleteFileVectorIndex( this.embeddingModel, file, ) } - async processQuery({ - query, - scope, - onQueryProgressChange, - }: { - query: string - scope?: { - files: string[] - folders: string[] - } - onQueryProgressChange?: (queryProgress: QueryProgressState) => void - }): Promise< - (Omit<SelectVector, 'embedding'> & { - similarity: number - })[] - > { - 
if (!this.embeddingModel) { - throw new Error('Embedding model is not set') - } + async processQuery({ + query, + scope, + onQueryProgressChange, + }: { + query: string + scope?: { + files: string[] + folders: string[] + } + onQueryProgressChange?: (queryProgress: QueryProgressState) => void + }): Promise< + (Omit<SelectVector, 'embedding'> & { + similarity: number + })[] + > { + if (!this.embeddingModel) { + throw new Error('Embedding model is not set') + } + + await this.initializeDimension() if (!this.initialized) { - await this.updateVaultIndex({ reindexAll: false }, onQueryProgressChange) - } - const queryEmbedding = await this.getEmbedding(query) - onQueryProgressChange?.({ - type: 'querying', + await this.updateVaultIndex({ reindexAll: false }, onQueryProgressChange) + } + const queryEmbedding = await this.getEmbedding(query) + onQueryProgressChange?.({ + type: 'querying', }) - console.log('query, ', { - minSimilarity: this.settings.ragOptions.minSimilarity, - limit: this.settings.ragOptions.limit, - scope, - }) - const queryResult = await this.vectorManager.performSimilaritySearch( - queryEmbedding, - this.embeddingModel, - { - minSimilarity: this.settings.ragOptions.minSimilarity, - limit: this.settings.ragOptions.limit, - scope, - }, + const queryResult = await this.vectorManager.performSimilaritySearch( + queryEmbedding, + this.embeddingModel, + { + minSimilarity: this.settings.ragOptions.minSimilarity, + limit: this.settings.ragOptions.limit, + scope, + }, ) - console.log('queryResult', queryResult) - onQueryProgressChange?.({ - type: 'querying-done', - queryResult, - }) - return queryResult - } + onQueryProgressChange?.({ + type: 'querying-done', + queryResult, + }) + return queryResult + } - async getEmbedding(query: string): Promise<number[]> { - if (!this.embeddingModel) { - throw new Error('Embedding model is not set') - } - return this.embeddingModel.getEmbedding(query) - } + async getEmbedding(query: string): Promise<number[]> { + if (!this.embeddingModel) { + throw new Error('Embedding model 
is not set') + } + return this.embeddingModel.getEmbedding(query) + } } diff --git a/src/settings/components/ModelProviderSettings.tsx b/src/settings/components/ModelProviderSettings.tsx index ddbae16..9cc4633 100644 --- a/src/settings/components/ModelProviderSettings.tsx +++ b/src/settings/components/ModelProviderSettings.tsx @@ -141,13 +141,15 @@ const CustomProviderSettings: React.FC = ({ plugin, onChange={updateProvider} />
- + {currProvider !== ApiProvider.Ollama && ( + + )}
{ ApiProvider.OpenAI, ApiProvider.SiliconFlow, ApiProvider.Google, - ApiProvider.AlibabaQwen + ApiProvider.AlibabaQwen, + ApiProvider.OpenAICompatible, + ApiProvider.Ollama, ] } diff --git a/src/utils/web-search.ts b/src/utils/web-search.ts index cfe0146..4659604 100644 --- a/src/utils/web-search.ts +++ b/src/utils/web-search.ts @@ -210,7 +210,6 @@ export async function webSearch(query: string, serperApiKey: string, jinaApiKey: } } -// todo: update export async function fetchUrlsContent(urls: string[], apiKey: string): Promise { return new Promise((resolve) => { const results = urls.map(async (url) => {