feat: save system prompt

This commit is contained in:
archer 2023-05-02 14:06:10 +08:00
parent b0d414ac12
commit 90456301d2
No known key found for this signature in database
GPG Key ID: 569A5660D2379E28
10 changed files with 104 additions and 27 deletions

View File

@ -1,4 +1,6 @@
import { getToken } from '../utils/user'; import { getToken } from '../utils/user';
import { SYSTEM_PROMPT_PREFIX } from '@/constants/chat';
interface StreamFetchProps { interface StreamFetchProps {
url: string; url: string;
data: any; data: any;
@ -6,7 +8,7 @@ interface StreamFetchProps {
abortSignal: AbortController; abortSignal: AbortController;
} }
export const streamFetch = ({ url, data, onMessage, abortSignal }: StreamFetchProps) => export const streamFetch = ({ url, data, onMessage, abortSignal }: StreamFetchProps) =>
new Promise<string>(async (resolve, reject) => { new Promise<{ responseText: string; systemPrompt: string }>(async (resolve, reject) => {
try { try {
const res = await fetch(url, { const res = await fetch(url, {
method: 'POST', method: 'POST',
@ -19,15 +21,22 @@ export const streamFetch = ({ url, data, onMessage, abortSignal }: StreamFetchPr
}); });
const reader = res.body?.getReader(); const reader = res.body?.getReader();
if (!reader) return; if (!reader) return;
if (res.status !== 200) {
console.log(res);
return reject('chat error');
}
const decoder = new TextDecoder(); const decoder = new TextDecoder();
let responseText = ''; let responseText = '';
let systemPrompt = '';
const read = async () => { const read = async () => {
try { try {
const { done, value } = await reader?.read(); const { done, value } = await reader?.read();
if (done) { if (done) {
if (res.status === 200) { if (res.status === 200) {
resolve(responseText); resolve({ responseText, systemPrompt });
} else { } else {
const parseError = JSON.parse(responseText); const parseError = JSON.parse(responseText);
reject(parseError?.message || '请求异常'); reject(parseError?.message || '请求异常');
@ -36,12 +45,19 @@ export const streamFetch = ({ url, data, onMessage, abortSignal }: StreamFetchPr
return; return;
} }
const text = decoder.decode(value).replace(/<br\/>/g, '\n'); const text = decoder.decode(value).replace(/<br\/>/g, '\n');
res.status === 200 && onMessage(text);
responseText += text; // check system prompt
if (text.startsWith(SYSTEM_PROMPT_PREFIX)) {
systemPrompt = text.replace(SYSTEM_PROMPT_PREFIX, '');
} else {
responseText += text;
onMessage(text);
}
read(); read();
} catch (err: any) { } catch (err: any) {
if (err?.message === 'The user aborted a request.') { if (err?.message === 'The user aborted a request.') {
return resolve(responseText); return resolve({ responseText, systemPrompt });
} }
reject(typeof err === 'string' ? err : err?.message || '请求异常'); reject(typeof err === 'string' ? err : err?.message || '请求异常');
} }

1
src/constants/chat.ts Normal file
View File

@ -0,0 +1 @@
export const SYSTEM_PROMPT_PREFIX = 'SYSTEM_PROMPT:';

View File

@ -41,7 +41,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
await connectToDatabase(); await connectToDatabase();
let startTime = Date.now(); let startTime = Date.now();
const { model, content, userApiKey, systemKey, userId } = await authChat({ const { model, showModelDetail, content, userApiKey, systemKey, userId } = await authChat({
modelId, modelId,
chatId, chatId,
authorization authorization
@ -120,7 +120,9 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
const { responseContent } = await gpt35StreamResponse({ const { responseContent } = await gpt35StreamResponse({
res, res,
stream, stream,
chatResponse chatResponse,
systemPrompt:
showModelDetail && filterPrompts[0].role === 'system' ? filterPrompts[0].content : ''
}); });
// 只有使用平台的 key 才计费 // 只有使用平台的 key 才计费

View File

@ -48,7 +48,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
$project: { $project: {
_id: '$content._id', _id: '$content._id',
obj: '$content.obj', obj: '$content.obj',
value: '$content.value' value: '$content.value',
systemPrompt: '$content.systemPrompt'
} }
} }
]); ]);

View File

@ -26,7 +26,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
const content = prompts.map((item) => ({ const content = prompts.map((item) => ({
_id: new mongoose.Types.ObjectId(item._id), _id: new mongoose.Types.ObjectId(item._id),
obj: item.obj, obj: item.obj,
value: item.value value: item.value,
systemPrompt: item.systemPrompt
})); }));
await authModel({ modelId, userId, authOwner: false }); await authModel({ modelId, userId, authOwner: false });

View File

@ -16,7 +16,13 @@ import {
MenuButton, MenuButton,
MenuList, MenuList,
MenuItem, MenuItem,
Image Image,
Button,
Modal,
ModalOverlay,
ModalContent,
ModalBody,
ModalCloseButton
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useToast } from '@/hooks/useToast'; import { useToast } from '@/hooks/useToast';
import { useScreen } from '@/hooks/useScreen'; import { useScreen } from '@/hooks/useScreen';
@ -29,7 +35,7 @@ import { streamFetch } from '@/api/fetch';
import Icon from '@/components/Icon'; import Icon from '@/components/Icon';
import MyIcon from '@/components/Icon'; import MyIcon from '@/components/Icon';
import { throttle } from 'lodash'; import { throttle } from 'lodash';
import mongoose from 'mongoose'; import { Types } from 'mongoose';
const SlideBar = dynamic(() => import('./components/SlideBar')); const SlideBar = dynamic(() => import('./components/SlideBar'));
const Empty = dynamic(() => import('./components/Empty')); const Empty = dynamic(() => import('./components/Empty'));
@ -67,7 +73,8 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
history: [] history: []
}); // 聊天框整体数据 }); // 聊天框整体数据
const [inputVal, setInputVal] = useState(''); // 输入的内容 const [inputVal, setInputVal] = useState(''); // user input prompt
const [showSystemPrompt, setShowSystemPrompt] = useState('');
const isChatting = useMemo( const isChatting = useMemo(
() => chatData.history[chatData.history.length - 1]?.status === 'loading', () => chatData.history[chatData.history.length - 1]?.status === 'loading',
@ -199,7 +206,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
}; };
// 流请求,获取数据 // 流请求,获取数据
const responseText = await streamFetch({ const { responseText, systemPrompt } = await streamFetch({
url: '/api/chat/chat', url: '/api/chat/chat',
data: { data: {
prompt, prompt,
@ -228,7 +235,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
} }
let newChatId = ''; let newChatId = '';
// 保存对话信息 // save chat record
try { try {
newChatId = await postSaveChat({ newChatId = await postSaveChat({
modelId, modelId,
@ -242,7 +249,8 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
{ {
_id: prompts[1]._id, _id: prompts[1]._id,
obj: 'AI', obj: 'AI',
value: responseText value: responseText,
systemPrompt
} }
] ]
}); });
@ -266,7 +274,8 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
if (index !== state.history.length - 1) return item; if (index !== state.history.length - 1) return item;
return { return {
...item, ...item,
status: 'finish' status: 'finish',
systemPrompt
}; };
}) })
})); }));
@ -300,13 +309,13 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
const newChatList: ChatSiteItemType[] = [ const newChatList: ChatSiteItemType[] = [
...chatData.history, ...chatData.history,
{ {
_id: String(new mongoose.Types.ObjectId()), _id: String(new Types.ObjectId()),
obj: 'Human', obj: 'Human',
value: val, value: val,
status: 'finish' status: 'finish'
}, },
{ {
_id: String(new mongoose.Types.ObjectId()), _id: String(new Types.ObjectId()),
obj: 'AI', obj: 'AI',
value: '', value: '',
status: 'loading' status: 'loading'
@ -492,10 +501,24 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
</Menu> </Menu>
<Box flex={'1 0 0'} w={0} overflow={'hidden'}> <Box flex={'1 0 0'} w={0} overflow={'hidden'}>
{item.obj === 'AI' ? ( {item.obj === 'AI' ? (
<Markdown <>
source={item.value} <Markdown
isChatting={isChatting && index === chatData.history.length - 1} source={item.value}
/> isChatting={isChatting && index === chatData.history.length - 1}
/>
{item.systemPrompt && (
<Button
size={'xs'}
mt={2}
fontWeight={'normal'}
colorScheme={'gray'}
variant={'outline'}
onClick={() => setShowSystemPrompt(item.systemPrompt || '')}
>
</Button>
)}
</>
) : ( ) : (
<Box className="markdown" whiteSpace={'pre-wrap'}> <Box className="markdown" whiteSpace={'pre-wrap'}>
<Box as={'p'}>{item.value}</Box> <Box as={'p'}>{item.value}</Box>
@ -617,6 +640,19 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
</Box> </Box>
</Box> </Box>
</Flex> </Flex>
{/* system prompt show modal */}
{
<Modal isOpen={!!showSystemPrompt} onClose={() => setShowSystemPrompt('')}>
<ModalOverlay />
<ModalContent maxW={'min(90vw, 600px)'} pr={2} maxH={'80vh'} overflowY={'auto'}>
<ModalCloseButton />
<ModalBody pt={10} fontSize={'sm'} whiteSpace={'pre-wrap'} textAlign={'justify'}>
{showSystemPrompt}
</ModalBody>
</ModalContent>
</Modal>
}
</Flex> </Flex>
); );
}; };

View File

@ -41,6 +41,10 @@ const ChatSchema = new Schema({
value: { value: {
type: String, type: String,
required: true required: true
},
systemPrompt: {
type: String,
default: ''
} }
} }
], ],

View File

@ -75,7 +75,7 @@ export const authModel = async ({
}; };
} }
return { model }; return { model, showModelDetail: model.share.isShareDetail || userId === String(model.userId) };
}; };
// 获取对话校验 // 获取对话校验
@ -91,7 +91,12 @@ export const authChat = async ({
const userId = await authToken(authorization); const userId = await authToken(authorization);
// 获取 model 数据 // 获取 model 数据
const { model } = await authModel({ modelId, userId, authOwner: false, reserveDetail: true }); const { model, showModelDetail } = await authModel({
modelId,
userId,
authOwner: false,
reserveDetail: true
});
// 聊天内容 // 聊天内容
let content: ChatItemSimpleType[] = []; let content: ChatItemSimpleType[] = [];
@ -124,7 +129,8 @@ export const authChat = async ({
systemKey, systemKey,
content, content,
userId, userId,
model model,
showModelDetail
}; };
}; };

View File

@ -7,6 +7,7 @@ import { User } from '../models/user';
import { formatPrice } from '@/utils/user'; import { formatPrice } from '@/utils/user';
import { embeddingModel } from '@/constants/model'; import { embeddingModel } from '@/constants/model';
import { pushGenerateVectorBill } from '../events/pushBill'; import { pushGenerateVectorBill } from '../events/pushBill';
import { SYSTEM_PROMPT_PREFIX } from '@/constants/chat';
/* 获取用户 api 的 openai 信息 */ /* 获取用户 api 的 openai 信息 */
export const getUserApiOpenai = async (userId: string) => { export const getUserApiOpenai = async (userId: string) => {
@ -110,11 +111,13 @@ export const openaiCreateEmbedding = async ({
export const gpt35StreamResponse = ({ export const gpt35StreamResponse = ({
res, res,
stream, stream,
chatResponse chatResponse,
systemPrompt = ''
}: { }: {
res: NextApiResponse; res: NextApiResponse;
stream: PassThrough; stream: PassThrough;
chatResponse: any; chatResponse: any;
systemPrompt?: string;
}) => }) =>
new Promise<{ responseContent: string }>(async (resolve, reject) => { new Promise<{ responseContent: string }>(async (resolve, reject) => {
try { try {
@ -144,8 +147,8 @@ export const gpt35StreamResponse = ({
} }
}; };
const decoder = new TextDecoder();
try { try {
const decoder = new TextDecoder();
const parser = createParser(onParse); const parser = createParser(onParse);
for await (const chunk of chatResponse.data as any) { for await (const chunk of chatResponse.data as any) {
if (stream.destroyed) { if (stream.destroyed) {
@ -157,6 +160,12 @@ export const gpt35StreamResponse = ({
} catch (error) { } catch (error) {
console.log('pipe error', error); console.log('pipe error', error);
} }
// push system prompt
!stream.destroyed &&
systemPrompt &&
stream.push(`${SYSTEM_PROMPT_PREFIX}${systemPrompt.replace(/\n/g, '<br/>')}`);
// close stream // close stream
!stream.destroyed && stream.push(null); !stream.destroyed && stream.push(null);
stream.destroy(); stream.destroy();

1
src/types/chat.d.ts vendored
View File

@ -1,6 +1,7 @@
export type ChatItemSimpleType = { export type ChatItemSimpleType = {
obj: 'Human' | 'AI' | 'SYSTEM'; obj: 'Human' | 'AI' | 'SYSTEM';
value: string; value: string;
systemPrompt?: string;
}; };
export type ChatItemType = { export type ChatItemType = {
_id: string; _id: string;