perf: not cut text when little text
This commit is contained in:
parent
3294be5e7f
commit
ce68791c3c
@ -6,7 +6,7 @@ import { OpenApi, User } from '../mongo';
|
|||||||
import { formatPrice } from '@/utils/user';
|
import { formatPrice } from '@/utils/user';
|
||||||
import { ERROR_ENUM } from '../errorCode';
|
import { ERROR_ENUM } from '../errorCode';
|
||||||
import { countChatTokens } from '@/utils/tools';
|
import { countChatTokens } from '@/utils/tools';
|
||||||
import { ChatCompletionRequestMessageRoleEnum } from 'openai';
|
import { ChatCompletionRequestMessageRoleEnum, ChatCompletionRequestMessage } from 'openai';
|
||||||
import { ChatModelEnum } from '@/constants/model';
|
import { ChatModelEnum } from '@/constants/model';
|
||||||
|
|
||||||
/* 密码加密 */
|
/* 密码加密 */
|
||||||
@ -88,6 +88,13 @@ export const authOpenApiKey = async (req: NextApiRequest) => {
|
|||||||
export const httpsAgent = (fast: boolean) =>
|
export const httpsAgent = (fast: boolean) =>
|
||||||
fast ? global.httpsAgentFast : global.httpsAgentNormal;
|
fast ? global.httpsAgentFast : global.httpsAgentNormal;
|
||||||
|
|
||||||
|
/* delete invalid symbol */
|
||||||
|
const simplifyStr = (str: string) =>
|
||||||
|
str
|
||||||
|
.replace(/\n+/g, '\n') // 连续空行
|
||||||
|
.replace(/[^\S\r\n]+/g, ' ') // 连续空白内容
|
||||||
|
.trim();
|
||||||
|
|
||||||
/* 聊天内容 tokens 截断 */
|
/* 聊天内容 tokens 截断 */
|
||||||
export const openaiChatFilter = ({
|
export const openaiChatFilter = ({
|
||||||
model,
|
model,
|
||||||
@ -98,40 +105,44 @@ export const openaiChatFilter = ({
|
|||||||
prompts: ChatItemType[];
|
prompts: ChatItemType[];
|
||||||
maxTokens: number;
|
maxTokens: number;
|
||||||
}) => {
|
}) => {
|
||||||
const formatPrompts = prompts.map((item) => ({
|
// role map
|
||||||
obj: item.obj,
|
|
||||||
value: item.value
|
|
||||||
// .replace(/[\u3000\u3001\uff01-\uff5e\u3002]/g, ' ') // 中文标点改空格
|
|
||||||
.replace(/\n+/g, '\n') // 连续空行
|
|
||||||
.replace(/[^\S\r\n]+/g, ' ') // 连续空白内容
|
|
||||||
.trim()
|
|
||||||
}));
|
|
||||||
|
|
||||||
let chats: ChatItemType[] = [];
|
|
||||||
let systemPrompt: ChatItemType | null = null;
|
|
||||||
|
|
||||||
// System 词保留
|
|
||||||
if (formatPrompts[0]?.obj === 'SYSTEM') {
|
|
||||||
systemPrompt = formatPrompts.shift() as ChatItemType;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 格式化文本内容成 chatgpt 格式
|
|
||||||
const map = {
|
const map = {
|
||||||
Human: ChatCompletionRequestMessageRoleEnum.User,
|
Human: ChatCompletionRequestMessageRoleEnum.User,
|
||||||
AI: ChatCompletionRequestMessageRoleEnum.Assistant,
|
AI: ChatCompletionRequestMessageRoleEnum.Assistant,
|
||||||
SYSTEM: ChatCompletionRequestMessageRoleEnum.System
|
SYSTEM: ChatCompletionRequestMessageRoleEnum.System
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let rawTextLen = 0;
|
||||||
|
const formatPrompts = prompts.map((item) => {
|
||||||
|
const val = simplifyStr(item.value);
|
||||||
|
rawTextLen += val.length;
|
||||||
|
return {
|
||||||
|
role: map[item.obj],
|
||||||
|
content: val
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
// 长度太小时,不需要进行 token 截断
|
||||||
|
if (rawTextLen < maxTokens * 0.5) {
|
||||||
|
return formatPrompts;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根据 tokens 截断内容
|
||||||
|
const chats: ChatCompletionRequestMessage[] = [];
|
||||||
|
let systemPrompt: ChatCompletionRequestMessage | null = null;
|
||||||
|
|
||||||
|
// System 词保留
|
||||||
|
if (formatPrompts[0]?.role === 'system') {
|
||||||
|
systemPrompt = formatPrompts.shift() as ChatCompletionRequestMessage;
|
||||||
|
}
|
||||||
|
|
||||||
let messages: { role: ChatCompletionRequestMessageRoleEnum; content: string }[] = [];
|
let messages: { role: ChatCompletionRequestMessageRoleEnum; content: string }[] = [];
|
||||||
|
|
||||||
// 从后往前截取对话内容
|
// 从后往前截取对话内容
|
||||||
for (let i = formatPrompts.length - 1; i >= 0; i--) {
|
for (let i = formatPrompts.length - 1; i >= 0; i--) {
|
||||||
chats.unshift(formatPrompts[i]);
|
chats.unshift(formatPrompts[i]);
|
||||||
|
|
||||||
messages = (systemPrompt ? [systemPrompt, ...chats] : chats).map((item) => ({
|
messages = systemPrompt ? [systemPrompt, ...chats] : chats;
|
||||||
role: map[item.obj],
|
|
||||||
content: item.value
|
|
||||||
}));
|
|
||||||
|
|
||||||
const tokens = countChatTokens({
|
const tokens = countChatTokens({
|
||||||
model,
|
model,
|
||||||
@ -147,7 +158,7 @@ export const openaiChatFilter = ({
|
|||||||
return messages;
|
return messages;
|
||||||
};
|
};
|
||||||
|
|
||||||
/* system 内容截断 */
|
/* system 内容截断. 相似度从高到低 */
|
||||||
export const systemPromptFilter = ({
|
export const systemPromptFilter = ({
|
||||||
model,
|
model,
|
||||||
prompts,
|
prompts,
|
||||||
@ -161,7 +172,7 @@ export const systemPromptFilter = ({
|
|||||||
|
|
||||||
// 从前往前截取
|
// 从前往前截取
|
||||||
for (let i = 0; i < prompts.length; i++) {
|
for (let i = 0; i < prompts.length; i++) {
|
||||||
const prompt = prompts[i].replace(/\n+/g, '\n');
|
const prompt = simplifyStr(prompts[i]);
|
||||||
|
|
||||||
splitText += `${prompt}\n`;
|
splitText += `${prompt}\n`;
|
||||||
const tokens = countChatTokens({ model, messages: [{ role: 'system', content: splitText }] });
|
const tokens = countChatTokens({ model, messages: [{ role: 'system', content: splitText }] });
|
||||||
@ -170,5 +181,5 @@ export const systemPromptFilter = ({
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return splitText.slice(0, splitText.length - 1).replace(/\n+/g, '\n');
|
return splitText.slice(0, splitText.length - 1);
|
||||||
};
|
};
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user