perf: model framework

This commit is contained in:
archer 2023-04-29 15:55:47 +08:00
parent cd9acab938
commit 78762498eb
No known key found for this signature in database
GPG Key ID: 569A5660D2379E28
30 changed files with 649 additions and 757 deletions

View File

@ -12,8 +12,7 @@ export const getMyModels = () => GET<ModelSchema[]>('/model/list');
/** /**
* *
*/ */
export const postCreateModel = (data: { name: string; serviceModelName: string }) => export const postCreateModel = (data: { name: string }) => POST<string>('/model/create', data);
POST<ModelSchema>('/model/create', data);
/** /**
* ID * ID

View File

@ -7,7 +7,6 @@ export type InitChatResponse = {
name: string; name: string;
avatar: string; avatar: string;
intro: string; intro: string;
chatModel: ModelSchema.service.chatModel; // 对话模型名 chatModel: ModelSchema['chat']['chatModel']; // 对话模型名
modelName: ModelSchema.service.modelName; // 底层模型
history: ChatItemType[]; history: ChatItemType[];
}; };

View File

@ -1,50 +1,32 @@
import type { ModelSchema } from '@/types/mongoSchema'; import type { ModelSchema } from '@/types/mongoSchema';
export const embeddingModel = 'text-embedding-ada-002'; export const embeddingModel = 'text-embedding-ada-002';
export enum ChatModelEnum { export enum ChatModelEnum {
'GPT35' = 'gpt-3.5-turbo', 'GPT35' = 'gpt-3.5-turbo',
'GPT4' = 'gpt-4', 'GPT4' = 'gpt-4',
'GPT432k' = 'gpt-4-32k' 'GPT432k' = 'gpt-4-32k'
} }
export const ChatModelMap = {
export enum ModelNameEnum { // ui name
GPT35 = 'gpt-3.5-turbo', [ChatModelEnum.GPT35]: 'ChatGpt',
VECTOR_GPT = 'VECTOR_GPT' [ChatModelEnum.GPT4]: 'Gpt4',
} [ChatModelEnum.GPT432k]: 'Gpt4-32k'
export const Model2ChatModelMap: Record<`${ModelNameEnum}`, `${ChatModelEnum}`> = {
[ModelNameEnum.GPT35]: 'gpt-3.5-turbo',
[ModelNameEnum.VECTOR_GPT]: 'gpt-3.5-turbo'
}; };
export type ModelConstantsData = { export type ChatModelConstantType = {
icon: 'model' | 'dbModel'; chatModel: `${ChatModelEnum}`;
name: string;
model: `${ModelNameEnum}`;
trainName: string; // 空字符串代表不能训练
contextMaxToken: number; contextMaxToken: number;
maxTemperature: number; maxTemperature: number;
price: number; // 多少钱 / 1token单位: 0.00001元 price: number; // 多少钱 / 1token单位: 0.00001元
}; };
export const modelList: ModelConstantsData[] = [ export const modelList: ChatModelConstantType[] = [
{ {
icon: 'model', chatModel: ChatModelEnum.GPT35,
name: 'chatGPT',
model: ModelNameEnum.GPT35,
trainName: '',
contextMaxToken: 4096, contextMaxToken: 4096,
maxTemperature: 1.5, maxTemperature: 1.5,
price: 3 price: 3
},
{
icon: 'dbModel',
name: '知识库',
model: ModelNameEnum.VECTOR_GPT,
trainName: 'vector',
contextMaxToken: 4096,
maxTemperature: 1,
price: 3
} }
]; ];
@ -115,14 +97,16 @@ export const ModelVectorSearchModeMap: Record<
export const defaultModel: ModelSchema = { export const defaultModel: ModelSchema = {
_id: 'modelId', _id: 'modelId',
userId: 'userId', userId: 'userId',
name: 'modelName', name: '模型名称',
avatar: '/icon/logo.png', avatar: '/icon/logo.png',
status: ModelStatusEnum.pending, status: ModelStatusEnum.pending,
updateTime: Date.now(), updateTime: Date.now(),
systemPrompt: '', chat: {
temperature: 5, useKb: false,
search: { searchMode: ModelVectorSearchModeEnum.hightSimilarity,
mode: ModelVectorSearchModeEnum.hightSimilarity systemPrompt: '',
temperature: 0,
chatModel: ChatModelEnum.GPT35
}, },
share: { share: {
isShare: false, isShare: false,
@ -130,10 +114,6 @@ export const defaultModel: ModelSchema = {
intro: '', intro: '',
collection: 0 collection: 0
}, },
service: {
chatModel: ModelNameEnum.GPT35,
modelName: ModelNameEnum.GPT35
},
security: { security: {
domain: ['*'], domain: ['*'],
contextMaxLen: 1, contextMaxLen: 1,

View File

@ -1,13 +1,14 @@
import type { NextApiRequest, NextApiResponse } from 'next'; import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase } from '@/service/mongo'; import { connectToDatabase } from '@/service/mongo';
import { getOpenAIApi, authChat } from '@/service/utils/auth'; import { getOpenAIApi, authChat } from '@/service/utils/auth';
import { axiosConfig, openaiChatFilter } from '@/service/utils/tools'; import { axiosConfig, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools';
import { ChatItemType } from '@/types/chat'; import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { PassThrough } from 'stream'; import { PassThrough } from 'stream';
import { modelList } from '@/constants/model'; import { modelList, ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill'; import { pushChatBill } from '@/service/events/pushBill';
import { gpt35StreamResponse } from '@/service/utils/openai'; import { gpt35StreamResponse } from '@/service/utils/openai';
import { searchKb_openai } from '@/service/tools/searchKb';
/* 发送提示词 */ /* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) { export default async function handler(req: NextApiRequest, res: NextApiResponse) {
@ -46,7 +47,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
authorization authorization
}); });
const modelConstantsData = modelList.find((item) => item.model === model.service.modelName); const modelConstantsData = modelList.find((item) => item.chatModel === model.chat.chatModel);
if (!modelConstantsData) { if (!modelConstantsData) {
throw new Error('模型加载异常'); throw new Error('模型加载异常');
} }
@ -54,31 +55,84 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
// 读取对话内容 // 读取对话内容
const prompts = [...content, prompt]; const prompts = [...content, prompt];
// 如果有系统提示词,自动插入 // 使用了知识库搜索
if (model.systemPrompt) { if (model.chat.useKb) {
prompts.unshift({ const { systemPrompts } = await searchKb_openai({
obj: 'SYSTEM', apiKey: userApiKey || systemKey,
value: model.systemPrompt isPay: !userApiKey,
text: prompt.value,
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22,
modelId,
userId
}); });
// filter system prompt
if (
systemPrompts.length === 0 &&
model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity
) {
return res.send('对不起,你的问题不在知识库中。');
}
/* 高相似度+无上下文,不添加额外知识,仅用系统提示词 */
if (
systemPrompts.length === 0 &&
model.chat.searchMode === ModelVectorSearchModeEnum.noContext
) {
prompts.unshift({
obj: 'SYSTEM',
value: model.chat.systemPrompt
});
} else {
// 有匹配情况下system 添加知识库内容。
// 系统提示词过滤,最多 2500 tokens
const filterSystemPrompt = systemPromptFilter({
model: model.chat.chatModel,
prompts: systemPrompts,
maxTokens: 2500
});
prompts.unshift({
obj: 'SYSTEM',
value: `
${model.chat.systemPrompt}
${
model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity
? `不回答知识库外的内容.`
: ''
}
知识库内容为: ${filterSystemPrompt}'
`
});
}
} else {
// 没有用知识库搜索,仅用系统提示词
if (model.chat.systemPrompt) {
prompts.unshift({
obj: 'SYSTEM',
value: model.chat.systemPrompt
});
}
} }
// 控制在 tokens 数量,防止超出 // 控制 tokens 数量,防止超出
const filterPrompts = openaiChatFilter({ const filterPrompts = openaiChatFilter({
model: model.service.chatModel, model: model.chat.chatModel,
prompts, prompts,
maxTokens: modelConstantsData.contextMaxToken - 500 maxTokens: modelConstantsData.contextMaxToken - 500
}); });
// 计算温度 // 计算温度
const temperature = modelConstantsData.maxTemperature * (model.temperature / 10); const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
2
);
// console.log(filterPrompts); // console.log(filterPrompts);
// 获取 chatAPI // 获取 chatAPI
const chatAPI = getOpenAIApi(userApiKey || systemKey); const chatAPI = getOpenAIApi(userApiKey || systemKey);
// 发出请求 // 发出请求
const chatResponse = await chatAPI.createChatCompletion( const chatResponse = await chatAPI.createChatCompletion(
{ {
model: model.service.chatModel, model: model.chat.chatModel,
temperature, temperature: Number(temperature) || 0,
messages: filterPrompts, messages: filterPrompts,
frequency_penalty: 0.5, // 越大,重复内容越少 frequency_penalty: 0.5, // 越大,重复内容越少
presence_penalty: -0.5, // 越大,越容易出现新内容 presence_penalty: -0.5, // 越大,越容易出现新内容
@ -105,7 +159,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
// 只有使用平台的 key 才计费 // 只有使用平台的 key 才计费
pushChatBill({ pushChatBill({
isPay: !userApiKey, isPay: !userApiKey,
modelName: model.service.modelName, chatModel: model.chat.chatModel,
userId, userId,
chatId, chatId,
messages: filterPrompts.concat({ role: 'assistant', content: responseContent }) messages: filterPrompts.concat({ role: 'assistant', content: responseContent })

View File

@ -59,8 +59,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
name: model.name, name: model.name,
avatar: model.avatar, avatar: model.avatar,
intro: model.share.intro, intro: model.share.intro,
modelName: model.service.modelName, chatModel: model.chat.chatModel,
chatModel: model.service.chatModel,
history history
} }
}); });

View File

@ -1,189 +0,0 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase } from '@/service/mongo';
import { authChat } from '@/service/utils/auth';
import { axiosConfig, systemPromptFilter, openaiChatFilter } from '@/service/utils/tools';
import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response';
import { PassThrough } from 'stream';
import {
  modelList,
  ModelVectorSearchModeMap,
  ModelVectorSearchModeEnum,
  ModelDataStatusEnum
} from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill';
import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai';
import dayjs from 'dayjs';
import { PgClient } from '@/service/pg';

/* Knowledge-base chat endpoint: embeds the prompt, runs a pgvector similarity
 * search over the model's data, builds the system prompt from the matches, and
 * streams the chat completion back to the client. */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
  let step = 0; // step === 1 means the stream response has started (JSON errors can no longer be sent)
  const stream = new PassThrough();
  stream.on('error', () => {
    console.log('error: ', 'stream error');
    stream.destroy();
  });
  res.on('close', () => {
    stream.destroy();
  });
  res.on('error', () => {
    console.log('error: ', 'request error');
    stream.destroy();
  });

  try {
    const { modelId, chatId, prompt } = req.body as {
      modelId: string;
      chatId: '' | string; // empty string means a new, unsaved chat
      prompt: ChatItemType;
    };
    const { authorization } = req.headers;
    if (!modelId || !prompt) {
      throw new Error('缺少参数');
    }

    await connectToDatabase();
    let startTime = Date.now();

    // Validates the caller and loads the model plus prior conversation content
    const { model, content, userApiKey, systemKey, userId } = await authChat({
      modelId,
      chatId,
      authorization
    });

    const modelConstantsData = modelList.find((item) => item.model === model.service.modelName);
    if (!modelConstantsData) {
      throw new Error('模型加载异常');
    }

    // Conversation history plus the incoming prompt
    const prompts = [...content, prompt];

    // Embed the prompt text for the vector search
    const { vector: promptVector, chatAPI } = await openaiCreateEmbedding({
      isPay: !userApiKey, // billed only when falling back to the platform key
      apiKey: userApiKey || systemKey,
      userId,
      text: prompt.value
    });

    // Similarity search: fetch the closest Q/A pairs for this model
    const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22;
    const vectorSearch = await PgClient.select<{ id: string; q: string; a: string }>('modelData', {
      fields: ['id', 'q', 'a'],
      where: [
        ['status', ModelDataStatusEnum.ready],
        'AND',
        ['model_id', model._id],
        'AND',
        `vector <=> '[${promptVector}]' < ${similarity}`
      ],
      order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
      limit: 20
    });

    const formatRedisPrompt: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`);

    /* High-similarity mode with no match: refuse and exit early */
    if (
      formatRedisPrompt.length === 0 &&
      model.search.mode === ModelVectorSearchModeEnum.hightSimilarity
    ) {
      return res.send('对不起,你的问题不在知识库中。');
    }

    /* No-context mode with no match: add no extra knowledge, just the system prompt */
    if (
      formatRedisPrompt.length === 0 &&
      model.search.mode === ModelVectorSearchModeEnum.noContext
    ) {
      prompts.unshift({
        obj: 'SYSTEM',
        value: model.systemPrompt
      });
    } else {
      // Matches found: prepend knowledge-base content to the system prompt.
      // Trim the knowledge text to at most 2500 tokens.
      const systemPrompt = systemPromptFilter({
        model: model.service.chatModel,
        prompts: formatRedisPrompt,
        maxTokens: 2500
      });

      prompts.unshift({
        obj: 'SYSTEM',
        value: `
${model.systemPrompt}
${
  model.search.mode === ModelVectorSearchModeEnum.hightSimilarity
    ? `你只能从知识库选择内容回答.不在知识库内容拒绝回复`
    : ''
}
知识库内容为: 当前时间为${dayjs().format('YYYY/MM/DD HH:mm:ss')}\n${systemPrompt}'
`
      });
    }

    // Cap total tokens so the request stays within the model's context window
    const filterPrompts = openaiChatFilter({
      model: model.service.chatModel,
      prompts,
      maxTokens: modelConstantsData.contextMaxToken - 500
    });

    // console.log(filterPrompts);
    // Scale temperature: model.temperature is presumably on a 0-10 scale — TODO confirm
    const temperature = modelConstantsData.maxTemperature * (model.temperature / 10);
    // Issue the completion request (streaming)
    const chatResponse = await chatAPI.createChatCompletion(
      {
        model: model.service.chatModel,
        temperature,
        messages: filterPrompts,
        frequency_penalty: 0.5, // higher = less repetition
        presence_penalty: -0.5, // higher = more new topics
        stream: true,
        stop: ['.!?。']
      },
      {
        timeout: 40000,
        responseType: 'stream',
        ...axiosConfig()
      }
    );
    console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);

    step = 1; // streaming has begun; error handling switches to stream teardown
    const { responseContent } = await gpt35StreamResponse({
      res,
      stream,
      chatResponse
    });

    // Bill only when the platform key was used (isPay is false for user-supplied keys)
    pushChatBill({
      isPay: !userApiKey,
      modelName: model.service.modelName,
      userId,
      chatId,
      messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
    });
    // jsonRes(res);
  } catch (err: any) {
    if (step === 1) {
      // Stream already started: terminate it instead of sending JSON
      console.log('error结束');
      stream.destroy();
    } else {
      res.status(500);
      jsonRes(res, {
        code: 500,
        error: err
      });
    }
  }
}

View File

@ -3,14 +3,13 @@ import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { connectToDatabase } from '@/service/mongo'; import { connectToDatabase } from '@/service/mongo';
import { authToken } from '@/service/utils/tools'; import { authToken } from '@/service/utils/tools';
import { ModelStatusEnum, modelList, ModelNameEnum, Model2ChatModelMap } from '@/constants/model'; import { ModelStatusEnum } from '@/constants/model';
import { Model } from '@/service/models/model'; import { Model } from '@/service/models/model';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) { export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try { try {
const { name, serviceModelName } = req.body as { const { name } = req.body as {
name: string; name: string;
serviceModelName: `${ModelNameEnum}`;
}; };
const { authorization } = req.headers; const { authorization } = req.headers;
@ -18,45 +17,32 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
throw new Error('无权操作'); throw new Error('无权操作');
} }
if (!name || !serviceModelName) { if (!name) {
throw new Error('缺少参数'); throw new Error('缺少参数');
} }
// 凭证校验 // 凭证校验
const userId = await authToken(authorization); const userId = await authToken(authorization);
const modelItem = modelList.find((item) => item.model === serviceModelName);
if (!modelItem) {
throw new Error('模型不存在');
}
await connectToDatabase(); await connectToDatabase();
// 上限校验 // 上限校验
const authCount = await Model.countDocuments({ const authCount = await Model.countDocuments({
userId userId
}); });
if (authCount >= 20) { if (authCount >= 30) {
throw new Error('上限 20 个模型'); throw new Error('上限 30 个模型');
} }
// 创建模型 // 创建模型
const response = await Model.create({ const response = await Model.create({
name, name,
userId, userId,
status: ModelStatusEnum.running, status: ModelStatusEnum.running
service: {
chatModel: Model2ChatModelMap[modelItem.model], // 聊天时用的模型
modelName: modelItem.model // 最底层的模型,不会变,用于计费等核心操作
}
}); });
// 根据 id 获取模型信息
const model = await Model.findById(response._id);
jsonRes(res, { jsonRes(res, {
data: model data: response._id
}); });
} catch (err) { } catch (err) {
jsonRes(res, { jsonRes(res, {

View File

@ -9,8 +9,7 @@ import { authModel } from '@/service/utils/auth';
/* 获取我的模型 */ /* 获取我的模型 */
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) { export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try { try {
const { name, avatar, search, share, service, security, systemPrompt, temperature } = const { name, avatar, chat, share, security } = req.body as ModelUpdateParams;
req.body as ModelUpdateParams;
const { modelId } = req.query as { modelId: string }; const { modelId } = req.query as { modelId: string };
const { authorization } = req.headers; const { authorization } = req.headers;
@ -18,7 +17,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
throw new Error('无权操作'); throw new Error('无权操作');
} }
if (!name || !service || !security || !modelId) { if (!name || !chat || !security || !modelId) {
throw new Error('参数错误'); throw new Error('参数错误');
} }
@ -41,12 +40,10 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
{ {
name, name,
avatar, avatar,
systemPrompt, chat,
temperature,
'share.isShare': share.isShare, 'share.isShare': share.isShare,
'share.isShareDetail': share.isShareDetail, 'share.isShareDetail': share.isShareDetail,
'share.intro': share.intro, 'share.intro': share.intro,
search,
security security
} }
); );

View File

@ -0,0 +1,202 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase } from '@/service/mongo';
import { getOpenAIApi, authOpenApiKey, authModel } from '@/service/utils/auth';
import { axiosConfig, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools';
import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response';
import { PassThrough } from 'stream';
import { modelList, ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill';
import { gpt35StreamResponse } from '@/service/utils/openai';
import { searchKb_openai } from '@/service/tools/searchKb';

/* OpenAPI chat endpoint: authenticates via API key, optionally augments the
 * prompts with knowledge-base (vector search) content, then forwards them to
 * the chat model. Supports both streamed and plain JSON responses. */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
  let step = 0; // step === 1 means streaming has started (errors can no longer be sent as JSON)
  const stream = new PassThrough();
  stream.on('error', () => {
    console.log('error: ', 'stream error');
    stream.destroy();
  });
  res.on('close', () => {
    stream.destroy();
  });
  res.on('error', () => {
    console.log('error: ', 'request error');
    stream.destroy();
  });

  try {
    const {
      prompts,
      modelId,
      isStream = true
    } = req.body as {
      prompts: ChatItemType[];
      modelId: string;
      isStream: boolean;
    };
    if (!prompts || !modelId) {
      throw new Error('缺少参数');
    }
    if (!Array.isArray(prompts)) {
      throw new Error('prompts is not array');
    }
    if (prompts.length > 30 || prompts.length === 0) {
      throw new Error('prompts length range 1-30');
    }

    await connectToDatabase();
    let startTime = Date.now();

    /* Credential checks: API key first, then model ownership */
    const { apiKey, userId } = await authOpenApiKey(req);
    const { model } = await authModel({
      userId,
      modelId
    });

    const modelConstantsData = modelList.find((item) => item.chatModel === model.chat.chatModel);
    if (!modelConstantsData) {
      throw new Error('模型加载异常');
    }

    // Knowledge-base search is enabled for this model
    if (model.chat.useKb) {
      const similarity = ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22;
      // Vector-search the last prompt against the model's knowledge base
      const { systemPrompts } = await searchKb_openai({
        apiKey,
        isPay: true,
        text: prompts[prompts.length - 1].value,
        similarity,
        modelId,
        userId
      });

      // High-similarity mode with no match: refuse to answer
      if (
        systemPrompts.length === 0 &&
        model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity
      ) {
        return jsonRes(res, {
          code: 500,
          message: '对不起,你的问题不在知识库中。',
          data: '对不起,你的问题不在知识库中。'
        });
      }

      /* No-context mode with no match: add no extra knowledge, only the system prompt */
      if (
        systemPrompts.length === 0 &&
        model.chat.searchMode === ModelVectorSearchModeEnum.noContext
      ) {
        prompts.unshift({
          obj: 'SYSTEM',
          value: model.chat.systemPrompt
        });
      } else {
        // Matches found: prepend knowledge-base content to the system prompt.
        // Trim the knowledge text to at most 2500 tokens.
        const filterSystemPrompt = systemPromptFilter({
          model: model.chat.chatModel,
          prompts: systemPrompts,
          maxTokens: 2500
        });

        // NOTE: removed a stray trailing apostrophe that was being appended
        // after the knowledge-base content in the prompt sent to the model.
        prompts.unshift({
          obj: 'SYSTEM',
          value: `
${model.chat.systemPrompt}
${
  model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity
    ? `不回答知识库外的内容.`
    : ''
}
知识库内容为: ${filterSystemPrompt}
`
        });
      }
    } else {
      // No knowledge-base search: only insert the system prompt, if any
      if (model.chat.systemPrompt) {
        prompts.unshift({
          obj: 'SYSTEM',
          value: model.chat.systemPrompt
        });
      }
    }

    // Cap total tokens so the request stays within the model's context window
    const filterPrompts = openaiChatFilter({
      model: model.chat.chatModel,
      prompts,
      maxTokens: modelConstantsData.contextMaxToken - 500
    });

    // Scale temperature: model.chat.temperature is presumably on a 0-10 scale — TODO confirm
    const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
      2
    );
    // console.log(filterPrompts);
    // Build the OpenAI client for this API key
    const chatAPI = getOpenAIApi(apiKey);
    // Issue the completion request
    const chatResponse = await chatAPI.createChatCompletion(
      {
        model: model.chat.chatModel,
        temperature: Number(temperature) || 0,
        messages: filterPrompts,
        frequency_penalty: 0.5, // higher = less repetition
        presence_penalty: -0.5, // higher = more new topics
        stream: isStream,
        stop: ['.!?。']
      },
      {
        timeout: 180000,
        responseType: isStream ? 'stream' : 'json',
        ...axiosConfig()
      }
    );
    console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);

    let responseContent = '';

    if (isStream) {
      step = 1; // streaming has begun; error handling switches to stream teardown
      const streamResponse = await gpt35StreamResponse({
        res,
        stream,
        chatResponse
      });
      responseContent = streamResponse.responseContent;
    } else {
      responseContent = chatResponse.data.choices?.[0]?.message?.content || '';
      jsonRes(res, {
        data: responseContent
      });
    }

    // OpenAPI calls always bill the platform key (isPay is unconditionally true here)
    pushChatBill({
      isPay: true,
      chatModel: model.chat.chatModel,
      userId,
      messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
    });
  } catch (err: any) {
    if (step === 1) {
      // Stream already started: terminate it instead of sending JSON
      console.log('error结束');
      stream.destroy();
    } else {
      res.status(500);
      jsonRes(res, {
        code: 500,
        error: err
      });
    }
  }
}

View File

@ -1,7 +1,7 @@
import type { NextApiRequest, NextApiResponse } from 'next'; import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase, Model } from '@/service/mongo'; import { connectToDatabase, Model } from '@/service/mongo';
import { getOpenAIApi } from '@/service/utils/auth'; import { getOpenAIApi, authOpenApiKey } from '@/service/utils/auth';
import { axiosConfig, openaiChatFilter, authOpenApiKey } from '@/service/utils/tools'; import { axiosConfig, openaiChatFilter } from '@/service/utils/tools';
import { ChatItemType } from '@/types/chat'; import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { PassThrough } from 'stream'; import { PassThrough } from 'stream';
@ -60,37 +60,38 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
throw new Error('无权使用该模型'); throw new Error('无权使用该模型');
} }
const modelConstantsData = modelList.find((item) => item.model === model.service.modelName); const modelConstantsData = modelList.find((item) => item.chatModel === model.chat.chatModel);
if (!modelConstantsData) { if (!modelConstantsData) {
throw new Error('模型加载异常'); throw new Error('模型加载异常');
} }
// 如果有系统提示词,自动插入 // 如果有系统提示词,自动插入
if (model.systemPrompt) { if (model.chat.systemPrompt) {
prompts.unshift({ prompts.unshift({
obj: 'SYSTEM', obj: 'SYSTEM',
value: model.systemPrompt value: model.chat.systemPrompt
}); });
} }
// 控制在 tokens 数量,防止超出 // 控制在 tokens 数量,防止超出
const filterPrompts = openaiChatFilter({ const filterPrompts = openaiChatFilter({
model: model.service.chatModel, model: model.chat.chatModel,
prompts, prompts,
maxTokens: modelConstantsData.contextMaxToken - 500 maxTokens: modelConstantsData.contextMaxToken - 500
}); });
// console.log(filterPrompts); // console.log(filterPrompts);
// 计算温度 // 计算温度
const temperature = modelConstantsData.maxTemperature * (model.temperature / 10); const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
2
);
// 获取 chatAPI // 获取 chatAPI
const chatAPI = getOpenAIApi(apiKey); const chatAPI = getOpenAIApi(apiKey);
// 发出请求 // 发出请求
const chatResponse = await chatAPI.createChatCompletion( const chatResponse = await chatAPI.createChatCompletion(
{ {
model: model.service.chatModel, model: model.chat.chatModel,
temperature, temperature: Number(temperature) || 0,
messages: filterPrompts, messages: filterPrompts,
frequency_penalty: 0.5, // 越大,重复内容越少 frequency_penalty: 0.5, // 越大,重复内容越少
presence_penalty: -0.5, // 越大,越容易出现新内容 presence_penalty: -0.5, // 越大,越容易出现新内容
@ -126,7 +127,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
// 只有使用平台的 key 才计费 // 只有使用平台的 key 才计费
pushChatBill({ pushChatBill({
isPay: true, isPay: true,
modelName: model.service.modelName, chatModel: model.chat.chatModel,
userId, userId,
messages: filterPrompts.concat({ role: 'assistant', content: responseContent }) messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
}); });

View File

@ -1,20 +1,14 @@
import type { NextApiRequest, NextApiResponse } from 'next'; import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase, Model } from '@/service/mongo'; import { connectToDatabase, Model } from '@/service/mongo';
import { getOpenAIApi } from '@/service/utils/auth'; import { getOpenAIApi, authOpenApiKey } from '@/service/utils/auth';
import { authOpenApiKey } from '@/service/utils/tools';
import { axiosConfig, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools'; import { axiosConfig, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools';
import { ChatItemType } from '@/types/chat'; import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { PassThrough } from 'stream'; import { PassThrough } from 'stream';
import { import { modelList, ModelVectorSearchModeMap, ChatModelEnum } from '@/constants/model';
ModelNameEnum,
modelList,
ModelVectorSearchModeMap,
ChatModelEnum
} from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill'; import { pushChatBill } from '@/service/events/pushBill';
import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai'; import { gpt35StreamResponse } from '@/service/utils/openai';
import { PgClient } from '@/service/pg'; import { searchKb_openai } from '@/service/tools/searchKb';
/* 发送提示词 */ /* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) { export default async function handler(req: NextApiRequest, res: NextApiResponse) {
@ -59,10 +53,11 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
throw new Error('找不到模型'); throw new Error('找不到模型');
} }
const modelConstantsData = modelList.find((item) => item.model === ModelNameEnum.VECTOR_GPT); const modelConstantsData = modelList.find((item) => item.chatModel === model.chat.chatModel);
if (!modelConstantsData) { if (!modelConstantsData) {
throw new Error('模型已下架'); throw new Error('model is undefined');
} }
console.log('laf gpt start'); console.log('laf gpt start');
// 获取 chatAPI // 获取 chatAPI
@ -132,62 +127,48 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
prompt.value += ` ${promptResolve}`; prompt.value += ` ${promptResolve}`;
console.log('prompt resolve success, time:', `${(Date.now() - startTime) / 1000}s`); console.log('prompt resolve success, time:', `${(Date.now() - startTime) / 1000}s`);
// 获取提示词的向量
const { vector: promptVector } = await openaiCreateEmbedding({
isPay: true,
apiKey,
userId,
text: prompt.value
});
// 读取对话内容 // 读取对话内容
const prompts = [prompt]; const prompts = [prompt];
// 相似度搜索 // 获取向量匹配到的提示词
const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22; const { systemPrompts } = await searchKb_openai({
const vectorSearch = await PgClient.select<{ id: string; q: string; a: string }>('modelData', { isPay: true,
fields: ['id', 'q', 'a'], apiKey,
order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }], similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22,
where: [ text: prompt.value,
['model_id', model._id], modelId,
'AND', userId
['user_id', userId],
'AND',
`vector <=> '[${promptVector}]' < ${similarity}`
],
limit: 30
}); });
const formatRedisPrompt: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`);
// system 筛选,最多 2500 tokens // system 筛选,最多 2500 tokens
const systemPrompt = systemPromptFilter({ const filterSystemPrompt = systemPromptFilter({
model: model.service.chatModel, model: model.chat.chatModel,
prompts: formatRedisPrompt, prompts: systemPrompts,
maxTokens: 2500 maxTokens: 2500
}); });
prompts.unshift({ prompts.unshift({
obj: 'SYSTEM', obj: 'SYSTEM',
value: `${model.systemPrompt} 知识库是最新的,下面是知识库内容:${systemPrompt}` value: `${model.chat.systemPrompt} 知识库是最新的,下面是知识库内容:${filterSystemPrompt}`
}); });
// 控制上下文 tokens 数量,防止超出 // 控制上下文 tokens 数量,防止超出
const filterPrompts = openaiChatFilter({ const filterPrompts = openaiChatFilter({
model: model.service.chatModel, model: model.chat.chatModel,
prompts, prompts,
maxTokens: modelConstantsData.contextMaxToken - 500 maxTokens: modelConstantsData.contextMaxToken - 500
}); });
// console.log(filterPrompts); // console.log(filterPrompts);
// 计算温度 // 计算温度
const temperature = modelConstantsData.maxTemperature * (model.temperature / 10); const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
2
);
// 发出请求 // 发出请求
const chatResponse = await chatAPI.createChatCompletion( const chatResponse = await chatAPI.createChatCompletion(
{ {
model: model.service.chatModel, model: model.chat.chatModel,
temperature, temperature: Number(temperature) || 0,
messages: filterPrompts, messages: filterPrompts,
frequency_penalty: 0.5, // 越大,重复内容越少 frequency_penalty: 0.5, // 越大,重复内容越少
presence_penalty: -0.5, // 越大,越容易出现新内容 presence_penalty: -0.5, // 越大,越容易出现新内容
@ -223,7 +204,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
pushChatBill({ pushChatBill({
isPay: true, isPay: true,
modelName: model.service.modelName, chatModel: model.chat.chatModel,
userId, userId,
messages: filterPrompts.concat({ role: 'assistant', content: responseContent }) messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
}); });

View File

@ -1,24 +1,14 @@
import type { NextApiRequest, NextApiResponse } from 'next'; import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase, Model } from '@/service/mongo'; import { connectToDatabase, Model } from '@/service/mongo';
import { import { axiosConfig, systemPromptFilter, openaiChatFilter } from '@/service/utils/tools';
axiosConfig, import { getOpenAIApi, authOpenApiKey } from '@/service/utils/auth';
systemPromptFilter,
authOpenApiKey,
openaiChatFilter
} from '@/service/utils/tools';
import { ChatItemType } from '@/types/chat'; import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { PassThrough } from 'stream'; import { PassThrough } from 'stream';
import { import { modelList, ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model';
modelList,
ModelVectorSearchModeMap,
ModelVectorSearchModeEnum,
ModelDataStatusEnum
} from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill'; import { pushChatBill } from '@/service/events/pushBill';
import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai'; import { gpt35StreamResponse } from '@/service/utils/openai';
import dayjs from 'dayjs'; import { searchKb_openai } from '@/service/tools/searchKb';
import { PgClient } from '@/service/pg';
/* 发送提示词 */ /* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) { export default async function handler(req: NextApiRequest, res: NextApiResponse) {
@ -72,96 +62,86 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
throw new Error('无权使用该模型'); throw new Error('无权使用该模型');
} }
const modelConstantsData = modelList.find((item) => item.model === model?.service?.modelName); const modelConstantsData = modelList.find((item) => item.chatModel === model.chat.chatModel);
if (!modelConstantsData) { if (!modelConstantsData) {
throw new Error('模型初始化异常'); throw new Error('模型初始化异常');
} }
// 获取提示词的向量 // 获取向量匹配到的提示词
const { vector: promptVector, chatAPI } = await openaiCreateEmbedding({ const { systemPrompts } = await searchKb_openai({
isPay: true, isPay: true,
apiKey, apiKey,
userId, similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22,
text: prompts[prompts.length - 1].value // 取最后一个 text: prompts[prompts.length - 1].value,
modelId,
userId
}); });
// 相似度搜素
const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22;
const vectorSearch = await PgClient.select<{ id: string; q: string; a: string }>('modelData', {
fields: ['id', 'q', 'a'],
where: [
['status', ModelDataStatusEnum.ready],
'AND',
['model_id', model._id],
'AND',
`vector <=> '[${promptVector}]' < ${similarity}`
],
order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
limit: 20
});
const formatRedisPrompt: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`);
// system 合并 // system 合并
if (prompts[0].obj === 'SYSTEM') { if (prompts[0].obj === 'SYSTEM') {
formatRedisPrompt.unshift(prompts.shift()?.value || ''); systemPrompts.unshift(prompts.shift()?.value || '');
} }
/* 高相似度+退出,无法匹配时直接退出 */ /* 高相似度+退出,无法匹配时直接退出 */
if ( if (
formatRedisPrompt.length === 0 && systemPrompts.length === 0 &&
model.search.mode === ModelVectorSearchModeEnum.hightSimilarity model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity
) { ) {
return res.send('对不起,你的问题不在知识库中。'); return jsonRes(res, {
code: 500,
message: '对不起,你的问题不在知识库中。',
data: '对不起,你的问题不在知识库中。'
});
} }
/* 高相似度+无上下文,不添加额外知识 */ /* 高相似度+无上下文,不添加额外知识 */
if ( if (
formatRedisPrompt.length === 0 && systemPrompts.length === 0 &&
model.search.mode === ModelVectorSearchModeEnum.noContext model.chat.searchMode === ModelVectorSearchModeEnum.noContext
) { ) {
prompts.unshift({ prompts.unshift({
obj: 'SYSTEM', obj: 'SYSTEM',
value: model.systemPrompt value: model.chat.systemPrompt
}); });
} else { } else {
// 有匹配或者低匹配度模式情况下,添加知识库内容。 // 有匹配或者低匹配度模式情况下,添加知识库内容。
// 系统提示词过滤,最多 2500 tokens // 系统提示词过滤,最多 2500 tokens
const systemPrompt = systemPromptFilter({ const systemPrompt = systemPromptFilter({
model: model.service.chatModel, model: model.chat.chatModel,
prompts: formatRedisPrompt, prompts: systemPrompts,
maxTokens: 2500 maxTokens: 2500
}); });
prompts.unshift({ prompts.unshift({
obj: 'SYSTEM', obj: 'SYSTEM',
value: ` value: `
${model.systemPrompt} ${model.chat.systemPrompt}
${ ${
model.search.mode === ModelVectorSearchModeEnum.hightSimilarity model.chat.searchMode === ModelVectorSearchModeEnum.hightSimilarity ? `不回答知识库外的内容.` : ''
? `你只能从知识库选择内容回答.不在知识库内容拒绝回复`
: ''
} }
知识库内容为: 当前时间为${dayjs().format('YYYY/MM/DD HH:mm:ss')}\n${systemPrompt}' 知识库内容为: ${systemPrompt}'
` `
}); });
} }
// 控制在 tokens 数量,防止超出 // 控制在 tokens 数量,防止超出
const filterPrompts = openaiChatFilter({ const filterPrompts = openaiChatFilter({
model: model.service.chatModel, model: model.chat.chatModel,
prompts, prompts,
maxTokens: modelConstantsData.contextMaxToken - 500 maxTokens: modelConstantsData.contextMaxToken - 500
}); });
// console.log(filterPrompts); // console.log(filterPrompts);
// 计算温度 // 计算温度
const temperature = modelConstantsData.maxTemperature * (model.temperature / 10); const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
2
);
const chatAPI = getOpenAIApi(apiKey);
// 发出请求 // 发出请求
const chatResponse = await chatAPI.createChatCompletion( const chatResponse = await chatAPI.createChatCompletion(
{ {
model: model.service.chatModel, model: model.chat.chatModel,
temperature, temperature: Number(temperature) || 0,
messages: filterPrompts, messages: filterPrompts,
frequency_penalty: 0.5, // 越大,重复内容越少 frequency_penalty: 0.5, // 越大,重复内容越少
presence_penalty: -0.5, // 越大,越容易出现新内容 presence_penalty: -0.5, // 越大,越容易出现新内容
@ -196,7 +176,7 @@ ${
pushChatBill({ pushChatBill({
isPay: true, isPay: true,
modelName: model.service.modelName, chatModel: model.chat.chatModel,
userId, userId,
messages: filterPrompts.concat({ role: 'assistant', content: responseContent }) messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
}); });

View File

@ -52,7 +52,7 @@ const SlideBar = ({
const myModelList = myModels.map((item) => ({ const myModelList = myModels.map((item) => ({
id: item._id, id: item._id,
name: item.name, name: item.name,
icon: modelList.find((model) => model.model === item?.service?.modelName)?.icon || 'model' icon: 'model' as any
})); }));
const collectionList = collectionModels const collectionList = collectionModels
.map((item) => ({ .map((item) => ({

View File

@ -1,6 +1,5 @@
import React, { useCallback, useState, useRef, useMemo, useEffect } from 'react'; import React, { useCallback, useState, useRef, useMemo, useEffect } from 'react';
import { useRouter } from 'next/router'; import { useRouter } from 'next/router';
import Image from 'next/image';
import { getInitChatSiteInfo, delChatRecordByIndex, postSaveChat } from '@/api/chat'; import { getInitChatSiteInfo, delChatRecordByIndex, postSaveChat } from '@/api/chat';
import type { InitChatResponse } from '@/api/response/chat'; import type { InitChatResponse } from '@/api/response/chat';
import type { ChatItemType } from '@/types/chat'; import type { ChatItemType } from '@/types/chat';
@ -16,12 +15,13 @@ import {
Menu, Menu,
MenuButton, MenuButton,
MenuList, MenuList,
MenuItem MenuItem,
Image
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useToast } from '@/hooks/useToast'; import { useToast } from '@/hooks/useToast';
import { useScreen } from '@/hooks/useScreen'; import { useScreen } from '@/hooks/useScreen';
import { useQuery } from '@tanstack/react-query'; import { useQuery } from '@tanstack/react-query';
import { ModelNameEnum } from '@/constants/model'; import { ChatModelEnum } from '@/constants/model';
import dynamic from 'next/dynamic'; import dynamic from 'next/dynamic';
import { useGlobalStore } from '@/store/global'; import { useGlobalStore } from '@/store/global';
import { useCopyData } from '@/utils/tools'; import { useCopyData } from '@/utils/tools';
@ -65,8 +65,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
name: '', name: '',
avatar: '/icon/logo.png', avatar: '/icon/logo.png',
intro: '', intro: '',
chatModel: '', chatModel: ChatModelEnum.GPT35,
modelName: '',
history: [] history: []
}); // 聊天框整体数据 }); // 聊天框整体数据
@ -193,13 +192,6 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
// gpt 对话 // gpt 对话
const gptChatPrompt = useCallback( const gptChatPrompt = useCallback(
async (prompts: ChatSiteItemType) => { async (prompts: ChatSiteItemType) => {
const urlMap: Record<string, string> = {
[ModelNameEnum.GPT35]: '/api/chat/chatGpt',
[ModelNameEnum.VECTOR_GPT]: '/api/chat/vectorGpt'
};
if (!urlMap[chatData.modelName]) return Promise.reject('找不到模型');
// create abort obj // create abort obj
const abortSignal = new AbortController(); const abortSignal = new AbortController();
controller.current = abortSignal; controller.current = abortSignal;
@ -212,7 +204,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
// 流请求,获取数据 // 流请求,获取数据
const responseText = await streamFetch({ const responseText = await streamFetch({
url: urlMap[chatData.modelName], url: '/api/chat/chat',
data: { data: {
prompt, prompt,
chatId, chatId,
@ -278,7 +270,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
}) })
})); }));
}, },
[chatData.modelName, chatId, generatingMessage, modelId, router, toast] [chatId, generatingMessage, modelId, router, toast]
); );
/** /**
@ -393,7 +385,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
// 更新流中断对象 // 更新流中断对象
useEffect(() => { useEffect(() => {
return () => { return () => {
// eslint-disable-next-line react-hooks/exhaustive-deps isResetPage.current = true;
controller.current?.abort(); controller.current?.abort();
}; };
}, []); }, []);
@ -476,8 +468,9 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
: chatData.avatar || '/icon/logo.png' : chatData.avatar || '/icon/logo.png'
} }
alt="avatar" alt="avatar"
width={media(30, 20)} w={['20px', '30px']}
height={media(30, 20)} maxH={'50px'}
objectFit={'contain'}
/> />
</MenuButton> </MenuButton>
<MenuList fontSize={'sm'}> <MenuList fontSize={'sm'}>

View File

@ -45,9 +45,10 @@ const ModelDataCard = ({ modelId, isOwner }: { modelId: string; isOwner: boolean
const [searchText, setSearchText] = useState(''); const [searchText, setSearchText] = useState('');
const tdStyles = useRef<BoxProps>({ const tdStyles = useRef<BoxProps>({
fontSize: 'xs', fontSize: 'xs',
minW: '150px',
maxW: '500px', maxW: '500px',
whiteSpace: 'pre-wrap',
maxH: '250px', maxH: '250px',
whiteSpace: 'pre-wrap',
overflowY: 'auto' overflowY: 'auto'
}); });
const { const {
@ -132,7 +133,7 @@ const ModelDataCard = ({ modelId, isOwner }: { modelId: string; isOwner: boolean
<> <>
<Flex> <Flex>
<Box fontWeight={'bold'} fontSize={'lg'} flex={1} mr={2}> <Box fontWeight={'bold'} fontSize={'lg'} flex={1} mr={2}>
: {total} : {total}
</Box> </Box>
{isOwner && ( {isOwner && (
<> <>

View File

@ -21,7 +21,7 @@ import {
import { QuestionOutlineIcon } from '@chakra-ui/icons'; import { QuestionOutlineIcon } from '@chakra-ui/icons';
import type { ModelSchema } from '@/types/mongoSchema'; import type { ModelSchema } from '@/types/mongoSchema';
import { UseFormReturn } from 'react-hook-form'; import { UseFormReturn } from 'react-hook-form';
import { modelList, ModelVectorSearchModeMap } from '@/constants/model'; import { ChatModelMap, modelList, ModelVectorSearchModeMap } from '@/constants/model';
import { formatPrice } from '@/utils/user'; import { formatPrice } from '@/utils/user';
import { useConfirm } from '@/hooks/useConfirm'; import { useConfirm } from '@/hooks/useConfirm';
import { useSelectFile } from '@/hooks/useSelectFile'; import { useSelectFile } from '@/hooks/useSelectFile';
@ -30,12 +30,10 @@ import { fileToBase64 } from '@/utils/file';
const ModelEditForm = ({ const ModelEditForm = ({
formHooks, formHooks,
canTrain,
isOwner, isOwner,
handleDelModel handleDelModel
}: { }: {
formHooks: UseFormReturn<ModelSchema>; formHooks: UseFormReturn<ModelSchema>;
canTrain: boolean;
isOwner: boolean; isOwner: boolean;
handleDelModel: () => void; handleDelModel: () => void;
}) => { }) => {
@ -73,6 +71,12 @@ const ModelEditForm = ({
<> <>
<Card p={4}> <Card p={4}>
<Box fontWeight={'bold'}></Box> <Box fontWeight={'bold'}></Box>
<Flex alignItems={'center'} mt={4}>
<Box flex={'0 0 80px'} w={0}>
modelId:
</Box>
<Box>{getValues('_id')}</Box>
</Flex>
<Flex mt={4} alignItems={'center'}> <Flex mt={4} alignItems={'center'}>
<Box flex={'0 0 80px'} w={0}> <Box flex={'0 0 80px'} w={0}>
: :
@ -101,17 +105,12 @@ const ModelEditForm = ({
></Input> ></Input>
</Flex> </Flex>
</FormControl> </FormControl>
<Flex alignItems={'center'} mt={5}> <Flex alignItems={'center'} mt={5}>
<Box flex={'0 0 80px'} w={0}> <Box flex={'0 0 80px'} w={0}>
modelId: :
</Box> </Box>
<Box>{getValues('_id')}</Box> <Box>{ChatModelMap[getValues('chat.chatModel')]}</Box>
</Flex>
<Flex alignItems={'center'} mt={5}>
<Box flex={'0 0 80px'} w={0}>
:
</Box>
<Box>{modelList.find((item) => item.model === getValues('service.modelName'))?.name}</Box>
</Flex> </Flex>
<Flex alignItems={'center'} mt={5}> <Flex alignItems={'center'} mt={5}>
<Box flex={'0 0 80px'} w={0}> <Box flex={'0 0 80px'} w={0}>
@ -119,7 +118,7 @@ const ModelEditForm = ({
</Box> </Box>
<Box> <Box>
{formatPrice( {formatPrice(
modelList.find((item) => item.model === getValues('service.modelName'))?.price || 0, modelList.find((item) => item.chatModel === getValues('chat.chatModel'))?.price || 0,
1000 1000
)} )}
/1K tokens() /1K tokens()
@ -163,15 +162,15 @@ const ModelEditForm = ({
min={0} min={0}
max={10} max={10}
step={1} step={1}
value={getValues('temperature')} value={getValues('chat.temperature')}
isDisabled={!isOwner} isDisabled={!isOwner}
onChange={(e) => { onChange={(e) => {
setValue('temperature', e); setValue('chat.temperature', e);
setRefresh(!refresh); setRefresh(!refresh);
}} }}
> >
<SliderMark <SliderMark
value={getValues('temperature')} value={getValues('chat.temperature')}
textAlign="center" textAlign="center"
bg="blue.500" bg="blue.500"
color="white" color="white"
@ -181,7 +180,7 @@ const ModelEditForm = ({
fontSize={'xs'} fontSize={'xs'}
transform={'translate(-50%, -200%)'} transform={'translate(-50%, -200%)'}
> >
{getValues('temperature')} {getValues('chat.temperature')}
</SliderMark> </SliderMark>
<SliderTrack> <SliderTrack>
<SliderFilledTrack /> <SliderFilledTrack />
@ -190,35 +189,42 @@ const ModelEditForm = ({
</Slider> </Slider>
</Flex> </Flex>
</FormControl> </FormControl>
{canTrain && ( <Flex mt={4} alignItems={'center'}>
<FormControl mt={4}> <Box mr={4}></Box>
<Flex alignItems={'center'}> <Switch
<Box flex={'0 0 70px'}></Box> isChecked={getValues('chat.useKb')}
<Select onChange={() => {
isDisabled={!isOwner} setValue('chat.useKb', !getValues('chat.useKb'));
{...register('search.mode', { required: '搜索模式不能为空' })} setRefresh(!refresh);
> }}
{Object.entries(ModelVectorSearchModeMap).map(([key, { text }]) => ( />
<option key={key} value={key}> </Flex>
{text} {getValues('chat.useKb') && (
</option> <Flex mt={4} alignItems={'center'}>
))} <Box mr={4} whiteSpace={'nowrap'}>
</Select> &emsp;
</Flex> </Box>
</FormControl> <Select
isDisabled={!isOwner}
{...register('chat.searchMode', { required: '搜索模式不能为空' })}
>
{Object.entries(ModelVectorSearchModeMap).map(([key, { text }]) => (
<option key={key} value={key}>
{text}
</option>
))}
</Select>
</Flex>
)} )}
<Box mt={4}> <Box mt={4}>
<Box mb={1}></Box> <Box mb={1}></Box>
<Textarea <Textarea
rows={8} rows={8}
maxLength={-1} maxLength={-1}
isDisabled={!isOwner} isDisabled={!isOwner}
placeholder={ placeholder={'模型默认的 prompt 词,通过调整该内容,可以引导模型聊天方向。'}
canTrain {...register('chat.systemPrompt')}
? '训练的模型会根据知识库内容,生成一部分系统提示词,因此在对话时需要消耗更多的 tokens。你可以增加提示词让效果更符合预期。例如: \n1. 请根据知识库内容回答用户问题。\n2. 知识库是电影《铃芽之旅》的内容,根据知识库内容回答。无关问题,拒绝回复!'
: '模型默认的 prompt 词,通过调整该内容,可以生成一个限定范围的模型。\n注意改功能会影响对话的整体朝向'
}
{...register('systemPrompt')}
/> />
</Box> </Box>
</Card> </Card>

View File

@ -27,11 +27,6 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
defaultValues: model defaultValues: model
}); });
const canTrain = useMemo(() => {
const openai = modelList.find((item) => item.model === model?.service.modelName);
return !!(openai && openai.trainName);
}, [model]);
const isOwner = useMemo(() => model.userId === userInfo?._id, [model.userId, userInfo?._id]); const isOwner = useMemo(() => model.userId === userInfo?._id, [model.userId, userInfo?._id]);
/* 加载模型数据 */ /* 加载模型数据 */
@ -86,11 +81,8 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
await putModelById(data._id, { await putModelById(data._id, {
name: data.name, name: data.name,
avatar: data.avatar || '/icon/logo.png', avatar: data.avatar || '/icon/logo.png',
systemPrompt: data.systemPrompt, chat: data.chat,
temperature: data.temperature,
search: data.search,
share: data.share, share: data.share,
service: data.service,
security: data.security security: data.security
}); });
toast({ toast({
@ -171,11 +163,15 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
</Tag> </Tag>
</Flex> </Flex>
<Box mt={4} textAlign={'right'}> <Box mt={4} textAlign={'right'}>
<Button variant={'outline'} onClick={handlePreviewChat}> <Button variant={'outline'} size={'sm'} onClick={handlePreviewChat}>
</Button> </Button>
{isOwner && ( {isOwner && (
<Button ml={4} onClick={formHooks.handleSubmit(saveSubmitSuccess, saveSubmitError)}> <Button
ml={4}
size={'sm'}
onClick={formHooks.handleSubmit(saveSubmitSuccess, saveSubmitError)}
>
</Button> </Button>
)} )}
@ -184,16 +180,11 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
)} )}
</Card> </Card>
<Grid mt={5} gridTemplateColumns={['1fr', '1fr 1fr']} gridGap={5}> <Grid mt={5} gridTemplateColumns={['1fr', '1fr 1fr']} gridGap={5}>
<ModelEditForm <ModelEditForm formHooks={formHooks} handleDelModel={handleDelModel} isOwner={isOwner} />
formHooks={formHooks}
handleDelModel={handleDelModel}
canTrain={canTrain}
isOwner={isOwner}
/>
{canTrain && !!model._id && ( {modelId && (
<Card p={4} gridColumnStart={[1, 1]} gridColumnEnd={[2, 3]}> <Card p={4} gridColumnStart={[1, 1]} gridColumnEnd={[2, 3]}>
<ModelDataCard modelId={model._id} isOwner={isOwner} /> <ModelDataCard modelId={modelId} isOwner={isOwner} />
</Card> </Card>
)} )}
</Grid> </Grid>

View File

@ -1,138 +0,0 @@
import React, { Dispatch, useState, useCallback, useMemo } from 'react';
import {
Modal,
ModalOverlay,
ModalContent,
ModalHeader,
ModalFooter,
ModalBody,
ModalCloseButton,
FormControl,
FormErrorMessage,
Button,
useToast,
Input,
Select,
Box
} from '@chakra-ui/react';
import { useForm } from 'react-hook-form';
import { postCreateModel } from '@/api/model';
import type { ModelSchema } from '@/types/mongoSchema';
import { modelList } from '@/constants/model';
import { formatPrice } from '@/utils/user';
interface CreateFormType {
name: string;
serviceModelName: string;
}
const CreateModel = ({
setCreateModelOpen,
onSuccess
}: {
setCreateModelOpen: Dispatch<boolean>;
onSuccess: Dispatch<ModelSchema>;
}) => {
const [requesting, setRequesting] = useState(false);
const [refresh, setRefresh] = useState(false);
const toast = useToast({
duration: 2000,
position: 'top'
});
const {
getValues,
register,
handleSubmit,
formState: { errors }
} = useForm<CreateFormType>({
defaultValues: {
serviceModelName: modelList[0].model
}
});
const handleCreateModel = useCallback(
async (data: CreateFormType) => {
setRequesting(true);
try {
const res = await postCreateModel(data);
toast({
title: '创建成功',
status: 'success'
});
onSuccess(res);
setCreateModelOpen(false);
} catch (err: any) {
toast({
title: typeof err === 'string' ? err : err.message || '出现了意外',
status: 'error'
});
}
setRequesting(false);
},
[onSuccess, setCreateModelOpen, toast]
);
return (
<>
<Modal isOpen={true} onClose={() => setCreateModelOpen(false)}>
<ModalOverlay />
<ModalContent>
<ModalHeader></ModalHeader>
<ModalCloseButton />
<ModalBody>
<FormControl mb={8} isInvalid={!!errors.name}>
<Input
placeholder="模型名称"
{...register('name', {
required: '模型名不能为空'
})}
/>
<FormErrorMessage position={'absolute'} fontSize="xs">
{!!errors.name && errors.name.message}
</FormErrorMessage>
</FormControl>
<FormControl isInvalid={!!errors.serviceModelName}>
<Select
placeholder="选择基础模型类型"
{...register('serviceModelName', {
required: '底层模型不能为空',
onChange() {
setRefresh(!refresh);
}
})}
>
{modelList.map((item) => (
<option key={item.model} value={item.model}>
{item.name}
</option>
))}
</Select>
<FormErrorMessage position={'absolute'} fontSize="xs">
{!!errors.serviceModelName && errors.serviceModelName.message}
</FormErrorMessage>
</FormControl>
<Box mt={3} textAlign={'center'} fontSize={'sm'} color={'blackAlpha.600'}>
{formatPrice(
modelList.find((item) => item.model === getValues('serviceModelName'))?.price || 0,
1000
)}
/1K tokens()
</Box>
</ModalBody>
<ModalFooter>
<Button mr={3} colorScheme={'gray'} onClick={() => setCreateModelOpen(false)}>
</Button>
<Button isLoading={requesting} onClick={handleSubmit(handleCreateModel)}>
</Button>
</ModalFooter>
</ModalContent>
</Modal>
</>
);
};
export default CreateModel;

View File

@ -2,8 +2,8 @@ import React, { useEffect } from 'react';
import { Box, Button, Flex, Tag } from '@chakra-ui/react'; import { Box, Button, Flex, Tag } from '@chakra-ui/react';
import type { ModelSchema } from '@/types/mongoSchema'; import type { ModelSchema } from '@/types/mongoSchema';
import { formatModelStatus } from '@/constants/model'; import { formatModelStatus } from '@/constants/model';
import dayjs from 'dayjs';
import { useRouter } from 'next/router'; import { useRouter } from 'next/router';
import { ChatModelMap } from '@/constants/model';
const ModelPhoneList = ({ const ModelPhoneList = ({
models, models,
@ -42,12 +42,12 @@ const ModelPhoneList = ({
</Tag> </Tag>
</Flex> </Flex>
<Flex mt={5}> <Flex mt={5}>
<Box flex={'0 0 100px'}>: </Box> <Box flex={'0 0 100px'}>: </Box>
<Box color={'blackAlpha.500'}>{dayjs(model.updateTime).format('YYYY-MM-DD HH:mm')}</Box> <Box color={'blackAlpha.500'}>{ChatModelMap[model.chat.chatModel]}</Box>
</Flex> </Flex>
<Flex mt={5}> <Flex mt={5}>
<Box flex={'0 0 100px'}>AI模型: </Box> <Box flex={'0 0 100px'}>: </Box>
<Box color={'blackAlpha.500'}>{model.service.modelName}</Box> <Box color={'blackAlpha.500'}>{model.chat.temperature}</Box>
</Flex> </Flex>
<Flex mt={5} justifyContent={'flex-end'}> <Flex mt={5} justifyContent={'flex-end'}>
<Button <Button

View File

@ -13,10 +13,9 @@ import {
Box Box
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { formatModelStatus } from '@/constants/model'; import { formatModelStatus } from '@/constants/model';
import dayjs from 'dayjs';
import type { ModelSchema } from '@/types/mongoSchema'; import type { ModelSchema } from '@/types/mongoSchema';
import { useRouter } from 'next/router'; import { useRouter } from 'next/router';
import { modelList } from '@/constants/model'; import { ChatModelMap } from '@/constants/model';
const ModelTable = ({ const ModelTable = ({
models = [], models = [],
@ -33,18 +32,18 @@ const ModelTable = ({
dataIndex: 'name' dataIndex: 'name'
}, },
{ {
title: '型类型', title: '对话模型',
key: 'service', key: 'service',
render: (model: ModelSchema) => ( render: (model: ModelSchema) => (
<Box fontWeight={'bold'} whiteSpace={'pre-wrap'} maxW={'200px'}> <Box fontWeight={'bold'} whiteSpace={'pre-wrap'} maxW={'200px'}>
{modelList.find((item) => item.model === model.service.modelName)?.name} {ChatModelMap[model.chat.chatModel]}
</Box> </Box>
) )
}, },
{ {
title: '最后更新时间', title: '温度',
key: 'updateTime', key: 'temperature',
render: (item: ModelSchema) => dayjs(item.updateTime).format('YYYY-MM-DD HH:mm') render: (model: ModelSchema) => <>{model.chat.temperature}</>
}, },
{ {
title: '状态', title: '状态',

View File

@ -1,4 +1,4 @@
import React, { useState, useCallback } from 'react'; import React, { useCallback } from 'react';
import { Box, Button, Flex, Card } from '@chakra-ui/react'; import { Box, Button, Flex, Card } from '@chakra-ui/react';
import type { ModelSchema } from '@/types/mongoSchema'; import type { ModelSchema } from '@/types/mongoSchema';
import { useRouter } from 'next/router'; import { useRouter } from 'next/router';
@ -7,30 +7,37 @@ import ModelPhoneList from './components/ModelPhoneList';
import { useScreen } from '@/hooks/useScreen'; import { useScreen } from '@/hooks/useScreen';
import { useQuery } from '@tanstack/react-query'; import { useQuery } from '@tanstack/react-query';
import { useLoading } from '@/hooks/useLoading'; import { useLoading } from '@/hooks/useLoading';
import dynamic from 'next/dynamic';
import { useToast } from '@/hooks/useToast'; import { useToast } from '@/hooks/useToast';
import { useUserStore } from '@/store/user'; import { useUserStore } from '@/store/user';
import { postCreateModel } from '@/api/model';
const CreateModel = dynamic(() => import('./components/CreateModel'));
const modelList = () => { const modelList = () => {
const { toast } = useToast(); const { toast } = useToast();
const { isPc } = useScreen(); const { isPc } = useScreen();
const router = useRouter(); const router = useRouter();
const { myModels, setMyModels, getMyModels } = useUserStore(); const { myModels, getMyModels } = useUserStore();
const [openCreateModel, setOpenCreateModel] = useState(false);
const { Loading, setIsLoading } = useLoading(); const { Loading, setIsLoading } = useLoading();
/* 加载模型 */ /* 加载模型 */
const { isLoading } = useQuery(['loadModels'], getMyModels); const { isLoading } = useQuery(['loadModels'], getMyModels);
/* 创建成功回调 */ const handleCreateModel = useCallback(async () => {
const createModelSuccess = useCallback( setIsLoading(true);
(data: ModelSchema) => { try {
setMyModels([data, ...myModels]); const id = await postCreateModel({ name: `模型${myModels.length}` });
}, toast({
[myModels, setMyModels] title: '创建成功',
); status: 'success'
});
router.push(`/model/detail?modelId=${id}`);
} catch (err: any) {
toast({
title: typeof err === 'string' ? err : err.message || '出现了意外',
status: 'error'
});
}
setIsLoading(false);
}, [myModels.length, router, setIsLoading, toast]);
/* 点前往聊天预览页 */ /* 点前往聊天预览页 */
const handlePreviewChat = useCallback( const handlePreviewChat = useCallback(
@ -61,7 +68,7 @@ const modelList = () => {
</Box> </Box>
<Button flex={'0 0 145px'} variant={'outline'} onClick={() => setOpenCreateModel(true)}> <Button flex={'0 0 145px'} variant={'outline'} onClick={handleCreateModel}>
</Button> </Button>
</Flex> </Flex>
@ -74,10 +81,6 @@ const modelList = () => {
<ModelPhoneList models={myModels} handlePreviewChat={handlePreviewChat} /> <ModelPhoneList models={myModels} handlePreviewChat={handlePreviewChat} />
)} )}
</Box> </Box>
{/* 创建弹窗 */}
{openCreateModel && (
<CreateModel setCreateModelOpen={setOpenCreateModel} onSuccess={createModelSuccess} />
)}
<Loading loading={isLoading} /> <Loading loading={isLoading} />
</Box> </Box>

View File

@ -1,23 +1,17 @@
import { connectToDatabase, Bill, User } from '../mongo'; import { connectToDatabase, Bill, User } from '../mongo';
import { import { modelList, ChatModelEnum, embeddingModel } from '@/constants/model';
modelList,
ChatModelEnum,
ModelNameEnum,
Model2ChatModelMap,
embeddingModel
} from '@/constants/model';
import { BillTypeEnum } from '@/constants/user'; import { BillTypeEnum } from '@/constants/user';
import { countChatTokens } from '@/utils/tools'; import { countChatTokens } from '@/utils/tools';
export const pushChatBill = async ({ export const pushChatBill = async ({
isPay, isPay,
modelName, chatModel,
userId, userId,
chatId, chatId,
messages messages
}: { }: {
isPay: boolean; isPay: boolean;
modelName: `${ModelNameEnum}`; chatModel: `${ChatModelEnum}`;
userId: string; userId: string;
chatId?: '' | string; chatId?: '' | string;
messages: { role: 'system' | 'user' | 'assistant'; content: string }[]; messages: { role: 'system' | 'user' | 'assistant'; content: string }[];
@ -26,7 +20,7 @@ export const pushChatBill = async ({
try { try {
// 计算 token 数量 // 计算 token 数量
const tokens = countChatTokens({ model: Model2ChatModelMap[modelName] as any, messages }); const tokens = countChatTokens({ model: chatModel, messages });
const text = messages.map((item) => item.content).join(''); const text = messages.map((item) => item.content).join('');
console.log( console.log(
@ -37,7 +31,7 @@ export const pushChatBill = async ({
await connectToDatabase(); await connectToDatabase();
// 获取模型单价格 // 获取模型单价格
const modelItem = modelList.find((item) => item.model === modelName); const modelItem = modelList.find((item) => item.chatModel === chatModel);
// 计算价格 // 计算价格
const unitPrice = modelItem?.price || 5; const unitPrice = modelItem?.price || 5;
const price = unitPrice * tokens; const price = unitPrice * tokens;
@ -47,7 +41,7 @@ export const pushChatBill = async ({
const res = await Bill.create({ const res = await Bill.create({
userId, userId,
type: 'chat', type: 'chat',
modelName, modelName: chatModel,
chatId: chatId ? chatId : undefined, chatId: chatId ? chatId : undefined,
textLen: text.length, textLen: text.length,
tokenLen: tokens, tokenLen: tokens,
@ -94,7 +88,7 @@ export const pushSplitDataBill = async ({
if (isPay) { if (isPay) {
try { try {
// 获取模型单价格, 都是用 gpt35 拆分 // 获取模型单价格, 都是用 gpt35 拆分
const modelItem = modelList.find((item) => item.model === ChatModelEnum.GPT35); const modelItem = modelList.find((item) => item.chatModel === ChatModelEnum.GPT35);
const unitPrice = modelItem?.price || 3; const unitPrice = modelItem?.price || 3;
// 计算价格 // 计算价格
const price = unitPrice * tokenLen; const price = unitPrice * tokenLen;

View File

@ -1,5 +1,5 @@
import { Schema, model, models, Model } from 'mongoose'; import { Schema, model, models, Model } from 'mongoose';
import { modelList } from '@/constants/model'; import { ChatModelMap } from '@/constants/model';
import { BillSchema as BillType } from '@/types/mongoSchema'; import { BillSchema as BillType } from '@/types/mongoSchema';
import { BillTypeMap } from '@/constants/user'; import { BillTypeMap } from '@/constants/user';
@ -16,7 +16,7 @@ const BillSchema = new Schema({
}, },
modelName: { modelName: {
type: String, type: String,
enum: [...modelList.map((item) => item.model), 'text-embedding-ada-002'], enum: [...Object.keys(ChatModelMap), 'text-embedding-ada-002'],
required: true required: true
}, },
chatId: { chatId: {

View File

@ -1,6 +1,11 @@
import { Schema, model, models, Model as MongoModel } from 'mongoose'; import { Schema, model, models, Model as MongoModel } from 'mongoose';
import { ModelSchema as ModelType } from '@/types/mongoSchema'; import { ModelSchema as ModelType } from '@/types/mongoSchema';
import { ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model'; import {
ModelVectorSearchModeMap,
ModelVectorSearchModeEnum,
ChatModelMap,
ChatModelEnum
} from '@/constants/model';
const ModelSchema = new Schema({ const ModelSchema = new Schema({
userId: { userId: {
@ -16,11 +21,6 @@ const ModelSchema = new Schema({
type: String, type: String,
default: '/icon/logo.png' default: '/icon/logo.png'
}, },
systemPrompt: {
// 系统提示词
type: String,
default: ''
},
status: { status: {
type: String, type: String,
required: true, required: true,
@ -30,17 +30,34 @@ const ModelSchema = new Schema({
type: Date, type: Date,
default: () => new Date() default: () => new Date()
}, },
temperature: { chat: {
type: Number, useKb: {
min: 0, // use knowledge base to search
max: 10, type: Boolean,
default: 4 default: false
}, },
search: { searchMode: {
mode: { // knowledge base search mode
type: String, type: String,
enum: Object.keys(ModelVectorSearchModeMap), enum: Object.keys(ModelVectorSearchModeMap),
default: ModelVectorSearchModeEnum.hightSimilarity default: ModelVectorSearchModeEnum.hightSimilarity
},
systemPrompt: {
// 系统提示词
type: String,
default: ''
},
temperature: {
type: Number,
min: 0,
max: 10,
default: 0
},
chatModel: {
// 聊天时使用的模型
type: String,
enum: Object.keys(ChatModelMap),
default: ChatModelEnum.GPT35
} }
}, },
share: { share: {
@ -63,18 +80,6 @@ const ModelSchema = new Schema({
default: 0 default: 0
} }
}, },
service: {
chatModel: {
// 聊天时使用的模型
type: String,
required: true
},
modelName: {
// 底层模型的名称
type: String,
required: true
}
},
security: { security: {
type: { type: {
domain: { domain: {
@ -100,8 +105,7 @@ const ModelSchema = new Schema({
default: -1 default: -1
} }
}, },
default: {}, default: {}
required: true
} }
}); });

View File

@ -0,0 +1,47 @@
import { openaiCreateEmbedding } from '../utils/openai';
import { PgClient } from '@/service/pg';
import { ModelDataStatusEnum } from '@/constants/model';
/**
* use openai embedding search kb
*/
export const searchKb_openai = async ({
apiKey,
isPay,
text,
similarity,
modelId,
userId
}: {
apiKey: string;
isPay: boolean;
text: string;
modelId: string;
userId: string;
similarity: number;
}) => {
// 获取提示词的向量
const { vector: promptVector } = await openaiCreateEmbedding({
isPay,
apiKey,
userId,
text
});
const vectorSearch = await PgClient.select<{ id: string; q: string; a: string }>('modelData', {
fields: ['id', 'q', 'a'],
where: [
['status', ModelDataStatusEnum.ready],
'AND',
['model_id', modelId],
'AND',
`vector <=> '[${promptVector}]' < ${similarity}`
],
order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
limit: 20
});
const systemPrompts: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`);
return { systemPrompts };
};

View File

@ -1,10 +1,33 @@
import { Configuration, OpenAIApi } from 'openai'; import { Configuration, OpenAIApi } from 'openai';
import { Chat, Model } from '../mongo'; import type { NextApiRequest } from 'next';
import jwt from 'jsonwebtoken';
import { Chat, Model, OpenApi, User } from '../mongo';
import type { ModelSchema } from '@/types/mongoSchema'; import type { ModelSchema } from '@/types/mongoSchema';
import { authToken } from './tools';
import { getOpenApiKey } from './openai'; import { getOpenApiKey } from './openai';
import type { ChatItemType } from '@/types/chat'; import type { ChatItemType } from '@/types/chat';
import mongoose from 'mongoose'; import mongoose from 'mongoose';
import { defaultModel } from '@/constants/model';
import { formatPrice } from '@/utils/user';
import { ERROR_ENUM } from '../errorCode';
/* 校验 token */
export const authToken = (token?: string): Promise<string> => {
return new Promise((resolve, reject) => {
if (!token) {
reject('缺少登录凭证');
return;
}
const key = process.env.TOKEN_KEY as string;
jwt.verify(token, key, function (err, decoded: any) {
if (err || !decoded?.userId) {
reject('凭证无效');
return;
}
resolve(decoded.userId);
});
});
};
export const getOpenAIApi = (apiKey: string) => { export const getOpenAIApi = (apiKey: string) => {
const configuration = new Configuration({ const configuration = new Configuration({
@ -20,12 +43,14 @@ export const authModel = async ({
modelId, modelId,
userId, userId,
authUser = true, authUser = true,
authOwner = true authOwner = true,
reserveDetail = false
}: { }: {
modelId: string; modelId: string;
userId: string; userId: string;
authUser?: boolean; authUser?: boolean;
authOwner?: boolean; authOwner?: boolean;
reserveDetail?: boolean; // focus reserve detail
}) => { }) => {
// 获取 model 数据 // 获取 model 数据
const model = await Model.findById<ModelSchema>(modelId); const model = await Model.findById<ModelSchema>(modelId);
@ -33,15 +58,21 @@ export const authModel = async ({
return Promise.reject('模型不存在'); return Promise.reject('模型不存在');
} }
// 使用权限校验 /*
Access verification
1. authOwner=true or authUser = true , just owner can use
2. authUser = false and share, anyone can use
*/
if ((authOwner || (authUser && !model.share.isShare)) && userId !== String(model.userId)) { if ((authOwner || (authUser && !model.share.isShare)) && userId !== String(model.userId)) {
return Promise.reject('无权操作该模型'); return Promise.reject('无权操作该模型');
} }
// detail 内容去除 // do not share detail info
if (!model.share.isShareDetail && userId !== String(model.userId)) { if (!reserveDetail && !model.share.isShareDetail && userId !== String(model.userId)) {
model.systemPrompt = ''; model.chat = {
model.temperature = 0; ...defaultModel.chat,
chatModel: model.chat.chatModel
};
} }
return { model }; return { model };
@ -60,7 +91,7 @@ export const authChat = async ({
const userId = await authToken(authorization); const userId = await authToken(authorization);
// 获取 model 数据 // 获取 model 数据
const { model } = await authModel({ modelId, userId, authOwner: false }); const { model } = await authModel({ modelId, userId, authOwner: false, reserveDetail: true });
// 聊天内容 // 聊天内容
let content: ChatItemType[] = []; let content: ChatItemType[] = [];
@ -91,3 +122,41 @@ export const authChat = async ({
model model
}; };
}; };
/* 校验 open api key */
export const authOpenApiKey = async (req: NextApiRequest) => {
const { apikey: apiKey } = req.headers;
if (!apiKey) {
return Promise.reject(ERROR_ENUM.unAuthorization);
}
try {
const openApi = await OpenApi.findOne({ apiKey });
if (!openApi) {
return Promise.reject(ERROR_ENUM.unAuthorization);
}
const userId = String(openApi.userId);
// 余额校验
const user = await User.findById(userId);
if (!user) {
return Promise.reject(ERROR_ENUM.unAuthorization);
}
if (formatPrice(user.balance) <= 0) {
return Promise.reject(ERROR_ENUM.insufficientQuota);
}
// 更新使用的时间
await OpenApi.findByIdAndUpdate(openApi._id, {
lastUsedTime: new Date()
});
return {
apiKey: process.env.OPENAIKEY as string,
userId
};
} catch (error) {
return Promise.reject(error);
}
};

View File

@ -1,6 +1,5 @@
import * as nodemailer from 'nodemailer'; import * as nodemailer from 'nodemailer';
import { UserAuthTypeEnum } from '@/constants/common'; import { UserAuthTypeEnum } from '@/constants/common';
import dayjs from 'dayjs';
import Dysmsapi, * as dysmsapi from '@alicloud/dysmsapi20170525'; import Dysmsapi, * as dysmsapi from '@alicloud/dysmsapi20170525';
// @ts-ignore // @ts-ignore
import * as OpenApi from '@alicloud/openapi-client'; import * as OpenApi from '@alicloud/openapi-client';
@ -48,25 +47,6 @@ export const sendEmailCode = (email: string, code: string, type: `${UserAuthType
}); });
}; };
export const sendTrainSucceed = (email: string, modelName: string) => {
return new Promise((resolve, reject) => {
const options = {
from: `"FastGPT" ${myEmail}`,
to: email,
subject: '模型训练完成通知',
html: `你的模型 ${modelName} 已于 ${dayjs().format('YYYY-MM-DD HH:mm')} 训练完成!`
};
mailTransport.sendMail(options, function (err, msg) {
if (err) {
console.log('send email error->', err);
reject('邮箱异常');
} else {
resolve('');
}
});
});
};
export const sendPhoneCode = async (phone: string, code: string) => { export const sendPhoneCode = async (phone: string, code: string) => {
const accessKeyId = process.env.aliAccessKeyId; const accessKeyId = process.env.aliAccessKeyId;
const accessKeySecret = process.env.aliAccessKeySecret; const accessKeySecret = process.env.aliAccessKeySecret;

View File

@ -1,10 +1,6 @@
import type { NextApiRequest } from 'next';
import crypto from 'crypto'; import crypto from 'crypto';
import jwt from 'jsonwebtoken'; import jwt from 'jsonwebtoken';
import { ChatItemType } from '@/types/chat'; import { ChatItemType } from '@/types/chat';
import { OpenApi, User } from '../mongo';
import { formatPrice } from '@/utils/user';
import { ERROR_ENUM } from '../errorCode';
import { countChatTokens } from '@/utils/tools'; import { countChatTokens } from '@/utils/tools';
import { ChatCompletionRequestMessageRoleEnum, ChatCompletionRequestMessage } from 'openai'; import { ChatCompletionRequestMessageRoleEnum, ChatCompletionRequestMessage } from 'openai';
import { ChatModelEnum } from '@/constants/model'; import { ChatModelEnum } from '@/constants/model';
@ -46,44 +42,6 @@ export const authToken = (token?: string): Promise<string> => {
}); });
}; };
/* 校验 open api key */
export const authOpenApiKey = async (req: NextApiRequest) => {
const { apikey: apiKey } = req.headers;
if (!apiKey) {
return Promise.reject(ERROR_ENUM.unAuthorization);
}
try {
const openApi = await OpenApi.findOne({ apiKey });
if (!openApi) {
return Promise.reject(ERROR_ENUM.unAuthorization);
}
const userId = String(openApi.userId);
// 余额校验
const user = await User.findById(userId);
if (!user) {
return Promise.reject(ERROR_ENUM.unAuthorization);
}
if (formatPrice(user.balance) <= 0) {
return Promise.reject('Insufficient account balance');
}
// 更新使用的时间
await OpenApi.findByIdAndUpdate(openApi._id, {
lastUsedTime: new Date()
});
return {
apiKey: process.env.OPENAIKEY as string,
userId
};
} catch (error) {
return Promise.reject(error);
}
};
/* openai axios config */ /* openai axios config */
export const axiosConfig = () => ({ export const axiosConfig = () => ({
httpsAgent: global.httpsAgent, httpsAgent: global.httpsAgent,

View File

@ -1,13 +1,11 @@
import { ModelStatusEnum } from '@/constants/model'; import { ModelStatusEnum } from '@/constants/model';
import type { ModelSchema } from './mongoSchema'; import type { ModelSchema } from './mongoSchema';
export interface ModelUpdateParams { export interface ModelUpdateParams {
name: string; name: string;
avatar: string; avatar: string;
systemPrompt: string; chat: ModelSchema['chat'];
temperature: number;
search: ModelSchema['search'];
share: ModelSchema['share']; share: ModelSchema['share'];
service: ModelSchema['service'];
security: ModelSchema['security']; security: ModelSchema['security'];
} }

View File

@ -31,15 +31,17 @@ export interface AuthCodeSchema {
export interface ModelSchema { export interface ModelSchema {
_id: string; _id: string;
userId: string;
name: string; name: string;
avatar: string; avatar: string;
systemPrompt: string;
userId: string;
status: `${ModelStatusEnum}`; status: `${ModelStatusEnum}`;
updateTime: number; updateTime: number;
temperature: number; chat: {
search: { useKb: boolean;
mode: `${ModelVectorSearchModeEnum}`; searchMode: `${ModelVectorSearchModeEnum}`;
systemPrompt: string;
temperature: number;
chatModel: `${ChatModelEnum}`; // 聊天时用的模型,训练后就是训练的模型
}; };
share: { share: {
isShare: boolean; isShare: boolean;
@ -47,10 +49,6 @@ export interface ModelSchema {
intro: string; intro: string;
collection: number; collection: number;
}; };
service: {
chatModel: `${ChatModelEnum}`; // 聊天时用的模型,训练后就是训练的模型
modelName: `${ModelNameEnum}`; // 底层模型名称,不会变
};
security: { security: {
domain: string[]; domain: string[];
contextMaxLen: number; contextMaxLen: number;