perf: 生成对话框时机

This commit is contained in:
archer 2023-04-23 14:07:17 +08:00
parent 9682c82713
commit c2c73ed23c
No known key found for this signature in database
GPG Key ID: 569A5660D2379E28
25 changed files with 299 additions and 327 deletions

View File

@ -2,16 +2,11 @@ import { GET, POST, DELETE } from './request';
import type { ChatItemType, ChatSiteItemType } from '@/types/chat'; import type { ChatItemType, ChatSiteItemType } from '@/types/chat';
import type { InitChatResponse } from './response/chat'; import type { InitChatResponse } from './response/chat';
/**
* ID
*/
export const getChatSiteId = (modelId: string) => GET<string>(`/chat/generate?modelId=${modelId}`);
/** /**
* *
*/ */
export const getInitChatSiteInfo = (chatId: string) => export const getInitChatSiteInfo = (modelId: string, chatId: '' | string) =>
GET<InitChatResponse>(`/chat/init?chatId=${chatId}`); GET<InitChatResponse>(`/chat/init?modelId=${modelId}&chatId=${chatId}`);
/** /**
* GPT3 prompt * GPT3 prompt
@ -34,8 +29,11 @@ export const postGPT3SendPrompt = ({
/** /**
* *
*/ */
export const postSaveChat = (data: { chatId: string; prompts: ChatItemType[] }) => export const postSaveChat = (data: {
POST('/chat/saveChat', data); modelId: string;
chatId: '' | string;
prompts: ChatItemType[];
}) => POST<string>('/chat/saveChat', data);
/** /**
* *

View File

@ -1,4 +1,4 @@
import type { ServiceName, ModelDataType, ModelSchema } from '@/types/mongoSchema'; import type { ModelDataType, ModelSchema } from '@/types/mongoSchema';
export enum ModelDataStatusEnum { export enum ModelDataStatusEnum {
ready = 'ready', ready = 'ready',
@ -18,7 +18,6 @@ export const ChatModelNameMap = {
}; };
export type ModelConstantsData = { export type ModelConstantsData = {
serviceCompany: `${ServiceName}`;
name: string; name: string;
model: `${ChatModelNameEnum}`; model: `${ChatModelNameEnum}`;
trainName: string; // 空字符串代表不能训练 trainName: string; // 空字符串代表不能训练
@ -30,7 +29,6 @@ export type ModelConstantsData = {
export const modelList: ModelConstantsData[] = [ export const modelList: ModelConstantsData[] = [
{ {
serviceCompany: 'openai',
name: 'chatGPT', name: 'chatGPT',
model: ChatModelNameEnum.GPT35, model: ChatModelNameEnum.GPT35,
trainName: '', trainName: '',
@ -40,7 +38,6 @@ export const modelList: ModelConstantsData[] = [
price: 3 price: 3
}, },
{ {
serviceCompany: 'openai',
name: '知识库', name: '知识库',
model: ChatModelNameEnum.VECTOR_GPT, model: ChatModelNameEnum.VECTOR_GPT,
trainName: 'vector', trainName: 'vector',
@ -132,7 +129,6 @@ export const defaultModel: ModelSchema = {
mode: ModelVectorSearchModeEnum.hightSimilarity mode: ModelVectorSearchModeEnum.hightSimilarity
}, },
service: { service: {
company: 'openai',
trainId: '', trainId: '',
chatModel: ChatModelNameEnum.GPT35, chatModel: ChatModelNameEnum.GPT35,
modelName: ChatModelNameEnum.GPT35 modelName: ChatModelNameEnum.GPT35

View File

@ -1,11 +1,10 @@
import type { NextApiRequest, NextApiResponse } from 'next'; import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase } from '@/service/mongo'; import { connectToDatabase } from '@/service/mongo';
import { getOpenAIApi, authChat } from '@/service/utils/chat'; import { getOpenAIApi, authChat } from '@/service/utils/auth';
import { httpsAgent, openaiChatFilter } from '@/service/utils/tools'; import { httpsAgent, openaiChatFilter } from '@/service/utils/tools';
import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai'; import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai';
import { ChatItemType } from '@/types/chat'; import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import type { ModelSchema } from '@/types/mongoSchema';
import { PassThrough } from 'stream'; import { PassThrough } from 'stream';
import { modelList } from '@/constants/model'; import { modelList } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill'; import { pushChatBill } from '@/service/events/pushBill';
@ -28,29 +27,33 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
}); });
try { try {
const { chatId, prompt } = req.body as { const { chatId, prompt, modelId } = req.body as {
prompt: ChatItemType; prompt: ChatItemType;
chatId: string; modelId: string;
chatId: '' | string;
}; };
const { authorization } = req.headers; const { authorization } = req.headers;
if (!chatId || !prompt) { if (!modelId || !prompt) {
throw new Error('缺少参数'); throw new Error('缺少参数');
} }
await connectToDatabase(); await connectToDatabase();
let startTime = Date.now(); let startTime = Date.now();
const { chat, userApiKey, systemKey, userId } = await authChat(chatId, authorization); const { model, content, userApiKey, systemKey, userId } = await authChat({
modelId,
chatId,
authorization
});
const model: ModelSchema = chat.modelId;
const modelConstantsData = modelList.find((item) => item.model === model.service.modelName); const modelConstantsData = modelList.find((item) => item.model === model.service.modelName);
if (!modelConstantsData) { if (!modelConstantsData) {
throw new Error('模型加载异常'); throw new Error('模型加载异常');
} }
// 读取对话内容 // 读取对话内容
const prompts = [...chat.content, prompt]; const prompts = [...content, prompt];
// 如果有系统提示词,自动插入 // 如果有系统提示词,自动插入
if (model.systemPrompt) { if (model.systemPrompt) {

View File

@ -1,54 +0,0 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, Model, Chat } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import type { ModelSchema } from '@/types/mongoSchema';
/* 获取我的模型 */
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try {
const { modelId } = req.query as {
modelId: string;
};
const { authorization } = req.headers;
if (!authorization) {
throw new Error('无权生成对话');
}
if (!modelId) {
throw new Error('缺少参数');
}
// 凭证校验
const userId = await authToken(authorization);
await connectToDatabase();
// 校验是否为用户的模型
const model = await Model.findOne<ModelSchema>({
_id: modelId,
userId
});
if (!model) {
throw new Error('无权使用该模型');
}
// 创建 chat 数据
const response = await Chat.create({
userId,
modelId,
content: []
});
jsonRes(res, {
data: response._id // 即聊天框的 ID
});
} catch (err) {
jsonRes(res, {
code: 500,
error: err
});
}
}

View File

@ -1,9 +1,10 @@
import type { NextApiRequest, NextApiResponse } from 'next'; import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { connectToDatabase, Chat } from '@/service/mongo'; import { connectToDatabase, Chat, Model } from '@/service/mongo';
import type { ChatPopulate } from '@/types/mongoSchema';
import type { InitChatResponse } from '@/api/response/chat'; import type { InitChatResponse } from '@/api/response/chat';
import { authToken } from '@/service/utils/tools'; import { authToken } from '@/service/utils/tools';
import { ChatItemType } from '@/types/chat';
import { authModel } from '@/service/utils/auth';
/* 初始化我的聊天框,需要身份验证 */ /* 初始化我的聊天框,需要身份验证 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) { export default async function handler(req: NextApiRequest, res: NextApiResponse) {
@ -11,43 +12,46 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
const { authorization } = req.headers; const { authorization } = req.headers;
const userId = await authToken(authorization); const userId = await authToken(authorization);
const { chatId } = req.query as { chatId: string }; const { modelId, chatId } = req.query as { modelId: string; chatId: '' | string };
if (!chatId) { if (!modelId) {
throw new Error('缺少参数'); throw new Error('缺少参数');
} }
await connectToDatabase(); await connectToDatabase();
// 获取 chat 数据 // 获取 model 数据
const chat = await Chat.findOne<ChatPopulate>({ const { model } = await authModel(modelId, userId);
_id: chatId,
userId
}).populate({
path: 'modelId',
options: {
strictPopulate: false
}
});
if (!chat) { // 历史记录
throw new Error('聊天框不存在'); let history: ChatItemType[] = [];
if (chatId) {
// 获取 chat 数据
const chat = await Chat.findOne({
_id: chatId,
userId,
modelId
});
if (!chat) {
throw new Error('聊天框不存在');
}
// filter 被 deleted 的内容
history = chat.content.filter((item) => item.deleted !== true);
} }
// filter 掉被 deleted 的内容
chat.content = chat.content.filter((item) => item.deleted !== true);
const model = chat.modelId;
jsonRes<InitChatResponse>(res, { jsonRes<InitChatResponse>(res, {
data: { data: {
chatId: chat._id, chatId: chatId || '',
modelId: model._id, modelId: modelId,
name: model.name, name: model.name,
avatar: model.avatar, avatar: model.avatar,
intro: model.intro, intro: model.intro,
modelName: model.service.modelName, modelName: model.service.modelName,
chatModel: model.service.chatModel, chatModel: model.service.chatModel,
history: chat.content history
} }
}); });
} catch (err) { } catch (err) {

View File

@ -2,34 +2,53 @@ import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { ChatItemType } from '@/types/chat'; import { ChatItemType } from '@/types/chat';
import { connectToDatabase, Chat } from '@/service/mongo'; import { connectToDatabase, Chat } from '@/service/mongo';
import { authModel } from '@/service/utils/auth';
import { authToken } from '@/service/utils/tools';
/* 聊天内容存存储 */ /* 聊天内容存存储 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) { export default async function handler(req: NextApiRequest, res: NextApiResponse) {
try { try {
const { chatId, prompts } = req.body as { const { chatId, modelId, prompts } = req.body as {
chatId: string; chatId: '' | string;
modelId: string;
prompts: ChatItemType[]; prompts: ChatItemType[];
}; };
if (!chatId || !prompts) { if (!prompts) {
throw new Error('缺少参数'); throw new Error('缺少参数');
} }
const userId = await authToken(req.headers.authorization);
await connectToDatabase(); await connectToDatabase();
// 存入库 const content = prompts.map((item) => ({
await Chat.findByIdAndUpdate(chatId, { obj: item.obj,
$push: { value: item.value
content: { }));
$each: prompts.map((item) => ({
obj: item.obj,
value: item.value
}))
}
},
updateTime: new Date()
});
// 没有 chatId, 创建一个对话
if (!chatId) {
await authModel(modelId, userId);
const { _id } = await Chat.create({
userId,
modelId,
content
});
return jsonRes(res, {
data: _id
});
} else {
// 已经有记录,追加入库
await Chat.findByIdAndUpdate(chatId, {
$push: {
content: {
$each: content
}
},
updateTime: new Date()
});
}
jsonRes(res); jsonRes(res);
} catch (err) { } catch (err) {
jsonRes(res, { jsonRes(res, {

View File

@ -1,6 +1,6 @@
import type { NextApiRequest, NextApiResponse } from 'next'; import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase } from '@/service/mongo'; import { connectToDatabase } from '@/service/mongo';
import { authChat } from '@/service/utils/chat'; import { authChat } from '@/service/utils/auth';
import { httpsAgent, systemPromptFilter, openaiChatFilter } from '@/service/utils/tools'; import { httpsAgent, systemPromptFilter, openaiChatFilter } from '@/service/utils/tools';
import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai'; import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai';
import { ChatItemType } from '@/types/chat'; import { ChatItemType } from '@/types/chat';
@ -35,29 +35,33 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
}); });
try { try {
const { chatId, prompt } = req.body as { const { modelId, chatId, prompt } = req.body as {
modelId: string;
chatId: '' | string;
prompt: ChatItemType; prompt: ChatItemType;
chatId: string;
}; };
const { authorization } = req.headers; const { authorization } = req.headers;
if (!chatId || !prompt) { if (!modelId || !prompt) {
throw new Error('缺少参数'); throw new Error('缺少参数');
} }
await connectToDatabase(); await connectToDatabase();
let startTime = Date.now(); let startTime = Date.now();
const { chat, userApiKey, systemKey, userId } = await authChat(chatId, authorization); const { model, content, userApiKey, systemKey, userId } = await authChat({
modelId,
chatId,
authorization
});
const model: ModelSchema = chat.modelId;
const modelConstantsData = modelList.find((item) => item.model === model.service.modelName); const modelConstantsData = modelList.find((item) => item.model === model.service.modelName);
if (!modelConstantsData) { if (!modelConstantsData) {
throw new Error('模型加载异常'); throw new Error('模型加载异常');
} }
// 读取对话内容 // 读取对话内容
const prompts = [...chat.content, prompt]; const prompts = [...content, prompt];
// 获取提示词的向量 // 获取提示词的向量
const { vector: promptVector, chatAPI } = await openaiCreateEmbedding({ const { vector: promptVector, chatAPI } = await openaiCreateEmbedding({

View File

@ -47,7 +47,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
userId, userId,
status: ModelStatusEnum.running, status: ModelStatusEnum.running,
service: { service: {
company: modelItem.serviceCompany,
trainId: '', trainId: '',
chatModel: ChatModelNameMap[modelItem.model], // 聊天时用的模型 chatModel: ChatModelNameMap[modelItem.model], // 聊天时用的模型
modelName: modelItem.model // 最底层的模型,不会变,用于计费等核心操作 modelName: modelItem.model // 最底层的模型,不会变,用于计费等核心操作

View File

@ -36,14 +36,14 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
const textList: string[] = []; const textList: string[] = [];
let splitText = ''; let splitText = '';
/* 取 3k ~ 4K tokens 内容 */ /* 取 2.5k ~ 3.5K tokens 内容 */
chunks.forEach((chunk) => { chunks.forEach((chunk) => {
const tokens = encode(splitText + chunk).length; const tokens = encode(splitText + chunk).length;
if (tokens >= 4000) { if (tokens >= 3500) {
// 超过 4000不要这块内容 // 超过 3500不要这块内容
splitText && textList.push(splitText); splitText && textList.push(splitText);
splitText = chunk; splitText = chunk;
} else if (tokens >= 3000) { } else if (tokens >= 2500) {
// 超过 3000取内容 // 超过 3000取内容
splitText && textList.push(splitText + chunk); splitText && textList.push(splitText + chunk);
splitText = ''; splitText = '';

View File

@ -1,6 +1,6 @@
import type { NextApiRequest, NextApiResponse } from 'next'; import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase, Model } from '@/service/mongo'; import { connectToDatabase, Model } from '@/service/mongo';
import { getOpenAIApi } from '@/service/utils/chat'; import { getOpenAIApi } from '@/service/utils/auth';
import { httpsAgent, openaiChatFilter, authOpenApiKey } from '@/service/utils/tools'; import { httpsAgent, openaiChatFilter, authOpenApiKey } from '@/service/utils/tools';
import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai'; import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai';
import { ChatItemType } from '@/types/chat'; import { ChatItemType } from '@/types/chat';

View File

@ -1,6 +1,6 @@
import type { NextApiRequest, NextApiResponse } from 'next'; import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase, Model } from '@/service/mongo'; import { connectToDatabase, Model } from '@/service/mongo';
import { getOpenAIApi } from '@/service/utils/chat'; import { getOpenAIApi } from '@/service/utils/auth';
import { authOpenApiKey } from '@/service/utils/tools'; import { authOpenApiKey } from '@/service/utils/tools';
import { httpsAgent, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools'; import { httpsAgent, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools';
import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai'; import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai';

View File

@ -11,13 +11,6 @@ import {
Flex, Flex,
Divider, Divider,
IconButton, IconButton,
Modal,
ModalOverlay,
ModalContent,
ModalHeader,
ModalFooter,
ModalBody,
ModalCloseButton,
useDisclosure, useDisclosure,
useColorMode, useColorMode,
useColorModeValue useColorModeValue
@ -29,8 +22,6 @@ import { useRouter } from 'next/router';
import { getToken } from '@/utils/user'; import { getToken } from '@/utils/user';
import MyIcon from '@/components/Icon'; import MyIcon from '@/components/Icon';
import { useCopyData } from '@/utils/tools'; import { useCopyData } from '@/utils/tools';
import Markdown from '@/components/Markdown';
import { getChatSiteId } from '@/api/chat';
import WxConcat from '@/components/WxConcat'; import WxConcat from '@/components/WxConcat';
import { useMarkdown } from '@/hooks/useMarkdown'; import { useMarkdown } from '@/hooks/useMarkdown';
@ -42,7 +33,7 @@ const SlideBar = ({
}: { }: {
chatId: string; chatId: string;
modelId: string; modelId: string;
resetChat: () => void; resetChat: (modelId?: string, chatId?: string) => void;
onClose: () => void; onClose: () => void;
}) => { }) => {
const router = useRouter(); const router = useRouter();
@ -86,7 +77,7 @@ const SlideBar = ({
: {})} : {})}
onClick={() => { onClick={() => {
if (item.chatId === chatId) return; if (item.chatId === chatId) return;
router.replace(`/chat?chatId=${item.chatId}`); resetChat(modelId, item.chatId);
onClose(); onClose();
}} }}
> >
@ -155,7 +146,7 @@ const SlideBar = ({
mb={4} mb={4}
mx={'auto'} mx={'auto'}
leftIcon={<AddIcon />} leftIcon={<AddIcon />}
onClick={resetChat} onClick={() => resetChat()}
> >
</Button> </Button>
@ -194,7 +185,7 @@ const SlideBar = ({
: {})} : {})}
onClick={async () => { onClick={async () => {
if (item._id === modelId) return; if (item._id === modelId) return;
router.replace(`/chat?chatId=${await getChatSiteId(item._id)}`); resetChat(item._id);
onClose(); onClose();
}} }}
> >
@ -260,49 +251,6 @@ const SlideBar = ({
/> />
</Flex> </Flex>
{/* 分享提示modal */}
<Modal isOpen={isOpenShare} onClose={onCloseShare}>
<ModalOverlay />
<ModalContent color={useColorModeValue('blackAlpha.700', 'white')}>
<ModalHeader></ModalHeader>
<ModalCloseButton />
<ModalBody>
<Markdown source={shareHint} />
</ModalBody>
<ModalFooter>
<Button colorScheme="gray" variant={'outline'} mr={3} onClick={onCloseShare}>
</Button>
{getToken() && (
<Button
variant="outline"
mr={3}
onClick={async () => {
copyData(
`${location.origin}/chat?chatId=${await getChatSiteId(modelId)}`,
'已复制分享链接'
);
onCloseShare();
onClose();
}}
>
</Button>
)}
<Button
onClick={() => {
copyData(`${location.origin}/chat?chatId=${chatId}`, '已复制分享链接');
onCloseShare();
onClose();
}}
>
</Button>
</ModalFooter>
</ModalContent>
</Modal>
{/* wx 联系 */} {/* wx 联系 */}
{isOpenWx && <WxConcat onClose={onCloseWx} />} {isOpenWx && <WxConcat onClose={onCloseWx} />}
</Flex> </Flex>

View File

@ -1,7 +1,7 @@
import React, { useCallback, useState, useRef, useMemo, useEffect } from 'react'; import React, { useCallback, useState, useRef, useMemo, useEffect } from 'react';
import { useRouter } from 'next/router'; import { useRouter } from 'next/router';
import Image from 'next/image'; import Image from 'next/image';
import { getInitChatSiteInfo, getChatSiteId, delChatRecordByIndex, postSaveChat } from '@/api/chat'; import { getInitChatSiteInfo, delChatRecordByIndex, postSaveChat } from '@/api/chat';
import type { InitChatResponse } from '@/api/response/chat'; import type { InitChatResponse } from '@/api/response/chat';
import { ChatSiteItemType } from '@/types/chat'; import { ChatSiteItemType } from '@/types/chat';
import { import {
@ -41,18 +41,17 @@ interface ChatType extends InitChatResponse {
history: ChatSiteItemType[]; history: ChatSiteItemType[];
} }
const Chat = ({ chatId }: { chatId: string }) => { const Chat = ({ modelId, chatId }: { modelId: string; chatId: string }) => {
const router = useRouter();
const ChatBox = useRef<HTMLDivElement>(null); const ChatBox = useRef<HTMLDivElement>(null);
const TextareaDom = useRef<HTMLTextAreaElement>(null); const TextareaDom = useRef<HTMLTextAreaElement>(null);
const { toast } = useToast();
const router = useRouter();
// 中断请求 // 中断请求
const controller = useRef(new AbortController()); const controller = useRef(new AbortController());
const [chatData, setChatData] = useState<ChatType>({ const [chatData, setChatData] = useState<ChatType>({
chatId: '', chatId,
modelId: '', modelId,
name: '', name: '',
avatar: '', avatar: '',
intro: '', intro: '',
@ -60,6 +59,7 @@ const Chat = ({ chatId }: { chatId: string }) => {
modelName: '', modelName: '',
history: [] history: []
}); // 聊天框整体数据 }); // 聊天框整体数据
const [inputVal, setInputVal] = useState(''); // 输入的内容 const [inputVal, setInputVal] = useState(''); // 输入的内容
const isChatting = useMemo( const isChatting = useMemo(
@ -68,6 +68,7 @@ const Chat = ({ chatId }: { chatId: string }) => {
); );
const { isOpen: isOpenSlider, onClose: onCloseSlider, onOpen: onOpenSlider } = useDisclosure(); const { isOpen: isOpenSlider, onClose: onCloseSlider, onOpen: onOpenSlider } = useDisclosure();
const { toast } = useToast();
const { copyData } = useCopyData(); const { copyData } = useCopyData();
const { isPc, media } = useScreen(); const { isPc, media } = useScreen();
const { setLoading } = useGlobalStore(); const { setLoading } = useGlobalStore();
@ -108,19 +109,72 @@ const Chat = ({ chatId }: { chatId: string }) => {
}, 100); }, 100);
}, []); }, []);
// 重载对话 // 获取对话信息
const resetChat = useCallback(async () => { const loadChatInfo = useCallback(
if (!chatData) return; async ({
try { modelId,
router.replace(`/chat?chatId=${await getChatSiteId(chatData.modelId)}`); chatId,
} catch (error: any) { isLoading = false,
toast({ isScroll = false
title: error?.message || '生成新对话失败', }: {
status: 'warning' modelId: string;
}); chatId: string;
} isLoading?: boolean;
onCloseSlider(); isScroll?: boolean;
}, [chatData, onCloseSlider, router, toast]); }) => {
isLoading && setLoading(true);
try {
const res = await getInitChatSiteInfo(modelId, chatId);
setChatData({
...res,
history: res.history.map((item) => ({
...item,
status: 'finish'
}))
});
if (isScroll && res.history.length > 0) {
setTimeout(() => {
scrollToBottom('auto');
}, 2000);
}
} catch (e: any) {
toast({
title: e?.message || '获取对话信息异常,请检查地址',
status: 'error',
isClosable: true,
duration: 5000
});
router.replace('/model/list');
}
setLoading(false);
return null;
},
[router, scrollToBottom, setLoading, toast]
);
// 重载新的对话
const resetChat = useCallback(
async (modelId = chatData.modelId, chatId = '') => {
// 强制中断流
controller.current?.abort();
try {
router.replace(`/chat?modelId=${modelId}&chatId=${chatId}`);
loadChatInfo({
modelId,
chatId,
isLoading: true,
isScroll: true
});
} catch (error: any) {
toast({
title: error?.message || '生成新对话失败',
status: 'warning'
});
}
onCloseSlider();
},
[chatData.modelId, loadChatInfo, onCloseSlider, router, toast]
);
// gpt 对话 // gpt 对话
const gptChatPrompt = useCallback( const gptChatPrompt = useCallback(
@ -132,6 +186,10 @@ const Chat = ({ chatId }: { chatId: string }) => {
if (!urlMap[chatData.modelName]) return Promise.reject('找不到模型'); if (!urlMap[chatData.modelName]) return Promise.reject('找不到模型');
// create abort obj
const abortSignal = new AbortController();
controller.current = abortSignal;
const prompt = { const prompt = {
obj: prompts.obj, obj: prompts.obj,
value: prompts.value value: prompts.value
@ -141,7 +199,8 @@ const Chat = ({ chatId }: { chatId: string }) => {
url: urlMap[chatData.modelName], url: urlMap[chatData.modelName],
data: { data: {
prompt, prompt,
chatId chatId,
modelId
}, },
onMessage: (text: string) => { onMessage: (text: string) => {
setChatData((state) => ({ setChatData((state) => ({
@ -156,12 +215,14 @@ const Chat = ({ chatId }: { chatId: string }) => {
})); }));
generatingMessage(); generatingMessage();
}, },
abortSignal: controller.current abortSignal
}); });
let id = '';
// 保存对话信息 // 保存对话信息
try { try {
await postSaveChat({ id = await postSaveChat({
modelId,
chatId, chatId,
prompts: [ prompts: [
prompt, prompt,
@ -171,6 +232,9 @@ const Chat = ({ chatId }: { chatId: string }) => {
} }
] ]
}); });
if (id) {
router.replace(`/chat?modelId=${modelId}&chatId=${id}`);
}
} catch (err) { } catch (err) {
toast({ toast({
title: '对话出现异常, 继续对话会导致上下文丢失,请刷新页面', title: '对话出现异常, 继续对话会导致上下文丢失,请刷新页面',
@ -183,6 +247,7 @@ const Chat = ({ chatId }: { chatId: string }) => {
// 设置完成状态 // 设置完成状态
setChatData((state) => ({ setChatData((state) => ({
...state, ...state,
chatId: id || state.chatId, // 如果有 Id说明是新创建的对话
history: state.history.map((item, index) => { history: state.history.map((item, index) => {
if (index !== state.history.length - 1) return item; if (index !== state.history.length - 1) return item;
return { return {
@ -192,7 +257,7 @@ const Chat = ({ chatId }: { chatId: string }) => {
}) })
})); }));
}, },
[chatData.modelName, chatId, generatingMessage, toast] [chatData.modelName, chatId, generatingMessage, modelId, router, toast]
); );
/** /**
@ -210,7 +275,7 @@ const Chat = ({ chatId }: { chatId: string }) => {
// 去除空行 // 去除空行
const val = inputVal.trim().replace(/\n\s*/g, '\n'); const val = inputVal.trim().replace(/\n\s*/g, '\n');
if (!chatData?.modelId || !val) { if (!val) {
toast({ toast({
title: '内容为空', title: '内容为空',
status: 'warning' status: 'warning'
@ -271,12 +336,12 @@ const Chat = ({ chatId }: { chatId: string }) => {
})); }));
} }
}, [ }, [
inputVal,
chatData,
isChatting, isChatting,
inputVal,
chatData.history,
resetInputVal, resetInputVal,
scrollToBottom,
toast, toast,
scrollToBottom,
gptChatPrompt, gptChatPrompt,
pushChatHistory, pushChatHistory,
chatId chatId
@ -312,50 +377,22 @@ const Chat = ({ chatId }: { chatId: string }) => {
); );
// 初始化聊天框 // 初始化聊天框
useQuery( useQuery(['init'], () =>
['init', chatId], loadChatInfo({
() => { modelId,
setLoading(true); chatId,
return getInitChatSiteInfo(chatId); isLoading: true,
}, isScroll: true
{ })
onSuccess(res) {
setChatData({
...res,
history: res.history.map((item) => ({
...item,
status: 'finish'
}))
});
if (res.history.length > 0) {
setTimeout(() => {
scrollToBottom('auto');
}, 2000);
}
},
onError(e: any) {
toast({
title: e?.message || '初始化异常,请检查地址',
status: 'error',
isClosable: true,
duration: 5000
});
router.push('/model/list');
},
onSettled() {
setLoading(false);
}
}
); );
// 更新流中断对象 // 更新流中断对象
useEffect(() => { useEffect(() => {
controller.current = new AbortController();
return () => { return () => {
// eslint-disable-next-line react-hooks/exhaustive-deps // eslint-disable-next-line react-hooks/exhaustive-deps
controller.current?.abort(); controller.current?.abort();
}; };
}, [chatId]); }, []);
return ( return (
<Flex <Flex
@ -368,7 +405,7 @@ const Chat = ({ chatId }: { chatId: string }) => {
<SlideBar <SlideBar
resetChat={resetChat} resetChat={resetChat}
chatId={chatId} chatId={chatId}
modelId={chatData.modelId} modelId={modelId}
onClose={onCloseSlider} onClose={onCloseSlider}
/> />
</Box> </Box>
@ -399,7 +436,7 @@ const Chat = ({ chatId }: { chatId: string }) => {
<SlideBar <SlideBar
resetChat={resetChat} resetChat={resetChat}
chatId={chatId} chatId={chatId}
modelId={chatData.modelId} modelId={modelId}
onClose={onCloseSlider} onClose={onCloseSlider}
/> />
</DrawerContent> </DrawerContent>
@ -565,9 +602,10 @@ const Chat = ({ chatId }: { chatId: string }) => {
export default Chat; export default Chat;
export async function getServerSideProps(context: any) { export async function getServerSideProps(context: any) {
const chatId = context?.query?.chatId || 'noid'; const modelId = context?.query?.modelId || '';
const chatId = context?.query?.chatId || '';
return { return {
props: { chatId } props: { modelId, chatId }
}; };
} }

View File

@ -1,5 +1,5 @@
import React, { useEffect } from 'react'; import React, { useEffect } from 'react';
import { Card } from '@chakra-ui/react'; import { Card, Box, Link } from '@chakra-ui/react';
import Markdown from '@/components/Markdown'; import Markdown from '@/components/Markdown';
import { useMarkdown } from '@/hooks/useMarkdown'; import { useMarkdown } from '@/hooks/useMarkdown';
import { useRouter } from 'next/router'; import { useRouter } from 'next/router';
@ -15,9 +15,20 @@ const Home = () => {
}, [inviterId]); }, [inviterId]);
return ( return (
<Card p={5} lineHeight={2}> <>
<Markdown source={data} isChatting={false} /> <Card p={5} lineHeight={2}>
</Card> <Markdown source={data} isChatting={false} />
</Card>
<Card p={5} mt={4} textAlign={'center'}>
<Box>
{/* <Link href="https://beian.miit.gov.cn/" target="_blank">
B2-20080101
</Link> */}
</Box>
<Box>Made by FastGpt Team.</Box>
</Card>
</>
); );
}; };

View File

@ -1,7 +1,6 @@
import React, { useCallback, useState, useRef, useMemo, useEffect } from 'react'; import React, { useCallback, useState, useRef, useMemo, useEffect } from 'react';
import { useRouter } from 'next/router'; import { useRouter } from 'next/router';
import { getModelById, delModelById, putModelTrainingStatus, putModelById } from '@/api/model'; import { getModelById, delModelById, putModelTrainingStatus, putModelById } from '@/api/model';
import { getChatSiteId } from '@/api/chat';
import type { ModelSchema } from '@/types/mongoSchema'; import type { ModelSchema } from '@/types/mongoSchema';
import { Card, Box, Flex, Button, Tag, Grid } from '@chakra-ui/react'; import { Card, Box, Flex, Button, Tag, Grid } from '@chakra-ui/react';
import { useToast } from '@/hooks/useToast'; import { useToast } from '@/hooks/useToast';
@ -70,14 +69,12 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
const handlePreviewChat = useCallback(async () => { const handlePreviewChat = useCallback(async () => {
setLoading(true); setLoading(true);
try { try {
const chatId = await getChatSiteId(model._id); router.push(`/chat?modelId=${modelId}`);
router.push(`/chat?chatId=${chatId}`);
} catch (err) { } catch (err) {
console.log('error->', err); console.log('error->', err);
} }
setLoading(false); setLoading(false);
}, [setLoading, model, router]); }, [setLoading, router, modelId]);
/* 上传数据集,触发微调 */ /* 上传数据集,触发微调 */
// const startTraining = useCallback( // const startTraining = useCallback(

View File

@ -1,6 +1,5 @@
import React, { useState, useCallback } from 'react'; import React, { useState, useCallback } from 'react';
import { Box, Button, Flex, Card } from '@chakra-ui/react'; import { Box, Button, Flex, Card } from '@chakra-ui/react';
import { getChatSiteId } from '@/api/chat';
import type { ModelSchema } from '@/types/mongoSchema'; import type { ModelSchema } from '@/types/mongoSchema';
import { useRouter } from 'next/router'; import { useRouter } from 'next/router';
import ModelTable from './components/ModelTable'; import ModelTable from './components/ModelTable';

View File

@ -1,5 +1,5 @@
import { DataItem } from '@/service/mongo'; import { DataItem } from '@/service/mongo';
import { getOpenAIApi } from '@/service/utils/chat'; import { getOpenAIApi } from '@/service/utils/auth';
import { httpsAgent } from '@/service/utils/tools'; import { httpsAgent } from '@/service/utils/tools';
import { getOpenApiKey } from '../utils/openai'; import { getOpenApiKey } from '../utils/openai';
import type { ChatCompletionRequestMessage } from 'openai'; import type { ChatCompletionRequestMessage } from 'openai';

View File

@ -1,5 +1,5 @@
import { SplitData } from '@/service/mongo'; import { SplitData } from '@/service/mongo';
import { getOpenAIApi } from '@/service/utils/chat'; import { getOpenAIApi } from '@/service/utils/auth';
import { httpsAgent } from '@/service/utils/tools'; import { httpsAgent } from '@/service/utils/tools';
import { getOpenApiKey } from '../utils/openai'; import { getOpenApiKey } from '../utils/openai';
import type { ChatCompletionRequestMessage } from 'openai'; import type { ChatCompletionRequestMessage } from 'openai';

View File

@ -14,7 +14,7 @@ export const pushChatBill = async ({
isPay: boolean; isPay: boolean;
modelName: string; modelName: string;
userId: string; userId: string;
chatId?: string; chatId?: '' | string;
text: string; text: string;
}) => { }) => {
let billId; let billId;
@ -42,7 +42,7 @@ export const pushChatBill = async ({
userId, userId,
type: 'chat', type: 'chat',
modelName, modelName,
chatId, chatId: chatId ? chatId : undefined,
textLen: text.length, textLen: text.length,
tokenLen: tokens, tokenLen: tokens,
price price

View File

@ -53,11 +53,6 @@ const ModelSchema = new Schema({
} }
}, },
service: { service: {
company: {
type: String,
required: true,
enum: ['openai']
},
trainId: { trainId: {
// 训练时需要的 ID 不能训练的模型没有这个值。 // 训练时需要的 ID 不能训练的模型没有这个值。
type: String, type: String,

70
src/service/utils/auth.ts Normal file
View File

@ -0,0 +1,70 @@
import { Configuration, OpenAIApi } from 'openai';
import { Chat, Model } from '../mongo';
import type { ModelSchema } from '@/types/mongoSchema';
import { authToken } from './tools';
import { getOpenApiKey } from './openai';
import type { ChatItemType } from '@/types/chat';
export const getOpenAIApi = (apiKey: string) => {
const configuration = new Configuration({
apiKey
});
return new OpenAIApi(configuration, undefined);
};
// 模型使用权校验
export const authModel = async (modelId: string, userId: string) => {
// 获取 model 数据
const model = await Model.findById<ModelSchema>(modelId);
if (!model) {
return Promise.reject('模型不存在');
}
// 凭证校验
if (userId !== String(model.userId)) {
return Promise.reject('无权使用该模型');
}
return { model };
};
// 获取对话校验
export const authChat = async ({
modelId,
chatId,
authorization
}: {
modelId: string;
chatId: '' | string;
authorization?: string;
}) => {
const userId = await authToken(authorization);
// 获取 model 数据
const { model } = await authModel(modelId, userId);
// 聊天内容
let content: ChatItemType[] = [];
if (chatId) {
// 获取 chat 数据
const chat = await Chat.findById(chatId);
if (!chat) {
return Promise.reject('对话不存在');
}
// filter 掉被 deleted 的内容
content = chat.content.filter((item) => item.deleted !== true);
}
// 获取 user 的 apiKey
const { userApiKey, systemKey } = await getOpenApiKey(userId);
return {
userApiKey,
systemKey,
content,
userId,
model
};
};

View File

@ -1,46 +0,0 @@
import { Configuration, OpenAIApi } from 'openai';
import { Chat } from '../mongo';
import type { ChatPopulate } from '@/types/mongoSchema';
import { authToken } from './tools';
import { getOpenApiKey } from './openai';
export const getOpenAIApi = (apiKey: string) => {
const configuration = new Configuration({
apiKey
});
return new OpenAIApi(configuration, undefined);
};
export const authChat = async (chatId: string, authorization?: string) => {
// 获取 chat 数据
const chat = await Chat.findById<ChatPopulate>(chatId).populate({
path: 'modelId',
options: {
strictPopulate: false
}
});
if (!chat || !chat.modelId || !chat.userId) {
return Promise.reject('模型不存在');
}
// 凭证校验
const userId = await authToken(authorization);
if (userId !== String(chat.userId._id)) {
return Promise.reject('无权使用该对话');
}
// 获取 user 的 apiKey
const { user, userApiKey, systemKey } = await getOpenApiKey(chat.userId as unknown as string);
// filter 掉被 deleted 的内容
chat.content = chat.content.filter((item) => item.deleted !== true);
return {
userApiKey,
systemKey,
chat,
userId: user._id
};
};

View File

@ -1,7 +1,7 @@
import type { NextApiResponse } from 'next'; import type { NextApiResponse } from 'next';
import type { PassThrough } from 'stream'; import type { PassThrough } from 'stream';
import { createParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser'; import { createParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser';
import { getOpenAIApi } from '@/service/utils/chat'; import { getOpenAIApi } from '@/service/utils/auth';
import { httpsAgent } from './tools'; import { httpsAgent } from './tools';
import { User } from '../models/user'; import { User } from '../models/user';
import { formatPrice } from '@/utils/user'; import { formatPrice } from '@/utils/user';

View File

@ -2,7 +2,6 @@ import { create } from 'zustand';
import { devtools, persist } from 'zustand/middleware'; import { devtools, persist } from 'zustand/middleware';
import { immer } from 'zustand/middleware/immer'; import { immer } from 'zustand/middleware/immer';
import type { HistoryItem } from '@/types/chat'; import type { HistoryItem } from '@/types/chat';
import { getChatSiteId } from '@/api/chat';
type Props = { type Props = {
chatHistory: HistoryItem[]; chatHistory: HistoryItem[];
@ -10,7 +9,6 @@ type Props = {
updateChatHistory: (chatId: string, title: string) => void; updateChatHistory: (chatId: string, title: string) => void;
removeChatHistoryByWindowId: (chatId: string) => void; removeChatHistoryByWindowId: (chatId: string) => void;
clearHistory: () => void; clearHistory: () => void;
generateChatWindow: (modelId: string) => Promise<string>;
}; };
export const useChatStore = create<Props>()( export const useChatStore = create<Props>()(
devtools( devtools(
@ -40,9 +38,6 @@ export const useChatStore = create<Props>()(
set((state) => { set((state) => {
state.chatHistory = []; state.chatHistory = [];
}); });
},
generateChatWindow(modelId: string) {
return getChatSiteId(modelId);
} }
})), })),
{ {

View File

@ -7,8 +7,6 @@ import {
} from '@/constants/model'; } from '@/constants/model';
import type { DataType } from './data'; import type { DataType } from './data';
export type ServiceName = 'openai';
export interface UserModelSchema { export interface UserModelSchema {
_id: string; _id: string;
username: string; username: string;
@ -46,7 +44,6 @@ export interface ModelSchema {
mode: `${ModelVectorSearchModeEnum}`; mode: `${ModelVectorSearchModeEnum}`;
}; };
service: { service: {
company: ServiceName;
trainId: string; // 训练的模型训练后就是训练的模型id trainId: string; // 训练的模型训练后就是训练的模型id
chatModel: string; // 聊天时用的模型,训练后就是训练的模型 chatModel: string; // 聊天时用的模型,训练后就是训练的模型
modelName: `${ChatModelNameEnum}`; // 底层模型名称,不会变 modelName: `${ChatModelNameEnum}`; // 底层模型名称,不会变
@ -86,7 +83,6 @@ export interface ModelSplitDataSchema {
export interface TrainingSchema { export interface TrainingSchema {
_id: string; _id: string;
serviceName: ServiceName;
tuneId: string; tuneId: string;
modelId: string; modelId: string;
status: `${TrainingStatusEnum}`; status: `${TrainingStatusEnum}`;