perf: chat framwork
This commit is contained in:
parent
91decc3683
commit
00a99261ae
File diff suppressed because one or more lines are too long
@ -1,6 +1,7 @@
|
|||||||
import type { ModelSchema } from '@/types/mongoSchema';
|
import type { ModelSchema } from '@/types/mongoSchema';
|
||||||
|
|
||||||
export const embeddingModel = 'text-embedding-ada-002';
|
export const embeddingModel = 'text-embedding-ada-002';
|
||||||
|
export type EmbeddingModelType = 'text-embedding-ada-002';
|
||||||
|
|
||||||
export enum OpenAiChatEnum {
|
export enum OpenAiChatEnum {
|
||||||
'GPT35' = 'gpt-3.5-turbo',
|
'GPT35' = 'gpt-3.5-turbo',
|
||||||
@ -25,7 +26,7 @@ export const ChatModelMap = {
|
|||||||
},
|
},
|
||||||
[OpenAiChatEnum.GPT432k]: {
|
[OpenAiChatEnum.GPT432k]: {
|
||||||
name: 'Gpt4-32k',
|
name: 'Gpt4-32k',
|
||||||
contextMaxToken: 8000,
|
contextMaxToken: 32000,
|
||||||
maxTemperature: 1.5,
|
maxTemperature: 1.5,
|
||||||
price: 30
|
price: 30
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,14 +1,15 @@
|
|||||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||||
import { connectToDatabase } from '@/service/mongo';
|
import { connectToDatabase } from '@/service/mongo';
|
||||||
import { getOpenAIApi, authChat } from '@/service/utils/auth';
|
import { authChat } from '@/service/utils/auth';
|
||||||
import { axiosConfig, openaiChatFilter } from '@/service/utils/tools';
|
import { modelServiceToolMap } from '@/service/utils/chat';
|
||||||
import { ChatItemSimpleType } from '@/types/chat';
|
import { ChatItemSimpleType } from '@/types/chat';
|
||||||
import { jsonRes } from '@/service/response';
|
import { jsonRes } from '@/service/response';
|
||||||
import { PassThrough } from 'stream';
|
import { PassThrough } from 'stream';
|
||||||
import { ChatModelMap, ModelVectorSearchModeMap } from '@/constants/model';
|
import { ChatModelMap, ModelVectorSearchModeMap } from '@/constants/model';
|
||||||
import { pushChatBill } from '@/service/events/pushBill';
|
import { pushChatBill } from '@/service/events/pushBill';
|
||||||
import { gpt35StreamResponse } from '@/service/utils/openai';
|
import { resStreamResponse } from '@/service/utils/chat';
|
||||||
import { searchKb_openai } from '@/service/tools/searchKb';
|
import { searchKb } from '@/service/plugins/searchKb';
|
||||||
|
import { ChatRoleEnum } from '@/constants/chat';
|
||||||
|
|
||||||
/* 发送提示词 */
|
/* 发送提示词 */
|
||||||
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
|
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
|
||||||
@ -41,7 +42,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
await connectToDatabase();
|
await connectToDatabase();
|
||||||
let startTime = Date.now();
|
let startTime = Date.now();
|
||||||
|
|
||||||
const { model, showModelDetail, content, userApiKey, systemKey, userId } = await authChat({
|
const { model, showModelDetail, content, userApiKey, systemApiKey, userId } = await authChat({
|
||||||
modelId,
|
modelId,
|
||||||
chatId,
|
chatId,
|
||||||
authorization
|
authorization
|
||||||
@ -54,9 +55,9 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
|
|
||||||
// 使用了知识库搜索
|
// 使用了知识库搜索
|
||||||
if (model.chat.useKb) {
|
if (model.chat.useKb) {
|
||||||
const { code, searchPrompt } = await searchKb_openai({
|
const { code, searchPrompt } = await searchKb({
|
||||||
apiKey: userApiKey || systemKey,
|
userApiKey,
|
||||||
isPay: !userApiKey,
|
systemApiKey,
|
||||||
text: prompt.value,
|
text: prompt.value,
|
||||||
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity,
|
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity,
|
||||||
model,
|
model,
|
||||||
@ -73,53 +74,37 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
// 没有用知识库搜索,仅用系统提示词
|
// 没有用知识库搜索,仅用系统提示词
|
||||||
model.chat.systemPrompt &&
|
model.chat.systemPrompt &&
|
||||||
prompts.unshift({
|
prompts.unshift({
|
||||||
obj: 'SYSTEM',
|
obj: ChatRoleEnum.System,
|
||||||
value: model.chat.systemPrompt
|
value: model.chat.systemPrompt
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// 控制总 tokens 数量,防止超出
|
|
||||||
const filterPrompts = openaiChatFilter({
|
|
||||||
model: model.chat.chatModel,
|
|
||||||
prompts,
|
|
||||||
maxTokens: modelConstantsData.contextMaxToken - 300
|
|
||||||
});
|
|
||||||
|
|
||||||
// 计算温度
|
// 计算温度
|
||||||
const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
|
const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
|
||||||
2
|
2
|
||||||
);
|
);
|
||||||
// console.log(filterPrompts);
|
// console.log(filterPrompts);
|
||||||
// 获取 chatAPI
|
|
||||||
const chatAPI = getOpenAIApi(userApiKey || systemKey);
|
|
||||||
// 发出请求
|
// 发出请求
|
||||||
const chatResponse = await chatAPI.createChatCompletion(
|
const { streamResponse } = await modelServiceToolMap[model.chat.chatModel].chatCompletion({
|
||||||
{
|
apiKey: userApiKey || systemApiKey,
|
||||||
model: model.chat.chatModel,
|
temperature: +temperature,
|
||||||
temperature: Number(temperature) || 0,
|
messages: prompts,
|
||||||
messages: filterPrompts,
|
stream: true
|
||||||
frequency_penalty: 0.5, // 越大,重复内容越少
|
});
|
||||||
presence_penalty: -0.5, // 越大,越容易出现新内容
|
|
||||||
stream: true,
|
|
||||||
stop: ['.!?。']
|
|
||||||
},
|
|
||||||
{
|
|
||||||
timeout: 40000,
|
|
||||||
responseType: 'stream',
|
|
||||||
...axiosConfig()
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);
|
console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);
|
||||||
|
|
||||||
step = 1;
|
step = 1;
|
||||||
|
|
||||||
const { responseContent } = await gpt35StreamResponse({
|
const { totalTokens, finishMessages } = await resStreamResponse({
|
||||||
|
model: model.chat.chatModel,
|
||||||
res,
|
res,
|
||||||
stream,
|
stream,
|
||||||
chatResponse,
|
chatResponse: streamResponse,
|
||||||
|
prompts,
|
||||||
systemPrompt:
|
systemPrompt:
|
||||||
showModelDetail && filterPrompts[0].role === 'system' ? filterPrompts[0].content : ''
|
showModelDetail && prompts[0].obj === ChatRoleEnum.System ? prompts[0].value : ''
|
||||||
});
|
});
|
||||||
|
|
||||||
// 只有使用平台的 key 才计费
|
// 只有使用平台的 key 才计费
|
||||||
@ -128,7 +113,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
chatModel: model.chat.chatModel,
|
chatModel: model.chat.chatModel,
|
||||||
userId,
|
userId,
|
||||||
chatId,
|
chatId,
|
||||||
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
|
textLen: finishMessages.map((item) => item.value).join('').length,
|
||||||
|
tokens: totalTokens
|
||||||
});
|
});
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
if (step === 1) {
|
if (step === 1) {
|
||||||
|
|||||||
@ -1,14 +1,14 @@
|
|||||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||||
import { connectToDatabase } from '@/service/mongo';
|
import { connectToDatabase } from '@/service/mongo';
|
||||||
import { getOpenAIApi, authOpenApiKey, authModel } from '@/service/utils/auth';
|
import { authOpenApiKey, authModel } from '@/service/utils/auth';
|
||||||
import { axiosConfig, openaiChatFilter } from '@/service/utils/tools';
|
import { modelServiceToolMap, resStreamResponse } from '@/service/utils/chat';
|
||||||
import { ChatItemSimpleType } from '@/types/chat';
|
import { ChatItemSimpleType } from '@/types/chat';
|
||||||
import { jsonRes } from '@/service/response';
|
import { jsonRes } from '@/service/response';
|
||||||
import { PassThrough } from 'stream';
|
import { PassThrough } from 'stream';
|
||||||
import { ChatModelMap, ModelVectorSearchModeMap } from '@/constants/model';
|
import { ChatModelMap, ModelVectorSearchModeMap } from '@/constants/model';
|
||||||
import { pushChatBill } from '@/service/events/pushBill';
|
import { pushChatBill } from '@/service/events/pushBill';
|
||||||
import { gpt35StreamResponse } from '@/service/utils/openai';
|
import { searchKb } from '@/service/plugins/searchKb';
|
||||||
import { searchKb_openai } from '@/service/tools/searchKb';
|
import { ChatRoleEnum } from '@/constants/chat';
|
||||||
|
|
||||||
/* 发送提示词 */
|
/* 发送提示词 */
|
||||||
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
|
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
|
||||||
@ -64,9 +64,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
if (model.chat.useKb) {
|
if (model.chat.useKb) {
|
||||||
const similarity = ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22;
|
const similarity = ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22;
|
||||||
|
|
||||||
const { code, searchPrompt } = await searchKb_openai({
|
const { code, searchPrompt } = await searchKb({
|
||||||
apiKey,
|
systemApiKey: apiKey,
|
||||||
isPay: true,
|
|
||||||
text: prompts[prompts.length - 1].value,
|
text: prompts[prompts.length - 1].value,
|
||||||
similarity,
|
similarity,
|
||||||
model,
|
model,
|
||||||
@ -83,69 +82,55 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
// 没有用知识库搜索,仅用系统提示词
|
// 没有用知识库搜索,仅用系统提示词
|
||||||
if (model.chat.systemPrompt) {
|
if (model.chat.systemPrompt) {
|
||||||
prompts.unshift({
|
prompts.unshift({
|
||||||
obj: 'SYSTEM',
|
obj: ChatRoleEnum.System,
|
||||||
value: model.chat.systemPrompt
|
value: model.chat.systemPrompt
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 控制总 tokens 数量,防止超出
|
|
||||||
const filterPrompts = openaiChatFilter({
|
|
||||||
model: model.chat.chatModel,
|
|
||||||
prompts,
|
|
||||||
maxTokens: modelConstantsData.contextMaxToken - 300
|
|
||||||
});
|
|
||||||
|
|
||||||
// 计算温度
|
// 计算温度
|
||||||
const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
|
const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
|
||||||
2
|
2
|
||||||
);
|
);
|
||||||
// console.log(filterPrompts);
|
|
||||||
// 获取 chatAPI
|
|
||||||
const chatAPI = getOpenAIApi(apiKey);
|
|
||||||
// 发出请求
|
// 发出请求
|
||||||
const chatResponse = await chatAPI.createChatCompletion(
|
const { streamResponse, responseMessages, responseText, totalTokens } =
|
||||||
{
|
await modelServiceToolMap[model.chat.chatModel].chatCompletion({
|
||||||
model: model.chat.chatModel,
|
apiKey,
|
||||||
temperature: Number(temperature) || 0,
|
temperature: +temperature,
|
||||||
messages: filterPrompts,
|
messages: prompts,
|
||||||
frequency_penalty: 0.5, // 越大,重复内容越少
|
stream: isStream
|
||||||
presence_penalty: -0.5, // 越大,越容易出现新内容
|
});
|
||||||
stream: isStream,
|
|
||||||
stop: ['.!?。']
|
|
||||||
},
|
|
||||||
{
|
|
||||||
timeout: 180000,
|
|
||||||
responseType: isStream ? 'stream' : 'json',
|
|
||||||
...axiosConfig()
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);
|
console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);
|
||||||
|
|
||||||
let responseContent = '';
|
let textLen = 0;
|
||||||
|
let tokens = totalTokens;
|
||||||
|
|
||||||
if (isStream) {
|
if (isStream) {
|
||||||
step = 1;
|
step = 1;
|
||||||
const streamResponse = await gpt35StreamResponse({
|
const { finishMessages, totalTokens } = await resStreamResponse({
|
||||||
|
model: model.chat.chatModel,
|
||||||
res,
|
res,
|
||||||
stream,
|
stream,
|
||||||
chatResponse
|
chatResponse: streamResponse,
|
||||||
|
prompts
|
||||||
});
|
});
|
||||||
responseContent = streamResponse.responseContent;
|
textLen = finishMessages.map((item) => item.value).join('').length;
|
||||||
|
tokens = totalTokens;
|
||||||
} else {
|
} else {
|
||||||
responseContent = chatResponse.data.choices?.[0]?.message?.content || '';
|
textLen = responseMessages.map((item) => item.value).join('').length;
|
||||||
jsonRes(res, {
|
jsonRes(res, {
|
||||||
data: responseContent
|
data: responseText
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// 只有使用平台的 key 才计费
|
|
||||||
pushChatBill({
|
pushChatBill({
|
||||||
isPay: true,
|
isPay: true,
|
||||||
chatModel: model.chat.chatModel,
|
chatModel: model.chat.chatModel,
|
||||||
userId,
|
userId,
|
||||||
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
|
textLen,
|
||||||
|
tokens
|
||||||
});
|
});
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
if (step === 1) {
|
if (step === 1) {
|
||||||
|
|||||||
@ -1,144 +0,0 @@
|
|||||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
|
||||||
import { connectToDatabase, Model } from '@/service/mongo';
|
|
||||||
import { getOpenAIApi, authOpenApiKey } from '@/service/utils/auth';
|
|
||||||
import { axiosConfig, openaiChatFilter } from '@/service/utils/tools';
|
|
||||||
import { ChatItemSimpleType } from '@/types/chat';
|
|
||||||
import { jsonRes } from '@/service/response';
|
|
||||||
import { PassThrough } from 'stream';
|
|
||||||
import { ChatModelMap } from '@/constants/model';
|
|
||||||
import { pushChatBill } from '@/service/events/pushBill';
|
|
||||||
import { gpt35StreamResponse } from '@/service/utils/openai';
|
|
||||||
|
|
||||||
/* 发送提示词 */
|
|
||||||
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
|
|
||||||
let step = 0; // step=1时,表示开始了流响应
|
|
||||||
const stream = new PassThrough();
|
|
||||||
stream.on('error', () => {
|
|
||||||
console.log('error: ', 'stream error');
|
|
||||||
stream.destroy();
|
|
||||||
});
|
|
||||||
res.on('close', () => {
|
|
||||||
stream.destroy();
|
|
||||||
});
|
|
||||||
res.on('error', () => {
|
|
||||||
console.log('error: ', 'request error');
|
|
||||||
stream.destroy();
|
|
||||||
});
|
|
||||||
|
|
||||||
try {
|
|
||||||
const {
|
|
||||||
prompts,
|
|
||||||
modelId,
|
|
||||||
isStream = true
|
|
||||||
} = req.body as {
|
|
||||||
prompts: ChatItemSimpleType[];
|
|
||||||
modelId: string;
|
|
||||||
isStream: boolean;
|
|
||||||
};
|
|
||||||
|
|
||||||
if (!prompts || !modelId) {
|
|
||||||
throw new Error('缺少参数');
|
|
||||||
}
|
|
||||||
if (!Array.isArray(prompts)) {
|
|
||||||
throw new Error('prompts is not array');
|
|
||||||
}
|
|
||||||
if (prompts.length > 30 || prompts.length === 0) {
|
|
||||||
throw new Error('prompts length range 1-30');
|
|
||||||
}
|
|
||||||
|
|
||||||
await connectToDatabase();
|
|
||||||
let startTime = Date.now();
|
|
||||||
|
|
||||||
const { apiKey, userId } = await authOpenApiKey(req);
|
|
||||||
|
|
||||||
const model = await Model.findOne({
|
|
||||||
_id: modelId,
|
|
||||||
userId
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!model) {
|
|
||||||
throw new Error('无权使用该模型');
|
|
||||||
}
|
|
||||||
|
|
||||||
const modelConstantsData = ChatModelMap[model.chat.chatModel];
|
|
||||||
|
|
||||||
// 如果有系统提示词,自动插入
|
|
||||||
if (model.chat.systemPrompt) {
|
|
||||||
prompts.unshift({
|
|
||||||
obj: 'SYSTEM',
|
|
||||||
value: model.chat.systemPrompt
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// 控制在 tokens 数量,防止超出
|
|
||||||
const filterPrompts = openaiChatFilter({
|
|
||||||
model: model.chat.chatModel,
|
|
||||||
prompts,
|
|
||||||
maxTokens: modelConstantsData.contextMaxToken - 300
|
|
||||||
});
|
|
||||||
|
|
||||||
// console.log(filterPrompts);
|
|
||||||
// 计算温度
|
|
||||||
const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
|
|
||||||
2
|
|
||||||
);
|
|
||||||
// 获取 chatAPI
|
|
||||||
const chatAPI = getOpenAIApi(apiKey);
|
|
||||||
// 发出请求
|
|
||||||
const chatResponse = await chatAPI.createChatCompletion(
|
|
||||||
{
|
|
||||||
model: model.chat.chatModel,
|
|
||||||
temperature: Number(temperature) || 0,
|
|
||||||
messages: filterPrompts,
|
|
||||||
frequency_penalty: 0.5, // 越大,重复内容越少
|
|
||||||
presence_penalty: -0.5, // 越大,越容易出现新内容
|
|
||||||
stream: isStream,
|
|
||||||
stop: ['.!?。']
|
|
||||||
},
|
|
||||||
{
|
|
||||||
timeout: 40000,
|
|
||||||
responseType: isStream ? 'stream' : 'json',
|
|
||||||
...axiosConfig()
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);
|
|
||||||
|
|
||||||
let responseContent = '';
|
|
||||||
|
|
||||||
if (isStream) {
|
|
||||||
step = 1;
|
|
||||||
const streamResponse = await gpt35StreamResponse({
|
|
||||||
res,
|
|
||||||
stream,
|
|
||||||
chatResponse
|
|
||||||
});
|
|
||||||
responseContent = streamResponse.responseContent;
|
|
||||||
} else {
|
|
||||||
responseContent = chatResponse.data.choices?.[0]?.message?.content || '';
|
|
||||||
jsonRes(res, {
|
|
||||||
data: responseContent
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// 只有使用平台的 key 才计费
|
|
||||||
pushChatBill({
|
|
||||||
isPay: true,
|
|
||||||
chatModel: model.chat.chatModel,
|
|
||||||
userId,
|
|
||||||
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
|
|
||||||
});
|
|
||||||
} catch (err: any) {
|
|
||||||
if (step === 1) {
|
|
||||||
// 直接结束流
|
|
||||||
console.log('error,结束');
|
|
||||||
stream.destroy();
|
|
||||||
} else {
|
|
||||||
res.status(500);
|
|
||||||
jsonRes(res, {
|
|
||||||
code: 500,
|
|
||||||
error: err
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,14 +1,14 @@
|
|||||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||||
import { connectToDatabase, Model } from '@/service/mongo';
|
import { connectToDatabase, Model } from '@/service/mongo';
|
||||||
import { getOpenAIApi, authOpenApiKey } from '@/service/utils/auth';
|
import { authOpenApiKey } from '@/service/utils/auth';
|
||||||
import { axiosConfig, openaiChatFilter } from '@/service/utils/tools';
|
import { resStreamResponse, modelServiceToolMap } from '@/service/utils/chat';
|
||||||
import { ChatItemSimpleType } from '@/types/chat';
|
import { ChatItemSimpleType } from '@/types/chat';
|
||||||
import { jsonRes } from '@/service/response';
|
import { jsonRes } from '@/service/response';
|
||||||
import { PassThrough } from 'stream';
|
import { PassThrough } from 'stream';
|
||||||
import { ChatModelMap, ModelVectorSearchModeMap, OpenAiChatEnum } from '@/constants/model';
|
import { ChatModelMap, ModelVectorSearchModeMap } from '@/constants/model';
|
||||||
import { pushChatBill } from '@/service/events/pushBill';
|
import { pushChatBill } from '@/service/events/pushBill';
|
||||||
import { gpt35StreamResponse } from '@/service/utils/openai';
|
import { searchKb } from '@/service/plugins/searchKb';
|
||||||
import { searchKb_openai } from '@/service/tools/searchKb';
|
import { ChatRoleEnum } from '@/constants/chat';
|
||||||
|
|
||||||
/* 发送提示词 */
|
/* 发送提示词 */
|
||||||
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
|
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
|
||||||
@ -57,20 +57,16 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
|
|
||||||
console.log('laf gpt start');
|
console.log('laf gpt start');
|
||||||
|
|
||||||
// 获取 chatAPI
|
|
||||||
const chatAPI = getOpenAIApi(apiKey);
|
|
||||||
|
|
||||||
// 请求一次 chatgpt 拆解需求
|
// 请求一次 chatgpt 拆解需求
|
||||||
const promptResponse = await chatAPI.createChatCompletion(
|
const { responseText: resolveText, totalTokens: resolveTokens } = await modelServiceToolMap[
|
||||||
{
|
model.chat.chatModel
|
||||||
model: OpenAiChatEnum.GPT35,
|
].chatCompletion({
|
||||||
temperature: 0,
|
apiKey,
|
||||||
frequency_penalty: 0.5, // 越大,重复内容越少
|
temperature: 0,
|
||||||
presence_penalty: -0.5, // 越大,越容易出现新内容
|
messages: [
|
||||||
messages: [
|
{
|
||||||
{
|
obj: ChatRoleEnum.System,
|
||||||
role: 'system',
|
value: `服务端逻辑生成器.根据用户输入的需求,拆解成 laf 云函数实现的步骤,只返回步骤,按格式返回步骤: 1.\n2.\n3.\n ......
|
||||||
content: `服务端逻辑生成器.根据用户输入的需求,拆解成 laf 云函数实现的步骤,只返回步骤,按格式返回步骤: 1.\n2.\n3.\n ......
|
|
||||||
下面是一些例子:
|
下面是一些例子:
|
||||||
一个 hello world 例子
|
一个 hello world 例子
|
||||||
1. 返回字符串: "hello world"
|
1. 返回字符串: "hello world"
|
||||||
@ -103,35 +99,25 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
5. 获取当前时间,记录为 updateTime.
|
5. 获取当前时间,记录为 updateTime.
|
||||||
6. 更新数据库数据,表为"blogs",更新符合 blogId 的记录的内容为{blogText, tags, updateTime}.
|
6. 更新数据库数据,表为"blogs",更新符合 blogId 的记录的内容为{blogText, tags, updateTime}.
|
||||||
7. 返回结果 "更新博客记录成功"`
|
7. 返回结果 "更新博客记录成功"`
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
role: 'user',
|
obj: ChatRoleEnum.Human,
|
||||||
content: prompt.value
|
value: prompt.value
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
},
|
stream: false
|
||||||
{
|
});
|
||||||
timeout: 180000,
|
|
||||||
...axiosConfig()
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const promptResolve = promptResponse.data.choices?.[0]?.message?.content || '';
|
prompt.value += ` ${resolveText}`;
|
||||||
if (!promptResolve) {
|
|
||||||
throw new Error('gpt 异常');
|
|
||||||
}
|
|
||||||
|
|
||||||
prompt.value += ` ${promptResolve}`;
|
|
||||||
console.log('prompt resolve success, time:', `${(Date.now() - startTime) / 1000}s`);
|
console.log('prompt resolve success, time:', `${(Date.now() - startTime) / 1000}s`);
|
||||||
|
|
||||||
// 读取对话内容
|
// 读取对话内容
|
||||||
const prompts = [prompt];
|
const prompts = [prompt];
|
||||||
|
|
||||||
// 获取向量匹配到的提示词
|
// 获取向量匹配到的提示词
|
||||||
const { searchPrompt } = await searchKb_openai({
|
const { searchPrompt } = await searchKb({
|
||||||
isPay: true,
|
systemApiKey: apiKey,
|
||||||
apiKey,
|
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity,
|
||||||
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22,
|
|
||||||
text: prompt.value,
|
text: prompt.value,
|
||||||
model,
|
model,
|
||||||
userId
|
userId
|
||||||
@ -139,49 +125,41 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
|
|
||||||
searchPrompt && prompts.unshift(searchPrompt);
|
searchPrompt && prompts.unshift(searchPrompt);
|
||||||
|
|
||||||
// 控制上下文 tokens 数量,防止超出
|
|
||||||
const filterPrompts = openaiChatFilter({
|
|
||||||
model: model.chat.chatModel,
|
|
||||||
prompts,
|
|
||||||
maxTokens: modelConstantsData.contextMaxToken - 300
|
|
||||||
});
|
|
||||||
|
|
||||||
// console.log(filterPrompts);
|
|
||||||
// 计算温度
|
// 计算温度
|
||||||
const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
|
const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
|
||||||
2
|
2
|
||||||
);
|
);
|
||||||
// 发出请求
|
|
||||||
const chatResponse = await chatAPI.createChatCompletion(
|
|
||||||
{
|
|
||||||
model: model.chat.chatModel,
|
|
||||||
temperature: Number(temperature) || 0,
|
|
||||||
messages: filterPrompts,
|
|
||||||
frequency_penalty: 0.5, // 越大,重复内容越少
|
|
||||||
presence_penalty: -0.5, // 越大,越容易出现新内容
|
|
||||||
stream: isStream
|
|
||||||
},
|
|
||||||
{
|
|
||||||
timeout: 180000,
|
|
||||||
responseType: isStream ? 'stream' : 'json',
|
|
||||||
...axiosConfig()
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
let responseContent = '';
|
// 发出请求
|
||||||
|
const { streamResponse, responseMessages, responseText, totalTokens } =
|
||||||
|
await modelServiceToolMap[model.chat.chatModel].chatCompletion({
|
||||||
|
apiKey,
|
||||||
|
temperature: +temperature,
|
||||||
|
messages: prompts,
|
||||||
|
stream: isStream
|
||||||
|
});
|
||||||
|
|
||||||
|
console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);
|
||||||
|
|
||||||
|
let textLen = resolveText.length;
|
||||||
|
let tokens = resolveTokens;
|
||||||
|
|
||||||
if (isStream) {
|
if (isStream) {
|
||||||
step = 1;
|
step = 1;
|
||||||
const streamResponse = await gpt35StreamResponse({
|
const { finishMessages, totalTokens } = await resStreamResponse({
|
||||||
|
model: model.chat.chatModel,
|
||||||
res,
|
res,
|
||||||
stream,
|
stream,
|
||||||
chatResponse
|
chatResponse: streamResponse,
|
||||||
|
prompts
|
||||||
});
|
});
|
||||||
responseContent = streamResponse.responseContent;
|
textLen += finishMessages.map((item) => item.value).join('').length;
|
||||||
|
tokens += totalTokens;
|
||||||
} else {
|
} else {
|
||||||
responseContent = chatResponse.data.choices?.[0]?.message?.content || '';
|
textLen += responseMessages.map((item) => item.value).join('').length;
|
||||||
|
tokens += totalTokens;
|
||||||
jsonRes(res, {
|
jsonRes(res, {
|
||||||
data: responseContent
|
data: responseText
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -191,7 +169,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
isPay: true,
|
isPay: true,
|
||||||
chatModel: model.chat.chatModel,
|
chatModel: model.chat.chatModel,
|
||||||
userId,
|
userId,
|
||||||
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
|
textLen,
|
||||||
|
tokens
|
||||||
});
|
});
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
if (step === 1) {
|
if (step === 1) {
|
||||||
|
|||||||
@ -1,159 +0,0 @@
|
|||||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
|
||||||
import { connectToDatabase, Model } from '@/service/mongo';
|
|
||||||
import { axiosConfig, openaiChatFilter } from '@/service/utils/tools';
|
|
||||||
import { getOpenAIApi, authOpenApiKey } from '@/service/utils/auth';
|
|
||||||
import { ChatItemSimpleType } from '@/types/chat';
|
|
||||||
import { jsonRes } from '@/service/response';
|
|
||||||
import { PassThrough } from 'stream';
|
|
||||||
import {
|
|
||||||
ChatModelMap,
|
|
||||||
ModelVectorSearchModeMap,
|
|
||||||
ModelVectorSearchModeEnum
|
|
||||||
} from '@/constants/model';
|
|
||||||
import { pushChatBill } from '@/service/events/pushBill';
|
|
||||||
import { gpt35StreamResponse } from '@/service/utils/openai';
|
|
||||||
import { searchKb_openai } from '@/service/tools/searchKb';
|
|
||||||
|
|
||||||
/* 发送提示词 */
|
|
||||||
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
|
|
||||||
let step = 0; // step=1时,表示开始了流响应
|
|
||||||
const stream = new PassThrough();
|
|
||||||
stream.on('error', () => {
|
|
||||||
console.log('error: ', 'stream error');
|
|
||||||
stream.destroy();
|
|
||||||
});
|
|
||||||
res.on('close', () => {
|
|
||||||
stream.destroy();
|
|
||||||
});
|
|
||||||
res.on('error', () => {
|
|
||||||
console.log('error: ', 'request error');
|
|
||||||
stream.destroy();
|
|
||||||
});
|
|
||||||
|
|
||||||
try {
|
|
||||||
const {
|
|
||||||
prompts,
|
|
||||||
modelId,
|
|
||||||
isStream = true
|
|
||||||
} = req.body as {
|
|
||||||
prompts: ChatItemSimpleType[];
|
|
||||||
modelId: string;
|
|
||||||
isStream: boolean;
|
|
||||||
};
|
|
||||||
|
|
||||||
if (!prompts || !modelId) {
|
|
||||||
throw new Error('缺少参数');
|
|
||||||
}
|
|
||||||
if (!Array.isArray(prompts)) {
|
|
||||||
throw new Error('prompts is not array');
|
|
||||||
}
|
|
||||||
if (prompts.length > 30 || prompts.length === 0) {
|
|
||||||
throw new Error('prompts length range 1-30');
|
|
||||||
}
|
|
||||||
|
|
||||||
await connectToDatabase();
|
|
||||||
let startTime = Date.now();
|
|
||||||
|
|
||||||
/* 凭证校验 */
|
|
||||||
const { apiKey, userId } = await authOpenApiKey(req);
|
|
||||||
|
|
||||||
const model = await Model.findOne({
|
|
||||||
_id: modelId,
|
|
||||||
userId
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!model) {
|
|
||||||
throw new Error('无权使用该模型');
|
|
||||||
}
|
|
||||||
|
|
||||||
const modelConstantsData = ChatModelMap[model.chat.chatModel];
|
|
||||||
|
|
||||||
// 获取向量匹配到的提示词
|
|
||||||
const { code, searchPrompt } = await searchKb_openai({
|
|
||||||
isPay: true,
|
|
||||||
apiKey,
|
|
||||||
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity || 0.22,
|
|
||||||
text: prompts[prompts.length - 1].value,
|
|
||||||
model,
|
|
||||||
userId
|
|
||||||
});
|
|
||||||
|
|
||||||
// search result is empty
|
|
||||||
if (code === 201) {
|
|
||||||
return res.send(searchPrompt?.value);
|
|
||||||
}
|
|
||||||
|
|
||||||
searchPrompt && prompts.unshift(searchPrompt);
|
|
||||||
|
|
||||||
// 控制在 tokens 数量,防止超出
|
|
||||||
const filterPrompts = openaiChatFilter({
|
|
||||||
model: model.chat.chatModel,
|
|
||||||
prompts,
|
|
||||||
maxTokens: modelConstantsData.contextMaxToken - 300
|
|
||||||
});
|
|
||||||
|
|
||||||
// console.log(filterPrompts);
|
|
||||||
// 计算温度
|
|
||||||
const temperature = (modelConstantsData.maxTemperature * (model.chat.temperature / 10)).toFixed(
|
|
||||||
2
|
|
||||||
);
|
|
||||||
const chatAPI = getOpenAIApi(apiKey);
|
|
||||||
|
|
||||||
// 发出请求
|
|
||||||
const chatResponse = await chatAPI.createChatCompletion(
|
|
||||||
{
|
|
||||||
model: model.chat.chatModel,
|
|
||||||
temperature: Number(temperature) || 0,
|
|
||||||
messages: filterPrompts,
|
|
||||||
frequency_penalty: 0.5, // 越大,重复内容越少
|
|
||||||
presence_penalty: -0.5, // 越大,越容易出现新内容
|
|
||||||
stream: isStream,
|
|
||||||
stop: ['.!?。']
|
|
||||||
},
|
|
||||||
{
|
|
||||||
timeout: 180000,
|
|
||||||
responseType: isStream ? 'stream' : 'json',
|
|
||||||
...axiosConfig()
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);
|
|
||||||
|
|
||||||
let responseContent = '';
|
|
||||||
|
|
||||||
if (isStream) {
|
|
||||||
step = 1;
|
|
||||||
const streamResponse = await gpt35StreamResponse({
|
|
||||||
res,
|
|
||||||
stream,
|
|
||||||
chatResponse
|
|
||||||
});
|
|
||||||
responseContent = streamResponse.responseContent;
|
|
||||||
} else {
|
|
||||||
responseContent = chatResponse.data.choices?.[0]?.message?.content || '';
|
|
||||||
jsonRes(res, {
|
|
||||||
data: responseContent
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
pushChatBill({
|
|
||||||
isPay: true,
|
|
||||||
chatModel: model.chat.chatModel,
|
|
||||||
userId,
|
|
||||||
messages: filterPrompts.concat({ role: 'assistant', content: responseContent })
|
|
||||||
});
|
|
||||||
// jsonRes(res);
|
|
||||||
} catch (err: any) {
|
|
||||||
if (step === 1) {
|
|
||||||
// 直接结束流
|
|
||||||
console.log('error,结束');
|
|
||||||
stream.destroy();
|
|
||||||
} else {
|
|
||||||
res.status(500);
|
|
||||||
jsonRes(res, {
|
|
||||||
code: 500,
|
|
||||||
error: err
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,5 +0,0 @@
|
|||||||
export enum OpenAiTuneStatusEnum {
|
|
||||||
cancelled = 'cancelled',
|
|
||||||
succeeded = 'succeeded',
|
|
||||||
pending = 'pending'
|
|
||||||
}
|
|
||||||
@ -1,14 +1,13 @@
|
|||||||
import { SplitData } from '@/service/mongo';
|
import { SplitData } from '@/service/mongo';
|
||||||
import { getOpenAIApi } from '@/service/utils/auth';
|
import { getApiKey } from '../utils/auth';
|
||||||
import { axiosConfig } from '@/service/utils/tools';
|
|
||||||
import { getOpenApiKey } from '../utils/openai';
|
|
||||||
import type { ChatCompletionRequestMessage } from 'openai';
|
|
||||||
import { OpenAiChatEnum } from '@/constants/model';
|
import { OpenAiChatEnum } from '@/constants/model';
|
||||||
import { pushSplitDataBill } from '@/service/events/pushBill';
|
import { pushSplitDataBill } from '@/service/events/pushBill';
|
||||||
import { generateVector } from './generateVector';
|
import { generateVector } from './generateVector';
|
||||||
import { openaiError2 } from '../errorCode';
|
import { openaiError2 } from '../errorCode';
|
||||||
import { PgClient } from '@/service/pg';
|
import { PgClient } from '@/service/pg';
|
||||||
import { ModelSplitDataSchema } from '@/types/mongoSchema';
|
import { ModelSplitDataSchema } from '@/types/mongoSchema';
|
||||||
|
import { modelServiceToolMap } from '../utils/chat';
|
||||||
|
import { ChatRoleEnum } from '@/constants/chat';
|
||||||
|
|
||||||
export async function generateQA(next = false): Promise<any> {
|
export async function generateQA(next = false): Promise<any> {
|
||||||
if (process.env.queueTask !== '1') {
|
if (process.env.queueTask !== '1') {
|
||||||
@ -47,11 +46,11 @@ export async function generateQA(next = false): Promise<any> {
|
|||||||
|
|
||||||
// 获取 openapi Key
|
// 获取 openapi Key
|
||||||
let userApiKey = '',
|
let userApiKey = '',
|
||||||
systemKey = '';
|
systemApiKey = '';
|
||||||
try {
|
try {
|
||||||
const key = await getOpenApiKey(dataItem.userId);
|
const key = await getApiKey({ model: OpenAiChatEnum.GPT35, userId: dataItem.userId });
|
||||||
userApiKey = key.userApiKey;
|
userApiKey = key.userApiKey;
|
||||||
systemKey = key.systemKey;
|
systemApiKey = key.systemApiKey;
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
if (error?.code === 501) {
|
if (error?.code === 501) {
|
||||||
// 余额不够了, 清空该记录
|
// 余额不够了, 清空该记录
|
||||||
@ -69,55 +68,44 @@ export async function generateQA(next = false): Promise<any> {
|
|||||||
|
|
||||||
const startTime = Date.now();
|
const startTime = Date.now();
|
||||||
|
|
||||||
// 获取 openai 请求实例
|
|
||||||
const chatAPI = getOpenAIApi(userApiKey || systemKey);
|
|
||||||
const systemPrompt: ChatCompletionRequestMessage = {
|
|
||||||
role: 'system',
|
|
||||||
content: `你是出题人
|
|
||||||
${dataItem.prompt || '下面是"一段长文本"'}
|
|
||||||
从中选出5至20个题目和答案.答案详细.按格式返回: Q1:
|
|
||||||
A1:
|
|
||||||
Q2:
|
|
||||||
A2:
|
|
||||||
...`
|
|
||||||
};
|
|
||||||
|
|
||||||
// 请求 chatgpt 获取回答
|
// 请求 chatgpt 获取回答
|
||||||
const response = await Promise.allSettled(
|
const response = await Promise.allSettled(
|
||||||
textList.map((text) =>
|
textList.map((text) =>
|
||||||
chatAPI
|
modelServiceToolMap[OpenAiChatEnum.GPT35]
|
||||||
.createChatCompletion(
|
.chatCompletion({
|
||||||
{
|
apiKey: userApiKey || systemApiKey,
|
||||||
model: OpenAiChatEnum.GPT35,
|
temperature: 0.8,
|
||||||
temperature: 0.8,
|
messages: [
|
||||||
n: 1,
|
{
|
||||||
messages: [
|
obj: ChatRoleEnum.System,
|
||||||
systemPrompt,
|
value: `你是出题人
|
||||||
{
|
${dataItem.prompt || '下面是"一段长文本"'}
|
||||||
role: 'user',
|
从中选出5至20个题目和答案.答案详细.按格式返回: Q1:
|
||||||
content: text
|
A1:
|
||||||
}
|
Q2:
|
||||||
]
|
A2:
|
||||||
},
|
...`
|
||||||
{
|
},
|
||||||
timeout: 180000,
|
{
|
||||||
...axiosConfig()
|
obj: 'Human',
|
||||||
}
|
value: text
|
||||||
)
|
}
|
||||||
.then((res) => {
|
],
|
||||||
const rawContent = res?.data.choices[0].message?.content || ''; // chatgpt 原本的回复
|
stream: false
|
||||||
const result = formatSplitText(res?.data.choices[0].message?.content || ''); // 格式化后的QA对
|
})
|
||||||
|
.then(({ totalTokens, responseText, responseMessages }) => {
|
||||||
|
const result = formatSplitText(responseText); // 格式化后的QA对
|
||||||
console.log(`split result length: `, result.length);
|
console.log(`split result length: `, result.length);
|
||||||
// 计费
|
// 计费
|
||||||
pushSplitDataBill({
|
pushSplitDataBill({
|
||||||
isPay: !userApiKey && result.length > 0,
|
isPay: !userApiKey && result.length > 0,
|
||||||
userId: dataItem.userId,
|
userId: dataItem.userId,
|
||||||
type: 'QA',
|
type: 'QA',
|
||||||
text: systemPrompt.content + text + rawContent,
|
textLen: responseMessages.map((item) => item.value).join('').length,
|
||||||
tokenLen: res.data.usage?.total_tokens || 0
|
totalTokens
|
||||||
});
|
});
|
||||||
return {
|
return {
|
||||||
rawContent,
|
rawContent: responseText,
|
||||||
result
|
result
|
||||||
};
|
};
|
||||||
})
|
})
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
import { openaiCreateEmbedding, getOpenApiKey } from '../utils/openai';
|
import { openaiCreateEmbedding } from '../utils/chat/openai';
|
||||||
|
import { getApiKey } from '../utils/auth';
|
||||||
import { openaiError2 } from '../errorCode';
|
import { openaiError2 } from '../errorCode';
|
||||||
import { PgClient } from '@/service/pg';
|
import { PgClient } from '@/service/pg';
|
||||||
|
import { embeddingModel } from '@/constants/model';
|
||||||
|
|
||||||
export async function generateVector(next = false): Promise<any> {
|
export async function generateVector(next = false): Promise<any> {
|
||||||
if (process.env.queueTask !== '1') {
|
if (process.env.queueTask !== '1') {
|
||||||
@ -40,11 +42,11 @@ export async function generateVector(next = false): Promise<any> {
|
|||||||
dataId = dataItem.id;
|
dataId = dataItem.id;
|
||||||
|
|
||||||
// 获取 openapi Key
|
// 获取 openapi Key
|
||||||
let userApiKey, systemKey;
|
let userApiKey, systemApiKey;
|
||||||
try {
|
try {
|
||||||
const res = await getOpenApiKey(dataItem.userId);
|
const res = await getApiKey({ model: embeddingModel, userId: dataItem.userId });
|
||||||
userApiKey = res.userApiKey;
|
userApiKey = res.userApiKey;
|
||||||
systemKey = res.systemKey;
|
systemApiKey = res.systemApiKey;
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
if (error?.code === 501) {
|
if (error?.code === 501) {
|
||||||
await PgClient.delete('modelData', {
|
await PgClient.delete('modelData', {
|
||||||
@ -61,8 +63,8 @@ export async function generateVector(next = false): Promise<any> {
|
|||||||
const { vector } = await openaiCreateEmbedding({
|
const { vector } = await openaiCreateEmbedding({
|
||||||
text: dataItem.q,
|
text: dataItem.q,
|
||||||
userId: dataItem.userId,
|
userId: dataItem.userId,
|
||||||
isPay: !userApiKey,
|
userApiKey,
|
||||||
apiKey: userApiKey || systemKey
|
systemApiKey
|
||||||
});
|
});
|
||||||
|
|
||||||
// 更新 pg 向量和状态数据
|
// 更新 pg 向量和状态数据
|
||||||
|
|||||||
@ -1,60 +1,54 @@
|
|||||||
import { connectToDatabase, Bill, User } from '../mongo';
|
import { connectToDatabase, Bill, User } from '../mongo';
|
||||||
import { ChatModelMap, OpenAiChatEnum, ChatModelType, embeddingModel } from '@/constants/model';
|
import { ChatModelMap, OpenAiChatEnum, ChatModelType, embeddingModel } from '@/constants/model';
|
||||||
import { BillTypeEnum } from '@/constants/user';
|
import { BillTypeEnum } from '@/constants/user';
|
||||||
import { countChatTokens } from '@/utils/tools';
|
|
||||||
|
|
||||||
export const pushChatBill = async ({
|
export const pushChatBill = async ({
|
||||||
isPay,
|
isPay,
|
||||||
chatModel,
|
chatModel,
|
||||||
userId,
|
userId,
|
||||||
chatId,
|
chatId,
|
||||||
messages
|
textLen,
|
||||||
|
tokens
|
||||||
}: {
|
}: {
|
||||||
isPay: boolean;
|
isPay: boolean;
|
||||||
chatModel: ChatModelType;
|
chatModel: ChatModelType;
|
||||||
userId: string;
|
userId: string;
|
||||||
chatId?: '' | string;
|
chatId?: '' | string;
|
||||||
messages: { role: 'system' | 'user' | 'assistant'; content: string }[];
|
textLen: number;
|
||||||
|
tokens: number;
|
||||||
}) => {
|
}) => {
|
||||||
|
console.log(`chat generate success. text len: ${textLen}. token len: ${tokens}. pay:${isPay}`);
|
||||||
|
if (!isPay) return;
|
||||||
|
|
||||||
let billId = '';
|
let billId = '';
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// 计算 token 数量
|
await connectToDatabase();
|
||||||
const tokens = countChatTokens({ model: chatModel, messages });
|
|
||||||
const text = messages.map((item) => item.content).join('');
|
|
||||||
|
|
||||||
console.log(
|
// 计算价格
|
||||||
`chat generate success. text len: ${text.length}. token len: ${tokens}. pay:${isPay}`
|
const unitPrice = ChatModelMap[chatModel]?.price || 5;
|
||||||
);
|
const price = unitPrice * tokens;
|
||||||
|
|
||||||
if (isPay) {
|
try {
|
||||||
await connectToDatabase();
|
// 插入 Bill 记录
|
||||||
|
const res = await Bill.create({
|
||||||
|
userId,
|
||||||
|
type: 'chat',
|
||||||
|
modelName: chatModel,
|
||||||
|
chatId: chatId ? chatId : undefined,
|
||||||
|
textLen,
|
||||||
|
tokenLen: tokens,
|
||||||
|
price
|
||||||
|
});
|
||||||
|
billId = res._id;
|
||||||
|
|
||||||
// 计算价格
|
// 账号扣费
|
||||||
const unitPrice = ChatModelMap[chatModel]?.price || 5;
|
await User.findByIdAndUpdate(userId, {
|
||||||
const price = unitPrice * tokens;
|
$inc: { balance: -price }
|
||||||
|
});
|
||||||
try {
|
} catch (error) {
|
||||||
// 插入 Bill 记录
|
console.log('创建账单失败:', error);
|
||||||
const res = await Bill.create({
|
billId && Bill.findByIdAndDelete(billId);
|
||||||
userId,
|
|
||||||
type: 'chat',
|
|
||||||
modelName: chatModel,
|
|
||||||
chatId: chatId ? chatId : undefined,
|
|
||||||
textLen: text.length,
|
|
||||||
tokenLen: tokens,
|
|
||||||
price
|
|
||||||
});
|
|
||||||
billId = res._id;
|
|
||||||
|
|
||||||
// 账号扣费
|
|
||||||
await User.findByIdAndUpdate(userId, {
|
|
||||||
$inc: { balance: -price }
|
|
||||||
});
|
|
||||||
} catch (error) {
|
|
||||||
console.log('创建账单失败:', error);
|
|
||||||
billId && Bill.findByIdAndDelete(billId);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.log(error);
|
console.log(error);
|
||||||
@ -64,54 +58,49 @@ export const pushChatBill = async ({
|
|||||||
export const pushSplitDataBill = async ({
|
export const pushSplitDataBill = async ({
|
||||||
isPay,
|
isPay,
|
||||||
userId,
|
userId,
|
||||||
tokenLen,
|
totalTokens,
|
||||||
text,
|
textLen,
|
||||||
type
|
type
|
||||||
}: {
|
}: {
|
||||||
isPay: boolean;
|
isPay: boolean;
|
||||||
userId: string;
|
userId: string;
|
||||||
tokenLen: number;
|
totalTokens: number;
|
||||||
text: string;
|
textLen: number;
|
||||||
type: `${BillTypeEnum}`;
|
type: `${BillTypeEnum}`;
|
||||||
}) => {
|
}) => {
|
||||||
await connectToDatabase();
|
console.log(
|
||||||
|
`splitData generate success. text len: ${textLen}. token len: ${totalTokens}. pay:${isPay}`
|
||||||
|
);
|
||||||
|
if (!isPay) return;
|
||||||
|
|
||||||
let billId;
|
let billId;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
console.log(
|
await connectToDatabase();
|
||||||
`splitData generate success. text len: ${text.length}. token len: ${tokenLen}. pay:${isPay}`
|
|
||||||
);
|
|
||||||
|
|
||||||
if (isPay) {
|
// 获取模型单价格, 都是用 gpt35 拆分
|
||||||
try {
|
const unitPrice = ChatModelMap[OpenAiChatEnum.GPT35].price || 3;
|
||||||
// 获取模型单价格, 都是用 gpt35 拆分
|
// 计算价格
|
||||||
const unitPrice = ChatModelMap[OpenAiChatEnum.GPT35]?.price || 3;
|
const price = unitPrice * totalTokens;
|
||||||
// 计算价格
|
|
||||||
const price = unitPrice * tokenLen;
|
|
||||||
|
|
||||||
// 插入 Bill 记录
|
// 插入 Bill 记录
|
||||||
const res = await Bill.create({
|
const res = await Bill.create({
|
||||||
userId,
|
userId,
|
||||||
type,
|
type,
|
||||||
modelName: OpenAiChatEnum.GPT35,
|
modelName: OpenAiChatEnum.GPT35,
|
||||||
textLen: text.length,
|
textLen,
|
||||||
tokenLen,
|
tokenLen: totalTokens,
|
||||||
price
|
price
|
||||||
});
|
});
|
||||||
billId = res._id;
|
billId = res._id;
|
||||||
|
|
||||||
// 账号扣费
|
// 账号扣费
|
||||||
await User.findByIdAndUpdate(userId, {
|
await User.findByIdAndUpdate(userId, {
|
||||||
$inc: { balance: -price }
|
$inc: { balance: -price }
|
||||||
});
|
});
|
||||||
} catch (error) {
|
|
||||||
console.log('创建账单失败:', error);
|
|
||||||
billId && Bill.findByIdAndDelete(billId);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.log(error);
|
console.log('创建账单失败:', error);
|
||||||
|
billId && Bill.findByIdAndDelete(billId);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -126,41 +115,40 @@ export const pushGenerateVectorBill = async ({
|
|||||||
text: string;
|
text: string;
|
||||||
tokenLen: number;
|
tokenLen: number;
|
||||||
}) => {
|
}) => {
|
||||||
await connectToDatabase();
|
console.log(
|
||||||
|
`vector generate success. text len: ${text.length}. token len: ${tokenLen}. pay:${isPay}`
|
||||||
|
);
|
||||||
|
if (!isPay) return;
|
||||||
|
|
||||||
let billId;
|
let billId;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
console.log(
|
await connectToDatabase();
|
||||||
`vector generate success. text len: ${text.length}. token len: ${tokenLen}. pay:${isPay}`
|
|
||||||
);
|
|
||||||
|
|
||||||
if (isPay) {
|
try {
|
||||||
try {
|
const unitPrice = 0.4;
|
||||||
const unitPrice = 0.4;
|
// 计算价格. 至少为1
|
||||||
// 计算价格. 至少为1
|
let price = unitPrice * tokenLen;
|
||||||
let price = unitPrice * tokenLen;
|
price = price > 1 ? price : 1;
|
||||||
price = price > 1 ? price : 1;
|
|
||||||
|
|
||||||
// 插入 Bill 记录
|
// 插入 Bill 记录
|
||||||
const res = await Bill.create({
|
const res = await Bill.create({
|
||||||
userId,
|
userId,
|
||||||
type: BillTypeEnum.vector,
|
type: BillTypeEnum.vector,
|
||||||
modelName: embeddingModel,
|
modelName: embeddingModel,
|
||||||
textLen: text.length,
|
textLen: text.length,
|
||||||
tokenLen,
|
tokenLen,
|
||||||
price
|
price
|
||||||
});
|
});
|
||||||
billId = res._id;
|
billId = res._id;
|
||||||
|
|
||||||
// 账号扣费
|
// 账号扣费
|
||||||
await User.findByIdAndUpdate(userId, {
|
await User.findByIdAndUpdate(userId, {
|
||||||
$inc: { balance: -price }
|
$inc: { balance: -price }
|
||||||
});
|
});
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.log('创建账单失败:', error);
|
console.log('创建账单失败:', error);
|
||||||
billId && Bill.findByIdAndDelete(billId);
|
billId && Bill.findByIdAndDelete(billId);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.log(error);
|
console.log(error);
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import { Schema, model, models, Model } from 'mongoose';
|
import { Schema, model, models, Model } from 'mongoose';
|
||||||
import { ChatSchema as ChatType } from '@/types/mongoSchema';
|
import { ChatSchema as ChatType } from '@/types/mongoSchema';
|
||||||
|
import { ChatRoleMap } from '@/constants/chat';
|
||||||
|
|
||||||
const ChatSchema = new Schema({
|
const ChatSchema = new Schema({
|
||||||
userId: {
|
userId: {
|
||||||
@ -36,7 +37,7 @@ const ChatSchema = new Schema({
|
|||||||
obj: {
|
obj: {
|
||||||
type: String,
|
type: String,
|
||||||
required: true,
|
required: true,
|
||||||
enum: ['Human', 'AI', 'SYSTEM']
|
enum: Object.keys(ChatRoleMap)
|
||||||
},
|
},
|
||||||
value: {
|
value: {
|
||||||
type: String,
|
type: String,
|
||||||
|
|||||||
@ -1,22 +1,23 @@
|
|||||||
import { openaiCreateEmbedding } from '../utils/openai';
|
|
||||||
import { PgClient } from '@/service/pg';
|
import { PgClient } from '@/service/pg';
|
||||||
import { ModelDataStatusEnum, ModelVectorSearchModeEnum, ChatModelMap } from '@/constants/model';
|
import { ModelDataStatusEnum, ModelVectorSearchModeEnum, ChatModelMap } from '@/constants/model';
|
||||||
import { ModelSchema } from '@/types/mongoSchema';
|
import { ModelSchema } from '@/types/mongoSchema';
|
||||||
import { systemPromptFilter } from '../utils/tools';
|
import { openaiCreateEmbedding } from '../utils/chat/openai';
|
||||||
|
import { ChatRoleEnum } from '@/constants/chat';
|
||||||
|
import { sliceTextByToken } from '@/utils/chat';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* use openai embedding search kb
|
* use openai embedding search kb
|
||||||
*/
|
*/
|
||||||
export const searchKb_openai = async ({
|
export const searchKb = async ({
|
||||||
apiKey,
|
userApiKey,
|
||||||
isPay = true,
|
systemApiKey,
|
||||||
text,
|
text,
|
||||||
similarity = 0.2,
|
similarity = 0.2,
|
||||||
model,
|
model,
|
||||||
userId
|
userId
|
||||||
}: {
|
}: {
|
||||||
apiKey: string;
|
userApiKey?: string;
|
||||||
isPay: boolean;
|
systemApiKey: string;
|
||||||
text: string;
|
text: string;
|
||||||
model: ModelSchema;
|
model: ModelSchema;
|
||||||
userId: string;
|
userId: string;
|
||||||
@ -24,7 +25,7 @@ export const searchKb_openai = async ({
|
|||||||
}): Promise<{
|
}): Promise<{
|
||||||
code: 200 | 201;
|
code: 200 | 201;
|
||||||
searchPrompt?: {
|
searchPrompt?: {
|
||||||
obj: 'Human' | 'AI' | 'SYSTEM';
|
obj: `${ChatRoleEnum}`;
|
||||||
value: string;
|
value: string;
|
||||||
};
|
};
|
||||||
}> => {
|
}> => {
|
||||||
@ -32,8 +33,8 @@ export const searchKb_openai = async ({
|
|||||||
|
|
||||||
// 获取提示词的向量
|
// 获取提示词的向量
|
||||||
const { vector: promptVector } = await openaiCreateEmbedding({
|
const { vector: promptVector } = await openaiCreateEmbedding({
|
||||||
isPay,
|
userApiKey,
|
||||||
apiKey,
|
systemApiKey,
|
||||||
userId,
|
userId,
|
||||||
text
|
text
|
||||||
});
|
});
|
||||||
@ -61,7 +62,7 @@ export const searchKb_openai = async ({
|
|||||||
return {
|
return {
|
||||||
code: 201,
|
code: 201,
|
||||||
searchPrompt: {
|
searchPrompt: {
|
||||||
obj: 'AI',
|
obj: ChatRoleEnum.AI,
|
||||||
value: '对不起,你的问题不在知识库中。'
|
value: '对不起,你的问题不在知识库中。'
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -72,7 +73,7 @@ export const searchKb_openai = async ({
|
|||||||
code: 200,
|
code: 200,
|
||||||
searchPrompt: model.chat.systemPrompt
|
searchPrompt: model.chat.systemPrompt
|
||||||
? {
|
? {
|
||||||
obj: 'SYSTEM',
|
obj: ChatRoleEnum.System,
|
||||||
value: model.chat.systemPrompt
|
value: model.chat.systemPrompt
|
||||||
}
|
}
|
||||||
: undefined
|
: undefined
|
||||||
@ -81,16 +82,16 @@ export const searchKb_openai = async ({
|
|||||||
|
|
||||||
// 有匹配情况下,system 添加知识库内容。
|
// 有匹配情况下,system 添加知识库内容。
|
||||||
// 系统提示词过滤,最多 65% tokens
|
// 系统提示词过滤,最多 65% tokens
|
||||||
const filterSystemPrompt = systemPromptFilter({
|
const filterSystemPrompt = sliceTextByToken({
|
||||||
model: model.chat.chatModel,
|
model: model.chat.chatModel,
|
||||||
prompts: systemPrompts,
|
text: systemPrompts.join('\n'),
|
||||||
maxTokens: Math.floor(modelConstantsData.contextMaxToken * 0.65)
|
length: Math.floor(modelConstantsData.contextMaxToken * 0.65)
|
||||||
});
|
});
|
||||||
|
|
||||||
return {
|
return {
|
||||||
code: 200,
|
code: 200,
|
||||||
searchPrompt: {
|
searchPrompt: {
|
||||||
obj: 'SYSTEM',
|
obj: ChatRoleEnum.System,
|
||||||
value: `
|
value: `
|
||||||
${model.chat.systemPrompt}
|
${model.chat.systemPrompt}
|
||||||
${
|
${
|
||||||
@ -1,14 +1,18 @@
|
|||||||
import { Configuration, OpenAIApi } from 'openai';
|
|
||||||
import type { NextApiRequest } from 'next';
|
import type { NextApiRequest } from 'next';
|
||||||
import jwt from 'jsonwebtoken';
|
import jwt from 'jsonwebtoken';
|
||||||
import { Chat, Model, OpenApi, User } from '../mongo';
|
import { Chat, Model, OpenApi, User } from '../mongo';
|
||||||
import type { ModelSchema } from '@/types/mongoSchema';
|
import type { ModelSchema } from '@/types/mongoSchema';
|
||||||
import { getOpenApiKey } from './openai';
|
|
||||||
import type { ChatItemSimpleType } from '@/types/chat';
|
import type { ChatItemSimpleType } from '@/types/chat';
|
||||||
import mongoose from 'mongoose';
|
import mongoose from 'mongoose';
|
||||||
import { defaultModel } from '@/constants/model';
|
import { defaultModel } from '@/constants/model';
|
||||||
import { formatPrice } from '@/utils/user';
|
import { formatPrice } from '@/utils/user';
|
||||||
import { ERROR_ENUM } from '../errorCode';
|
import { ERROR_ENUM } from '../errorCode';
|
||||||
|
import {
|
||||||
|
ChatModelType,
|
||||||
|
OpenAiChatEnum,
|
||||||
|
embeddingModel,
|
||||||
|
EmbeddingModelType
|
||||||
|
} from '@/constants/model';
|
||||||
|
|
||||||
/* 校验 token */
|
/* 校验 token */
|
||||||
export const authToken = (token?: string): Promise<string> => {
|
export const authToken = (token?: string): Promise<string> => {
|
||||||
@ -29,13 +33,63 @@ export const authToken = (token?: string): Promise<string> => {
|
|||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
export const getOpenAIApi = (apiKey: string) => {
|
/* 获取 api 请求的 key */
|
||||||
const configuration = new Configuration({
|
export const getApiKey = async ({
|
||||||
apiKey,
|
model,
|
||||||
basePath: process.env.OPENAI_BASE_URL
|
userId
|
||||||
});
|
}: {
|
||||||
|
model: ChatModelType | EmbeddingModelType;
|
||||||
|
userId: string;
|
||||||
|
}) => {
|
||||||
|
const user = await User.findById(userId);
|
||||||
|
if (!user) {
|
||||||
|
return Promise.reject({
|
||||||
|
code: 501,
|
||||||
|
message: '找不到用户'
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
return new OpenAIApi(configuration);
|
const keyMap = {
|
||||||
|
[OpenAiChatEnum.GPT35]: {
|
||||||
|
userApiKey: user.openaiKey || '',
|
||||||
|
systemApiKey: process.env.OPENAIKEY as string
|
||||||
|
},
|
||||||
|
[OpenAiChatEnum.GPT4]: {
|
||||||
|
userApiKey: user.openaiKey || '',
|
||||||
|
systemApiKey: process.env.OPENAIKEY as string
|
||||||
|
},
|
||||||
|
[OpenAiChatEnum.GPT432k]: {
|
||||||
|
userApiKey: user.openaiKey || '',
|
||||||
|
systemApiKey: process.env.OPENAIKEY as string
|
||||||
|
},
|
||||||
|
[embeddingModel]: {
|
||||||
|
userApiKey: user.openaiKey || '',
|
||||||
|
systemApiKey: process.env.OPENAIKEY as string
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// 有自己的key
|
||||||
|
if (keyMap[model].userApiKey) {
|
||||||
|
return {
|
||||||
|
user,
|
||||||
|
userApiKey: keyMap[model].userApiKey,
|
||||||
|
systemApiKey: ''
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// 平台账号余额校验
|
||||||
|
if (formatPrice(user.balance) <= 0) {
|
||||||
|
return Promise.reject({
|
||||||
|
code: 501,
|
||||||
|
message: '账号余额不足'
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
user,
|
||||||
|
userApiKey: '',
|
||||||
|
systemApiKey: keyMap[model].systemApiKey
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
// 模型使用权校验
|
// 模型使用权校验
|
||||||
@ -122,11 +176,11 @@ export const authChat = async ({
|
|||||||
]);
|
]);
|
||||||
}
|
}
|
||||||
// 获取 user 的 apiKey
|
// 获取 user 的 apiKey
|
||||||
const { userApiKey, systemKey } = await getOpenApiKey(userId);
|
const { userApiKey, systemApiKey } = await getApiKey({ model: model.chat.chatModel, userId });
|
||||||
|
|
||||||
return {
|
return {
|
||||||
userApiKey,
|
userApiKey,
|
||||||
systemKey,
|
systemApiKey,
|
||||||
content,
|
content,
|
||||||
userId,
|
userId,
|
||||||
model,
|
model,
|
||||||
|
|||||||
155
src/service/utils/chat/index.ts
Normal file
155
src/service/utils/chat/index.ts
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
import { ChatItemSimpleType } from '@/types/chat';
|
||||||
|
import { modelToolMap } from '@/utils/chat';
|
||||||
|
import type { ChatModelType } from '@/constants/model';
|
||||||
|
import { ChatRoleEnum, SYSTEM_PROMPT_PREFIX } from '@/constants/chat';
|
||||||
|
import { OpenAiChatEnum } from '@/constants/model';
|
||||||
|
import { chatResponse, openAiStreamResponse } from './openai';
|
||||||
|
import type { NextApiResponse } from 'next';
|
||||||
|
import type { PassThrough } from 'stream';
|
||||||
|
|
||||||
|
export type ChatCompletionType = {
|
||||||
|
apiKey: string;
|
||||||
|
temperature: number;
|
||||||
|
messages: ChatItemSimpleType[];
|
||||||
|
stream: boolean;
|
||||||
|
};
|
||||||
|
export type StreamResponseType = {
|
||||||
|
stream: PassThrough;
|
||||||
|
chatResponse: any;
|
||||||
|
prompts: ChatItemSimpleType[];
|
||||||
|
};
|
||||||
|
|
||||||
|
export const modelServiceToolMap = {
|
||||||
|
[OpenAiChatEnum.GPT35]: {
|
||||||
|
chatCompletion: (data: ChatCompletionType) =>
|
||||||
|
chatResponse({ model: OpenAiChatEnum.GPT35, ...data }),
|
||||||
|
streamResponse: (data: StreamResponseType) =>
|
||||||
|
openAiStreamResponse({
|
||||||
|
model: OpenAiChatEnum.GPT35,
|
||||||
|
...data
|
||||||
|
})
|
||||||
|
},
|
||||||
|
[OpenAiChatEnum.GPT4]: {
|
||||||
|
chatCompletion: (data: ChatCompletionType) =>
|
||||||
|
chatResponse({ model: OpenAiChatEnum.GPT4, ...data }),
|
||||||
|
streamResponse: (data: StreamResponseType) =>
|
||||||
|
openAiStreamResponse({
|
||||||
|
model: OpenAiChatEnum.GPT4,
|
||||||
|
...data
|
||||||
|
})
|
||||||
|
},
|
||||||
|
[OpenAiChatEnum.GPT432k]: {
|
||||||
|
chatCompletion: (data: ChatCompletionType) =>
|
||||||
|
chatResponse({ model: OpenAiChatEnum.GPT432k, ...data }),
|
||||||
|
streamResponse: (data: StreamResponseType) =>
|
||||||
|
openAiStreamResponse({
|
||||||
|
model: OpenAiChatEnum.GPT432k,
|
||||||
|
...data
|
||||||
|
})
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/* delete invalid symbol */
|
||||||
|
const simplifyStr = (str: string) =>
|
||||||
|
str
|
||||||
|
.replace(/\n+/g, '\n') // 连续空行
|
||||||
|
.replace(/[^\S\r\n]+/g, ' ') // 连续空白内容
|
||||||
|
.trim();
|
||||||
|
|
||||||
|
/* 聊天上下文 tokens 截断 */
|
||||||
|
export const ChatContextFilter = ({
|
||||||
|
model,
|
||||||
|
prompts,
|
||||||
|
maxTokens
|
||||||
|
}: {
|
||||||
|
model: ChatModelType;
|
||||||
|
prompts: ChatItemSimpleType[];
|
||||||
|
maxTokens: number;
|
||||||
|
}) => {
|
||||||
|
let rawTextLen = 0;
|
||||||
|
const formatPrompts = prompts.map<ChatItemSimpleType>((item) => {
|
||||||
|
const val = simplifyStr(item.value);
|
||||||
|
rawTextLen += val.length;
|
||||||
|
return {
|
||||||
|
obj: item.obj,
|
||||||
|
value: val
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
// 长度太小时,不需要进行 token 截断
|
||||||
|
if (formatPrompts.length <= 2 || rawTextLen < maxTokens * 0.5) {
|
||||||
|
return formatPrompts;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根据 tokens 截断内容
|
||||||
|
const chats: ChatItemSimpleType[] = [];
|
||||||
|
let systemPrompt: ChatItemSimpleType | null = null;
|
||||||
|
|
||||||
|
// System 词保留
|
||||||
|
if (formatPrompts[0].obj === ChatRoleEnum.System) {
|
||||||
|
const prompt = formatPrompts.shift();
|
||||||
|
if (prompt) {
|
||||||
|
systemPrompt = prompt;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let messages: ChatItemSimpleType[] = [];
|
||||||
|
|
||||||
|
// 从后往前截取对话内容
|
||||||
|
for (let i = formatPrompts.length - 1; i >= 0; i--) {
|
||||||
|
chats.unshift(formatPrompts[i]);
|
||||||
|
|
||||||
|
messages = systemPrompt ? [systemPrompt, ...chats] : chats;
|
||||||
|
|
||||||
|
const tokens = modelToolMap[model].countTokens({
|
||||||
|
messages
|
||||||
|
});
|
||||||
|
|
||||||
|
/* 整体 tokens 超出范围 */
|
||||||
|
if (tokens >= maxTokens) {
|
||||||
|
return systemPrompt ? [systemPrompt, ...chats.slice(1)] : chats.slice(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return messages;
|
||||||
|
};
|
||||||
|
|
||||||
|
/* stream response */
|
||||||
|
export const resStreamResponse = async ({
|
||||||
|
model,
|
||||||
|
res,
|
||||||
|
stream,
|
||||||
|
chatResponse,
|
||||||
|
systemPrompt,
|
||||||
|
prompts
|
||||||
|
}: StreamResponseType & {
|
||||||
|
model: ChatModelType;
|
||||||
|
res: NextApiResponse;
|
||||||
|
systemPrompt?: string;
|
||||||
|
}) => {
|
||||||
|
// 创建响应流
|
||||||
|
res.setHeader('Content-Type', 'text/event-stream;charset-utf-8');
|
||||||
|
res.setHeader('Access-Control-Allow-Origin', '*');
|
||||||
|
res.setHeader('X-Accel-Buffering', 'no');
|
||||||
|
res.setHeader('Cache-Control', 'no-cache, no-transform');
|
||||||
|
stream.pipe(res);
|
||||||
|
|
||||||
|
const { responseContent, totalTokens, finishMessages } = await modelServiceToolMap[
|
||||||
|
model
|
||||||
|
].streamResponse({
|
||||||
|
chatResponse,
|
||||||
|
stream,
|
||||||
|
prompts
|
||||||
|
});
|
||||||
|
|
||||||
|
// push system prompt
|
||||||
|
!stream.destroyed &&
|
||||||
|
systemPrompt &&
|
||||||
|
stream.push(`${SYSTEM_PROMPT_PREFIX}${systemPrompt.replace(/\n/g, '<br/>')}`);
|
||||||
|
|
||||||
|
// close stream
|
||||||
|
!stream.destroyed && stream.push(null);
|
||||||
|
stream.destroy();
|
||||||
|
|
||||||
|
return { responseContent, totalTokens, finishMessages };
|
||||||
|
};
|
||||||
174
src/service/utils/chat/openai.ts
Normal file
174
src/service/utils/chat/openai.ts
Normal file
@ -0,0 +1,174 @@
|
|||||||
|
import { Configuration, OpenAIApi } from 'openai';
|
||||||
|
import { createParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser';
|
||||||
|
import { axiosConfig } from '../tools';
|
||||||
|
import { ChatModelMap, embeddingModel, OpenAiChatEnum } from '@/constants/model';
|
||||||
|
import { pushGenerateVectorBill } from '../../events/pushBill';
|
||||||
|
import { adaptChatItem_openAI } from '@/utils/chat/openai';
|
||||||
|
import { modelToolMap } from '@/utils/chat';
|
||||||
|
import { ChatCompletionType, ChatContextFilter, StreamResponseType } from './index';
|
||||||
|
import { ChatRoleEnum } from '@/constants/chat';
|
||||||
|
|
||||||
|
export const getOpenAIApi = (apiKey: string) => {
|
||||||
|
const configuration = new Configuration({
|
||||||
|
apiKey,
|
||||||
|
basePath: process.env.OPENAI_BASE_URL
|
||||||
|
});
|
||||||
|
|
||||||
|
return new OpenAIApi(configuration);
|
||||||
|
};
|
||||||
|
|
||||||
|
/* 获取向量 */
|
||||||
|
export const openaiCreateEmbedding = async ({
|
||||||
|
userApiKey,
|
||||||
|
systemApiKey,
|
||||||
|
userId,
|
||||||
|
text
|
||||||
|
}: {
|
||||||
|
userApiKey?: string;
|
||||||
|
systemApiKey: string;
|
||||||
|
userId: string;
|
||||||
|
text: string;
|
||||||
|
}) => {
|
||||||
|
// 获取 chatAPI
|
||||||
|
const chatAPI = getOpenAIApi(userApiKey || systemApiKey);
|
||||||
|
|
||||||
|
// 把输入的内容转成向量
|
||||||
|
const res = await chatAPI
|
||||||
|
.createEmbedding(
|
||||||
|
{
|
||||||
|
model: embeddingModel,
|
||||||
|
input: text
|
||||||
|
},
|
||||||
|
{
|
||||||
|
timeout: 60000,
|
||||||
|
...axiosConfig()
|
||||||
|
}
|
||||||
|
)
|
||||||
|
.then((res) => ({
|
||||||
|
tokenLen: res.data.usage.total_tokens || 0,
|
||||||
|
vector: res.data.data?.[0]?.embedding || []
|
||||||
|
}));
|
||||||
|
|
||||||
|
pushGenerateVectorBill({
|
||||||
|
isPay: !userApiKey,
|
||||||
|
userId,
|
||||||
|
text,
|
||||||
|
tokenLen: res.tokenLen
|
||||||
|
});
|
||||||
|
|
||||||
|
return {
|
||||||
|
vector: res.vector,
|
||||||
|
chatAPI
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
/* 模型对话 */
|
||||||
|
export const chatResponse = async ({
|
||||||
|
model,
|
||||||
|
apiKey,
|
||||||
|
temperature,
|
||||||
|
messages,
|
||||||
|
stream
|
||||||
|
}: ChatCompletionType & { model: `${OpenAiChatEnum}` }) => {
|
||||||
|
const filterMessages = ChatContextFilter({
|
||||||
|
model,
|
||||||
|
prompts: messages,
|
||||||
|
maxTokens: Math.ceil(ChatModelMap[model].contextMaxToken * 0.9)
|
||||||
|
});
|
||||||
|
|
||||||
|
const adaptMessages = adaptChatItem_openAI({ messages: filterMessages });
|
||||||
|
const chatAPI = getOpenAIApi(apiKey);
|
||||||
|
|
||||||
|
const response = await chatAPI.createChatCompletion(
|
||||||
|
{
|
||||||
|
model,
|
||||||
|
temperature: Number(temperature) || 0,
|
||||||
|
messages: adaptMessages,
|
||||||
|
frequency_penalty: 0.5, // 越大,重复内容越少
|
||||||
|
presence_penalty: -0.5, // 越大,越容易出现新内容
|
||||||
|
stream,
|
||||||
|
stop: ['.!?。']
|
||||||
|
},
|
||||||
|
{
|
||||||
|
timeout: stream ? 40000 : 240000,
|
||||||
|
responseType: stream ? 'stream' : 'json',
|
||||||
|
...axiosConfig()
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
let responseText = '';
|
||||||
|
let totalTokens = 0;
|
||||||
|
|
||||||
|
// adapt data
|
||||||
|
if (!stream) {
|
||||||
|
responseText = response.data.choices[0].message?.content || '';
|
||||||
|
totalTokens = response.data.usage?.total_tokens || 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
streamResponse: response,
|
||||||
|
responseMessages: filterMessages.concat({ obj: 'AI', value: responseText }),
|
||||||
|
responseText,
|
||||||
|
totalTokens
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
/* openai stream response */
|
||||||
|
export const openAiStreamResponse = async ({
|
||||||
|
model,
|
||||||
|
stream,
|
||||||
|
chatResponse,
|
||||||
|
prompts
|
||||||
|
}: StreamResponseType & {
|
||||||
|
model: `${OpenAiChatEnum}`;
|
||||||
|
}) => {
|
||||||
|
try {
|
||||||
|
let responseContent = '';
|
||||||
|
|
||||||
|
const onParse = async (event: ParsedEvent | ReconnectInterval) => {
|
||||||
|
if (event.type !== 'event') return;
|
||||||
|
const data = event.data;
|
||||||
|
if (data === '[DONE]') return;
|
||||||
|
try {
|
||||||
|
const json = JSON.parse(data);
|
||||||
|
const content: string = json?.choices?.[0].delta.content || '';
|
||||||
|
responseContent += content;
|
||||||
|
|
||||||
|
!stream.destroyed && content && stream.push(content.replace(/\n/g, '<br/>'));
|
||||||
|
} catch (error) {
|
||||||
|
error;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
try {
|
||||||
|
const decoder = new TextDecoder();
|
||||||
|
const parser = createParser(onParse);
|
||||||
|
for await (const chunk of chatResponse.data as any) {
|
||||||
|
if (stream.destroyed) {
|
||||||
|
// 流被中断了,直接忽略后面的内容
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
parser.feed(decoder.decode(chunk, { stream: true }));
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.log('pipe error', error);
|
||||||
|
}
|
||||||
|
|
||||||
|
// count tokens
|
||||||
|
const finishMessages = prompts.concat({
|
||||||
|
obj: ChatRoleEnum.AI,
|
||||||
|
value: responseContent
|
||||||
|
});
|
||||||
|
const totalTokens = modelToolMap[model].countTokens({
|
||||||
|
messages: finishMessages
|
||||||
|
});
|
||||||
|
|
||||||
|
return {
|
||||||
|
responseContent,
|
||||||
|
totalTokens,
|
||||||
|
finishMessages
|
||||||
|
};
|
||||||
|
} catch (error) {
|
||||||
|
return Promise.reject(error);
|
||||||
|
}
|
||||||
|
};
|
||||||
@ -1,179 +0,0 @@
|
|||||||
import type { NextApiResponse } from 'next';
|
|
||||||
import type { PassThrough } from 'stream';
|
|
||||||
import { createParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser';
|
|
||||||
import { getOpenAIApi } from '@/service/utils/auth';
|
|
||||||
import { axiosConfig } from './tools';
|
|
||||||
import { User } from '../models/user';
|
|
||||||
import { formatPrice } from '@/utils/user';
|
|
||||||
import { embeddingModel } from '@/constants/model';
|
|
||||||
import { pushGenerateVectorBill } from '../events/pushBill';
|
|
||||||
import { SYSTEM_PROMPT_PREFIX } from '@/constants/chat';
|
|
||||||
|
|
||||||
/* 获取用户 api 的 openai 信息 */
|
|
||||||
export const getUserApiOpenai = async (userId: string) => {
|
|
||||||
const user = await User.findById(userId);
|
|
||||||
|
|
||||||
const userApiKey = user?.openaiKey;
|
|
||||||
|
|
||||||
if (!userApiKey) {
|
|
||||||
return Promise.reject('缺少ApiKey, 无法请求');
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
user,
|
|
||||||
openai: getOpenAIApi(userApiKey),
|
|
||||||
apiKey: userApiKey
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
/* 获取 open api key,如果用户没有自己的key,就用平台的,用平台记得加账单 */
|
|
||||||
export const getOpenApiKey = async (userId: string) => {
|
|
||||||
const user = await User.findById(userId);
|
|
||||||
if (!user) {
|
|
||||||
return Promise.reject({
|
|
||||||
code: 501,
|
|
||||||
message: '找不到用户'
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
const userApiKey = user?.openaiKey;
|
|
||||||
|
|
||||||
// 有自己的key
|
|
||||||
if (userApiKey) {
|
|
||||||
return {
|
|
||||||
user,
|
|
||||||
userApiKey,
|
|
||||||
systemKey: ''
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// 平台账号余额校验
|
|
||||||
if (formatPrice(user.balance) <= 0) {
|
|
||||||
return Promise.reject({
|
|
||||||
code: 501,
|
|
||||||
message: '账号余额不足'
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
user,
|
|
||||||
userApiKey: '',
|
|
||||||
systemKey: process.env.OPENAIKEY as string
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
/* 获取向量 */
|
|
||||||
export const openaiCreateEmbedding = async ({
|
|
||||||
isPay,
|
|
||||||
userId,
|
|
||||||
apiKey,
|
|
||||||
text
|
|
||||||
}: {
|
|
||||||
isPay: boolean;
|
|
||||||
userId: string;
|
|
||||||
apiKey: string;
|
|
||||||
text: string;
|
|
||||||
}) => {
|
|
||||||
// 获取 chatAPI
|
|
||||||
const chatAPI = getOpenAIApi(apiKey);
|
|
||||||
|
|
||||||
// 把输入的内容转成向量
|
|
||||||
const res = await chatAPI
|
|
||||||
.createEmbedding(
|
|
||||||
{
|
|
||||||
model: embeddingModel,
|
|
||||||
input: text
|
|
||||||
},
|
|
||||||
{
|
|
||||||
timeout: 60000,
|
|
||||||
...axiosConfig()
|
|
||||||
}
|
|
||||||
)
|
|
||||||
.then((res) => ({
|
|
||||||
tokenLen: res.data.usage.total_tokens || 0,
|
|
||||||
vector: res.data.data?.[0]?.embedding || []
|
|
||||||
}));
|
|
||||||
|
|
||||||
pushGenerateVectorBill({
|
|
||||||
isPay,
|
|
||||||
userId,
|
|
||||||
text,
|
|
||||||
tokenLen: res.tokenLen
|
|
||||||
});
|
|
||||||
|
|
||||||
return {
|
|
||||||
vector: res.vector,
|
|
||||||
chatAPI
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
/* gpt35 响应 */
|
|
||||||
export const gpt35StreamResponse = ({
|
|
||||||
res,
|
|
||||||
stream,
|
|
||||||
chatResponse,
|
|
||||||
systemPrompt = ''
|
|
||||||
}: {
|
|
||||||
res: NextApiResponse;
|
|
||||||
stream: PassThrough;
|
|
||||||
chatResponse: any;
|
|
||||||
systemPrompt?: string;
|
|
||||||
}) =>
|
|
||||||
new Promise<{ responseContent: string }>(async (resolve, reject) => {
|
|
||||||
try {
|
|
||||||
// 创建响应流
|
|
||||||
res.setHeader('Content-Type', 'text/event-stream;charset-utf-8');
|
|
||||||
res.setHeader('Access-Control-Allow-Origin', '*');
|
|
||||||
res.setHeader('X-Accel-Buffering', 'no');
|
|
||||||
res.setHeader('Cache-Control', 'no-cache, no-transform');
|
|
||||||
stream.pipe(res);
|
|
||||||
|
|
||||||
let responseContent = '';
|
|
||||||
|
|
||||||
const onParse = async (event: ParsedEvent | ReconnectInterval) => {
|
|
||||||
if (event.type !== 'event') return;
|
|
||||||
const data = event.data;
|
|
||||||
if (data === '[DONE]') return;
|
|
||||||
try {
|
|
||||||
const json = JSON.parse(data);
|
|
||||||
const content: string = json?.choices?.[0].delta.content || '';
|
|
||||||
responseContent += content;
|
|
||||||
|
|
||||||
if (!stream.destroyed && content) {
|
|
||||||
stream.push(content.replace(/\n/g, '<br/>'));
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
error;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
try {
|
|
||||||
const decoder = new TextDecoder();
|
|
||||||
const parser = createParser(onParse);
|
|
||||||
for await (const chunk of chatResponse.data as any) {
|
|
||||||
if (stream.destroyed) {
|
|
||||||
// 流被中断了,直接忽略后面的内容
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
parser.feed(decoder.decode(chunk, { stream: true }));
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
console.log('pipe error', error);
|
|
||||||
}
|
|
||||||
|
|
||||||
// push system prompt
|
|
||||||
!stream.destroyed &&
|
|
||||||
systemPrompt &&
|
|
||||||
stream.push(`${SYSTEM_PROMPT_PREFIX}${systemPrompt.replace(/\n/g, '<br/>')}`);
|
|
||||||
|
|
||||||
// close stream
|
|
||||||
!stream.destroyed && stream.push(null);
|
|
||||||
stream.destroy();
|
|
||||||
|
|
||||||
resolve({
|
|
||||||
responseContent
|
|
||||||
});
|
|
||||||
} catch (error) {
|
|
||||||
reject(error);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
@ -1,9 +1,5 @@
|
|||||||
import crypto from 'crypto';
|
import crypto from 'crypto';
|
||||||
import jwt from 'jsonwebtoken';
|
import jwt from 'jsonwebtoken';
|
||||||
import { ChatItemSimpleType } from '@/types/chat';
|
|
||||||
import { countChatTokens, sliceTextByToken } from '@/utils/tools';
|
|
||||||
import { ChatCompletionRequestMessageRoleEnum, ChatCompletionRequestMessage } from 'openai';
|
|
||||||
import type { ChatModelType } from '@/constants/model';
|
|
||||||
|
|
||||||
/* 密码加密 */
|
/* 密码加密 */
|
||||||
export const hashPassword = (psw: string) => {
|
export const hashPassword = (psw: string) => {
|
||||||
@ -30,92 +26,3 @@ export const axiosConfig = () => ({
|
|||||||
auth: process.env.OPENAI_BASE_URL_AUTH || ''
|
auth: process.env.OPENAI_BASE_URL_AUTH || ''
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
/* delete invalid symbol */
|
|
||||||
const simplifyStr = (str: string) =>
|
|
||||||
str
|
|
||||||
.replace(/\n+/g, '\n') // 连续空行
|
|
||||||
.replace(/[^\S\r\n]+/g, ' ') // 连续空白内容
|
|
||||||
.trim();
|
|
||||||
|
|
||||||
/* 聊天内容 tokens 截断 */
|
|
||||||
export const openaiChatFilter = ({
|
|
||||||
model,
|
|
||||||
prompts,
|
|
||||||
maxTokens
|
|
||||||
}: {
|
|
||||||
model: ChatModelType;
|
|
||||||
prompts: ChatItemSimpleType[];
|
|
||||||
maxTokens: number;
|
|
||||||
}) => {
|
|
||||||
// role map
|
|
||||||
const map = {
|
|
||||||
Human: ChatCompletionRequestMessageRoleEnum.User,
|
|
||||||
AI: ChatCompletionRequestMessageRoleEnum.Assistant,
|
|
||||||
SYSTEM: ChatCompletionRequestMessageRoleEnum.System
|
|
||||||
};
|
|
||||||
|
|
||||||
let rawTextLen = 0;
|
|
||||||
const formatPrompts = prompts.map((item) => {
|
|
||||||
const val = simplifyStr(item.value);
|
|
||||||
rawTextLen += val.length;
|
|
||||||
return {
|
|
||||||
role: map[item.obj],
|
|
||||||
content: val
|
|
||||||
};
|
|
||||||
});
|
|
||||||
|
|
||||||
// 长度太小时,不需要进行 token 截断
|
|
||||||
if (rawTextLen < maxTokens * 0.5) {
|
|
||||||
return formatPrompts;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 根据 tokens 截断内容
|
|
||||||
const chats: ChatCompletionRequestMessage[] = [];
|
|
||||||
let systemPrompt: ChatCompletionRequestMessage | null = null;
|
|
||||||
|
|
||||||
// System 词保留
|
|
||||||
if (formatPrompts[0]?.role === 'system') {
|
|
||||||
systemPrompt = formatPrompts.shift() as ChatCompletionRequestMessage;
|
|
||||||
}
|
|
||||||
|
|
||||||
let messages: { role: ChatCompletionRequestMessageRoleEnum; content: string }[] = [];
|
|
||||||
|
|
||||||
// 从后往前截取对话内容
|
|
||||||
for (let i = formatPrompts.length - 1; i >= 0; i--) {
|
|
||||||
chats.unshift(formatPrompts[i]);
|
|
||||||
|
|
||||||
messages = systemPrompt ? [systemPrompt, ...chats] : chats;
|
|
||||||
|
|
||||||
const tokens = countChatTokens({
|
|
||||||
model,
|
|
||||||
messages
|
|
||||||
});
|
|
||||||
|
|
||||||
/* 整体 tokens 超出范围 */
|
|
||||||
if (tokens >= maxTokens) {
|
|
||||||
return systemPrompt ? [systemPrompt, ...chats.slice(1)] : chats.slice(1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return messages;
|
|
||||||
};
|
|
||||||
|
|
||||||
/* system 内容截断. 相似度从高到低 */
|
|
||||||
export const systemPromptFilter = ({
|
|
||||||
model,
|
|
||||||
prompts,
|
|
||||||
maxTokens
|
|
||||||
}: {
|
|
||||||
model: 'gpt-4' | 'gpt-4-32k' | 'gpt-3.5-turbo';
|
|
||||||
prompts: string[];
|
|
||||||
maxTokens: number;
|
|
||||||
}) => {
|
|
||||||
const systemPrompt = prompts.join('\n');
|
|
||||||
|
|
||||||
return sliceTextByToken({
|
|
||||||
model,
|
|
||||||
text: systemPrompt,
|
|
||||||
length: maxTokens
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|||||||
4
src/types/chat.d.ts
vendored
4
src/types/chat.d.ts
vendored
@ -1,5 +1,7 @@
|
|||||||
|
import { ChatRoleEnum } from '@/constants/chat';
|
||||||
|
|
||||||
export type ChatItemSimpleType = {
|
export type ChatItemSimpleType = {
|
||||||
obj: 'Human' | 'AI' | 'SYSTEM';
|
obj: `${ChatRoleEnum}`;
|
||||||
value: string;
|
value: string;
|
||||||
systemPrompt?: string;
|
systemPrompt?: string;
|
||||||
};
|
};
|
||||||
|
|||||||
39
src/utils/chat/index.ts
Normal file
39
src/utils/chat/index.ts
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
import { OpenAiChatEnum } from '@/constants/model';
|
||||||
|
import type { ChatModelType } from '@/constants/model';
|
||||||
|
import type { ChatItemSimpleType } from '@/types/chat';
|
||||||
|
import { countOpenAIToken, getOpenAiEncMap, adaptChatItem_openAI } from './openai';
|
||||||
|
|
||||||
|
export type CountTokenType = { messages: ChatItemSimpleType[] };
|
||||||
|
|
||||||
|
export const modelToolMap = {
|
||||||
|
[OpenAiChatEnum.GPT35]: {
|
||||||
|
countTokens: ({ messages }: CountTokenType) =>
|
||||||
|
countOpenAIToken({ model: OpenAiChatEnum.GPT35, messages }),
|
||||||
|
adaptChatMessages: adaptChatItem_openAI
|
||||||
|
},
|
||||||
|
[OpenAiChatEnum.GPT4]: {
|
||||||
|
countTokens: ({ messages }: CountTokenType) =>
|
||||||
|
countOpenAIToken({ model: OpenAiChatEnum.GPT4, messages }),
|
||||||
|
adaptChatMessages: adaptChatItem_openAI
|
||||||
|
},
|
||||||
|
[OpenAiChatEnum.GPT432k]: {
|
||||||
|
countTokens: ({ messages }: CountTokenType) =>
|
||||||
|
countOpenAIToken({ model: OpenAiChatEnum.GPT432k, messages }),
|
||||||
|
adaptChatMessages: adaptChatItem_openAI
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
export const sliceTextByToken = ({
|
||||||
|
model = 'gpt-3.5-turbo',
|
||||||
|
text,
|
||||||
|
length
|
||||||
|
}: {
|
||||||
|
model: ChatModelType;
|
||||||
|
text: string;
|
||||||
|
length: number;
|
||||||
|
}) => {
|
||||||
|
const enc = getOpenAiEncMap()[model];
|
||||||
|
const encodeText = enc.encode(text);
|
||||||
|
const decoder = new TextDecoder();
|
||||||
|
return decoder.decode(enc.decode(encodeText.slice(0, length)));
|
||||||
|
};
|
||||||
106
src/utils/chat/openai.ts
Normal file
106
src/utils/chat/openai.ts
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
import { encoding_for_model, type Tiktoken } from '@dqbd/tiktoken';
|
||||||
|
import type { ChatItemSimpleType } from '@/types/chat';
|
||||||
|
import { ChatRoleEnum } from '@/constants/chat';
|
||||||
|
import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai';
|
||||||
|
|
||||||
|
import Graphemer from 'graphemer';
|
||||||
|
|
||||||
|
const textDecoder = new TextDecoder();
|
||||||
|
const graphemer = new Graphemer();
|
||||||
|
|
||||||
|
export const adaptChatItem_openAI = ({
|
||||||
|
messages
|
||||||
|
}: {
|
||||||
|
messages: ChatItemSimpleType[];
|
||||||
|
}): ChatCompletionRequestMessage[] => {
|
||||||
|
const map = {
|
||||||
|
[ChatRoleEnum.AI]: ChatCompletionRequestMessageRoleEnum.Assistant,
|
||||||
|
[ChatRoleEnum.Human]: ChatCompletionRequestMessageRoleEnum.User,
|
||||||
|
[ChatRoleEnum.System]: ChatCompletionRequestMessageRoleEnum.System
|
||||||
|
};
|
||||||
|
return messages.map((item) => ({
|
||||||
|
role: map[item.obj] || ChatCompletionRequestMessageRoleEnum.System,
|
||||||
|
content: item.value || ''
|
||||||
|
}));
|
||||||
|
};
|
||||||
|
|
||||||
|
/* count openai chat token*/
|
||||||
|
let OpenAiEncMap: Record<string, Tiktoken>;
|
||||||
|
export const getOpenAiEncMap = () => {
|
||||||
|
if (OpenAiEncMap) return OpenAiEncMap;
|
||||||
|
OpenAiEncMap = {
|
||||||
|
'gpt-3.5-turbo': encoding_for_model('gpt-3.5-turbo', {
|
||||||
|
'<|im_start|>': 100264,
|
||||||
|
'<|im_end|>': 100265,
|
||||||
|
'<|im_sep|>': 100266
|
||||||
|
}),
|
||||||
|
'gpt-4': encoding_for_model('gpt-4', {
|
||||||
|
'<|im_start|>': 100264,
|
||||||
|
'<|im_end|>': 100265,
|
||||||
|
'<|im_sep|>': 100266
|
||||||
|
}),
|
||||||
|
'gpt-4-32k': encoding_for_model('gpt-4-32k', {
|
||||||
|
'<|im_start|>': 100264,
|
||||||
|
'<|im_end|>': 100265,
|
||||||
|
'<|im_sep|>': 100266
|
||||||
|
})
|
||||||
|
};
|
||||||
|
return OpenAiEncMap;
|
||||||
|
};
|
||||||
|
export function countOpenAIToken({
|
||||||
|
messages,
|
||||||
|
model
|
||||||
|
}: {
|
||||||
|
messages: ChatItemSimpleType[];
|
||||||
|
model: 'gpt-3.5-turbo' | 'gpt-4' | 'gpt-4-32k';
|
||||||
|
}) {
|
||||||
|
function getChatGPTEncodingText(
|
||||||
|
messages: { role: 'system' | 'user' | 'assistant'; content: string; name?: string }[],
|
||||||
|
model: 'gpt-3.5-turbo' | 'gpt-4' | 'gpt-4-32k'
|
||||||
|
) {
|
||||||
|
const isGpt3 = model === 'gpt-3.5-turbo';
|
||||||
|
|
||||||
|
const msgSep = isGpt3 ? '\n' : '';
|
||||||
|
const roleSep = isGpt3 ? '\n' : '<|im_sep|>';
|
||||||
|
|
||||||
|
return [
|
||||||
|
messages
|
||||||
|
.map(({ name = '', role, content }) => {
|
||||||
|
return `<|im_start|>${name || role}${roleSep}${content}<|im_end|>`;
|
||||||
|
})
|
||||||
|
.join(msgSep),
|
||||||
|
`<|im_start|>assistant${roleSep}`
|
||||||
|
].join(msgSep);
|
||||||
|
}
|
||||||
|
function text2TokensLen(encoder: Tiktoken, inputText: string) {
|
||||||
|
const encoding = encoder.encode(inputText, 'all');
|
||||||
|
const segments: { text: string; tokens: { id: number; idx: number }[] }[] = [];
|
||||||
|
|
||||||
|
let byteAcc: number[] = [];
|
||||||
|
let tokenAcc: { id: number; idx: number }[] = [];
|
||||||
|
let inputGraphemes = graphemer.splitGraphemes(inputText);
|
||||||
|
|
||||||
|
for (let idx = 0; idx < encoding.length; idx++) {
|
||||||
|
const token = encoding[idx]!;
|
||||||
|
byteAcc.push(...encoder.decode_single_token_bytes(token));
|
||||||
|
tokenAcc.push({ id: token, idx });
|
||||||
|
|
||||||
|
const segmentText = textDecoder.decode(new Uint8Array(byteAcc));
|
||||||
|
const graphemes = graphemer.splitGraphemes(segmentText);
|
||||||
|
|
||||||
|
if (graphemes.every((item, idx) => inputGraphemes[idx] === item)) {
|
||||||
|
segments.push({ text: segmentText, tokens: tokenAcc });
|
||||||
|
|
||||||
|
byteAcc = [];
|
||||||
|
tokenAcc = [];
|
||||||
|
inputGraphemes = inputGraphemes.slice(graphemes.length);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return segments.reduce((memo, i) => memo + i.tokens.length, 0) ?? 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
const adaptMessages = adaptChatItem_openAI({ messages });
|
||||||
|
|
||||||
|
return text2TokensLen(getOpenAiEncMap()[model], getChatGPTEncodingText(adaptMessages, model));
|
||||||
|
}
|
||||||
@ -1,6 +1,6 @@
|
|||||||
import mammoth from 'mammoth';
|
import mammoth from 'mammoth';
|
||||||
import Papa from 'papaparse';
|
import Papa from 'papaparse';
|
||||||
import { getEncMap } from './tools';
|
import { getOpenAiEncMap } from './chat/openai';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 读取 txt 文件内容
|
* 读取 txt 文件内容
|
||||||
@ -154,7 +154,7 @@ export const splitText_token = ({
|
|||||||
maxLen: number;
|
maxLen: number;
|
||||||
slideLen: number;
|
slideLen: number;
|
||||||
}) => {
|
}) => {
|
||||||
const enc = getEncMap()['gpt-3.5-turbo'];
|
const enc = getOpenAiEncMap()['gpt-3.5-turbo'];
|
||||||
// filter empty text. encode sentence
|
// filter empty text. encode sentence
|
||||||
const encodeText = enc.encode(text);
|
const encodeText = enc.encode(text);
|
||||||
|
|
||||||
|
|||||||
@ -1,33 +1,5 @@
|
|||||||
import crypto from 'crypto';
|
import crypto from 'crypto';
|
||||||
import { useToast } from '@/hooks/useToast';
|
import { useToast } from '@/hooks/useToast';
|
||||||
import { encoding_for_model, type Tiktoken } from '@dqbd/tiktoken';
|
|
||||||
import Graphemer from 'graphemer';
|
|
||||||
import type { ChatModelType } from '@/constants/model';
|
|
||||||
|
|
||||||
const textDecoder = new TextDecoder();
|
|
||||||
const graphemer = new Graphemer();
|
|
||||||
let encMap: Record<string, Tiktoken>;
|
|
||||||
export const getEncMap = () => {
|
|
||||||
if (encMap) return encMap;
|
|
||||||
encMap = {
|
|
||||||
'gpt-3.5-turbo': encoding_for_model('gpt-3.5-turbo', {
|
|
||||||
'<|im_start|>': 100264,
|
|
||||||
'<|im_end|>': 100265,
|
|
||||||
'<|im_sep|>': 100266
|
|
||||||
}),
|
|
||||||
'gpt-4': encoding_for_model('gpt-4', {
|
|
||||||
'<|im_start|>': 100264,
|
|
||||||
'<|im_end|>': 100265,
|
|
||||||
'<|im_sep|>': 100266
|
|
||||||
}),
|
|
||||||
'gpt-4-32k': encoding_for_model('gpt-4-32k', {
|
|
||||||
'<|im_start|>': 100264,
|
|
||||||
'<|im_end|>': 100265,
|
|
||||||
'<|im_sep|>': 100266
|
|
||||||
})
|
|
||||||
};
|
|
||||||
return encMap;
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* copy text data
|
* copy text data
|
||||||
@ -79,75 +51,3 @@ export const Obj2Query = (obj: Record<string, string | number>) => {
|
|||||||
}
|
}
|
||||||
return queryParams.toString();
|
return queryParams.toString();
|
||||||
};
|
};
|
||||||
|
|
||||||
/* 格式化 chat 聊天内容 */
|
|
||||||
function getChatGPTEncodingText(
|
|
||||||
messages: { role: 'system' | 'user' | 'assistant'; content: string; name?: string }[],
|
|
||||||
model: 'gpt-3.5-turbo' | 'gpt-4' | 'gpt-4-32k'
|
|
||||||
) {
|
|
||||||
const isGpt3 = model === 'gpt-3.5-turbo';
|
|
||||||
|
|
||||||
const msgSep = isGpt3 ? '\n' : '';
|
|
||||||
const roleSep = isGpt3 ? '\n' : '<|im_sep|>';
|
|
||||||
|
|
||||||
return [
|
|
||||||
messages
|
|
||||||
.map(({ name = '', role, content }) => {
|
|
||||||
return `<|im_start|>${name || role}${roleSep}${content}<|im_end|>`;
|
|
||||||
})
|
|
||||||
.join(msgSep),
|
|
||||||
`<|im_start|>assistant${roleSep}`
|
|
||||||
].join(msgSep);
|
|
||||||
}
|
|
||||||
function text2TokensLen(encoder: Tiktoken, inputText: string) {
|
|
||||||
const encoding = encoder.encode(inputText, 'all');
|
|
||||||
const segments: { text: string; tokens: { id: number; idx: number }[] }[] = [];
|
|
||||||
|
|
||||||
let byteAcc: number[] = [];
|
|
||||||
let tokenAcc: { id: number; idx: number }[] = [];
|
|
||||||
let inputGraphemes = graphemer.splitGraphemes(inputText);
|
|
||||||
|
|
||||||
for (let idx = 0; idx < encoding.length; idx++) {
|
|
||||||
const token = encoding[idx]!;
|
|
||||||
byteAcc.push(...encoder.decode_single_token_bytes(token));
|
|
||||||
tokenAcc.push({ id: token, idx });
|
|
||||||
|
|
||||||
const segmentText = textDecoder.decode(new Uint8Array(byteAcc));
|
|
||||||
const graphemes = graphemer.splitGraphemes(segmentText);
|
|
||||||
|
|
||||||
if (graphemes.every((item, idx) => inputGraphemes[idx] === item)) {
|
|
||||||
segments.push({ text: segmentText, tokens: tokenAcc });
|
|
||||||
|
|
||||||
byteAcc = [];
|
|
||||||
tokenAcc = [];
|
|
||||||
inputGraphemes = inputGraphemes.slice(graphemes.length);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return segments.reduce((memo, i) => memo + i.tokens.length, 0) ?? 0;
|
|
||||||
}
|
|
||||||
export const countChatTokens = ({
|
|
||||||
model = 'gpt-3.5-turbo',
|
|
||||||
messages
|
|
||||||
}: {
|
|
||||||
model?: ChatModelType;
|
|
||||||
messages: { role: 'system' | 'user' | 'assistant'; content: string }[];
|
|
||||||
}) => {
|
|
||||||
const text = getChatGPTEncodingText(messages, model);
|
|
||||||
return text2TokensLen(getEncMap()[model], text);
|
|
||||||
};
|
|
||||||
|
|
||||||
export const sliceTextByToken = ({
|
|
||||||
model = 'gpt-3.5-turbo',
|
|
||||||
text,
|
|
||||||
length
|
|
||||||
}: {
|
|
||||||
model?: ChatModelType;
|
|
||||||
text: string;
|
|
||||||
length: number;
|
|
||||||
}) => {
|
|
||||||
const enc = getEncMap()[model];
|
|
||||||
const encodeText = enc.encode(text);
|
|
||||||
const decoder = new TextDecoder();
|
|
||||||
return decoder.decode(enc.decode(encodeText.slice(0, length)));
|
|
||||||
};
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user