merge dev2.0

This commit is contained in:
archer 2023-03-23 23:07:24 +08:00
commit 8b72dca533
No known key found for this signature in database
GPG Key ID: 166CA6BF2383B2BB
4 changed files with 80 additions and 40 deletions

View File

@ -3,8 +3,9 @@ interface StreamFetchProps {
url: string; url: string;
data: any; data: any;
onMessage: (text: string) => void; onMessage: (text: string) => void;
abortSignal: AbortController;
} }
export const streamFetch = ({ url, data, onMessage }: StreamFetchProps) => export const streamFetch = ({ url, data, onMessage, abortSignal }: StreamFetchProps) =>
new Promise(async (resolve, reject) => { new Promise(async (resolve, reject) => {
try { try {
const res = await fetch(url, { const res = await fetch(url, {
@ -13,7 +14,8 @@ export const streamFetch = ({ url, data, onMessage }: StreamFetchProps) =>
'Content-Type': 'application/json', 'Content-Type': 'application/json',
Authorization: getToken() || '' Authorization: getToken() || ''
}, },
body: JSON.stringify(data) body: JSON.stringify(data),
signal: abortSignal.signal
}); });
const reader = res.body?.getReader(); const reader = res.body?.getReader();
if (!reader) return; if (!reader) return;

View File

@ -13,13 +13,26 @@ import { pushBill } from '@/service/events/bill';
/* 发送提示词 */ /* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) { export default async function handler(req: NextApiRequest, res: NextApiResponse) {
const { chatId, prompt } = req.body as { let step = 0; // step=1时表示开始了流响应
prompt: ChatItemType; const stream = new PassThrough();
chatId: string; stream.on('error', () => {
}; console.log('error: ', 'stream error');
const { authorization } = req.headers; stream.destroy();
});
res.on('close', () => {
stream.destroy();
});
res.on('error', () => {
console.log('error: ', 'request error');
stream.destroy();
});
try { try {
const { chatId, prompt } = req.body as {
prompt: ChatItemType;
chatId: string;
};
const { authorization } = req.headers;
if (!chatId || !prompt) { if (!chatId || !prompt) {
throw new Error('缺少参数'); throw new Error('缺少参数');
} }
@ -92,10 +105,10 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
res.setHeader('Access-Control-Allow-Origin', '*'); res.setHeader('Access-Control-Allow-Origin', '*');
res.setHeader('X-Accel-Buffering', 'no'); res.setHeader('X-Accel-Buffering', 'no');
res.setHeader('Cache-Control', 'no-cache, no-transform'); res.setHeader('Cache-Control', 'no-cache, no-transform');
step = 1;
let responseContent = ''; let responseContent = '';
const pass = new PassThrough(); stream.pipe(res);
pass.pipe(res);
const onParse = async (event: ParsedEvent | ReconnectInterval) => { const onParse = async (event: ParsedEvent | ReconnectInterval) => {
if (event.type !== 'event') return; if (event.type !== 'event') return;
@ -107,7 +120,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
if (!content) return; if (!content) return;
responseContent += content; responseContent += content;
// console.log('content:', content) // console.log('content:', content)
pass.push(content.replace(/\n/g, '<br/>')); stream.push(content.replace(/\n/g, '<br/>'));
} catch (error) { } catch (error) {
error; error;
} }
@ -116,13 +129,17 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
const decoder = new TextDecoder(); const decoder = new TextDecoder();
try { try {
for await (const chunk of chatResponse.data as any) { for await (const chunk of chatResponse.data as any) {
if (stream.destroyed) {
// 流被中断了,直接忽略后面的内容
break;
}
const parser = createParser(onParse); const parser = createParser(onParse);
parser.feed(decoder.decode(chunk)); parser.feed(decoder.decode(chunk));
} }
} catch (error) { } catch (error) {
console.log('pipe error', error); console.log('pipe error', error);
} }
pass.push(null); stream.push(null);
const promptsLen = formatPrompts.reduce((sum, item) => sum + item.content.length, 0); const promptsLen = formatPrompts.reduce((sum, item) => sum + item.content.length, 0);
console.log(`responseLen: ${responseContent.length}`, `promptLen: ${promptsLen}`); console.log(`responseLen: ${responseContent.length}`, `promptLen: ${promptsLen}`);
@ -135,10 +152,16 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
textLen: promptsLen + responseContent.length textLen: promptsLen + responseContent.length
}); });
} catch (err: any) { } catch (err: any) {
res.status(500); if (step === 1) {
jsonRes(res, { console.log('error结束');
code: 500, // 直接结束流
error: err stream.destroy();
}); } else {
res.status(500);
jsonRes(res, {
code: 500,
error: err
});
}
} }
} }

View File

@ -1,4 +1,4 @@
import React, { useCallback, useState, useRef, useMemo } from 'react'; import React, { useCallback, useState, useRef, useMemo, useEffect } from 'react';
import { useRouter } from 'next/router'; import { useRouter } from 'next/router';
import Image from 'next/image'; import Image from 'next/image';
import { import {
@ -88,6 +88,16 @@ const Chat = ({ chatId }: { chatId: string }) => {
}, [chatData]); }, [chatData]);
const { pushChatHistory } = useChatStore(); const { pushChatHistory } = useChatStore();
// 中断请求
const controller = useRef(new AbortController());
useEffect(() => {
controller.current = new AbortController();
return () => {
console.log('close========');
// eslint-disable-next-line react-hooks/exhaustive-deps
controller.current?.abort();
};
}, [chatId]);
// 滚动到底部 // 滚动到底部
const scrollToBottom = useCallback(() => { const scrollToBottom = useCallback(() => {
@ -212,7 +222,8 @@ const Chat = ({ chatId }: { chatId: string }) => {
}; };
}) })
})); }));
} },
abortSignal: controller.current
}); });
// 保存对话信息 // 保存对话信息

View File

@ -12,30 +12,34 @@ export const pushBill = async ({
chatId: string; chatId: string;
textLen: number; textLen: number;
}) => { }) => {
await connectToDatabase();
const modelItem = ModelList.find((item) => item.model === modelName);
if (!modelItem) return;
const price = modelItem.price * textLen;
let billId;
try { try {
// 插入 Bill 记录 await connectToDatabase();
const res = await Bill.create({
userId,
chatId,
textLen,
price
});
billId = res._id;
// 扣费 const modelItem = ModelList.find((item) => item.model === modelName);
await User.findByIdAndUpdate(userId, {
$inc: { balance: -price } if (!modelItem) return;
});
const price = modelItem.price * textLen;
let billId;
try {
// 插入 Bill 记录
const res = await Bill.create({
userId,
chatId,
textLen,
price
});
billId = res._id;
// 扣费
await User.findByIdAndUpdate(userId, {
$inc: { balance: -price }
});
} catch (error) {
billId && Bill.findByIdAndDelete(billId);
}
} catch (error) { } catch (error) {
billId && Bill.findByIdAndDelete(billId); console.log(error);
} }
}; };