feat: save system prompt

commit 90456301d2 (parent b0d414ac12)
@@ -1,4 +1,6 @@
 import { getToken } from '../utils/user';
+import { SYSTEM_PROMPT_PREFIX } from '@/constants/chat';
+
 interface StreamFetchProps {
   url: string;
   data: any;
@@ -6,7 +8,7 @@ interface StreamFetchProps {
   abortSignal: AbortController;
 }
 export const streamFetch = ({ url, data, onMessage, abortSignal }: StreamFetchProps) =>
-  new Promise<string>(async (resolve, reject) => {
+  new Promise<{ responseText: string; systemPrompt: string }>(async (resolve, reject) => {
     try {
       const res = await fetch(url, {
         method: 'POST',
@@ -19,15 +21,22 @@ export const streamFetch = ({ url, data, onMessage, abortSignal }: StreamFetchProps) =>
       });
       const reader = res.body?.getReader();
       if (!reader) return;
+
+      if (res.status !== 200) {
+        console.log(res);
+        return reject('chat error');
+      }
+
       const decoder = new TextDecoder();
       let responseText = '';
+      let systemPrompt = '';

       const read = async () => {
         try {
           const { done, value } = await reader?.read();
           if (done) {
             if (res.status === 200) {
-              resolve(responseText);
+              resolve({ responseText, systemPrompt });
             } else {
               const parseError = JSON.parse(responseText);
               reject(parseError?.message || '请求异常');
@@ -36,12 +45,19 @@ export const streamFetch = ({ url, data, onMessage, abortSignal }: StreamFetchProps) =>
             return;
           }
           const text = decoder.decode(value).replace(/<br\/>/g, '\n');
-          res.status === 200 && onMessage(text);
-          responseText += text;
+          // check system prompt
+          if (text.startsWith(SYSTEM_PROMPT_PREFIX)) {
+            systemPrompt = text.replace(SYSTEM_PROMPT_PREFIX, '');
+          } else {
+            responseText += text;
+            onMessage(text);
+          }
+
           read();
         } catch (err: any) {
           if (err?.message === 'The user aborted a request.') {
-            return resolve(responseText);
+            return resolve({ responseText, systemPrompt });
           }
           reject(typeof err === 'string' ? err : err?.message || '请求异常');
         }

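With the change above, streamFetch now resolves to an object instead of a bare string. A minimal usage sketch (hypothetical caller; the payload fields and the UI callback are illustrative, not part of this diff):

  const controller = new AbortController();
  const { responseText, systemPrompt } = await streamFetch({
    url: '/api/chat/chat',
    data: { prompt, chatId, modelId }, // illustrative payload shape
    onMessage: (text) => appendToLastAiMessage(text), // hypothetical UI callback
    abortSignal: controller
  });
  // responseText: the streamed answer
  // systemPrompt: text extracted from the SYSTEM_PROMPT_PREFIX chunk, or '' if none was sent
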
src/constants/chat.ts (new file)
@@ -0,0 +1 @@
+export const SYSTEM_PROMPT_PREFIX = 'SYSTEM_PROMPT:';

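This constant is the in-band sentinel shared by server and client: the server prepends it to one extra stream chunk carrying the system prompt, and the client treats any chunk that starts with it as prompt text rather than answer tokens. Both sides appear elsewhere in this commit:

  // server (gpt35StreamResponse): pushed after the answer chunks
  stream.push(`${SYSTEM_PROMPT_PREFIX}${systemPrompt.replace(/\n/g, '<br/>')}`);

  // client (streamFetch): routed into systemPrompt instead of responseText
  if (text.startsWith(SYSTEM_PROMPT_PREFIX)) {
    systemPrompt = text.replace(SYSTEM_PROMPT_PREFIX, '');
  }
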
@@ -41,7 +41,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
   await connectToDatabase();
   let startTime = Date.now();

-  const { model, content, userApiKey, systemKey, userId } = await authChat({
+  const { model, showModelDetail, content, userApiKey, systemKey, userId } = await authChat({
     modelId,
     chatId,
     authorization
@@ -120,7 +120,9 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
     const { responseContent } = await gpt35StreamResponse({
       res,
       stream,
-      chatResponse
+      chatResponse,
+      systemPrompt:
+        showModelDetail && filterPrompts[0].role === 'system' ? filterPrompts[0].content : ''
     });

     // 只有使用平台的 key 才计费

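Here the handler forwards the system prompt onto the stream only when the caller is allowed to see model details (showModelDetail, returned by authChat below) and the first entry of filterPrompts is actually a system message. filterPrompts itself is not shown in this diff; an assumed shape, for illustration only:

  // hypothetical prompt list built earlier in this handler
  const filterPrompts: { role: 'system' | 'user' | 'assistant'; content: string }[] = [
    { role: 'system', content: 'You are a customer-support assistant…' }, // forwarded when showModelDetail is true
    { role: 'user', content: prompt }
  ];
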
@@ -48,7 +48,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
       $project: {
         _id: '$content._id',
         obj: '$content.obj',
-        value: '$content.value'
+        value: '$content.value',
+        systemPrompt: '$content.systemPrompt'
       }
     }
   ]);

@@ -26,7 +26,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
   const content = prompts.map((item) => ({
     _id: new mongoose.Types.ObjectId(item._id),
     obj: item.obj,
-    value: item.value
+    value: item.value,
+    systemPrompt: item.systemPrompt
   }));

   await authModel({ modelId, userId, authOwner: false });

@@ -16,7 +16,13 @@ import {
   MenuButton,
   MenuList,
   MenuItem,
-  Image
+  Image,
+  Button,
+  Modal,
+  ModalOverlay,
+  ModalContent,
+  ModalBody,
+  ModalCloseButton
 } from '@chakra-ui/react';
 import { useToast } from '@/hooks/useToast';
 import { useScreen } from '@/hooks/useScreen';
@@ -29,7 +35,7 @@ import { streamFetch } from '@/api/fetch';
 import Icon from '@/components/Icon';
 import MyIcon from '@/components/Icon';
 import { throttle } from 'lodash';
-import mongoose from 'mongoose';
+import { Types } from 'mongoose';

 const SlideBar = dynamic(() => import('./components/SlideBar'));
 const Empty = dynamic(() => import('./components/Empty'));
@@ -67,7 +73,8 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
     history: []
   }); // 聊天框整体数据

-  const [inputVal, setInputVal] = useState(''); // 输入的内容
+  const [inputVal, setInputVal] = useState(''); // user input prompt
+  const [showSystemPrompt, setShowSystemPrompt] = useState('');

   const isChatting = useMemo(
     () => chatData.history[chatData.history.length - 1]?.status === 'loading',
@@ -199,7 +206,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
     };

     // 流请求,获取数据
-    const responseText = await streamFetch({
+    const { responseText, systemPrompt } = await streamFetch({
       url: '/api/chat/chat',
       data: {
         prompt,
@@ -228,7 +235,7 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
     }

     let newChatId = '';
-    // 保存对话信息
+    // save chat record
     try {
       newChatId = await postSaveChat({
         modelId,
@@ -242,7 +249,8 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
           {
             _id: prompts[1]._id,
             obj: 'AI',
-            value: responseText
+            value: responseText,
+            systemPrompt
           }
         ]
       });
@@ -266,7 +274,8 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
           if (index !== state.history.length - 1) return item;
           return {
             ...item,
-            status: 'finish'
+            status: 'finish',
+            systemPrompt
           };
         })
       }));
@@ -300,13 +309,13 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
     const newChatList: ChatSiteItemType[] = [
       ...chatData.history,
       {
-        _id: String(new mongoose.Types.ObjectId()),
+        _id: String(new Types.ObjectId()),
         obj: 'Human',
         value: val,
         status: 'finish'
       },
       {
-        _id: String(new mongoose.Types.ObjectId()),
+        _id: String(new Types.ObjectId()),
         obj: 'AI',
         value: '',
         status: 'loading'
@@ -492,10 +501,24 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
               </Menu>
               <Box flex={'1 0 0'} w={0} overflow={'hidden'}>
                 {item.obj === 'AI' ? (
-                  <Markdown
-                    source={item.value}
-                    isChatting={isChatting && index === chatData.history.length - 1}
-                  />
+                  <>
+                    <Markdown
+                      source={item.value}
+                      isChatting={isChatting && index === chatData.history.length - 1}
+                    />
+                    {item.systemPrompt && (
+                      <Button
+                        size={'xs'}
+                        mt={2}
+                        fontWeight={'normal'}
+                        colorScheme={'gray'}
+                        variant={'outline'}
+                        onClick={() => setShowSystemPrompt(item.systemPrompt || '')}
+                      >
+                        查看提示词
+                      </Button>
+                    )}
+                  </>
                 ) : (
                   <Box className="markdown" whiteSpace={'pre-wrap'}>
                     <Box as={'p'}>{item.value}</Box>
@@ -617,6 +640,19 @@ const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
           </Box>
         </Box>
       </Flex>
+
+      {/* system prompt show modal */}
+      {
+        <Modal isOpen={!!showSystemPrompt} onClose={() => setShowSystemPrompt('')}>
+          <ModalOverlay />
+          <ModalContent maxW={'min(90vw, 600px)'} pr={2} maxH={'80vh'} overflowY={'auto'}>
+            <ModalCloseButton />
+            <ModalBody pt={10} fontSize={'sm'} whiteSpace={'pre-wrap'} textAlign={'justify'}>
+              {showSystemPrompt}
+            </ModalBody>
+          </ModalContent>
+        </Modal>
+      }
     </Flex>
   );
 };

@@ -41,6 +41,10 @@ const ChatSchema = new Schema({
       value: {
         type: String,
         required: true
+      },
+      systemPrompt: {
+        type: String,
+        default: ''
       }
     }
   ],
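With this schema change every stored chat item can carry the prompt that produced it; records saved before the change simply fall back to the '' default. An illustrative stored item after the change (field values are made up):

  {
    _id: ObjectId('…'),
    obj: 'AI',
    value: 'Here is the answer…',
    systemPrompt: 'You are a customer-support assistant for …'
  }
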
@@ -75,7 +75,7 @@ export const authModel = async ({
     };
   }

-  return { model };
+  return { model, showModelDetail: model.share.isShareDetail || userId === String(model.userId) };
 };

 // 获取对话校验
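showModelDetail therefore ends up true for the model owner, or for any user when the owner has enabled detail sharing (the share.isShareDetail flag used above). A condensed restatement of the rule this commit adds:

  // visible when detail sharing is enabled, or when the requester owns the model
  const showModelDetail = model.share.isShareDetail || userId === String(model.userId);
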
@@ -91,7 +91,12 @@ export const authChat = async ({
   const userId = await authToken(authorization);

   // 获取 model 数据
-  const { model } = await authModel({ modelId, userId, authOwner: false, reserveDetail: true });
+  const { model, showModelDetail } = await authModel({
+    modelId,
+    userId,
+    authOwner: false,
+    reserveDetail: true
+  });

   // 聊天内容
   let content: ChatItemSimpleType[] = [];
@@ -124,7 +129,8 @@ export const authChat = async ({
     systemKey,
     content,
     userId,
-    model
+    model,
+    showModelDetail
   };
 };

@@ -7,6 +7,7 @@ import { User } from '../models/user';
 import { formatPrice } from '@/utils/user';
 import { embeddingModel } from '@/constants/model';
 import { pushGenerateVectorBill } from '../events/pushBill';
+import { SYSTEM_PROMPT_PREFIX } from '@/constants/chat';

 /* 获取用户 api 的 openai 信息 */
 export const getUserApiOpenai = async (userId: string) => {
@@ -110,11 +111,13 @@ export const openaiCreateEmbedding = async ({
 export const gpt35StreamResponse = ({
   res,
   stream,
-  chatResponse
+  chatResponse,
+  systemPrompt = ''
 }: {
   res: NextApiResponse;
   stream: PassThrough;
   chatResponse: any;
+  systemPrompt?: string;
 }) =>
   new Promise<{ responseContent: string }>(async (resolve, reject) => {
     try {
@@ -144,8 +147,8 @@ export const gpt35StreamResponse = ({
       }
     };

-    const decoder = new TextDecoder();
     try {
+      const decoder = new TextDecoder();
       const parser = createParser(onParse);
       for await (const chunk of chatResponse.data as any) {
         if (stream.destroyed) {
@@ -157,6 +160,12 @@ export const gpt35StreamResponse = ({
     } catch (error) {
       console.log('pipe error', error);
     }
+
+    // push system prompt
+    !stream.destroyed &&
+      systemPrompt &&
+      stream.push(`${SYSTEM_PROMPT_PREFIX}${systemPrompt.replace(/\n/g, '<br/>')}`);
+
     // close stream
     !stream.destroyed && stream.push(null);
     stream.destroy();
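Because the prompt is pushed only after the answer chunks and before the stream is closed, the client sees it as the final chunk. An assumed end-to-end chunk sequence under this change, for illustration:

  // "Hello"                                   -> onMessage("Hello"), responseText += "Hello"
  // " world"                                  -> onMessage(" world"), responseText += " world"
  // "SYSTEM_PROMPT:You are …<br/>Rules: …"    -> client restores '\n' from '<br/>' and stores it in systemPrompt
  // null                                      -> stream closed, streamFetch resolves { responseText, systemPrompt }
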
src/types/chat.d.ts
@@ -1,6 +1,7 @@
 export type ChatItemSimpleType = {
   obj: 'Human' | 'AI' | 'SYSTEM';
   value: string;
+  systemPrompt?: string;
 };
 export type ChatItemType = {
   _id: string;