From c3cc81624fc64280594df2630e1bb4c9dea55beb Mon Sep 17 00:00:00 2001 From: duanfuxiang Date: Wed, 18 Jun 2025 08:02:32 +0800 Subject: [PATCH] update models select --- .../components/ModelProviderSettings.tsx | 27 ++++++++--- .../components/ProviderModelsPicker.tsx | 10 ++-- src/types/settings.ts | 48 +++++++++---------- 3 files changed, 51 insertions(+), 34 deletions(-) diff --git a/src/settings/components/ModelProviderSettings.tsx b/src/settings/components/ModelProviderSettings.tsx index 751c536..ad0cc67 100644 --- a/src/settings/components/ModelProviderSettings.tsx +++ b/src/settings/components/ModelProviderSettings.tsx @@ -293,7 +293,12 @@ const CustomProviderSettings: React.FC = ({ plugin, console.log(`updateChatModelId: ${provider} -> ${modelId}, isCustom: ${isCustom}`) const providerSettingKey = getProviderSettingKey(provider); const providerSettings = settings[providerSettingKey] || {}; - const currentModels = providerSettings.models || new Set(); + const currentModels = providerSettings.models || []; + + // 如果是自定义模型且不在列表中,则添加 + const updatedModels = isCustom && !currentModels.includes(modelId) + ? [...currentModels, modelId] + : currentModels; handleSettingsUpdate({ ...settings, @@ -301,7 +306,7 @@ const CustomProviderSettings: React.FC = ({ plugin, chatModelId: modelId, [providerSettingKey]: { ...providerSettings, - models: isCustom ? new Set([...currentModels, modelId]) : currentModels + models: updatedModels } }); }; @@ -310,7 +315,12 @@ const CustomProviderSettings: React.FC = ({ plugin, console.log(`updateApplyModelId: ${provider} -> ${modelId}, isCustom: ${isCustom}`) const providerSettingKey = getProviderSettingKey(provider); const providerSettings = settings[providerSettingKey] || {}; - const currentModels = providerSettings.models || new Set(); + const currentModels = providerSettings.models || []; + + // 如果是自定义模型且不在列表中,则添加 + const updatedModels = isCustom && !currentModels.includes(modelId) + ? [...currentModels, modelId] + : currentModels; handleSettingsUpdate({ ...settings, @@ -318,7 +328,7 @@ const CustomProviderSettings: React.FC = ({ plugin, applyModelId: modelId, [providerSettingKey]: { ...providerSettings, - models: isCustom ? new Set([...currentModels, modelId]) : currentModels + models: updatedModels } }); }; @@ -327,7 +337,12 @@ const CustomProviderSettings: React.FC = ({ plugin, console.log(`updateEmbeddingModelId: ${provider} -> ${modelId}, isCustom: ${isCustom}`) const providerSettingKey = getProviderSettingKey(provider); const providerSettings = settings[providerSettingKey] || {}; - const currentModels = providerSettings.models || new Set(); + const currentModels = providerSettings.models || []; + + // 如果是自定义模型且不在列表中,则添加 + const updatedModels = isCustom && !currentModels.includes(modelId) + ? [...currentModels, modelId] + : currentModels; handleSettingsUpdate({ ...settings, @@ -335,7 +350,7 @@ const CustomProviderSettings: React.FC = ({ plugin, embeddingModelId: modelId, [providerSettingKey]: { ...providerSettings, - models: isCustom ? new Set([...currentModels, modelId]) : currentModels + models: updatedModels } }); }; diff --git a/src/settings/components/ProviderModelsPicker.tsx b/src/settings/components/ProviderModelsPicker.tsx index fa7e340..4879b76 100644 --- a/src/settings/components/ProviderModelsPicker.tsx +++ b/src/settings/components/ProviderModelsPicker.tsx @@ -207,13 +207,13 @@ export const ComboBoxComponent: React.FC = ({ const combinedModelIds = useMemo(() => { const providerKey = getProviderSettingKey(modelProvider); const providerModels = settings?.[providerKey]?.models; - console.log(`🔍 Custom models in settings for ${modelProvider}:`, providerModels ? Array.from(providerModels) : 'none') - // Ensure providerModels is a Set of strings - if (!providerModels || !(providerModels instanceof Set)) { + console.log(`🔍 Custom models in settings for ${modelProvider}:`, providerModels || 'none') + // Ensure providerModels is an array of strings + if (!providerModels || !Array.isArray(providerModels)) { console.log(`📋 Using only official models (${modelIds.length}):`, modelIds); return modelIds; } - const additionalModels = Array.from(providerModels).filter((model): model is string => typeof model === 'string'); + const additionalModels = providerModels.filter((model): model is string => typeof model === 'string'); console.log(`📋 Combined models: ${modelIds.length} official + ${additionalModels.length} custom`); return [...modelIds, ...additionalModels]; }, [modelIds, settings, modelProvider]); @@ -286,6 +286,8 @@ export const ComboBoxComponent: React.FC = ({ if (isValidProvider(newProvider)) { setModelProvider(newProvider); + // 当提供商变更时,清空模型选择并通知父组件 + updateModel(newProvider, '', false); } }; diff --git a/src/types/settings.ts b/src/types/settings.ts index aa589fd..a95c0c8 100644 --- a/src/types/settings.ts +++ b/src/types/settings.ts @@ -20,13 +20,13 @@ const InfioProviderSchema = z.object({ apiKey: z.string().catch(''), baseUrl: z.string().catch(''), useCustomUrl: z.boolean().catch(false), - models: z.set(z.string()).catch(new Set()) + models: z.array(z.string()).catch([]) }).catch({ name: 'Infio', apiKey: '', baseUrl: '', useCustomUrl: false, - models: new Set() + models: [] }) const OpenRouterProviderSchema = z.object({ @@ -34,13 +34,13 @@ const OpenRouterProviderSchema = z.object({ apiKey: z.string().catch(''), baseUrl: z.string().catch(''), useCustomUrl: z.boolean().catch(false), - models: z.set(z.string()).catch(new Set()) + models: z.array(z.string()).catch([]) }).catch({ name: 'OpenRouter', apiKey: '', baseUrl: '', useCustomUrl: false, - models: new Set() + models: [] }) const SiliconFlowProviderSchema = z.object({ @@ -48,13 +48,13 @@ const SiliconFlowProviderSchema = z.object({ apiKey: z.string().catch(''), baseUrl: z.string().catch(''), useCustomUrl: z.boolean().catch(false), - models: z.set(z.string()).catch(new Set()) + models: z.array(z.string()).catch([]) }).catch({ name: 'SiliconFlow', apiKey: '', baseUrl: '', useCustomUrl: false, - models: new Set() + models: [] }) const AlibabaQwenProviderSchema = z.object({ @@ -62,13 +62,13 @@ const AlibabaQwenProviderSchema = z.object({ apiKey: z.string().catch(''), baseUrl: z.string().catch(''), useCustomUrl: z.boolean().catch(false), - models: z.set(z.string()).catch(new Set()) + models: z.array(z.string()).catch([]) }).catch({ name: 'AlibabaQwen', apiKey: '', baseUrl: '', useCustomUrl: false, - models: new Set() + models: [] }) const AnthropicProviderSchema = z.object({ @@ -76,13 +76,13 @@ const AnthropicProviderSchema = z.object({ apiKey: z.string().catch(''), baseUrl: z.string().optional(), useCustomUrl: z.boolean().catch(false), - models: z.set(z.string()).catch(new Set()) + models: z.array(z.string()).catch([]) }).catch({ name: 'Anthropic', apiKey: '', baseUrl: '', useCustomUrl: false, - models: new Set() + models: [] }) const DeepSeekProviderSchema = z.object({ @@ -90,13 +90,13 @@ const DeepSeekProviderSchema = z.object({ apiKey: z.string().catch(''), baseUrl: z.string().catch(''), useCustomUrl: z.boolean().catch(false), - models: z.set(z.string()).catch(new Set()) + models: z.array(z.string()).catch([]) }).catch({ name: 'DeepSeek', apiKey: '', baseUrl: '', useCustomUrl: false, - models: new Set() + models: [] }) const GoogleProviderSchema = z.object({ @@ -104,13 +104,13 @@ const GoogleProviderSchema = z.object({ apiKey: z.string().catch(''), baseUrl: z.string().catch(''), useCustomUrl: z.boolean().catch(false), - models: z.set(z.string()).catch(new Set()) + models: z.array(z.string()).catch([]) }).catch({ name: 'Google', apiKey: '', baseUrl: '', useCustomUrl: false, - models: new Set() + models: [] }) const OpenAIProviderSchema = z.object({ @@ -118,13 +118,13 @@ const OpenAIProviderSchema = z.object({ apiKey: z.string().catch(''), baseUrl: z.string().optional(), useCustomUrl: z.boolean().catch(false), - models: z.set(z.string()).catch(new Set()) + models: z.array(z.string()).catch([]) }).catch({ name: 'OpenAI', apiKey: '', baseUrl: '', useCustomUrl: false, - models: new Set() + models: [] }) const OpenAICompatibleProviderSchema = z.object({ @@ -132,13 +132,13 @@ const OpenAICompatibleProviderSchema = z.object({ apiKey: z.string().catch(''), baseUrl: z.string().optional(), useCustomUrl: z.boolean().catch(true), - models: z.set(z.string()).catch(new Set()) + models: z.array(z.string()).catch([]) }).catch({ name: 'OpenAICompatible', apiKey: '', baseUrl: '', useCustomUrl: true, - models: new Set() + models: [] }) const OllamaProviderSchema = z.object({ @@ -146,13 +146,13 @@ const OllamaProviderSchema = z.object({ apiKey: z.string().catch('ollama'), baseUrl: z.string().catch(''), useCustomUrl: z.boolean().catch(false), - models: z.set(z.string()).catch(new Set()) + models: z.array(z.string()).catch([]) }).catch({ name: 'Ollama', apiKey: 'ollama', baseUrl: '', useCustomUrl: true, - models: new Set() + models: [] }) const GroqProviderSchema = z.object({ @@ -160,13 +160,13 @@ const GroqProviderSchema = z.object({ apiKey: z.string().catch(''), baseUrl: z.string().catch(''), useCustomUrl: z.boolean().catch(false), - models: z.set(z.string()).catch(new Set()) + models: z.array(z.string()).catch([]) }).catch({ name: 'Groq', apiKey: '', baseUrl: '', useCustomUrl: false, - models: new Set() + models: [] }) const GrokProviderSchema = z.object({ @@ -174,13 +174,13 @@ const GrokProviderSchema = z.object({ apiKey: z.string().catch(''), baseUrl: z.string().catch(''), useCustomUrl: z.boolean().catch(false), - models: z.set(z.string()).catch(new Set()) + models: z.array(z.string()).catch([]) }).catch({ name: 'Grok', apiKey: '', baseUrl: '', useCustomUrl: false, - models: new Set() + models: [] }) const ollamaModelSchema = z.object({