commit 570e8d9564 (parent 76ecca0da9)

update for custom model, ollama

@@ -49,15 +49,35 @@ class LLMManager implements LLMManagerInterface {
 	constructor(settings: InfioSettings) {
 		this.infioProvider = new InfioProvider(settings.infioProvider.apiKey)
-		this.openrouterProvider = new OpenAICompatibleProvider(settings.openrouterProvider.apiKey, OPENROUTER_BASE_URL)
-		this.siliconflowProvider = new OpenAICompatibleProvider(settings.siliconflowProvider.apiKey, SILICONFLOW_BASE_URL)
-		this.alibabaQwenProvider = new OpenAICompatibleProvider(settings.alibabaQwenProvider.apiKey, ALIBABA_QWEN_BASE_URL)
-		this.deepseekProvider = new OpenAICompatibleProvider(settings.deepseekProvider.apiKey, DEEPSEEK_BASE_URL)
+		this.openrouterProvider = new OpenAICompatibleProvider(
+			settings.openrouterProvider.apiKey,
+			settings.openrouterProvider.baseUrl && settings.openrouterProvider.useCustomUrl ?
+				settings.openrouterProvider.baseUrl
+				: OPENROUTER_BASE_URL
+		)
+		this.siliconflowProvider = new OpenAICompatibleProvider(
+			settings.siliconflowProvider.apiKey,
+			settings.siliconflowProvider.baseUrl && settings.siliconflowProvider.useCustomUrl ?
+				settings.siliconflowProvider.baseUrl
+				: SILICONFLOW_BASE_URL
+		)
+		this.alibabaQwenProvider = new OpenAICompatibleProvider(
+			settings.alibabaQwenProvider.apiKey,
+			settings.alibabaQwenProvider.baseUrl && settings.alibabaQwenProvider.useCustomUrl ?
+				settings.alibabaQwenProvider.baseUrl
+				: ALIBABA_QWEN_BASE_URL
+		)
+		this.deepseekProvider = new OpenAICompatibleProvider(
+			settings.deepseekProvider.apiKey,
+			settings.deepseekProvider.baseUrl && settings.deepseekProvider.useCustomUrl ?
+				settings.deepseekProvider.baseUrl
+				: DEEPSEEK_BASE_URL
+		)
 		this.openaiProvider = new OpenAIAuthenticatedProvider(settings.openaiProvider.apiKey)
 		this.anthropicProvider = new AnthropicProvider(settings.anthropicProvider.apiKey)
 		this.googleProvider = new GeminiProvider(settings.googleProvider.apiKey)
 		this.groqProvider = new GroqProvider(settings.groqProvider.apiKey)
-		this.ollamaProvider = new OllamaProvider(settings.groqProvider.baseUrl)
+		this.ollamaProvider = new OllamaProvider(settings.ollamaProvider.baseUrl)
 		this.openaiCompatibleProvider = new OpenAICompatibleProvider(settings.openaicompatibleProvider.apiKey, settings.openaicompatibleProvider.baseUrl)
 		this.isInfioEnabled = !!settings.infioProvider.apiKey
 	}

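Note: besides expanding the four OpenAI-compatible constructors, this hunk fixes a copy-paste bug — the Ollama provider previously read its base URL from settings.groqProvider.baseUrl. The four expanded calls all repeat the same pattern: use the custom base URL only when one is set and useCustomUrl is enabled, otherwise fall back to the provider default. A minimal sketch of how that pattern could be factored out, assuming the settings shape shown in this diff (resolveBaseUrl is a hypothetical helper, not part of the codebase):

// Hypothetical helper: prefer the user-supplied base URL only when it is
// non-empty and the "use custom URL" toggle is on.
type ProviderSetting = { apiKey: string; baseUrl?: string; useCustomUrl?: boolean }

const resolveBaseUrl = (setting: ProviderSetting, defaultUrl: string): string =>
	setting.baseUrl && setting.useCustomUrl ? setting.baseUrl : defaultUrl

// e.g. new OpenAICompatibleProvider(
// 	settings.openrouterProvider.apiKey,
// 	resolveBaseUrl(settings.openrouterProvider, OPENROUTER_BASE_URL))
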
@@ -125,6 +145,8 @@ class LLMManager implements LLMManagerInterface {
 					request,
 					options,
 				)
+			case ApiProvider.OpenAICompatible:
+				return await this.openaiCompatibleProvider.generateResponse(model, request, options)
 			default:
 				throw new Error(`Unsupported model provider: ${model.provider}`)
 		}

@@ -68,7 +68,7 @@ export class OllamaProvider implements BaseLLMProvider {

 		const client = new NoStainlessOpenAI({
 			baseURL: `${this.baseUrl}/v1`,
-			apiKey: '',
+			apiKey: 'ollama',
 			dangerouslyAllowBrowser: true,
 		})
 		return this.adapter.generateResponse(client, request, options)

@@ -87,7 +87,7 @@ export class OllamaProvider implements BaseLLMProvider {

 		const client = new NoStainlessOpenAI({
 			baseURL: `${this.baseUrl}/v1`,
-			apiKey: '',
+			apiKey: 'ollama',
 			dangerouslyAllowBrowser: true,
 		})
 		return this.adapter.streamResponse(client, request, options)

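Note: both generateResponse and streamResponse switch from an empty API key to an 'ollama' placeholder. Ollama's OpenAI-compatible endpoint ignores the key's value, but OpenAI-style clients generally expect a non-empty string, so a dummy value is the usual convention. A self-contained sketch of the same setup using the stock openai npm package (the localhost address and model name are assumptions, not taken from this repo):

import OpenAI from 'openai'

// Ollama serves an OpenAI-compatible API under /v1; the key is ignored
// server-side but should be non-empty for the client.
const client = new OpenAI({
	baseURL: 'http://localhost:11434/v1', // assumed default Ollama address
	apiKey: 'ollama',
	dangerouslyAllowBrowser: true, // needed inside an Electron/browser renderer
})

const res = await client.chat.completions.create({
	model: 'llama3.1', // any model you have pulled locally
	messages: [{ role: 'user', content: 'hello' }],
})
console.log(res.choices[0].message.content)
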
@@ -159,10 +159,9 @@ export const getEmbeddingModel = (
 			dangerouslyAllowBrowser: true,
 			baseURL: `${settings.ollamaProvider.baseUrl}/v1`,
 		})
-		const modelInfo = GetEmbeddingModelInfo(settings.embeddingModelProvider, settings.embeddingModelId)
 		return {
 			id: settings.embeddingModelId,
-			dimension: modelInfo.dimensions,
+			dimension: 0,
 			getEmbedding: async (text: string) => {
 				if (!settings.ollamaProvider.baseUrl) {
 					throw new LLMBaseUrlNotSetException(

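Note: dimension: 0 is a sentinel. Unlike the hosted providers, an Ollama instance can serve any pulled embedding model, so the dimension cannot be looked up in a static model table; the RAGEngine hunk below probes it lazily instead. The probe amounts to embedding a throwaway string once and reading the vector's length — a sketch (probeDimension is an illustrative name, not code from this repo):

// Illustrative: discover an embedding model's output dimension at runtime
// by embedding a dummy string and measuring the result.
async function probeDimension(
	getEmbedding: (text: string) => Promise<number[]>,
): Promise<number> {
	const vector = await getEmbedding('hello world')
	return vector.length
}
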
@@ -5,128 +5,145 @@ import { DBManager } from '../../database/database-manager'
 import { VectorManager } from '../../database/modules/vector/vector-manager'
 import { SelectVector } from '../../database/schema'
 import { EmbeddingModel } from '../../types/embedding'
+import { ApiProvider } from '../../types/llm/model'
 import { InfioSettings } from '../../types/settings'

 import { getEmbeddingModel } from './embedding'

 export class RAGEngine {
 	private app: App
 	private settings: InfioSettings
 	private vectorManager: VectorManager
 	private embeddingModel: EmbeddingModel | null = null
 	private initialized = false

 	constructor(
 		app: App,
 		settings: InfioSettings,
 		dbManager: DBManager,
 	) {
 		this.app = app
 		this.settings = settings
 		this.vectorManager = dbManager.getVectorManager()
+		this.embeddingModel = getEmbeddingModel(settings)
 	}

 	setSettings(settings: InfioSettings) {
 		this.settings = settings
 		this.embeddingModel = getEmbeddingModel(settings)
 	}

+	async initializeDimension(): Promise<void> {
+		if (this.embeddingModel.dimension === 0 && this.settings.embeddingModelProvider === ApiProvider.Ollama) {
+			this.embeddingModel.dimension = (await this.embeddingModel.getEmbedding("hello world")).length
+		}
+	}
+
 	// TODO: Implement automatic vault re-indexing when settings are changed.
 	// Currently, users must manually re-index the vault.
 	async updateVaultIndex(
 		options: { reindexAll: boolean },
 		onQueryProgressChange?: (queryProgress: QueryProgressState) => void,
 	): Promise<void> {
 		if (!this.embeddingModel) {
 			throw new Error('Embedding model is not set')
 		}
+		await this.initializeDimension()
+
 		await this.vectorManager.updateVaultIndex(
 			this.embeddingModel,
 			{
 				chunkSize: this.settings.ragOptions.chunkSize,
 				excludePatterns: this.settings.ragOptions.excludePatterns,
 				includePatterns: this.settings.ragOptions.includePatterns,
 				reindexAll: options.reindexAll,
 			},
 			(indexProgress) => {
 				onQueryProgressChange?.({
 					type: 'indexing',
 					indexProgress,
 				})
 			},
 		)
 		this.initialized = true
 	}

 	async updateFileIndex(file: TFile) {
+		if (!this.embeddingModel) {
+			throw new Error('Embedding model is not set')
+		}
+
+		await this.initializeDimension()
+
 		await this.vectorManager.UpdateFileVectorIndex(
 			this.embeddingModel,
 			this.settings.ragOptions.chunkSize,
 			file,
 		)
 	}

 	async deleteFileIndex(file: TFile) {
 		if (!this.embeddingModel) {
 			throw new Error('Embedding model is not set')
 		}
+
+		await this.initializeDimension()
+
 		await this.vectorManager.DeleteFileVectorIndex(
 			this.embeddingModel,
 			file,
 		)
 	}

 	async processQuery({
 		query,
 		scope,
 		onQueryProgressChange,
 	}: {
 		query: string
 		scope?: {
 			files: string[]
 			folders: string[]
 		}
 		onQueryProgressChange?: (queryProgress: QueryProgressState) => void
 	}): Promise<
 		(Omit<SelectVector, 'embedding'> & {
 			similarity: number
 		})[]
 	> {
 		if (!this.embeddingModel) {
 			throw new Error('Embedding model is not set')
 		}
+
+		await this.initializeDimension()
+
 		if (!this.initialized) {
 			await this.updateVaultIndex({ reindexAll: false }, onQueryProgressChange)
 		}
 		const queryEmbedding = await this.getEmbedding(query)
 		onQueryProgressChange?.({
 			type: 'querying',
 		})
-		console.log('query, ', {
-			minSimilarity: this.settings.ragOptions.minSimilarity,
-			limit: this.settings.ragOptions.limit,
-			scope,
-		})
 		const queryResult = await this.vectorManager.performSimilaritySearch(
 			queryEmbedding,
 			this.embeddingModel,
 			{
 				minSimilarity: this.settings.ragOptions.minSimilarity,
 				limit: this.settings.ragOptions.limit,
 				scope,
 			},
 		)
-		console.log('queryResult', queryResult)
 		onQueryProgressChange?.({
 			type: 'querying-done',
 			queryResult,
 		})
 		return queryResult
 	}

 	async getEmbedding(query: string): Promise<number[]> {
 		if (!this.embeddingModel) {
 			throw new Error('Embedding model is not set')
 		}
 		return this.embeddingModel.getEmbedding(query)
 	}
 }

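Note: initializeDimension() is awaited at every entry point (updateVaultIndex, updateFileIndex, deleteFileIndex, processQuery). Once the first probe completes, dimension is non-zero and later calls return immediately; but two operations racing before the first probe resolves would each issue an embedding request. A memoized-promise variant shares one in-flight probe — a self-contained sketch, not code from this repo:

// Sketch: cache the in-flight probe so concurrent callers share one request.
class DimensionProber {
	private probe: Promise<number> | null = null

	constructor(private readonly getEmbedding: (text: string) => Promise<number[]>) {}

	dimension(): Promise<number> {
		this.probe ??= this.getEmbedding('hello world').then((v) => v.length)
		return this.probe
	}
}
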
@@ -141,13 +141,15 @@ const CustomProviderSettings: React.FC<CustomProviderSettingsProps> = ({ plugin,
 				onChange={updateProvider}
 			/>
 			<div className="infio-llm-setting-divider"></div>
-			<TextComponent
-				name={currProvider + " api key:"}
-				placeholder="Enter your api key"
-				value={providerSetting.apiKey || ''}
-				onChange={updateProviderApiKey}
-				type="password"
-			/>
+			{currProvider !== ApiProvider.Ollama && (
+				<TextComponent
+					name={currProvider + " api key:"}
+					placeholder="Enter your api key"
+					value={providerSetting.apiKey || ''}
+					onChange={updateProviderApiKey}
+					type="password"
+				/>
+			)}
 			<div className="infio-llm-setting-divider"></div>
 			<ToggleComponent
 				name="Use custom base url"

@@ -10,7 +10,6 @@ export enum ApiProvider {
 	Groq = "Groq",
 	Ollama = "Ollama",
 	OpenAICompatible = "OpenAICompatible",
-	TransformersJs = "TransformersJs",
 }

 export type LLMModel = {

@@ -125,14 +125,14 @@ const OpenAICompatibleProviderSchema = z.object({

 const OllamaProviderSchema = z.object({
 	name: z.literal('Ollama'),
-	apiKey: z.string().catch(''),
+	apiKey: z.string().catch('ollama'),
 	baseUrl: z.string().catch(''),
 	useCustomUrl: z.boolean().catch(false)
 }).catch({
 	name: 'Ollama',
-	apiKey: '',
+	apiKey: 'ollama',
 	baseUrl: '',
-	useCustomUrl: false
+	useCustomUrl: true
 })

 const GroqProviderSchema = z.object({

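Note on the schema change: zod's .catch() supplies a fallback instead of throwing — per field when that field fails to parse, and for the whole object when the input itself is invalid — so stored settings that predate these defaults degrade gracefully. Also worth flagging: after this change the field-level fallback for useCustomUrl (false) and the object-level fallback (true) disagree, which may or may not be intentional. A small self-contained illustration of the .catch() behavior (generic values, not this repo's schema):

import { z } from 'zod'

const schema = z.object({
	apiKey: z.string().catch('ollama'),    // invalid or missing field -> 'ollama'
	useCustomUrl: z.boolean().catch(true), // invalid or missing field -> true
}).catch({ apiKey: 'ollama', useCustomUrl: true }) // non-object input -> whole default

schema.parse({ apiKey: 123, useCustomUrl: 'yes' })
// -> { apiKey: 'ollama', useCustomUrl: true }
schema.parse(undefined)
// -> { apiKey: 'ollama', useCustomUrl: true }
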
@@ -1141,7 +1141,9 @@ export const GetEmbeddingProviders = (): ApiProvider[] => {
 		ApiProvider.OpenAI,
 		ApiProvider.SiliconFlow,
 		ApiProvider.Google,
-		ApiProvider.AlibabaQwen
+		ApiProvider.AlibabaQwen,
+		ApiProvider.OpenAICompatible,
+		ApiProvider.Ollama,
 	]
 }

@@ -210,7 +210,6 @@ export async function webSearch(query: string, serperApiKey: string, jinaApiKey:
 	}
 }

-// todo: update
 export async function fetchUrlsContent(urls: string[], apiKey: string): Promise<string> {
 	return new Promise((resolve) => {
 		const results = urls.map(async (url) => {