From 2f34794a3c501bfde45b2910afbff10ad741167d Mon Sep 17 00:00:00 2001 From: duanfuxiang Date: Thu, 12 Jun 2025 16:01:37 +0800 Subject: [PATCH] update models settings test --- src/core/llm/openai-compatible.ts | 25 ++++- src/core/llm/openai-message-adapter.ts | 4 +- .../components/ModelProviderSettings.tsx | 104 ++++++++++++++++-- 3 files changed, 122 insertions(+), 11 deletions(-) diff --git a/src/core/llm/openai-compatible.ts b/src/core/llm/openai-compatible.ts index f946115..9ce940b 100644 --- a/src/core/llm/openai-compatible.ts +++ b/src/core/llm/openai-compatible.ts @@ -1,5 +1,6 @@ import OpenAI from 'openai' +import { ALIBABA_QWEN_BASE_URL } from '../../constants' import { LLMModel } from '../../types/llm/model' import { LLMOptions, @@ -32,6 +33,24 @@ export class OpenAICompatibleProvider implements BaseLLMProvider { this.baseURL = baseURL } + // 检查是否为阿里云Qwen API + private isAlibabaQwen(): boolean { + return this.baseURL === ALIBABA_QWEN_BASE_URL || + this.baseURL?.includes('dashscope.aliyuncs.com') + } + + // 获取提供商特定的额外参数 + private getExtraParams(isStreaming: boolean): Record { + const extraParams: Record = {} + + // 阿里云Qwen API需要在非流式调用中设置 enable_thinking: false + if (this.isAlibabaQwen() && !isStreaming) { + extraParams.enable_thinking = false + } + + return extraParams + } + async generateResponse( model: LLMModel, request: LLMRequestNonStreaming, @@ -43,7 +62,8 @@ export class OpenAICompatibleProvider implements BaseLLMProvider { ) } - return this.adapter.generateResponse(this.client, request, options) + const extraParams = this.getExtraParams(false) // 非流式调用 + return this.adapter.generateResponse(this.client, request, options, extraParams) } async streamResponse( @@ -57,6 +77,7 @@ export class OpenAICompatibleProvider implements BaseLLMProvider { ) } - return this.adapter.streamResponse(this.client, request, options) + const extraParams = this.getExtraParams(true) // 流式调用 + return this.adapter.streamResponse(this.client, request, options, extraParams) } } diff --git a/src/core/llm/openai-message-adapter.ts b/src/core/llm/openai-message-adapter.ts index 43d0539..78251cb 100644 --- a/src/core/llm/openai-message-adapter.ts +++ b/src/core/llm/openai-message-adapter.ts @@ -21,7 +21,8 @@ export class OpenAIMessageAdapter { async generateResponse( client: OpenAI, request: LLMRequestNonStreaming, - options?: LLMOptions, + options?: LLMOptions, + extraParams?: Record, ): Promise { const response = await client.chat.completions.create( { @@ -36,6 +37,7 @@ export class OpenAIMessageAdapter { presence_penalty: request.presence_penalty, logit_bias: request.logit_bias, prediction: request.prediction, + ...extraParams, }, { signal: options?.signal, diff --git a/src/settings/components/ModelProviderSettings.tsx b/src/settings/components/ModelProviderSettings.tsx index cba3050..704998f 100644 --- a/src/settings/components/ModelProviderSettings.tsx +++ b/src/settings/components/ModelProviderSettings.tsx @@ -172,16 +172,104 @@ const CustomProviderSettings: React.FC = ({ plugin, }; const testApiConnection = async (provider: ApiProvider) => { - // TODO: 实现API连接测试逻辑 - // 这里应该根据provider类型调用对应的API测试接口 console.log(`Testing connection for ${provider}...`); - // 模拟延迟 - await new Promise(resolve => setTimeout(resolve, 1000)); - - // 模拟随机成功/失败(用于演示) - if (Math.random() > 0.5) { - throw new Error('Connection test failed'); + try { + // 动态导入LLMManager以避免循环依赖 + const { default: LLMManager } = await import('../../core/llm/manager'); + const { GetDefaultModelId } = await import('../../utils/api'); + + // 对于Ollama和OpenAICompatible,不支持测试API连接 + if (provider === ApiProvider.Ollama || provider === ApiProvider.OpenAICompatible) { + throw new Error(`不支持测试 ${provider} 的API连接`); + } + + // 创建LLM管理器实例 + const llmManager = new LLMManager(settings); + + // 获取提供商的默认聊天模型 + const defaultModels = GetDefaultModelId(provider); + const testModelId = defaultModels.chat; + + // 对于没有默认模型的提供商,使用通用的测试模型 + if (!testModelId) { + throw new Error(`No default chat model available for ${provider}`); + } + + // 构造测试模型对象 + const testModel = { + provider: provider, + modelId: testModelId + }; + + // 构造简单的测试请求 + const testRequest = { + messages: [ + { + role: 'user' as const, + content: 'echo hi' + } + ], + model: testModelId, + max_tokens: 10, + temperature: 0 + }; + + // 设置超时选项 + const abortController = new AbortController(); + const timeoutId = setTimeout(() => abortController.abort(), 10000); // 10秒超时 + + try { + // 发起API调用测试 + const response = await llmManager.generateResponse( + testModel, + testRequest, + { signal: abortController.signal } + ); + + clearTimeout(timeoutId); + + // 检查响应是否有效 + if (response && response.choices && response.choices.length > 0) { + console.log(`✅ ${provider} connection test successful:`, response.choices[0]?.message?.content); + // ApiKeyComponent expects no return value on success, just no thrown error + return; + } else { + throw new Error('Invalid response format'); + } + } catch (apiError) { + clearTimeout(timeoutId); + throw apiError; + } + + } catch (error) { + console.error(`❌ ${provider} connection test failed:`, error); + + // 根据错误类型提供更具体的错误信息 + let errorMessage = '连接测试失败'; + + if (error.message?.includes('API key')) { + errorMessage = 'API Key 无效或缺失'; + } else if (error.message?.includes('base URL') || error.message?.includes('baseURL')) { + errorMessage = '基础URL设置错误'; + } else if (error.message?.includes('timeout') || error.name === 'AbortError') { + errorMessage = '请求超时,请检查网络连接'; + } else if (error.message?.includes('fetch')) { + errorMessage = '网络连接失败'; + } else if (error.message?.includes('401')) { + errorMessage = 'API Key 授权失败'; + } else if (error.message?.includes('403')) { + errorMessage = '访问被拒绝,请检查API Key权限'; + } else if (error.message?.includes('429')) { + errorMessage = '请求频率过高,请稍后重试'; + } else if (error.message?.includes('500')) { + errorMessage = '服务器内部错误'; + } else if (error.message) { + errorMessage = error.message; + } + alert(errorMessage); + // 必须抛出错误,这样ApiKeyComponent才能正确显示失败状态 + throw new Error(errorMessage); } };