feat: use last quote

This commit is contained in:
archer 2023-05-30 21:18:08 +08:00
parent 59ddf09b94
commit 0cde9a10a8
No known key found for this signature in database
GPG Key ID: 569A5660D2379E28
7 changed files with 86 additions and 81 deletions

View File

@ -5,7 +5,7 @@ import { RequestPaging } from '../types/index';
import type { ShareChatSchema } from '@/types/mongoSchema'; import type { ShareChatSchema } from '@/types/mongoSchema';
import type { ShareChatEditType } from '@/types/model'; import type { ShareChatEditType } from '@/types/model';
import { Obj2Query } from '@/utils/tools'; import { Obj2Query } from '@/utils/tools';
import { QuoteItemType } from '@/pages/api/openapi/kb/appKbSearch'; import type { QuoteItemType } from '@/pages/api/openapi/kb/appKbSearch';
import type { Props as UpdateHistoryProps } from '@/pages/api/chat/history/updateChatHistory'; import type { Props as UpdateHistoryProps } from '@/pages/api/chat/history/updateChatHistory';
/** /**

View File

@ -50,6 +50,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
// 读取对话内容 // 读取对话内容
const prompts = [...content, prompt[0]]; const prompts = [...content, prompt[0]];
const { const {
code = 200, code = 200,
systemPrompts = [], systemPrompts = [],
@ -61,7 +62,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
const { code, searchPrompts, rawSearch, guidePrompt } = await appKbSearch({ const { code, searchPrompts, rawSearch, guidePrompt } = await appKbSearch({
model, model,
userId, userId,
prompts, fixedQuote: content[content.length - 1]?.quote || [],
prompt: prompt[0],
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity
}); });
@ -114,7 +116,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
return res.end(response); return res.end(response);
} }
prompts.splice(prompts.length - 3, 0, ...systemPrompts); prompts.unshift(...systemPrompts);
// content check // content check
await sensitiveCheck({ await sensitiveCheck({

View File

@ -47,7 +47,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
const { code, searchPrompts } = await appKbSearch({ const { code, searchPrompts } = await appKbSearch({
model, model,
userId, userId,
prompts, fixedQuote: [],
prompt: prompts[prompts.length - 1],
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity
}); });
@ -74,7 +75,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
return res.send(systemPrompts[0]?.value); return res.send(systemPrompts[0]?.value);
} }
prompts.splice(prompts.length - 3, 0, ...systemPrompts); prompts.unshift(...systemPrompts);
// content check // content check
await sensitiveCheck({ await sensitiveCheck({

View File

@ -75,10 +75,11 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
// 使用了知识库搜索 // 使用了知识库搜索
if (model.chat.relatedKbs.length > 0) { if (model.chat.relatedKbs.length > 0) {
const { code, searchPrompts } = await appKbSearch({ const { code, searchPrompts } = await appKbSearch({
prompts,
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity,
model, model,
userId userId,
fixedQuote: [],
prompt: prompts[prompts.length - 1],
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity
}); });
// search result is empty // search result is empty
@ -101,7 +102,7 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
]; ];
} }
prompts.splice(prompts.length - 3, 0, ...systemPrompts); prompts.unshift(...systemPrompts);
// content check // content check
await sensitiveCheck({ await sensitiveCheck({

View File

@ -49,10 +49,11 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
}); });
const result = await appKbSearch({ const result = await appKbSearch({
model,
userId, userId,
prompts, fixedQuote: [],
similarity, prompt: prompts[prompts.length - 1],
model similarity
}); });
jsonRes<Response>(res, { jsonRes<Response>(res, {
@ -70,67 +71,53 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
export async function appKbSearch({ export async function appKbSearch({
model, model,
userId, userId,
prompts, fixedQuote,
prompt,
similarity similarity
}: { }: {
userId: string;
prompts: ChatItemSimpleType[];
similarity: number;
model: ModelSchema; model: ModelSchema;
userId: string;
fixedQuote: QuoteItemType[];
prompt: ChatItemSimpleType;
similarity: number;
}): Promise<Response> { }): Promise<Response> {
const modelConstantsData = ChatModelMap[model.chat.chatModel]; const modelConstantsData = ChatModelMap[model.chat.chatModel];
// search two times.
const userPrompts = prompts.filter((item) => item.obj === 'Human');
const input: string[] = [
userPrompts[userPrompts.length - 1].value,
userPrompts[userPrompts.length - 2]?.value
].filter((item) => item);
// get vector // get vector
const promptVectors = await openaiEmbedding({ const promptVector = await openaiEmbedding({
userId, userId,
input, input: [prompt.value],
type: 'chat' type: 'chat'
}); });
// search kb // search kb
const searchRes = await Promise.all( const { rows: searchRes } = await PgClient.select<QuoteItemType>('modelData', {
promptVectors.map((promptVector) => fields: ['id', 'q', 'a'],
PgClient.select<QuoteItemType>('modelData', { where: [
fields: ['id', 'q', 'a'], `kb_id IN (${model.chat.relatedKbs.map((item) => `'${item}'`).join(',')})`,
where: [ 'AND',
`kb_id IN (${model.chat.relatedKbs.map((item) => `'${item}'`).join(',')})`, `vector <=> '[${promptVector[0]}]' < ${similarity}`
'AND', ],
`vector <=> '[${promptVector}]' < ${similarity}` order: [{ field: 'vector', mode: `<=> '[${promptVector[0]}]'` }],
], limit: 8
order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }], });
limit: promptVectors.length === 1 ? 15 : 10
}).then((res) => res.rows)
)
);
// filter same search result // filter same search result
const idSet = new Set<string>(); const idSet = new Set<string>();
const filterSearch = searchRes.map((search) => const filterSearch = [
search.filter((item) => { ...searchRes.slice(0, 3),
if (idSet.has(item.id)) { ...fixedQuote.slice(0, 2),
return false; ...searchRes.slice(3),
} ...fixedQuote.slice(2, 5)
idSet.add(item.id); ].filter((item) => {
return true; if (idSet.has(item.id)) {
}) return false;
); }
idSet.add(item.id);
return true;
});
// slice search result by rate.
const sliceRateMap: Record<number, number[]> = {
1: [1],
2: [0.7, 0.3]
};
const sliceRate = sliceRateMap[searchRes.length] || sliceRateMap[0];
// 计算固定提示词的 token 数量 // 计算固定提示词的 token 数量
const guidePrompt = model.chat.systemPrompt // user system prompt const guidePrompt = model.chat.systemPrompt // user system prompt
? { ? {
obj: ChatRoleEnum.System, obj: ChatRoleEnum.System,
@ -154,24 +141,21 @@ export async function appKbSearch({
const fixedSystemTokens = modelToolMap[model.chat.chatModel].countTokens({ const fixedSystemTokens = modelToolMap[model.chat.chatModel].countTokens({
messages: [guidePrompt] messages: [guidePrompt]
}); });
const maxTokens = modelConstantsData.systemMaxToken - fixedSystemTokens; const sliceResult = modelToolMap[model.chat.chatModel]
const sliceResult = sliceRate.map((rate, i) => .tokenSlice({
modelToolMap[model.chat.chatModel] maxToken: modelConstantsData.systemMaxToken - fixedSystemTokens,
.tokenSlice({ messages: filterSearch.map((item) => ({
maxToken: Math.round(maxTokens * rate), obj: ChatRoleEnum.System,
messages: filterSearch[i].map((item) => ({ value: `${item.q}\n${item.a}`
obj: ChatRoleEnum.System, }))
value: `${item.q}\n${item.a}` })
})) .map((item) => item.value);
})
.map((item) => item.value)
);
// slice filterSearch // slice filterSearch
const sliceSearch = filterSearch.map((item, i) => item.slice(0, sliceResult[i].length)).flat(); const rawSearch = filterSearch.slice(0, sliceResult.length);
// system prompt // system prompt
const systemPrompt = sliceResult.flat().join('\n').trim(); const systemPrompt = sliceResult.join('\n').trim();
/* 高相似度+不回复 */ /* 高相似度+不回复 */
if (!systemPrompt && model.chat.searchMode === appVectorSearchModeEnum.hightSimilarity) { if (!systemPrompt && model.chat.searchMode === appVectorSearchModeEnum.hightSimilarity) {
@ -206,7 +190,7 @@ export async function appKbSearch({
return { return {
code: 200, code: 200,
rawSearch: sliceSearch, rawSearch,
guidePrompt: guidePrompt.value || '', guidePrompt: guidePrompt.value || '',
searchPrompts: [ searchPrompts: [
{ {

View File

@ -280,7 +280,8 @@ export const authChat = async ({
{ {
$project: { $project: {
obj: '$content.obj', obj: '$content.obj',
value: '$content.value' value: '$content.value',
quote: '$content.quote'
} }
} }
]); ]);

View File

@ -89,39 +89,55 @@ export const ChatContextFilter = ({
prompts: ChatItemSimpleType[]; prompts: ChatItemSimpleType[];
maxTokens: number; maxTokens: number;
}) => { }) => {
const systemPrompts: ChatItemSimpleType[] = [];
const chatPrompts: ChatItemSimpleType[] = [];
let rawTextLen = 0; let rawTextLen = 0;
const formatPrompts = prompts.map<ChatItemSimpleType>((item) => { prompts.forEach((item) => {
const val = simplifyStr(item.value); const val = simplifyStr(item.value);
rawTextLen += val.length; rawTextLen += val.length;
return {
const data = {
obj: item.obj, obj: item.obj,
value: val value: val
}; };
if (item.obj === ChatRoleEnum.System) {
systemPrompts.push(data);
} else {
chatPrompts.push(data);
}
}); });
// 长度太小时,不需要进行 token 截断 // 长度太小时,不需要进行 token 截断
if (formatPrompts.length <= 2 || rawTextLen < maxTokens * 0.5) { if (rawTextLen < maxTokens * 0.5) {
return formatPrompts; return [...systemPrompts, ...chatPrompts];
} }
// 去掉 system 的 token
maxTokens -= modelToolMap[model].countTokens({
messages: systemPrompts
});
// 根据 tokens 截断内容 // 根据 tokens 截断内容
const chats: ChatItemSimpleType[] = []; const chats: ChatItemSimpleType[] = [];
// 从后往前截取对话内容 // 从后往前截取对话内容
for (let i = formatPrompts.length - 1; i >= 0; i--) { for (let i = chatPrompts.length - 1; i >= 0; i--) {
chats.unshift(formatPrompts[i]); chats.unshift(chatPrompts[i]);
const tokens = modelToolMap[model].countTokens({ const tokens = modelToolMap[model].countTokens({
messages: chats messages: chats
}); });
/* 整体 tokens 超出范围, system必须保留 */ /* 整体 tokens 超出范围, system必须保留 */
if (tokens >= maxTokens && formatPrompts[i].obj !== ChatRoleEnum.System) { if (tokens >= maxTokens) {
return chats.slice(1); chats.shift();
break;
} }
} }
return chats; return [...systemPrompts, ...chats];
}; };
/* stream response */ /* stream response */