Update model settings: pass Qwen-specific extra params through the OpenAI adapter and implement the provider API connection test

This commit is contained in:
duanfuxiang 2025-06-12 16:01:37 +08:00
parent 6501132d80
commit 2f34794a3c
3 changed files with 122 additions and 11 deletions

View File

@ -1,5 +1,6 @@
import OpenAI from 'openai' import OpenAI from 'openai'
import { ALIBABA_QWEN_BASE_URL } from '../../constants'
import { LLMModel } from '../../types/llm/model' import { LLMModel } from '../../types/llm/model'
import { import {
LLMOptions, LLMOptions,
@ -32,6 +33,24 @@ export class OpenAICompatibleProvider implements BaseLLMProvider {
this.baseURL = baseURL this.baseURL = baseURL
} }
// Returns true when this provider is talking to the Alibaba Cloud Qwen
// (DashScope) OpenAI-compatible endpoint, detected either by an exact
// match against the known base URL constant or by the DashScope hostname.
private isAlibabaQwen(): boolean {
  return (
    this.baseURL === ALIBABA_QWEN_BASE_URL ||
    // `?? false` keeps the return type a strict boolean when baseURL is
    // null/undefined (optional chaining alone would yield `undefined`).
    (this.baseURL?.includes('dashscope.aliyuncs.com') ?? false)
  )
}
// Builds provider-specific extra request parameters to be spread into the
// chat-completions payload.
//
// Alibaba Cloud Qwen's OpenAI-compatible API rejects non-streaming calls
// unless `enable_thinking` is explicitly set to false, so we add it here
// for that provider only. Returns an empty object for everyone else.
// (`unknown` instead of `any`: values are opaque pass-through data and the
// result remains assignable to the adapter's extraParams parameter.)
private getExtraParams(isStreaming: boolean): Record<string, unknown> {
  const extraParams: Record<string, unknown> = {}
  if (this.isAlibabaQwen() && !isStreaming) {
    extraParams.enable_thinking = false
  }
  return extraParams
}
async generateResponse( async generateResponse(
model: LLMModel, model: LLMModel,
request: LLMRequestNonStreaming, request: LLMRequestNonStreaming,
@ -43,7 +62,8 @@ export class OpenAICompatibleProvider implements BaseLLMProvider {
) )
} }
return this.adapter.generateResponse(this.client, request, options) const extraParams = this.getExtraParams(false) // 非流式调用
return this.adapter.generateResponse(this.client, request, options, extraParams)
} }
async streamResponse( async streamResponse(
@ -57,6 +77,7 @@ export class OpenAICompatibleProvider implements BaseLLMProvider {
) )
} }
return this.adapter.streamResponse(this.client, request, options) const extraParams = this.getExtraParams(true) // 流式调用
return this.adapter.streamResponse(this.client, request, options, extraParams)
} }
} }

View File

@ -21,7 +21,8 @@ export class OpenAIMessageAdapter {
async generateResponse( async generateResponse(
client: OpenAI, client: OpenAI,
request: LLMRequestNonStreaming, request: LLMRequestNonStreaming,
options?: LLMOptions, options?: LLMOptions,
extraParams?: Record<string, any>,
): Promise<LLMResponseNonStreaming> { ): Promise<LLMResponseNonStreaming> {
const response = await client.chat.completions.create( const response = await client.chat.completions.create(
{ {
@ -36,6 +37,7 @@ export class OpenAIMessageAdapter {
presence_penalty: request.presence_penalty, presence_penalty: request.presence_penalty,
logit_bias: request.logit_bias, logit_bias: request.logit_bias,
prediction: request.prediction, prediction: request.prediction,
...extraParams,
}, },
{ {
signal: options?.signal, signal: options?.signal,

View File

@ -172,16 +172,104 @@ const CustomProviderSettings: React.FC<CustomProviderSettingsProps> = ({ plugin,
}; };
const testApiConnection = async (provider: ApiProvider) => { const testApiConnection = async (provider: ApiProvider) => {
// TODO: 实现API连接测试逻辑
// 这里应该根据provider类型调用对应的API测试接口
console.log(`Testing connection for ${provider}...`); console.log(`Testing connection for ${provider}...`);
// 模拟延迟 try {
await new Promise(resolve => setTimeout(resolve, 1000)); // 动态导入LLMManager以避免循环依赖
const { default: LLMManager } = await import('../../core/llm/manager');
// 模拟随机成功/失败(用于演示) const { GetDefaultModelId } = await import('../../utils/api');
if (Math.random() > 0.5) {
throw new Error('Connection test failed'); // 对于Ollama和OpenAICompatible不支持测试API连接
if (provider === ApiProvider.Ollama || provider === ApiProvider.OpenAICompatible) {
throw new Error(`不支持测试 ${provider} 的API连接`);
}
// 创建LLM管理器实例
const llmManager = new LLMManager(settings);
// 获取提供商的默认聊天模型
const defaultModels = GetDefaultModelId(provider);
const testModelId = defaultModels.chat;
// 对于没有默认模型的提供商,使用通用的测试模型
if (!testModelId) {
throw new Error(`No default chat model available for ${provider}`);
}
// 构造测试模型对象
const testModel = {
provider: provider,
modelId: testModelId
};
// 构造简单的测试请求
const testRequest = {
messages: [
{
role: 'user' as const,
content: 'echo hi'
}
],
model: testModelId,
max_tokens: 10,
temperature: 0
};
// 设置超时选项
const abortController = new AbortController();
const timeoutId = setTimeout(() => abortController.abort(), 10000); // 10秒超时
try {
// 发起API调用测试
const response = await llmManager.generateResponse(
testModel,
testRequest,
{ signal: abortController.signal }
);
clearTimeout(timeoutId);
// 检查响应是否有效
if (response && response.choices && response.choices.length > 0) {
console.log(`${provider} connection test successful:`, response.choices[0]?.message?.content);
// ApiKeyComponent expects no return value on success, just no thrown error
return;
} else {
throw new Error('Invalid response format');
}
} catch (apiError) {
clearTimeout(timeoutId);
throw apiError;
}
} catch (error) {
console.error(`${provider} connection test failed:`, error);
// 根据错误类型提供更具体的错误信息
let errorMessage = '连接测试失败';
if (error.message?.includes('API key')) {
errorMessage = 'API Key 无效或缺失';
} else if (error.message?.includes('base URL') || error.message?.includes('baseURL')) {
errorMessage = '基础URL设置错误';
} else if (error.message?.includes('timeout') || error.name === 'AbortError') {
errorMessage = '请求超时,请检查网络连接';
} else if (error.message?.includes('fetch')) {
errorMessage = '网络连接失败';
} else if (error.message?.includes('401')) {
errorMessage = 'API Key 授权失败';
} else if (error.message?.includes('403')) {
errorMessage = '访问被拒绝请检查API Key权限';
} else if (error.message?.includes('429')) {
errorMessage = '请求频率过高,请稍后重试';
} else if (error.message?.includes('500')) {
errorMessage = '服务器内部错误';
} else if (error.message) {
errorMessage = error.message;
}
alert(errorMessage);
// 必须抛出错误这样ApiKeyComponent才能正确显示失败状态
throw new Error(errorMessage);
} }
}; };