From 90456301d263c3ec39c6afca0034cc0f1c9c08c1 Mon Sep 17 00:00:00 2001 From: archer <545436317@qq.com> Date: Tue, 2 May 2023 14:06:10 +0800 Subject: [PATCH] feat: save system prompt --- src/api/fetch.ts | 26 +++++++++++--- src/constants/chat.ts | 1 + src/pages/api/chat/chat.ts | 6 ++-- src/pages/api/chat/init.ts | 3 +- src/pages/api/chat/saveChat.ts | 3 +- src/pages/chat/index.tsx | 62 +++++++++++++++++++++++++++------- src/service/models/chat.ts | 4 +++ src/service/utils/auth.ts | 12 +++++-- src/service/utils/openai.ts | 13 +++++-- src/types/chat.d.ts | 1 + 10 files changed, 104 insertions(+), 27 deletions(-) create mode 100644 src/constants/chat.ts diff --git a/src/api/fetch.ts b/src/api/fetch.ts index a7822015b..9e06a27b3 100644 --- a/src/api/fetch.ts +++ b/src/api/fetch.ts @@ -1,4 +1,6 @@ import { getToken } from '../utils/user'; +import { SYSTEM_PROMPT_PREFIX } from '@/constants/chat'; + interface StreamFetchProps { url: string; data: any; @@ -6,7 +8,7 @@ interface StreamFetchProps { abortSignal: AbortController; } export const streamFetch = ({ url, data, onMessage, abortSignal }: StreamFetchProps) => - new Promise(async (resolve, reject) => { + new Promise<{ responseText: string; systemPrompt: string }>(async (resolve, reject) => { try { const res = await fetch(url, { method: 'POST', @@ -19,15 +21,22 @@ export const streamFetch = ({ url, data, onMessage, abortSignal }: StreamFetchPr }); const reader = res.body?.getReader(); if (!reader) return; + + if (res.status !== 200) { + console.log(res); + return reject('chat error'); + } + const decoder = new TextDecoder(); let responseText = ''; + let systemPrompt = ''; const read = async () => { try { const { done, value } = await reader?.read(); if (done) { if (res.status === 200) { - resolve(responseText); + resolve({ responseText, systemPrompt }); } else { const parseError = JSON.parse(responseText); reject(parseError?.message || '请求异常'); @@ -36,12 +45,19 @@ export const streamFetch = ({ url, data, onMessage, abortSignal }: StreamFetchPr return; } const text = decoder.decode(value).replace(//g, '\n'); - res.status === 200 && onMessage(text); - responseText += text; + + // check system prompt + if (text.startsWith(SYSTEM_PROMPT_PREFIX)) { + systemPrompt = text.replace(SYSTEM_PROMPT_PREFIX, ''); + } else { + responseText += text; + onMessage(text); + } + read(); } catch (err: any) { if (err?.message === 'The user aborted a request.') { - return resolve(responseText); + return resolve({ responseText, systemPrompt }); } reject(typeof err === 'string' ? err : err?.message || '请求异常'); } diff --git a/src/constants/chat.ts b/src/constants/chat.ts new file mode 100644 index 000000000..9c7021ab5 --- /dev/null +++ b/src/constants/chat.ts @@ -0,0 +1 @@ +export const SYSTEM_PROMPT_PREFIX = 'SYSTEM_PROMPT:'; diff --git a/src/pages/api/chat/chat.ts b/src/pages/api/chat/chat.ts index 5bbfbbbc1..5cfb8e44d 100644 --- a/src/pages/api/chat/chat.ts +++ b/src/pages/api/chat/chat.ts @@ -41,7 +41,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) await connectToDatabase(); let startTime = Date.now(); - const { model, content, userApiKey, systemKey, userId } = await authChat({ + const { model, showModelDetail, content, userApiKey, systemKey, userId } = await authChat({ modelId, chatId, authorization @@ -120,7 +120,9 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) const { responseContent } = await gpt35StreamResponse({ res, stream, - chatResponse + chatResponse, + systemPrompt: + showModelDetail && filterPrompts[0].role === 'system' ? filterPrompts[0].content : '' }); // 只有使用平台的 key 才计费 diff --git a/src/pages/api/chat/init.ts b/src/pages/api/chat/init.ts index 2acd9e6a7..6ebe5b88e 100644 --- a/src/pages/api/chat/init.ts +++ b/src/pages/api/chat/init.ts @@ -48,7 +48,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) $project: { _id: '$content._id', obj: '$content.obj', - value: '$content.value' + value: '$content.value', + systemPrompt: '$content.systemPrompt' } } ]); diff --git a/src/pages/api/chat/saveChat.ts b/src/pages/api/chat/saveChat.ts index 583c42bfa..0494fea0e 100644 --- a/src/pages/api/chat/saveChat.ts +++ b/src/pages/api/chat/saveChat.ts @@ -26,7 +26,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) const content = prompts.map((item) => ({ _id: new mongoose.Types.ObjectId(item._id), obj: item.obj, - value: item.value + value: item.value, + systemPrompt: item.systemPrompt })); await authModel({ modelId, userId, authOwner: false }); diff --git a/src/pages/chat/index.tsx b/src/pages/chat/index.tsx index d1e21ef01..96937cbd1 100644 --- a/src/pages/chat/index.tsx +++ b/src/pages/chat/index.tsx @@ -16,7 +16,13 @@ import { MenuButton, MenuList, MenuItem, - Image + Image, + Button, + Modal, + ModalOverlay, + ModalContent, + ModalBody, + ModalCloseButton } from '@chakra-ui/react'; import { useToast } from '@/hooks/useToast'; import { useScreen } from '@/hooks/useScreen'; @@ -29,7 +35,7 @@ import { streamFetch } from '@/api/fetch'; import Icon from '@/components/Icon'; import MyIcon from '@/components/Icon'; import { throttle } from 'lodash'; -import mongoose from 'mongoose'; +import { Types } from 'mongoose'; const SlideBar = dynamic(() => import('./components/SlideBar')); const Empty = dynamic(() => import('./components/Empty')); @@ -67,7 +73,8 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { history: [] }); // 聊天框整体数据 - const [inputVal, setInputVal] = useState(''); // 输入的内容 + const [inputVal, setInputVal] = useState(''); // user input prompt + const [showSystemPrompt, setShowSystemPrompt] = useState(''); const isChatting = useMemo( () => chatData.history[chatData.history.length - 1]?.status === 'loading', @@ -199,7 +206,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { }; // 流请求,获取数据 - const responseText = await streamFetch({ + const { responseText, systemPrompt } = await streamFetch({ url: '/api/chat/chat', data: { prompt, @@ -228,7 +235,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { } let newChatId = ''; - // 保存对话信息 + // save chat record try { newChatId = await postSaveChat({ modelId, @@ -242,7 +249,8 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { { _id: prompts[1]._id, obj: 'AI', - value: responseText + value: responseText, + systemPrompt } ] }); @@ -266,7 +274,8 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { if (index !== state.history.length - 1) return item; return { ...item, - status: 'finish' + status: 'finish', + systemPrompt }; }) })); @@ -300,13 +309,13 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { const newChatList: ChatSiteItemType[] = [ ...chatData.history, { - _id: String(new mongoose.Types.ObjectId()), + _id: String(new Types.ObjectId()), obj: 'Human', value: val, status: 'finish' }, { - _id: String(new mongoose.Types.ObjectId()), + _id: String(new Types.ObjectId()), obj: 'AI', value: '', status: 'loading' @@ -492,10 +501,24 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { {item.obj === 'AI' ? ( - + <> + + {item.systemPrompt && ( + + )} + ) : ( {item.value} @@ -617,6 +640,19 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => { + + {/* system prompt show modal */} + { + setShowSystemPrompt('')}> + + + + + {showSystemPrompt} + + + + } ); }; diff --git a/src/service/models/chat.ts b/src/service/models/chat.ts index 50eda2fa9..638ad7a63 100644 --- a/src/service/models/chat.ts +++ b/src/service/models/chat.ts @@ -41,6 +41,10 @@ const ChatSchema = new Schema({ value: { type: String, required: true + }, + systemPrompt: { + type: String, + default: '' } } ], diff --git a/src/service/utils/auth.ts b/src/service/utils/auth.ts index af90656a0..2ebb06cac 100644 --- a/src/service/utils/auth.ts +++ b/src/service/utils/auth.ts @@ -75,7 +75,7 @@ export const authModel = async ({ }; } - return { model }; + return { model, showModelDetail: model.share.isShareDetail || userId === String(model.userId) }; }; // 获取对话校验 @@ -91,7 +91,12 @@ export const authChat = async ({ const userId = await authToken(authorization); // 获取 model 数据 - const { model } = await authModel({ modelId, userId, authOwner: false, reserveDetail: true }); + const { model, showModelDetail } = await authModel({ + modelId, + userId, + authOwner: false, + reserveDetail: true + }); // 聊天内容 let content: ChatItemSimpleType[] = []; @@ -124,7 +129,8 @@ export const authChat = async ({ systemKey, content, userId, - model + model, + showModelDetail }; }; diff --git a/src/service/utils/openai.ts b/src/service/utils/openai.ts index 689d00faa..8cb0c539a 100644 --- a/src/service/utils/openai.ts +++ b/src/service/utils/openai.ts @@ -7,6 +7,7 @@ import { User } from '../models/user'; import { formatPrice } from '@/utils/user'; import { embeddingModel } from '@/constants/model'; import { pushGenerateVectorBill } from '../events/pushBill'; +import { SYSTEM_PROMPT_PREFIX } from '@/constants/chat'; /* 获取用户 api 的 openai 信息 */ export const getUserApiOpenai = async (userId: string) => { @@ -110,11 +111,13 @@ export const openaiCreateEmbedding = async ({ export const gpt35StreamResponse = ({ res, stream, - chatResponse + chatResponse, + systemPrompt = '' }: { res: NextApiResponse; stream: PassThrough; chatResponse: any; + systemPrompt?: string; }) => new Promise<{ responseContent: string }>(async (resolve, reject) => { try { @@ -144,8 +147,8 @@ export const gpt35StreamResponse = ({ } }; - const decoder = new TextDecoder(); try { + const decoder = new TextDecoder(); const parser = createParser(onParse); for await (const chunk of chatResponse.data as any) { if (stream.destroyed) { @@ -157,6 +160,12 @@ export const gpt35StreamResponse = ({ } catch (error) { console.log('pipe error', error); } + + // push system prompt + !stream.destroyed && + systemPrompt && + stream.push(`${SYSTEM_PROMPT_PREFIX}${systemPrompt.replace(/\n/g, '
')}`); + // close stream !stream.destroyed && stream.push(null); stream.destroy(); diff --git a/src/types/chat.d.ts b/src/types/chat.d.ts index 71f0808c0..641d2ede8 100644 --- a/src/types/chat.d.ts +++ b/src/types/chat.d.ts @@ -1,6 +1,7 @@ export type ChatItemSimpleType = { obj: 'Human' | 'AI' | 'SYSTEM'; value: string; + systemPrompt?: string; }; export type ChatItemType = { _id: string;