From 812aa763760c316ad41bd5feaf8dd38da1eff3d7 Mon Sep 17 00:00:00 2001
From: duanfuxiang
Date: Sun, 13 Apr 2025 15:04:27 +0800
Subject: [PATCH] add Grok (xAI) model provider

Register Grok as a new ApiProvider backed by the existing generic
OpenAICompatibleProvider: add GROK_BASE_URL, wire the provider into
LLMManager's generate/stream dispatch, expose it in the provider
settings UI, extend the settings schema with zod defaults so old
saved settings migrate cleanly, and hard-code the current Grok model
catalog.
---
 src/constants.ts                          |   1 +
 src/core/llm/manager.ts                   |  16 ++-
 .../components/ModelProviderSettings.tsx  |   4 +-
 src/types/llm/model.ts                    |   1 +
 src/types/settings.test.ts                |   6 +
 src/types/settings.ts                     |  13 ++
 src/utils/api.ts                          | 129 +++++++++++++----
 7 files changed, 131 insertions(+), 39 deletions(-)

diff --git a/src/constants.ts b/src/constants.ts
index 0cdaf4a..ebbf0e1 100644
--- a/src/constants.ts
+++ b/src/constants.ts
@@ -29,6 +29,7 @@ export const SUPPORT_EMBEDDING_SIMENTION: number[] = [
 export const OPENAI_BASE_URL = 'https://api.openai.com/v1'
 export const DEEPSEEK_BASE_URL = 'https://api.deepseek.com'
 export const OPENROUTER_BASE_URL = 'https://openrouter.ai/api/v1'
+export const GROK_BASE_URL = 'https://api.x.ai/v1'
 export const SILICONFLOW_BASE_URL = 'https://api.siliconflow.cn/v1'
 export const ALIBABA_QWEN_BASE_URL = 'https://dashscope.aliyuncs.com/compatible-mode/v1'
 export const INFIO_BASE_URL = 'https://api.infio.com/api/raw_message'
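GROK_BASE_URL points at xAI's OpenAI-compatible REST endpoint, which is what lets the rest of this patch reuse the generic `OpenAICompatibleProvider` instead of adding a bespoke client. A minimal sketch of exercising that endpoint directly with the `openai` npm package (the same client the provider wraps); the `XAI_API_KEY` environment variable is an assumption here, and `grok-3` is the default model id introduced later in this patch:

```ts
import OpenAI from 'openai'

// xAI's API is OpenAI-compatible, so the stock client works as-is.
const client = new OpenAI({
	apiKey: process.env.XAI_API_KEY, // assumed env var, not part of the patch
	baseURL: 'https://api.x.ai/v1',  // GROK_BASE_URL
})

const completion = await client.chat.completions.create({
	model: 'grok-3',
	messages: [{ role: 'user', content: 'Say hello from Grok.' }],
})
console.log(completion.choices[0].message.content)
```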
diff --git a/src/core/llm/manager.ts b/src/core/llm/manager.ts
index 76aa57f..74dc4ee 100644
--- a/src/core/llm/manager.ts
+++ b/src/core/llm/manager.ts
@@ -1,4 +1,4 @@
-import { ALIBABA_QWEN_BASE_URL, DEEPSEEK_BASE_URL, OPENROUTER_BASE_URL, SILICONFLOW_BASE_URL } from '../../constants'
+import { ALIBABA_QWEN_BASE_URL, DEEPSEEK_BASE_URL, GROK_BASE_URL, OPENROUTER_BASE_URL, SILICONFLOW_BASE_URL } from '../../constants'
 import { ApiProvider, LLMModel } from '../../types/llm/model'
 import {
 	LLMOptions,
@@ -39,6 +39,7 @@ class LLMManager implements LLMManagerInterface {
 	private anthropicProvider: AnthropicProvider
 	private googleProvider: GeminiProvider
 	private groqProvider: GroqProvider
+	private grokProvider: OpenAICompatibleProvider
 	private infioProvider: InfioProvider
 	private openrouterProvider: OpenAICompatibleProvider
 	private siliconflowProvider: OpenAICompatibleProvider
@@ -77,6 +78,11 @@ class LLMManager implements LLMManagerInterface {
 		this.anthropicProvider = new AnthropicProvider(settings.anthropicProvider.apiKey)
 		this.googleProvider = new GeminiProvider(settings.googleProvider.apiKey)
 		this.groqProvider = new GroqProvider(settings.groqProvider.apiKey)
+		this.grokProvider = new OpenAICompatibleProvider(settings.grokProvider.apiKey,
+			settings.grokProvider.baseUrl && settings.grokProvider.useCustomUrl
+				? settings.grokProvider.baseUrl
+				: GROK_BASE_URL
+		)
 		this.ollamaProvider = new OllamaProvider(settings.ollamaProvider.baseUrl)
 		this.openaiCompatibleProvider = new OpenAICompatibleProvider(settings.openaicompatibleProvider.apiKey, settings.openaicompatibleProvider.baseUrl)
 		this.isInfioEnabled = !!settings.infioProvider.apiKey
@@ -145,6 +151,12 @@ class LLMManager implements LLMManagerInterface {
 				request,
 				options,
 			)
+			case ApiProvider.Grok:
+				return await this.grokProvider.generateResponse(
+					model,
+					request,
+					options,
+				)
 			case ApiProvider.OpenAICompatible:
 				return await this.openaiCompatibleProvider.generateResponse(model, request, options)
 			default:
@@ -182,6 +194,8 @@ class LLMManager implements LLMManagerInterface {
 				return await this.googleProvider.streamResponse(model, request, options)
 			case ApiProvider.Groq:
 				return await this.groqProvider.streamResponse(model, request, options)
+			case ApiProvider.Grok:
+				return await this.grokProvider.streamResponse(model, request, options)
 			case ApiProvider.Ollama:
 				return await this.ollamaProvider.streamResponse(model, request, options)
 			case ApiProvider.OpenAICompatible:
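The constructor resolves the base URL with a guarded ternary: a custom URL is honored only when `baseUrl` is non-empty *and* `useCustomUrl` is enabled; otherwise it falls back to `GROK_BASE_URL`. A hypothetical helper (not part of the patch) that isolates this rule so it can be unit-tested:

```ts
const GROK_BASE_URL = 'https://api.x.ai/v1'

interface ProviderUrlSettings {
	baseUrl: string
	useCustomUrl: boolean
}

// Mirrors the ternary in LLMManager's constructor: both conditions must hold.
function resolveBaseUrl(p: ProviderUrlSettings, fallback: string): string {
	return p.baseUrl && p.useCustomUrl ? p.baseUrl : fallback
}

// An empty baseUrl falls back even when useCustomUrl is on...
console.assert(resolveBaseUrl({ baseUrl: '', useCustomUrl: true }, GROK_BASE_URL) === GROK_BASE_URL)
// ...and a set URL is ignored unless the user opted in.
console.assert(resolveBaseUrl({ baseUrl: 'http://localhost:8080/v1', useCustomUrl: false }, GROK_BASE_URL) === GROK_BASE_URL)
console.assert(resolveBaseUrl({ baseUrl: 'http://localhost:8080/v1', useCustomUrl: true }, GROK_BASE_URL) === 'http://localhost:8080/v1')
```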
diff --git a/src/settings/components/ModelProviderSettings.tsx b/src/settings/components/ModelProviderSettings.tsx
index 0819bca..318dd41 100644
--- a/src/settings/components/ModelProviderSettings.tsx
+++ b/src/settings/components/ModelProviderSettings.tsx
@@ -23,6 +23,7 @@ type ProviderSettingKey =
 	| 'deepseekProvider'
 	| 'googleProvider'
 	| 'groqProvider'
+	| 'grokProvider'
 	| 'ollamaProvider'
 	| 'openaicompatibleProvider';
 
@@ -36,8 +37,9 @@ const keyMap: Record<ApiProvider, ProviderSettingKey> = {
 	'Deepseek': 'deepseekProvider',
 	'Google': 'googleProvider',
 	'Groq': 'groqProvider',
+	'Grok': 'grokProvider',
 	'Ollama': 'ollamaProvider',
-	'OpenAICompatible': 'openaicompatibleProvider'
+	'OpenAICompatible': 'openaicompatibleProvider',
 };
 
 const getProviderSettingKey = (provider: ApiProvider): ProviderSettingKey => {
diff --git a/src/types/llm/model.ts b/src/types/llm/model.ts
index b345544..23bab4d 100644
--- a/src/types/llm/model.ts
+++ b/src/types/llm/model.ts
@@ -8,6 +8,7 @@ export enum ApiProvider {
 	OpenAI = "OpenAI",
 	Google = "Google",
 	Groq = "Groq",
+	Grok = "Grok",
 	Ollama = "Ollama",
 	OpenAICompatible = "OpenAICompatible",
 }
diff --git a/src/types/settings.test.ts b/src/types/settings.test.ts
index 6eaa37c..9e97743 100644
--- a/src/types/settings.test.ts
+++ b/src/types/settings.test.ts
@@ -296,6 +296,12 @@ describe('settings migration', () => {
 				baseUrl: '',
 				useCustomUrl: false,
 			},
+			grokProvider: {
+				name: 'Grok',
+				apiKey: '',
+				baseUrl: '',
+				useCustomUrl: false,
+			},
 			infioProvider: {
 				name: 'Infio',
 				apiKey: '',
diff --git a/src/types/settings.ts b/src/types/settings.ts
index 3fe4223..a6a9a3b 100644
--- a/src/types/settings.ts
+++ b/src/types/settings.ts
@@ -147,6 +147,18 @@ const GroqProviderSchema = z.object({
 	useCustomUrl: false
 })
 
+const GrokProviderSchema = z.object({
+	name: z.literal('Grok'),
+	apiKey: z.string().catch(''),
+	baseUrl: z.string().catch(''),
+	useCustomUrl: z.boolean().catch(false)
+}).catch({
+	name: 'Grok',
+	apiKey: '',
+	baseUrl: '',
+	useCustomUrl: false
+})
+
 const ollamaModelSchema = z.object({
 	baseUrl: z.string().catch(''),
 	model: z.string().catch(''),
@@ -205,6 +217,7 @@ export const InfioSettingsSchema = z.object({
 	googleProvider: GoogleProviderSchema,
 	ollamaProvider: OllamaProviderSchema,
 	groqProvider: GroqProviderSchema,
+	grokProvider: GrokProviderSchema,
 	openaicompatibleProvider: OpenAICompatibleProviderSchema,
 
 	// Chat Model
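GrokProviderSchema uses zod's `.catch` at two levels, and both matter for migration. A short sketch of the behavior, assuming only that `zod` is installed (the schema itself is copied from the patch):

```ts
import { z } from 'zod'

const GrokProviderSchema = z.object({
	name: z.literal('Grok'),
	apiKey: z.string().catch(''),
	baseUrl: z.string().catch(''),
	useCustomUrl: z.boolean().catch(false)
}).catch({
	name: 'Grok',
	apiKey: '',
	baseUrl: '',
	useCustomUrl: false
})

// Field-level .catch repairs individual bad values in an otherwise valid object:
GrokProviderSchema.parse({ name: 'Grok', apiKey: 42, baseUrl: '', useCustomUrl: false })
// -> { name: 'Grok', apiKey: '', baseUrl: '', useCustomUrl: false }

// The object-level .catch covers settings saved before this patch, where
// grokProvider is absent entirely (undefined fails z.object). This is exactly
// the default object the new settings.test.ts fixture asserts:
GrokProviderSchema.parse(undefined)
// -> { name: 'Grok', apiKey: '', baseUrl: '', useCustomUrl: false }
```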
diff --git a/src/utils/api.ts b/src/utils/api.ts
index 92970a7..9061b2d 100644
--- a/src/utils/api.ts
+++ b/src/utils/api.ts
@@ -1,4 +1,4 @@
-import { OPENROUTER_BASE_URL } from '../constants'
+import { GROK_BASE_URL, OPENROUTER_BASE_URL } from '../constants'
 import { ApiProvider } from '../types/llm/model'
 
 export interface ModelInfo {
@@ -159,6 +159,40 @@ export const openRouterDefaultModelInfo: ModelInfo = {
 	description: "The new Claude 3.5 Sonnet delivers better-than-Opus capabilities, faster-than-Sonnet speeds, at the same Sonnet prices. Sonnet is particularly good at:\n\n- Coding: New Sonnet scores ~49% on SWE-Bench Verified, higher than the last best score, and without any fancy prompt scaffolding\n- Data science: Augments human data science expertise; navigates unstructured data while using multiple tools for insights\n- Visual processing: excelling at interpreting charts, graphs, and images, accurately transcribing text to derive insights beyond just the text alone\n- Agentic tasks: exceptional tool use, making it great at agentic tasks (i.e. complex, multi-step problem solving tasks that require engaging with other systems)\n\n#multimodal",
 }
 
+let openRouterModelsCache: Record<string, ModelInfo> | null = null;
+async function fetchOpenRouterModels(): Promise<Record<string, ModelInfo>> {
+	if (openRouterModelsCache) {
+		return openRouterModelsCache;
+	}
+
+	try {
+		const response = await fetch(OPENROUTER_BASE_URL + "/models");
+		const data = await response.json();
+		const models: Record<string, ModelInfo> = {};
+
+		if (data?.data) {
+			for (const model of data.data) {
+				models[model.id] = {
+					maxTokens: model.top_provider?.max_completion_tokens ?? model.context_length,
+					contextWindow: model.context_length,
+					supportsImages: model.architecture?.modality?.includes("image") ?? false,
+					supportsPromptCache: false,
+					inputPrice: model.pricing?.prompt ?? 0,
+					outputPrice: model.pricing?.completion ?? 0,
+					description: model.description,
+				};
+			}
+		}
+
+		openRouterModelsCache = models;
+		return models;
+	} catch (error) {
+		console.error('Failed to fetch OpenRouter models:', error);
+		return {
+			[openRouterDefaultModelId]: openRouterDefaultModelInfo
+		};
+	}
+}
 
 // Gemini
 // https://ai.google.dev/gemini-api/docs/models/gemini
@@ -1130,6 +1164,61 @@
 	},
 } as const satisfies Record<string, ModelInfo>
 
+// Grok
+// https://docs.x.ai/docs/models
+export type GrokModelId = keyof typeof grokModels
+export const grokDefaultModelId: GrokModelId = "grok-3"
+export const grokModels = {
+	"grok-3": {
+		maxTokens: 8192,
+		contextWindow: 131072,
+		supportsImages: false,
+		supportsPromptCache: true,
+		inputPrice: 0,
+		outputPrice: 0,
+	},
+	"grok-3-fast": {
+		maxTokens: 8192,
+		contextWindow: 131072,
+		supportsImages: false,
+		supportsPromptCache: true,
+		inputPrice: 0,
+		outputPrice: 0,
+	},
+	"grok-3-mini": {
+		maxTokens: 8192,
+		contextWindow: 131072,
+		supportsImages: false,
+		supportsPromptCache: true,
+		inputPrice: 0,
+		outputPrice: 0,
+	},
+	"grok-3-mini-fast": {
+		maxTokens: 8192,
+		contextWindow: 131072,
+		supportsImages: false,
+		supportsPromptCache: true,
+		inputPrice: 0,
+		outputPrice: 0,
+	},
+	"grok-2-vision": {
+		maxTokens: 8192,
+		contextWindow: 131072,
+		supportsImages: true,
+		supportsPromptCache: true,
+		inputPrice: 0,
+		outputPrice: 0,
+	},
+	"grok-2-image": {
+		maxTokens: 8192,
+		contextWindow: 131072,
+		supportsImages: true,
+		supportsPromptCache: true,
+		inputPrice: 0,
+		outputPrice: 0,
+	}
+} as const satisfies Record<string, ModelInfo>
+
 /// helper functions
 // get all providers
 export const GetAllProviders = (): ApiProvider[] => {
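The `as const satisfies Record<string, ModelInfo>` pattern used for both tables keeps the literal keys, so `GrokModelId` is the union of the six ids above, while every entry is still checked against `ModelInfo`. A small sketch of what that buys at call sites (the helper is hypothetical, not part of the patch):

```ts
// Cast-free lookup: grokModels retains its literal keys, so indexing by
// GrokModelId is fully typed and unknown ids are compile-time errors.
function getGrokModel(id: GrokModelId): ModelInfo {
	return grokModels[id]
}

const info = getGrokModel(grokDefaultModelId) // "grok-3"
console.log(info.contextWindow) // 131072

// getGrokModel('grok-9') // error: '"grok-9"' is not assignable to GrokModelId
```

Worth noting: every Grok entry ships with `inputPrice`/`outputPrice` of 0, so any cost accounting derived from `ModelInfo` will treat Grok usage as free until real prices are filled in.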
@@ -1147,42 +1236,6 @@
 export const GetEmbeddingProviders = (): ApiProvider[] => {
 	return [
 	]
 }
 
-let openRouterModelsCache: Record<string, ModelInfo> | null = null;
-
-async function fetchOpenRouterModels(): Promise<Record<string, ModelInfo>> {
-	if (openRouterModelsCache) {
-		return openRouterModelsCache;
-	}
-
-	try {
-		const response = await fetch(OPENROUTER_BASE_URL + "/models");
-		const data = await response.json();
-		const models: Record<string, ModelInfo> = {};
-
-		if (data?.data) {
-			for (const model of data.data) {
-				models[model.id] = {
-					maxTokens: model.top_provider?.max_completion_tokens ?? model.context_length,
-					contextWindow: model.context_length,
-					supportsImages: model.architecture?.modality?.includes("image") ?? false,
-					supportsPromptCache: false,
-					inputPrice: model.pricing?.prompt ?? 0,
-					outputPrice: model.pricing?.completion ?? 0,
-					description: model.description,
-				};
-			}
-		}
-
-		openRouterModelsCache = models;
-		return models;
-	} catch (error) {
-		console.error('Failed to fetch OpenRouter models:', error);
-		return {
-			[openRouterDefaultModelId]: openRouterDefaultModelInfo
-		};
-	}
-}
-
 // Get all models for a provider
 export const GetProviderModels = async (provider: ApiProvider): Promise<Record<string, ModelInfo>> => {
 	switch (provider) {
@@ -1204,6 +1257,8 @@ export const GetProviderModels = async (provider: ApiProvider): Promise
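The final hunk is cut off above, but its `+2` line count and the surrounding `switch` strongly suggest it adds a Grok case mirroring the other static providers. A sketch under that assumption (the `ApiProvider.OpenRouter` member is likewise assumed from the `fetchOpenRouterModels` plumbing earlier in the file):

```ts
// Presumed shape of GetProviderModels after this patch, not confirmed by the
// truncated hunk: static catalogs are returned directly; only OpenRouter is
// fetched (and cached) over the network.
export const GetProviderModels = async (provider: ApiProvider): Promise<Record<string, ModelInfo>> => {
	switch (provider) {
		case ApiProvider.OpenRouter:
			return await fetchOpenRouterModels()
		case ApiProvider.Grok:
			return grokModels
		// ...remaining providers follow the same pattern
		default:
			return {}
	}
}
```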