Update support for custom model base URLs and the Ollama provider

This commit is contained in:
duanfuxiang 2025-03-20 12:44:53 +08:00
parent 76ecca0da9
commit 570e8d9564
9 changed files with 155 additions and 115 deletions

View File

@@ -49,15 +49,35 @@ class LLMManager implements LLMManagerInterface {
 	constructor(settings: InfioSettings) {
 		this.infioProvider = new InfioProvider(settings.infioProvider.apiKey)
-		this.openrouterProvider = new OpenAICompatibleProvider(settings.openrouterProvider.apiKey, OPENROUTER_BASE_URL)
-		this.siliconflowProvider = new OpenAICompatibleProvider(settings.siliconflowProvider.apiKey, SILICONFLOW_BASE_URL)
-		this.alibabaQwenProvider = new OpenAICompatibleProvider(settings.alibabaQwenProvider.apiKey, ALIBABA_QWEN_BASE_URL)
-		this.deepseekProvider = new OpenAICompatibleProvider(settings.deepseekProvider.apiKey, DEEPSEEK_BASE_URL)
+		this.openrouterProvider = new OpenAICompatibleProvider(
+			settings.openrouterProvider.apiKey,
+			settings.openrouterProvider.baseUrl && settings.openrouterProvider.useCustomUrl ?
+				settings.openrouterProvider.baseUrl
+				: OPENROUTER_BASE_URL
+		)
+		this.siliconflowProvider = new OpenAICompatibleProvider(
+			settings.siliconflowProvider.apiKey,
+			settings.siliconflowProvider.baseUrl && settings.siliconflowProvider.useCustomUrl ?
+				settings.siliconflowProvider.baseUrl
+				: SILICONFLOW_BASE_URL
+		)
+		this.alibabaQwenProvider = new OpenAICompatibleProvider(
+			settings.alibabaQwenProvider.apiKey,
+			settings.alibabaQwenProvider.baseUrl && settings.alibabaQwenProvider.useCustomUrl ?
+				settings.alibabaQwenProvider.baseUrl
+				: ALIBABA_QWEN_BASE_URL
+		)
+		this.deepseekProvider = new OpenAICompatibleProvider(
+			settings.deepseekProvider.apiKey,
+			settings.deepseekProvider.baseUrl && settings.deepseekProvider.useCustomUrl ?
+				settings.deepseekProvider.baseUrl
+				: DEEPSEEK_BASE_URL
+		)
 		this.openaiProvider = new OpenAIAuthenticatedProvider(settings.openaiProvider.apiKey)
 		this.anthropicProvider = new AnthropicProvider(settings.anthropicProvider.apiKey)
 		this.googleProvider = new GeminiProvider(settings.googleProvider.apiKey)
 		this.groqProvider = new GroqProvider(settings.groqProvider.apiKey)
-		this.ollamaProvider = new OllamaProvider(settings.groqProvider.baseUrl)
+		this.ollamaProvider = new OllamaProvider(settings.ollamaProvider.baseUrl)
 		this.openaiCompatibleProvider = new OpenAICompatibleProvider(settings.openaicompatibleProvider.apiKey, settings.openaicompatibleProvider.baseUrl)
 		this.isInfioEnabled = !!settings.infioProvider.apiKey
 	}
@@ -125,6 +145,8 @@ class LLMManager implements LLMManagerInterface {
 					request,
 					options,
 				)
+			case ApiProvider.OpenAICompatible:
+				return await this.openaiCompatibleProvider.generateResponse(model, request, options)
 			default:
 				throw new Error(`Unsupported model provider: ${model.provider}`)
 		}

View File

@@ -68,7 +68,7 @@ export class OllamaProvider implements BaseLLMProvider {
 		const client = new NoStainlessOpenAI({
 			baseURL: `${this.baseUrl}/v1`,
-			apiKey: '',
+			apiKey: 'ollama',
 			dangerouslyAllowBrowser: true,
 		})
 		return this.adapter.generateResponse(client, request, options)
@@ -87,7 +87,7 @@ export class OllamaProvider implements BaseLLMProvider {
 		const client = new NoStainlessOpenAI({
 			baseURL: `${this.baseUrl}/v1`,
-			apiKey: '',
+			apiKey: 'ollama',
 			dangerouslyAllowBrowser: true,
 		})
 		return this.adapter.streamResponse(client, request, options)

View File

@@ -159,10 +159,9 @@ export const getEmbeddingModel = (
 				dangerouslyAllowBrowser: true,
 				baseURL: `${settings.ollamaProvider.baseUrl}/v1`,
 			})
-			const modelInfo = GetEmbeddingModelInfo(settings.embeddingModelProvider, settings.embeddingModelId)
 			return {
 				id: settings.embeddingModelId,
-				dimension: modelInfo.dimensions,
+				dimension: 0,
 				getEmbedding: async (text: string) => {
 					if (!settings.ollamaProvider.baseUrl) {
 						throw new LLMBaseUrlNotSetException(

View File

@@ -5,6 +5,7 @@ import { DBManager } from '../../database/database-manager'
 import { VectorManager } from '../../database/modules/vector/vector-manager'
 import { SelectVector } from '../../database/schema'
 import { EmbeddingModel } from '../../types/embedding'
+import { ApiProvider } from '../../types/llm/model'
 import { InfioSettings } from '../../types/settings'
 import { getEmbeddingModel } from './embedding'
@@ -32,6 +33,12 @@ export class RAGEngine {
 		this.embeddingModel = getEmbeddingModel(settings)
 	}
 
+	async initializeDimension(): Promise<void> {
+		if (this.embeddingModel.dimension === 0 && this.settings.embeddingModelProvider === ApiProvider.Ollama) {
+			this.embeddingModel.dimension = (await this.embeddingModel.getEmbedding("hello world")).length
+		}
+	}
+
 	// TODO: Implement automatic vault re-indexing when settings are changed.
 	// Currently, users must manually re-index the vault.
 	async updateVaultIndex(
@@ -41,6 +48,8 @@ export class RAGEngine {
 		if (!this.embeddingModel) {
 			throw new Error('Embedding model is not set')
 		}
+		await this.initializeDimension()
+
 		await this.vectorManager.updateVaultIndex(
 			this.embeddingModel,
 			{
@@ -60,6 +69,12 @@ export class RAGEngine {
 	}
 
 	async updateFileIndex(file: TFile) {
+		if (!this.embeddingModel) {
+			throw new Error('Embedding model is not set')
+		}
+		await this.initializeDimension()
+
 		await this.vectorManager.UpdateFileVectorIndex(
 			this.embeddingModel,
 			this.settings.ragOptions.chunkSize,
@@ -68,6 +83,12 @@ export class RAGEngine {
 	}
 
 	async deleteFileIndex(file: TFile) {
+		if (!this.embeddingModel) {
+			throw new Error('Embedding model is not set')
+		}
+		await this.initializeDimension()
+
 		await this.vectorManager.DeleteFileVectorIndex(
 			this.embeddingModel,
 			file,
@@ -94,6 +115,8 @@ export class RAGEngine {
 			throw new Error('Embedding model is not set')
 		}
+		await this.initializeDimension()
+
 		if (!this.initialized) {
 			await this.updateVaultIndex({ reindexAll: false }, onQueryProgressChange)
 		}
@@ -101,11 +124,6 @@ export class RAGEngine {
 		onQueryProgressChange?.({
 			type: 'querying',
 		})
-		console.log('query, ', {
-			minSimilarity: this.settings.ragOptions.minSimilarity,
-			limit: this.settings.ragOptions.limit,
-			scope,
-		})
 		const queryResult = await this.vectorManager.performSimilaritySearch(
 			queryEmbedding,
 			this.embeddingModel,
@@ -115,7 +133,6 @@ export class RAGEngine {
 				scope,
 			},
 		)
-		console.log('queryResult', queryResult)
 		onQueryProgressChange?.({
 			type: 'querying-done',
 			queryResult,

View File

@@ -141,6 +141,7 @@ const CustomProviderSettings: React.FC<CustomProviderSettingsProps> = ({ plugin,
 				onChange={updateProvider}
 			/>
 			<div className="infio-llm-setting-divider"></div>
+			{currProvider !== ApiProvider.Ollama && (
 			<TextComponent
 				name={currProvider + " api key:"}
 				placeholder="Enter your api key"
@@ -148,6 +149,7 @@ const CustomProviderSettings: React.FC<CustomProviderSettingsProps> = ({ plugin,
 				onChange={updateProviderApiKey}
 				type="password"
 			/>
+			)}
 			<div className="infio-llm-setting-divider"></div>
 			<ToggleComponent
 				name="Use custom base url"

View File

@@ -10,7 +10,6 @@ export enum ApiProvider {
 	Groq = "Groq",
 	Ollama = "Ollama",
 	OpenAICompatible = "OpenAICompatible",
-	TransformersJs = "TransformersJs",
 }
 
 export type LLMModel = {

View File

@@ -125,14 +125,14 @@ const OpenAICompatibleProviderSchema = z.object({
 const OllamaProviderSchema = z.object({
 	name: z.literal('Ollama'),
-	apiKey: z.string().catch(''),
+	apiKey: z.string().catch('ollama'),
 	baseUrl: z.string().catch(''),
 	useCustomUrl: z.boolean().catch(false)
 }).catch({
 	name: 'Ollama',
-	apiKey: '',
+	apiKey: 'ollama',
 	baseUrl: '',
-	useCustomUrl: false
+	useCustomUrl: true
 })
 
 const GroqProviderSchema = z.object({

View File

@@ -1141,7 +1141,9 @@ export const GetEmbeddingProviders = (): ApiProvider[] => {
 		ApiProvider.OpenAI,
 		ApiProvider.SiliconFlow,
 		ApiProvider.Google,
-		ApiProvider.AlibabaQwen
+		ApiProvider.AlibabaQwen,
+		ApiProvider.OpenAICompatible,
+		ApiProvider.Ollama,
 	]
 }

View File

@@ -210,7 +210,6 @@ export async function webSearch(query: string, serperApiKey: string, jinaApiKey:
 	}
 }
 
-// todo: update
 export async function fetchUrlsContent(urls: string[], apiKey: string): Promise<string> {
 	return new Promise((resolve) => {
 		const results = urls.map(async (url) => {