update for custom model, ollama
This commit is contained in:
parent
76ecca0da9
commit
570e8d9564
@ -49,15 +49,35 @@ class LLMManager implements LLMManagerInterface {
|
|||||||
|
|
||||||
constructor(settings: InfioSettings) {
|
constructor(settings: InfioSettings) {
|
||||||
this.infioProvider = new InfioProvider(settings.infioProvider.apiKey)
|
this.infioProvider = new InfioProvider(settings.infioProvider.apiKey)
|
||||||
this.openrouterProvider = new OpenAICompatibleProvider(settings.openrouterProvider.apiKey, OPENROUTER_BASE_URL)
|
this.openrouterProvider = new OpenAICompatibleProvider(
|
||||||
this.siliconflowProvider = new OpenAICompatibleProvider(settings.siliconflowProvider.apiKey, SILICONFLOW_BASE_URL)
|
settings.openrouterProvider.apiKey,
|
||||||
this.alibabaQwenProvider = new OpenAICompatibleProvider(settings.alibabaQwenProvider.apiKey, ALIBABA_QWEN_BASE_URL)
|
settings.openrouterProvider.baseUrl && settings.openrouterProvider.useCustomUrl ?
|
||||||
this.deepseekProvider = new OpenAICompatibleProvider(settings.deepseekProvider.apiKey, DEEPSEEK_BASE_URL)
|
settings.openrouterProvider.baseUrl
|
||||||
|
: OPENROUTER_BASE_URL
|
||||||
|
)
|
||||||
|
this.siliconflowProvider = new OpenAICompatibleProvider(
|
||||||
|
settings.siliconflowProvider.apiKey,
|
||||||
|
settings.siliconflowProvider.baseUrl && settings.siliconflowProvider.useCustomUrl ?
|
||||||
|
settings.siliconflowProvider.baseUrl
|
||||||
|
: SILICONFLOW_BASE_URL
|
||||||
|
)
|
||||||
|
this.alibabaQwenProvider = new OpenAICompatibleProvider(
|
||||||
|
settings.alibabaQwenProvider.apiKey,
|
||||||
|
settings.alibabaQwenProvider.baseUrl && settings.alibabaQwenProvider.useCustomUrl ?
|
||||||
|
settings.alibabaQwenProvider.baseUrl
|
||||||
|
: ALIBABA_QWEN_BASE_URL
|
||||||
|
)
|
||||||
|
this.deepseekProvider = new OpenAICompatibleProvider(
|
||||||
|
settings.deepseekProvider.apiKey,
|
||||||
|
settings.deepseekProvider.baseUrl && settings.deepseekProvider.useCustomUrl ?
|
||||||
|
settings.deepseekProvider.baseUrl
|
||||||
|
: DEEPSEEK_BASE_URL
|
||||||
|
)
|
||||||
this.openaiProvider = new OpenAIAuthenticatedProvider(settings.openaiProvider.apiKey)
|
this.openaiProvider = new OpenAIAuthenticatedProvider(settings.openaiProvider.apiKey)
|
||||||
this.anthropicProvider = new AnthropicProvider(settings.anthropicProvider.apiKey)
|
this.anthropicProvider = new AnthropicProvider(settings.anthropicProvider.apiKey)
|
||||||
this.googleProvider = new GeminiProvider(settings.googleProvider.apiKey)
|
this.googleProvider = new GeminiProvider(settings.googleProvider.apiKey)
|
||||||
this.groqProvider = new GroqProvider(settings.groqProvider.apiKey)
|
this.groqProvider = new GroqProvider(settings.groqProvider.apiKey)
|
||||||
this.ollamaProvider = new OllamaProvider(settings.groqProvider.baseUrl)
|
this.ollamaProvider = new OllamaProvider(settings.ollamaProvider.baseUrl)
|
||||||
this.openaiCompatibleProvider = new OpenAICompatibleProvider(settings.openaicompatibleProvider.apiKey, settings.openaicompatibleProvider.baseUrl)
|
this.openaiCompatibleProvider = new OpenAICompatibleProvider(settings.openaicompatibleProvider.apiKey, settings.openaicompatibleProvider.baseUrl)
|
||||||
this.isInfioEnabled = !!settings.infioProvider.apiKey
|
this.isInfioEnabled = !!settings.infioProvider.apiKey
|
||||||
}
|
}
|
||||||
@ -125,6 +145,8 @@ class LLMManager implements LLMManagerInterface {
|
|||||||
request,
|
request,
|
||||||
options,
|
options,
|
||||||
)
|
)
|
||||||
|
case ApiProvider.OpenAICompatible:
|
||||||
|
return await this.openaiCompatibleProvider.generateResponse(model, request, options)
|
||||||
default:
|
default:
|
||||||
throw new Error(`Unsupported model provider: ${model.provider}`)
|
throw new Error(`Unsupported model provider: ${model.provider}`)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -68,7 +68,7 @@ export class OllamaProvider implements BaseLLMProvider {
|
|||||||
|
|
||||||
const client = new NoStainlessOpenAI({
|
const client = new NoStainlessOpenAI({
|
||||||
baseURL: `${this.baseUrl}/v1`,
|
baseURL: `${this.baseUrl}/v1`,
|
||||||
apiKey: '',
|
apiKey: 'ollama',
|
||||||
dangerouslyAllowBrowser: true,
|
dangerouslyAllowBrowser: true,
|
||||||
})
|
})
|
||||||
return this.adapter.generateResponse(client, request, options)
|
return this.adapter.generateResponse(client, request, options)
|
||||||
@ -87,7 +87,7 @@ export class OllamaProvider implements BaseLLMProvider {
|
|||||||
|
|
||||||
const client = new NoStainlessOpenAI({
|
const client = new NoStainlessOpenAI({
|
||||||
baseURL: `${this.baseUrl}/v1`,
|
baseURL: `${this.baseUrl}/v1`,
|
||||||
apiKey: '',
|
apiKey: 'ollama',
|
||||||
dangerouslyAllowBrowser: true,
|
dangerouslyAllowBrowser: true,
|
||||||
})
|
})
|
||||||
return this.adapter.streamResponse(client, request, options)
|
return this.adapter.streamResponse(client, request, options)
|
||||||
|
|||||||
@ -159,10 +159,9 @@ export const getEmbeddingModel = (
|
|||||||
dangerouslyAllowBrowser: true,
|
dangerouslyAllowBrowser: true,
|
||||||
baseURL: `${settings.ollamaProvider.baseUrl}/v1`,
|
baseURL: `${settings.ollamaProvider.baseUrl}/v1`,
|
||||||
})
|
})
|
||||||
const modelInfo = GetEmbeddingModelInfo(settings.embeddingModelProvider, settings.embeddingModelId)
|
|
||||||
return {
|
return {
|
||||||
id: settings.embeddingModelId,
|
id: settings.embeddingModelId,
|
||||||
dimension: modelInfo.dimensions,
|
dimension: 0,
|
||||||
getEmbedding: async (text: string) => {
|
getEmbedding: async (text: string) => {
|
||||||
if (!settings.ollamaProvider.baseUrl) {
|
if (!settings.ollamaProvider.baseUrl) {
|
||||||
throw new LLMBaseUrlNotSetException(
|
throw new LLMBaseUrlNotSetException(
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import { DBManager } from '../../database/database-manager'
|
|||||||
import { VectorManager } from '../../database/modules/vector/vector-manager'
|
import { VectorManager } from '../../database/modules/vector/vector-manager'
|
||||||
import { SelectVector } from '../../database/schema'
|
import { SelectVector } from '../../database/schema'
|
||||||
import { EmbeddingModel } from '../../types/embedding'
|
import { EmbeddingModel } from '../../types/embedding'
|
||||||
|
import { ApiProvider } from '../../types/llm/model'
|
||||||
import { InfioSettings } from '../../types/settings'
|
import { InfioSettings } from '../../types/settings'
|
||||||
|
|
||||||
import { getEmbeddingModel } from './embedding'
|
import { getEmbeddingModel } from './embedding'
|
||||||
@ -32,6 +33,12 @@ export class RAGEngine {
|
|||||||
this.embeddingModel = getEmbeddingModel(settings)
|
this.embeddingModel = getEmbeddingModel(settings)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async initializeDimension(): Promise<void> {
|
||||||
|
if (this.embeddingModel.dimension === 0 && this.settings.embeddingModelProvider === ApiProvider.Ollama) {
|
||||||
|
this.embeddingModel.dimension = (await this.embeddingModel.getEmbedding("hello world")).length
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: Implement automatic vault re-indexing when settings are changed.
|
// TODO: Implement automatic vault re-indexing when settings are changed.
|
||||||
// Currently, users must manually re-index the vault.
|
// Currently, users must manually re-index the vault.
|
||||||
async updateVaultIndex(
|
async updateVaultIndex(
|
||||||
@ -41,6 +48,8 @@ export class RAGEngine {
|
|||||||
if (!this.embeddingModel) {
|
if (!this.embeddingModel) {
|
||||||
throw new Error('Embedding model is not set')
|
throw new Error('Embedding model is not set')
|
||||||
}
|
}
|
||||||
|
await this.initializeDimension()
|
||||||
|
|
||||||
await this.vectorManager.updateVaultIndex(
|
await this.vectorManager.updateVaultIndex(
|
||||||
this.embeddingModel,
|
this.embeddingModel,
|
||||||
{
|
{
|
||||||
@ -60,6 +69,12 @@ export class RAGEngine {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async updateFileIndex(file: TFile) {
|
async updateFileIndex(file: TFile) {
|
||||||
|
if (!this.embeddingModel) {
|
||||||
|
throw new Error('Embedding model is not set')
|
||||||
|
}
|
||||||
|
|
||||||
|
await this.initializeDimension()
|
||||||
|
|
||||||
await this.vectorManager.UpdateFileVectorIndex(
|
await this.vectorManager.UpdateFileVectorIndex(
|
||||||
this.embeddingModel,
|
this.embeddingModel,
|
||||||
this.settings.ragOptions.chunkSize,
|
this.settings.ragOptions.chunkSize,
|
||||||
@ -68,6 +83,12 @@ export class RAGEngine {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async deleteFileIndex(file: TFile) {
|
async deleteFileIndex(file: TFile) {
|
||||||
|
if (!this.embeddingModel) {
|
||||||
|
throw new Error('Embedding model is not set')
|
||||||
|
}
|
||||||
|
|
||||||
|
await this.initializeDimension()
|
||||||
|
|
||||||
await this.vectorManager.DeleteFileVectorIndex(
|
await this.vectorManager.DeleteFileVectorIndex(
|
||||||
this.embeddingModel,
|
this.embeddingModel,
|
||||||
file,
|
file,
|
||||||
@ -94,6 +115,8 @@ export class RAGEngine {
|
|||||||
throw new Error('Embedding model is not set')
|
throw new Error('Embedding model is not set')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
await this.initializeDimension()
|
||||||
|
|
||||||
if (!this.initialized) {
|
if (!this.initialized) {
|
||||||
await this.updateVaultIndex({ reindexAll: false }, onQueryProgressChange)
|
await this.updateVaultIndex({ reindexAll: false }, onQueryProgressChange)
|
||||||
}
|
}
|
||||||
@ -101,11 +124,6 @@ export class RAGEngine {
|
|||||||
onQueryProgressChange?.({
|
onQueryProgressChange?.({
|
||||||
type: 'querying',
|
type: 'querying',
|
||||||
})
|
})
|
||||||
console.log('query, ', {
|
|
||||||
minSimilarity: this.settings.ragOptions.minSimilarity,
|
|
||||||
limit: this.settings.ragOptions.limit,
|
|
||||||
scope,
|
|
||||||
})
|
|
||||||
const queryResult = await this.vectorManager.performSimilaritySearch(
|
const queryResult = await this.vectorManager.performSimilaritySearch(
|
||||||
queryEmbedding,
|
queryEmbedding,
|
||||||
this.embeddingModel,
|
this.embeddingModel,
|
||||||
@ -115,7 +133,6 @@ export class RAGEngine {
|
|||||||
scope,
|
scope,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
console.log('queryResult', queryResult)
|
|
||||||
onQueryProgressChange?.({
|
onQueryProgressChange?.({
|
||||||
type: 'querying-done',
|
type: 'querying-done',
|
||||||
queryResult,
|
queryResult,
|
||||||
|
|||||||
@ -141,6 +141,7 @@ const CustomProviderSettings: React.FC<CustomProviderSettingsProps> = ({ plugin,
|
|||||||
onChange={updateProvider}
|
onChange={updateProvider}
|
||||||
/>
|
/>
|
||||||
<div className="infio-llm-setting-divider"></div>
|
<div className="infio-llm-setting-divider"></div>
|
||||||
|
{currProvider !== ApiProvider.Ollama && (
|
||||||
<TextComponent
|
<TextComponent
|
||||||
name={currProvider + " api key:"}
|
name={currProvider + " api key:"}
|
||||||
placeholder="Enter your api key"
|
placeholder="Enter your api key"
|
||||||
@ -148,6 +149,7 @@ const CustomProviderSettings: React.FC<CustomProviderSettingsProps> = ({ plugin,
|
|||||||
onChange={updateProviderApiKey}
|
onChange={updateProviderApiKey}
|
||||||
type="password"
|
type="password"
|
||||||
/>
|
/>
|
||||||
|
)}
|
||||||
<div className="infio-llm-setting-divider"></div>
|
<div className="infio-llm-setting-divider"></div>
|
||||||
<ToggleComponent
|
<ToggleComponent
|
||||||
name="Use custom base url"
|
name="Use custom base url"
|
||||||
|
|||||||
@ -10,7 +10,6 @@ export enum ApiProvider {
|
|||||||
Groq = "Groq",
|
Groq = "Groq",
|
||||||
Ollama = "Ollama",
|
Ollama = "Ollama",
|
||||||
OpenAICompatible = "OpenAICompatible",
|
OpenAICompatible = "OpenAICompatible",
|
||||||
TransformersJs = "TransformersJs",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export type LLMModel = {
|
export type LLMModel = {
|
||||||
|
|||||||
@ -125,14 +125,14 @@ const OpenAICompatibleProviderSchema = z.object({
|
|||||||
|
|
||||||
const OllamaProviderSchema = z.object({
|
const OllamaProviderSchema = z.object({
|
||||||
name: z.literal('Ollama'),
|
name: z.literal('Ollama'),
|
||||||
apiKey: z.string().catch(''),
|
apiKey: z.string().catch('ollama'),
|
||||||
baseUrl: z.string().catch(''),
|
baseUrl: z.string().catch(''),
|
||||||
useCustomUrl: z.boolean().catch(false)
|
useCustomUrl: z.boolean().catch(false)
|
||||||
}).catch({
|
}).catch({
|
||||||
name: 'Ollama',
|
name: 'Ollama',
|
||||||
apiKey: '',
|
apiKey: 'ollama',
|
||||||
baseUrl: '',
|
baseUrl: '',
|
||||||
useCustomUrl: false
|
useCustomUrl: true
|
||||||
})
|
})
|
||||||
|
|
||||||
const GroqProviderSchema = z.object({
|
const GroqProviderSchema = z.object({
|
||||||
|
|||||||
@ -1141,7 +1141,9 @@ export const GetEmbeddingProviders = (): ApiProvider[] => {
|
|||||||
ApiProvider.OpenAI,
|
ApiProvider.OpenAI,
|
||||||
ApiProvider.SiliconFlow,
|
ApiProvider.SiliconFlow,
|
||||||
ApiProvider.Google,
|
ApiProvider.Google,
|
||||||
ApiProvider.AlibabaQwen
|
ApiProvider.AlibabaQwen,
|
||||||
|
ApiProvider.OpenAICompatible,
|
||||||
|
ApiProvider.Ollama,
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -210,7 +210,6 @@ export async function webSearch(query: string, serperApiKey: string, jinaApiKey:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo: update
|
|
||||||
export async function fetchUrlsContent(urls: string[], apiKey: string): Promise<string> {
|
export async function fetchUrlsContent(urls: string[], apiKey: string): Promise<string> {
|
||||||
return new Promise((resolve) => {
|
return new Promise((resolve) => {
|
||||||
const results = urls.map(async (url) => {
|
const results = urls.map(async (url) => {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user