From 0cde9a10a827e3113dd8b29668abda62a2a3b102 Mon Sep 17 00:00:00 2001 From: archer <545436317@qq.com> Date: Tue, 30 May 2023 21:18:08 +0800 Subject: [PATCH] feat: use last quote --- src/api/chat.ts | 2 +- src/pages/api/chat/chat.ts | 6 +- src/pages/api/chat/shareChat/chat.ts | 5 +- src/pages/api/openapi/chat/chat.ts | 9 +- src/pages/api/openapi/kb/appKbSearch.ts | 108 ++++++++++-------------- src/service/utils/auth.ts | 3 +- src/service/utils/chat/index.ts | 34 ++++++-- 7 files changed, 86 insertions(+), 81 deletions(-) diff --git a/src/api/chat.ts b/src/api/chat.ts index d47fa17b5..2abc47c98 100644 --- a/src/api/chat.ts +++ b/src/api/chat.ts @@ -5,7 +5,7 @@ import { RequestPaging } from '../types/index'; import type { ShareChatSchema } from '@/types/mongoSchema'; import type { ShareChatEditType } from '@/types/model'; import { Obj2Query } from '@/utils/tools'; -import { QuoteItemType } from '@/pages/api/openapi/kb/appKbSearch'; +import type { QuoteItemType } from '@/pages/api/openapi/kb/appKbSearch'; import type { Props as UpdateHistoryProps } from '@/pages/api/chat/history/updateChatHistory'; /** diff --git a/src/pages/api/chat/chat.ts b/src/pages/api/chat/chat.ts index 923d328d7..95314a059 100644 --- a/src/pages/api/chat/chat.ts +++ b/src/pages/api/chat/chat.ts @@ -50,6 +50,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) // 读取对话内容 const prompts = [...content, prompt[0]]; + const { code = 200, systemPrompts = [], @@ -61,7 +62,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) const { code, searchPrompts, rawSearch, guidePrompt } = await appKbSearch({ model, userId, - prompts, + fixedQuote: content[content.length - 1]?.quote || [], + prompt: prompt[0], similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity }); @@ -114,7 +116,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) return res.end(response); } - prompts.splice(prompts.length - 3, 0, ...systemPrompts); + prompts.unshift(...systemPrompts); // content check await sensitiveCheck({ diff --git a/src/pages/api/chat/shareChat/chat.ts b/src/pages/api/chat/shareChat/chat.ts index b32bd8ccc..8d3d4a31e 100644 --- a/src/pages/api/chat/shareChat/chat.ts +++ b/src/pages/api/chat/shareChat/chat.ts @@ -47,7 +47,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) const { code, searchPrompts } = await appKbSearch({ model, userId, - prompts, + fixedQuote: [], + prompt: prompts[prompts.length - 1], similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity }); @@ -74,7 +75,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) return res.send(systemPrompts[0]?.value); } - prompts.splice(prompts.length - 3, 0, ...systemPrompts); + prompts.unshift(...systemPrompts); // content check await sensitiveCheck({ diff --git a/src/pages/api/openapi/chat/chat.ts b/src/pages/api/openapi/chat/chat.ts index a03853cd6..3af123876 100644 --- a/src/pages/api/openapi/chat/chat.ts +++ b/src/pages/api/openapi/chat/chat.ts @@ -75,10 +75,11 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex // 使用了知识库搜索 if (model.chat.relatedKbs.length > 0) { const { code, searchPrompts } = await appKbSearch({ - prompts, - similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity, model, - userId + userId, + fixedQuote: [], + prompt: prompts[prompts.length - 1], + similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity }); // search result is empty @@ -101,7 +102,7 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex ]; } - prompts.splice(prompts.length - 3, 0, ...systemPrompts); + prompts.unshift(...systemPrompts); // content check await sensitiveCheck({ diff --git a/src/pages/api/openapi/kb/appKbSearch.ts b/src/pages/api/openapi/kb/appKbSearch.ts index 8eaeb981d..125e8007d 100644 --- a/src/pages/api/openapi/kb/appKbSearch.ts +++ b/src/pages/api/openapi/kb/appKbSearch.ts @@ -49,10 +49,11 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex }); const result = await appKbSearch({ + model, userId, - prompts, - similarity, - model + fixedQuote: [], + prompt: prompts[prompts.length - 1], + similarity }); jsonRes(res, { @@ -70,67 +71,53 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex export async function appKbSearch({ model, userId, - prompts, + fixedQuote, + prompt, similarity }: { - userId: string; - prompts: ChatItemSimpleType[]; - similarity: number; model: ModelSchema; + userId: string; + fixedQuote: QuoteItemType[]; + prompt: ChatItemSimpleType; + similarity: number; }): Promise { const modelConstantsData = ChatModelMap[model.chat.chatModel]; - // search two times. - const userPrompts = prompts.filter((item) => item.obj === 'Human'); - - const input: string[] = [ - userPrompts[userPrompts.length - 1].value, - userPrompts[userPrompts.length - 2]?.value - ].filter((item) => item); - // get vector - const promptVectors = await openaiEmbedding({ + const promptVector = await openaiEmbedding({ userId, - input, + input: [prompt.value], type: 'chat' }); // search kb - const searchRes = await Promise.all( - promptVectors.map((promptVector) => - PgClient.select('modelData', { - fields: ['id', 'q', 'a'], - where: [ - `kb_id IN (${model.chat.relatedKbs.map((item) => `'${item}'`).join(',')})`, - 'AND', - `vector <=> '[${promptVector}]' < ${similarity}` - ], - order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }], - limit: promptVectors.length === 1 ? 15 : 10 - }).then((res) => res.rows) - ) - ); + const { rows: searchRes } = await PgClient.select('modelData', { + fields: ['id', 'q', 'a'], + where: [ + `kb_id IN (${model.chat.relatedKbs.map((item) => `'${item}'`).join(',')})`, + 'AND', + `vector <=> '[${promptVector[0]}]' < ${similarity}` + ], + order: [{ field: 'vector', mode: `<=> '[${promptVector[0]}]'` }], + limit: 8 + }); // filter same search result const idSet = new Set(); - const filterSearch = searchRes.map((search) => - search.filter((item) => { - if (idSet.has(item.id)) { - return false; - } - idSet.add(item.id); - return true; - }) - ); + const filterSearch = [ + ...searchRes.slice(0, 3), + ...fixedQuote.slice(0, 2), + ...searchRes.slice(3), + ...fixedQuote.slice(2, 5) + ].filter((item) => { + if (idSet.has(item.id)) { + return false; + } + idSet.add(item.id); + return true; + }); - // slice search result by rate. - const sliceRateMap: Record = { - 1: [1], - 2: [0.7, 0.3] - }; - const sliceRate = sliceRateMap[searchRes.length] || sliceRateMap[0]; // 计算固定提示词的 token 数量 - const guidePrompt = model.chat.systemPrompt // user system prompt ? { obj: ChatRoleEnum.System, @@ -154,24 +141,21 @@ export async function appKbSearch({ const fixedSystemTokens = modelToolMap[model.chat.chatModel].countTokens({ messages: [guidePrompt] }); - const maxTokens = modelConstantsData.systemMaxToken - fixedSystemTokens; - const sliceResult = sliceRate.map((rate, i) => - modelToolMap[model.chat.chatModel] - .tokenSlice({ - maxToken: Math.round(maxTokens * rate), - messages: filterSearch[i].map((item) => ({ - obj: ChatRoleEnum.System, - value: `${item.q}\n${item.a}` - })) - }) - .map((item) => item.value) - ); + const sliceResult = modelToolMap[model.chat.chatModel] + .tokenSlice({ + maxToken: modelConstantsData.systemMaxToken - fixedSystemTokens, + messages: filterSearch.map((item) => ({ + obj: ChatRoleEnum.System, + value: `${item.q}\n${item.a}` + })) + }) + .map((item) => item.value); // slice filterSearch - const sliceSearch = filterSearch.map((item, i) => item.slice(0, sliceResult[i].length)).flat(); + const rawSearch = filterSearch.slice(0, sliceResult.length); // system prompt - const systemPrompt = sliceResult.flat().join('\n').trim(); + const systemPrompt = sliceResult.join('\n').trim(); /* 高相似度+不回复 */ if (!systemPrompt && model.chat.searchMode === appVectorSearchModeEnum.hightSimilarity) { @@ -206,7 +190,7 @@ export async function appKbSearch({ return { code: 200, - rawSearch: sliceSearch, + rawSearch, guidePrompt: guidePrompt.value || '', searchPrompts: [ { diff --git a/src/service/utils/auth.ts b/src/service/utils/auth.ts index d3160d1e5..6f97dad59 100644 --- a/src/service/utils/auth.ts +++ b/src/service/utils/auth.ts @@ -280,7 +280,8 @@ export const authChat = async ({ { $project: { obj: '$content.obj', - value: '$content.value' + value: '$content.value', + quote: '$content.quote' } } ]); diff --git a/src/service/utils/chat/index.ts b/src/service/utils/chat/index.ts index 7b6073798..51219572d 100644 --- a/src/service/utils/chat/index.ts +++ b/src/service/utils/chat/index.ts @@ -89,39 +89,55 @@ export const ChatContextFilter = ({ prompts: ChatItemSimpleType[]; maxTokens: number; }) => { + const systemPrompts: ChatItemSimpleType[] = []; + const chatPrompts: ChatItemSimpleType[] = []; + let rawTextLen = 0; - const formatPrompts = prompts.map((item) => { + prompts.forEach((item) => { const val = simplifyStr(item.value); rawTextLen += val.length; - return { + + const data = { obj: item.obj, value: val }; + + if (item.obj === ChatRoleEnum.System) { + systemPrompts.push(data); + } else { + chatPrompts.push(data); + } }); // 长度太小时,不需要进行 token 截断 - if (formatPrompts.length <= 2 || rawTextLen < maxTokens * 0.5) { - return formatPrompts; + if (rawTextLen < maxTokens * 0.5) { + return [...systemPrompts, ...chatPrompts]; } + // 去掉 system 的 token + maxTokens -= modelToolMap[model].countTokens({ + messages: systemPrompts + }); + // 根据 tokens 截断内容 const chats: ChatItemSimpleType[] = []; // 从后往前截取对话内容 - for (let i = formatPrompts.length - 1; i >= 0; i--) { - chats.unshift(formatPrompts[i]); + for (let i = chatPrompts.length - 1; i >= 0; i--) { + chats.unshift(chatPrompts[i]); const tokens = modelToolMap[model].countTokens({ messages: chats }); /* 整体 tokens 超出范围, system必须保留 */ - if (tokens >= maxTokens && formatPrompts[i].obj !== ChatRoleEnum.System) { - return chats.slice(1); + if (tokens >= maxTokens) { + chats.shift(); + break; } } - return chats; + return [...systemPrompts, ...chats]; }; /* stream response */