From 6014a56e5432d75e70a6f482c0bb7d6bf7176fd7 Mon Sep 17 00:00:00 2001 From: archer <545436317@qq.com> Date: Tue, 23 May 2023 19:13:01 +0800 Subject: [PATCH] feat: system prompt --- src/api/fetch.ts | 9 ++-- src/constants/chat.ts | 1 + src/pages/api/chat/chat.ts | 19 +++++--- src/pages/api/chat/init.ts | 3 +- src/pages/api/chat/saveChat.ts | 6 +-- src/pages/api/openapi/kb/appKbSearch.ts | 47 +++++++++--------- src/pages/chat/index.tsx | 64 ++++++++++++++++++------- src/service/models/chat.ts | 29 +++++++++-- src/types/chat.d.ts | 1 + 9 files changed, 118 insertions(+), 61 deletions(-) diff --git a/src/api/fetch.ts b/src/api/fetch.ts index 38f3b185d..62de6dedc 100644 --- a/src/api/fetch.ts +++ b/src/api/fetch.ts @@ -1,4 +1,4 @@ -import { NEW_CHATID_HEADER, QUOTE_LEN_HEADER } from '@/constants/chat'; +import { GUIDE_PROMPT_HEADER, NEW_CHATID_HEADER, QUOTE_LEN_HEADER } from '@/constants/chat'; interface StreamFetchProps { url: string; @@ -7,7 +7,7 @@ interface StreamFetchProps { abortSignal: AbortController; } export const streamFetch = ({ url, data, onMessage, abortSignal }: StreamFetchProps) => - new Promise<{ responseText: string; newChatId: string; quoteLen: number }>( + new Promise<{ responseText: string; newChatId: string; systemPrompt: string; quoteLen: number }>( async (resolve, reject) => { try { const res = await fetch(url, { @@ -24,6 +24,7 @@ export const streamFetch = ({ url, data, onMessage, abortSignal }: StreamFetchPr const decoder = new TextDecoder(); const newChatId = decodeURIComponent(res.headers.get(NEW_CHATID_HEADER) || ''); + const systemPrompt = decodeURIComponent(res.headers.get(GUIDE_PROMPT_HEADER) || '').trim(); const quoteLen = res.headers.get(QUOTE_LEN_HEADER) ? Number(res.headers.get(QUOTE_LEN_HEADER)) : 0; @@ -35,7 +36,7 @@ export const streamFetch = ({ url, data, onMessage, abortSignal }: StreamFetchPr const { done, value } = await reader?.read(); if (done) { if (res.status === 200) { - resolve({ responseText, newChatId, quoteLen }); + resolve({ responseText, newChatId, quoteLen, systemPrompt }); } else { const parseError = JSON.parse(responseText); reject(parseError?.message || '请求异常'); @@ -49,7 +50,7 @@ export const streamFetch = ({ url, data, onMessage, abortSignal }: StreamFetchPr read(); } catch (err: any) { if (err?.message === 'The user aborted a request.') { - return resolve({ responseText, newChatId, quoteLen: 0 }); + return resolve({ responseText, newChatId, quoteLen: 0, systemPrompt: '' }); } reject(typeof err === 'string' ? err : err?.message || '请求异常'); } diff --git a/src/constants/chat.ts b/src/constants/chat.ts index 59b57ec4a..a816143a7 100644 --- a/src/constants/chat.ts +++ b/src/constants/chat.ts @@ -1,5 +1,6 @@ export const NEW_CHATID_HEADER = 'response-new-chat-id'; export const QUOTE_LEN_HEADER = 'response-quote-len'; +export const GUIDE_PROMPT_HEADER = 'response-guide-prompt'; export enum ChatRoleEnum { System = 'System', diff --git a/src/pages/api/chat/chat.ts b/src/pages/api/chat/chat.ts index d4ad11e71..923d328d7 100644 --- a/src/pages/api/chat/chat.ts +++ b/src/pages/api/chat/chat.ts @@ -8,7 +8,7 @@ import { ChatModelMap, ModelVectorSearchModeMap } from '@/constants/model'; import { pushChatBill } from '@/service/events/pushBill'; import { resStreamResponse } from '@/service/utils/chat'; import { appKbSearch } from '../openapi/kb/appKbSearch'; -import { ChatRoleEnum, QUOTE_LEN_HEADER } from '@/constants/chat'; +import { ChatRoleEnum, QUOTE_LEN_HEADER, GUIDE_PROMPT_HEADER } from '@/constants/chat'; import { BillTypeEnum } from '@/constants/user'; import { sensitiveCheck } from '@/service/api/text'; import { NEW_CHATID_HEADER } from '@/constants/chat'; @@ -53,11 +53,12 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) const { code = 200, systemPrompts = [], - quote = [] + quote = [], + guidePrompt = '' } = await (async () => { // 使用了知识库搜索 if (model.chat.relatedKbs.length > 0) { - const { code, searchPrompts, rawSearch } = await appKbSearch({ + const { code, searchPrompts, rawSearch, guidePrompt } = await appKbSearch({ model, userId, prompts, @@ -67,11 +68,13 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) return { code, quote: rawSearch, - systemPrompts: searchPrompts + systemPrompts: searchPrompts, + guidePrompt }; } if (model.chat.systemPrompt) { return { + guidePrompt: model.chat.systemPrompt, systemPrompts: [ { obj: ChatRoleEnum.System, @@ -86,7 +89,10 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) // get conversationId. create a newId if it is null const conversationId = chatId || String(new Types.ObjectId()); !chatId && res.setHeader(NEW_CHATID_HEADER, conversationId); - res.setHeader(QUOTE_LEN_HEADER, quote.length); + if (showModelDetail) { + guidePrompt && res.setHeader(GUIDE_PROMPT_HEADER, encodeURIComponent(guidePrompt)); + res.setHeader(QUOTE_LEN_HEADER, quote.length); + } // search result is empty if (code === 201) { @@ -151,8 +157,9 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) prompt[0], { ...prompt[1], + value: responseContent, quote: showModelDetail ? quote : [], - value: responseContent + systemPrompt: showModelDetail ? guidePrompt : '' } ], userId diff --git a/src/pages/api/chat/init.ts b/src/pages/api/chat/init.ts index c985b3692..62f8ce45b 100644 --- a/src/pages/api/chat/init.ts +++ b/src/pages/api/chat/init.ts @@ -73,7 +73,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) _id: '$content._id', obj: '$content.obj', value: '$content.value', - quoteLen: { $size: '$content.quote' } + systemPrompt: '$content.systemPrompt', + quoteLen: { $size: { $ifNull: ['$content.quote', []] } } } } ]); diff --git a/src/pages/api/chat/saveChat.ts b/src/pages/api/chat/saveChat.ts index 9b96949b6..11fa68d12 100644 --- a/src/pages/api/chat/saveChat.ts +++ b/src/pages/api/chat/saveChat.ts @@ -57,10 +57,8 @@ export async function saveChat({ _id: item._id ? new mongoose.Types.ObjectId(item._id) : undefined, obj: item.obj, value: item.value, - quote: item.quote?.map((item) => ({ - ...item, - isEdit: false - })) + systemPrompt: item.systemPrompt, + quote: item.quote || [] })); // 没有 chatId, 创建一个对话 diff --git a/src/pages/api/openapi/kb/appKbSearch.ts b/src/pages/api/openapi/kb/appKbSearch.ts index f74e16f7d..97f01e577 100644 --- a/src/pages/api/openapi/kb/appKbSearch.ts +++ b/src/pages/api/openapi/kb/appKbSearch.ts @@ -22,6 +22,7 @@ type Props = { type Response = { code: 200 | 201; rawSearch: QuoteItemType[]; + guidePrompt: string; searchPrompts: { obj: ChatRoleEnum; value: string; @@ -131,36 +132,29 @@ export async function appKbSearch({ }; const sliceRate = sliceRateMap[searchRes.length] || sliceRateMap[0]; // 计算固定提示词的 token 数量 - const fixedPrompts = [ - // user system prompt - ...(model.chat.systemPrompt - ? [ - { - obj: ChatRoleEnum.System, - value: model.chat.systemPrompt - } - ] - : model.chat.searchMode === ModelVectorSearchModeEnum.noContext - ? [ - { - obj: ChatRoleEnum.System, - value: `知识库是关于"${model.name}"的内容,根据知识库内容回答问题.` - } - ] - : [ - { - obj: ChatRoleEnum.System, - value: `玩一个问答游戏,规则为: + + const guidePrompt = model.chat.systemPrompt // user system prompt + ? { + obj: ChatRoleEnum.System, + value: model.chat.systemPrompt + } + : model.chat.searchMode === ModelVectorSearchModeEnum.noContext + ? { + obj: ChatRoleEnum.System, + value: `知识库是关于"${model.name}"的内容,根据知识库内容回答问题.` + } + : { + obj: ChatRoleEnum.System, + value: `玩一个问答游戏,规则为: 1.你完全忘记你已有的知识 2.你只回答关于"${model.name}"的问题 3.你只从知识库中选择内容进行回答 4.如果问题不在知识库中,你会回答:"我不知道。" 请务必遵守规则` - } - ]) - ]; + }; + const fixedSystemTokens = modelToolMap[model.chat.chatModel].countTokens({ - messages: fixedPrompts + messages: [guidePrompt] }); const maxTokens = modelConstantsData.systemMaxToken - fixedSystemTokens; const sliceResult = sliceRate.map((rate, i) => @@ -186,6 +180,7 @@ export async function appKbSearch({ return { code: 201, rawSearch: [], + guidePrompt: '', searchPrompts: [ { obj: ChatRoleEnum.System, @@ -199,6 +194,7 @@ export async function appKbSearch({ return { code: 200, rawSearch: [], + guidePrompt: model.chat.systemPrompt || '', searchPrompts: model.chat.systemPrompt ? [ { @@ -213,12 +209,13 @@ export async function appKbSearch({ return { code: 200, rawSearch: sliceSearch, + guidePrompt: guidePrompt.value || '', searchPrompts: [ { obj: ChatRoleEnum.System, value: `知识库:${systemPrompt}` }, - ...fixedPrompts + guidePrompt ] }; } diff --git a/src/pages/chat/index.tsx b/src/pages/chat/index.tsx index 4aabb01f8..bd544c4a9 100644 --- a/src/pages/chat/index.tsx +++ b/src/pages/chat/index.tsx @@ -76,6 +76,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { const isLeavePage = useRef(false); const [showHistoryQuote, setShowHistoryQuote] = useState(); + const [showSystemPrompt, setShowSystemPrompt] = useState(''); const [messageContextMenuData, setMessageContextMenuData] = useState<{ // message messageContextMenuData left: number; @@ -177,7 +178,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { })); // 流请求,获取数据 - const { newChatId, quoteLen } = await streamFetch({ + const { newChatId, quoteLen, systemPrompt } = await streamFetch({ url: '/api/chat/chat', data: { prompt, @@ -221,14 +222,15 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { return { ...item, status: 'finish', - quoteLen + quoteLen, + systemPrompt }; }) })); // refresh history - loadHistory({ pageNum: 1, init: true }); setTimeout(() => { + loadHistory({ pageNum: 1, init: true }); generatingMessage(); }, 100); }, @@ -699,6 +701,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { })} > { isChatting={isChatting && index === chatData.history.length - 1} formatLink /> - {!!item.quoteLen && ( - - )} + + {!!item.systemPrompt && ( + + )} + {!!item.quoteLen && ( + + )} + ) : ( @@ -876,6 +895,19 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { onClose={() => setShowHistoryQuote(undefined)} /> )} + {/* system prompt show modal */} + { + setShowSystemPrompt('')}> + + + + 提示词 + + {showSystemPrompt} + + + + } {/* context menu */} {messageContextMenuData && (