feat: use last quote
This commit is contained in:
parent
59ddf09b94
commit
0cde9a10a8
@ -5,7 +5,7 @@ import { RequestPaging } from '../types/index';
|
|||||||
import type { ShareChatSchema } from '@/types/mongoSchema';
|
import type { ShareChatSchema } from '@/types/mongoSchema';
|
||||||
import type { ShareChatEditType } from '@/types/model';
|
import type { ShareChatEditType } from '@/types/model';
|
||||||
import { Obj2Query } from '@/utils/tools';
|
import { Obj2Query } from '@/utils/tools';
|
||||||
import { QuoteItemType } from '@/pages/api/openapi/kb/appKbSearch';
|
import type { QuoteItemType } from '@/pages/api/openapi/kb/appKbSearch';
|
||||||
import type { Props as UpdateHistoryProps } from '@/pages/api/chat/history/updateChatHistory';
|
import type { Props as UpdateHistoryProps } from '@/pages/api/chat/history/updateChatHistory';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -50,6 +50,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
|
|
||||||
// 读取对话内容
|
// 读取对话内容
|
||||||
const prompts = [...content, prompt[0]];
|
const prompts = [...content, prompt[0]];
|
||||||
|
|
||||||
const {
|
const {
|
||||||
code = 200,
|
code = 200,
|
||||||
systemPrompts = [],
|
systemPrompts = [],
|
||||||
@ -61,7 +62,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
const { code, searchPrompts, rawSearch, guidePrompt } = await appKbSearch({
|
const { code, searchPrompts, rawSearch, guidePrompt } = await appKbSearch({
|
||||||
model,
|
model,
|
||||||
userId,
|
userId,
|
||||||
prompts,
|
fixedQuote: content[content.length - 1]?.quote || [],
|
||||||
|
prompt: prompt[0],
|
||||||
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity
|
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -114,7 +116,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
return res.end(response);
|
return res.end(response);
|
||||||
}
|
}
|
||||||
|
|
||||||
prompts.splice(prompts.length - 3, 0, ...systemPrompts);
|
prompts.unshift(...systemPrompts);
|
||||||
|
|
||||||
// content check
|
// content check
|
||||||
await sensitiveCheck({
|
await sensitiveCheck({
|
||||||
|
|||||||
@ -47,7 +47,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
const { code, searchPrompts } = await appKbSearch({
|
const { code, searchPrompts } = await appKbSearch({
|
||||||
model,
|
model,
|
||||||
userId,
|
userId,
|
||||||
prompts,
|
fixedQuote: [],
|
||||||
|
prompt: prompts[prompts.length - 1],
|
||||||
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity
|
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -74,7 +75,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
return res.send(systemPrompts[0]?.value);
|
return res.send(systemPrompts[0]?.value);
|
||||||
}
|
}
|
||||||
|
|
||||||
prompts.splice(prompts.length - 3, 0, ...systemPrompts);
|
prompts.unshift(...systemPrompts);
|
||||||
|
|
||||||
// content check
|
// content check
|
||||||
await sensitiveCheck({
|
await sensitiveCheck({
|
||||||
|
|||||||
@ -75,10 +75,11 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
|
|||||||
// 使用了知识库搜索
|
// 使用了知识库搜索
|
||||||
if (model.chat.relatedKbs.length > 0) {
|
if (model.chat.relatedKbs.length > 0) {
|
||||||
const { code, searchPrompts } = await appKbSearch({
|
const { code, searchPrompts } = await appKbSearch({
|
||||||
prompts,
|
|
||||||
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity,
|
|
||||||
model,
|
model,
|
||||||
userId
|
userId,
|
||||||
|
fixedQuote: [],
|
||||||
|
prompt: prompts[prompts.length - 1],
|
||||||
|
similarity: ModelVectorSearchModeMap[model.chat.searchMode]?.similarity
|
||||||
});
|
});
|
||||||
|
|
||||||
// search result is empty
|
// search result is empty
|
||||||
@ -101,7 +102,7 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
|
|||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
prompts.splice(prompts.length - 3, 0, ...systemPrompts);
|
prompts.unshift(...systemPrompts);
|
||||||
|
|
||||||
// content check
|
// content check
|
||||||
await sensitiveCheck({
|
await sensitiveCheck({
|
||||||
|
|||||||
@ -49,10 +49,11 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
|
|||||||
});
|
});
|
||||||
|
|
||||||
const result = await appKbSearch({
|
const result = await appKbSearch({
|
||||||
|
model,
|
||||||
userId,
|
userId,
|
||||||
prompts,
|
fixedQuote: [],
|
||||||
similarity,
|
prompt: prompts[prompts.length - 1],
|
||||||
model
|
similarity
|
||||||
});
|
});
|
||||||
|
|
||||||
jsonRes<Response>(res, {
|
jsonRes<Response>(res, {
|
||||||
@ -70,67 +71,53 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex
|
|||||||
export async function appKbSearch({
|
export async function appKbSearch({
|
||||||
model,
|
model,
|
||||||
userId,
|
userId,
|
||||||
prompts,
|
fixedQuote,
|
||||||
|
prompt,
|
||||||
similarity
|
similarity
|
||||||
}: {
|
}: {
|
||||||
userId: string;
|
|
||||||
prompts: ChatItemSimpleType[];
|
|
||||||
similarity: number;
|
|
||||||
model: ModelSchema;
|
model: ModelSchema;
|
||||||
|
userId: string;
|
||||||
|
fixedQuote: QuoteItemType[];
|
||||||
|
prompt: ChatItemSimpleType;
|
||||||
|
similarity: number;
|
||||||
}): Promise<Response> {
|
}): Promise<Response> {
|
||||||
const modelConstantsData = ChatModelMap[model.chat.chatModel];
|
const modelConstantsData = ChatModelMap[model.chat.chatModel];
|
||||||
|
|
||||||
// search two times.
|
|
||||||
const userPrompts = prompts.filter((item) => item.obj === 'Human');
|
|
||||||
|
|
||||||
const input: string[] = [
|
|
||||||
userPrompts[userPrompts.length - 1].value,
|
|
||||||
userPrompts[userPrompts.length - 2]?.value
|
|
||||||
].filter((item) => item);
|
|
||||||
|
|
||||||
// get vector
|
// get vector
|
||||||
const promptVectors = await openaiEmbedding({
|
const promptVector = await openaiEmbedding({
|
||||||
userId,
|
userId,
|
||||||
input,
|
input: [prompt.value],
|
||||||
type: 'chat'
|
type: 'chat'
|
||||||
});
|
});
|
||||||
|
|
||||||
// search kb
|
// search kb
|
||||||
const searchRes = await Promise.all(
|
const { rows: searchRes } = await PgClient.select<QuoteItemType>('modelData', {
|
||||||
promptVectors.map((promptVector) =>
|
fields: ['id', 'q', 'a'],
|
||||||
PgClient.select<QuoteItemType>('modelData', {
|
where: [
|
||||||
fields: ['id', 'q', 'a'],
|
`kb_id IN (${model.chat.relatedKbs.map((item) => `'${item}'`).join(',')})`,
|
||||||
where: [
|
'AND',
|
||||||
`kb_id IN (${model.chat.relatedKbs.map((item) => `'${item}'`).join(',')})`,
|
`vector <=> '[${promptVector[0]}]' < ${similarity}`
|
||||||
'AND',
|
],
|
||||||
`vector <=> '[${promptVector}]' < ${similarity}`
|
order: [{ field: 'vector', mode: `<=> '[${promptVector[0]}]'` }],
|
||||||
],
|
limit: 8
|
||||||
order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
|
});
|
||||||
limit: promptVectors.length === 1 ? 15 : 10
|
|
||||||
}).then((res) => res.rows)
|
|
||||||
)
|
|
||||||
);
|
|
||||||
|
|
||||||
// filter same search result
|
// filter same search result
|
||||||
const idSet = new Set<string>();
|
const idSet = new Set<string>();
|
||||||
const filterSearch = searchRes.map((search) =>
|
const filterSearch = [
|
||||||
search.filter((item) => {
|
...searchRes.slice(0, 3),
|
||||||
if (idSet.has(item.id)) {
|
...fixedQuote.slice(0, 2),
|
||||||
return false;
|
...searchRes.slice(3),
|
||||||
}
|
...fixedQuote.slice(2, 5)
|
||||||
idSet.add(item.id);
|
].filter((item) => {
|
||||||
return true;
|
if (idSet.has(item.id)) {
|
||||||
})
|
return false;
|
||||||
);
|
}
|
||||||
|
idSet.add(item.id);
|
||||||
|
return true;
|
||||||
|
});
|
||||||
|
|
||||||
// slice search result by rate.
|
|
||||||
const sliceRateMap: Record<number, number[]> = {
|
|
||||||
1: [1],
|
|
||||||
2: [0.7, 0.3]
|
|
||||||
};
|
|
||||||
const sliceRate = sliceRateMap[searchRes.length] || sliceRateMap[0];
|
|
||||||
// 计算固定提示词的 token 数量
|
// 计算固定提示词的 token 数量
|
||||||
|
|
||||||
const guidePrompt = model.chat.systemPrompt // user system prompt
|
const guidePrompt = model.chat.systemPrompt // user system prompt
|
||||||
? {
|
? {
|
||||||
obj: ChatRoleEnum.System,
|
obj: ChatRoleEnum.System,
|
||||||
@ -154,24 +141,21 @@ export async function appKbSearch({
|
|||||||
const fixedSystemTokens = modelToolMap[model.chat.chatModel].countTokens({
|
const fixedSystemTokens = modelToolMap[model.chat.chatModel].countTokens({
|
||||||
messages: [guidePrompt]
|
messages: [guidePrompt]
|
||||||
});
|
});
|
||||||
const maxTokens = modelConstantsData.systemMaxToken - fixedSystemTokens;
|
const sliceResult = modelToolMap[model.chat.chatModel]
|
||||||
const sliceResult = sliceRate.map((rate, i) =>
|
.tokenSlice({
|
||||||
modelToolMap[model.chat.chatModel]
|
maxToken: modelConstantsData.systemMaxToken - fixedSystemTokens,
|
||||||
.tokenSlice({
|
messages: filterSearch.map((item) => ({
|
||||||
maxToken: Math.round(maxTokens * rate),
|
obj: ChatRoleEnum.System,
|
||||||
messages: filterSearch[i].map((item) => ({
|
value: `${item.q}\n${item.a}`
|
||||||
obj: ChatRoleEnum.System,
|
}))
|
||||||
value: `${item.q}\n${item.a}`
|
})
|
||||||
}))
|
.map((item) => item.value);
|
||||||
})
|
|
||||||
.map((item) => item.value)
|
|
||||||
);
|
|
||||||
|
|
||||||
// slice filterSearch
|
// slice filterSearch
|
||||||
const sliceSearch = filterSearch.map((item, i) => item.slice(0, sliceResult[i].length)).flat();
|
const rawSearch = filterSearch.slice(0, sliceResult.length);
|
||||||
|
|
||||||
// system prompt
|
// system prompt
|
||||||
const systemPrompt = sliceResult.flat().join('\n').trim();
|
const systemPrompt = sliceResult.join('\n').trim();
|
||||||
|
|
||||||
/* 高相似度+不回复 */
|
/* 高相似度+不回复 */
|
||||||
if (!systemPrompt && model.chat.searchMode === appVectorSearchModeEnum.hightSimilarity) {
|
if (!systemPrompt && model.chat.searchMode === appVectorSearchModeEnum.hightSimilarity) {
|
||||||
@ -206,7 +190,7 @@ export async function appKbSearch({
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
code: 200,
|
code: 200,
|
||||||
rawSearch: sliceSearch,
|
rawSearch,
|
||||||
guidePrompt: guidePrompt.value || '',
|
guidePrompt: guidePrompt.value || '',
|
||||||
searchPrompts: [
|
searchPrompts: [
|
||||||
{
|
{
|
||||||
|
|||||||
@ -280,7 +280,8 @@ export const authChat = async ({
|
|||||||
{
|
{
|
||||||
$project: {
|
$project: {
|
||||||
obj: '$content.obj',
|
obj: '$content.obj',
|
||||||
value: '$content.value'
|
value: '$content.value',
|
||||||
|
quote: '$content.quote'
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
]);
|
]);
|
||||||
|
|||||||
@ -89,39 +89,55 @@ export const ChatContextFilter = ({
|
|||||||
prompts: ChatItemSimpleType[];
|
prompts: ChatItemSimpleType[];
|
||||||
maxTokens: number;
|
maxTokens: number;
|
||||||
}) => {
|
}) => {
|
||||||
|
const systemPrompts: ChatItemSimpleType[] = [];
|
||||||
|
const chatPrompts: ChatItemSimpleType[] = [];
|
||||||
|
|
||||||
let rawTextLen = 0;
|
let rawTextLen = 0;
|
||||||
const formatPrompts = prompts.map<ChatItemSimpleType>((item) => {
|
prompts.forEach((item) => {
|
||||||
const val = simplifyStr(item.value);
|
const val = simplifyStr(item.value);
|
||||||
rawTextLen += val.length;
|
rawTextLen += val.length;
|
||||||
return {
|
|
||||||
|
const data = {
|
||||||
obj: item.obj,
|
obj: item.obj,
|
||||||
value: val
|
value: val
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if (item.obj === ChatRoleEnum.System) {
|
||||||
|
systemPrompts.push(data);
|
||||||
|
} else {
|
||||||
|
chatPrompts.push(data);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
// 长度太小时,不需要进行 token 截断
|
// 长度太小时,不需要进行 token 截断
|
||||||
if (formatPrompts.length <= 2 || rawTextLen < maxTokens * 0.5) {
|
if (rawTextLen < maxTokens * 0.5) {
|
||||||
return formatPrompts;
|
return [...systemPrompts, ...chatPrompts];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 去掉 system 的 token
|
||||||
|
maxTokens -= modelToolMap[model].countTokens({
|
||||||
|
messages: systemPrompts
|
||||||
|
});
|
||||||
|
|
||||||
// 根据 tokens 截断内容
|
// 根据 tokens 截断内容
|
||||||
const chats: ChatItemSimpleType[] = [];
|
const chats: ChatItemSimpleType[] = [];
|
||||||
|
|
||||||
// 从后往前截取对话内容
|
// 从后往前截取对话内容
|
||||||
for (let i = formatPrompts.length - 1; i >= 0; i--) {
|
for (let i = chatPrompts.length - 1; i >= 0; i--) {
|
||||||
chats.unshift(formatPrompts[i]);
|
chats.unshift(chatPrompts[i]);
|
||||||
|
|
||||||
const tokens = modelToolMap[model].countTokens({
|
const tokens = modelToolMap[model].countTokens({
|
||||||
messages: chats
|
messages: chats
|
||||||
});
|
});
|
||||||
|
|
||||||
/* 整体 tokens 超出范围, system必须保留 */
|
/* 整体 tokens 超出范围, system必须保留 */
|
||||||
if (tokens >= maxTokens && formatPrompts[i].obj !== ChatRoleEnum.System) {
|
if (tokens >= maxTokens) {
|
||||||
return chats.slice(1);
|
chats.shift();
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return chats;
|
return [...systemPrompts, ...chats];
|
||||||
};
|
};
|
||||||
|
|
||||||
/* stream response */
|
/* stream response */
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user