perf: token split text

This commit is contained in:
archer 2023-04-30 22:35:47 +08:00
parent 39869bc4ea
commit 89a67ca9c0
No known key found for this signature in database
GPG Key ID: 569A5660D2379E28
8 changed files with 96 additions and 85 deletions

View File

@ -1,11 +1,11 @@
import type { NextApiRequest, NextApiResponse } from 'next'; import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase } from '@/service/mongo'; import { connectToDatabase } from '@/service/mongo';
import { getOpenAIApi, authChat } from '@/service/utils/auth'; import { getOpenAIApi, authChat } from '@/service/utils/auth';
import { axiosConfig, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools'; import { axiosConfig, openaiChatFilter } from '@/service/utils/tools';
import { ChatItemSimpleType } from '@/types/chat'; import { ChatItemSimpleType } from '@/types/chat';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { PassThrough } from 'stream'; import { PassThrough } from 'stream';
import { modelList, ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model'; import { modelList, ModelVectorSearchModeMap } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill'; import { pushChatBill } from '@/service/events/pushBill';
import { gpt35StreamResponse } from '@/service/utils/openai'; import { gpt35StreamResponse } from '@/service/utils/openai';
import { searchKb_openai } from '@/service/tools/searchKb'; import { searchKb_openai } from '@/service/tools/searchKb';

View File

@ -118,11 +118,11 @@ const InputDataModal = ({
px={6} px={6}
pb={2} pb={2}
> >
<Box flex={2} mr={[0, 4]} mb={[4, 0]} h={['230px', '100%']}> <Box flex={1} mr={[0, 4]} mb={[4, 0]} h={['230px', '100%']}>
<Box h={'30px'}>{'匹配的知识点'}</Box> <Box h={'30px'}>{'匹配的知识点'}</Box>
<Textarea <Textarea
placeholder={'匹配的知识点。这部分内容会被搜索,请把控内容的质量。最多 1000 字。'} placeholder={'匹配的知识点。这部分内容会被搜索,请把控内容的质量。最多 1500 字。'}
maxLength={2000} maxLength={1500}
resize={'none'} resize={'none'}
h={'calc(100% - 30px)'} h={'calc(100% - 30px)'}
{...register(`q`, { {...register(`q`, {
@ -130,13 +130,13 @@ const InputDataModal = ({
})} })}
/> />
</Box> </Box>
<Box flex={3} h={['330px', '100%']}> <Box flex={1} h={['330px', '100%']}>
<Box h={'30px'}></Box> <Box h={'30px'}></Box>
<Textarea <Textarea
placeholder={ placeholder={
'补充知识。这部分内容不会被搜索,但会作为"匹配的知识点"的内容补充,你可以讲一些细节的内容填写在这里。最多 2000 字。' '补充知识。这部分内容不会被搜索,但会作为"匹配的知识点"的内容补充,你可以讲一些细节的内容填写在这里。最多 1500 字。'
} }
maxLength={2000} maxLength={1500}
resize={'none'} resize={'none'}
h={'calc(100% - 30px)'} h={'calc(100% - 30px)'}
{...register('a')} {...register('a')}

View File

@ -16,8 +16,10 @@ import {
MenuButton, MenuButton,
MenuList, MenuList,
MenuItem, MenuItem,
Input Input,
Tooltip
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { QuestionOutlineIcon } from '@chakra-ui/icons';
import type { BoxProps } from '@chakra-ui/react'; import type { BoxProps } from '@chakra-ui/react';
import type { ModelDataItemType } from '@/types/model'; import type { ModelDataItemType } from '@/types/model';
import { ModelDataStatusMap } from '@/constants/model'; import { ModelDataStatusMap } from '@/constants/model';
@ -208,7 +210,16 @@ const ModelDataCard = ({ modelId, isOwner }: { modelId: string; isOwner: boolean
<Table variant={'simple'} w={'100%'}> <Table variant={'simple'} w={'100%'}>
<Thead> <Thead>
<Tr> <Tr>
<Th>{'匹配的知识点'}</Th> <Th>
<Tooltip
label={
'对话时,会将用户的问题和知识库的 "匹配知识点" 进行比较,找到最相似的前 n 条记录,将这些记录的 "匹配知识点"+"补充知识点" 作为 chatgpt 的系统提示词。'
}
>
<QuestionOutlineIcon ml={1} />
</Tooltip>
</Th>
<Th></Th> <Th></Th>
<Th></Th> <Th></Th>
{isOwner && <Th></Th>} {isOwner && <Th></Th>}

View File

@ -20,8 +20,7 @@ import { useMutation } from '@tanstack/react-query';
import { postModelDataSplitData } from '@/api/model'; import { postModelDataSplitData } from '@/api/model';
import { formatPrice } from '@/utils/user'; import { formatPrice } from '@/utils/user';
import Radio from '@/components/Radio'; import Radio from '@/components/Radio';
import { splitText } from '@/utils/file'; import { splitText_token } from '@/utils/file';
import { countChatTokens } from '@/utils/tools';
const fileExtension = '.txt,.doc,.docx,.pdf,.md'; const fileExtension = '.txt,.doc,.docx,.pdf,.md';
@ -49,7 +48,7 @@ const SelectFileModal = ({
onSuccess: () => void; onSuccess: () => void;
modelId: string; modelId: string;
}) => { }) => {
const [selecting, setSelecting] = useState(false); const [btnLoading, setBtnLoading] = useState(false);
const { toast } = useToast(); const { toast } = useToast();
const [prompt, setPrompt] = useState(''); const [prompt, setPrompt] = useState('');
const { File, onOpen } = useSelectFile({ fileType: fileExtension, multiple: true }); const { File, onOpen } = useSelectFile({ fileType: fileExtension, multiple: true });
@ -62,17 +61,21 @@ const SelectFileModal = ({
const { openConfirm, ConfirmChild } = useConfirm({ const { openConfirm, ConfirmChild } = useConfirm({
content: `确认导入该文件,需要一定时间进行拆解,该任务无法终止!如果余额不足,未完成的任务会被直接清除。一共 ${ content: `确认导入该文件,需要一定时间进行拆解,该任务无法终止!如果余额不足,未完成的任务会被直接清除。一共 ${
splitRes.chunks.length splitRes.chunks.length
} ${splitRes.tokens || '数量太多,未计算'} tokens, ${formatPrice( } ${
splitRes.tokens * modeMap[mode].price splitRes.tokens
)} ` ? `大约 ${splitRes.tokens} 个tokens, 约 ${formatPrice(
splitRes.tokens * modeMap[mode].price
)} `
: ''
}`
}); });
const onSelectFile = useCallback( const onSelectFile = useCallback(
async (e: File[]) => { async (files: File[]) => {
setSelecting(true); setBtnLoading(true);
try { try {
let promise = Promise.resolve(); let promise = Promise.resolve();
e.map((file) => { files.forEach((file) => {
promise = promise.then(async () => { promise = promise.then(async () => {
const extension = file?.name?.split('.')?.pop()?.toLowerCase(); const extension = file?.name?.split('.')?.pop()?.toLowerCase();
let text = ''; let text = '';
@ -101,7 +104,7 @@ const SelectFileModal = ({
status: 'error' status: 'error'
}); });
} }
setSelecting(false); setBtnLoading(false);
}, },
[toast] [toast]
); );
@ -131,31 +134,27 @@ const SelectFileModal = ({
} }
}); });
const onclickImport = useCallback(() => { const onclickImport = useCallback(async () => {
const chunks = fileTextArr setBtnLoading(true);
let promise = Promise.resolve();
const splitRes = fileTextArr
.filter((item) => item) .filter((item) => item)
.map((item) => .map((item) =>
splitText({ splitText_token({
text: item, text: item,
...modeMap[mode] ...modeMap[mode]
}) })
)
.flat();
let tokens: number[] = [];
// just count 100 sets of tokens
if (chunks.length < 100) {
tokens = chunks.map((item) =>
countChatTokens({ messages: [{ role: 'system', content: item }] })
); );
}
setSplitRes({ setSplitRes({
tokens: tokens.reduce((sum, item) => sum + item, 0), tokens: splitRes.reduce((sum, item) => sum + item.tokens, 0),
chunks chunks: splitRes.map((item) => item.chunks).flat()
}); });
setBtnLoading(false);
await promise;
openConfirm(mutate)(); openConfirm(mutate)();
}, [fileTextArr, mode, mutate, openConfirm]); }, [fileTextArr, mode, mutate, openConfirm]);
@ -239,7 +238,7 @@ const SelectFileModal = ({
</ModalBody> </ModalBody>
<Flex px={6} pt={2} pb={4}> <Flex px={6} pt={2} pb={4}>
<Button isLoading={selecting} onClick={onOpen}> <Button isLoading={btnLoading} onClick={onOpen}>
</Button> </Button>
<Box flex={1}></Box> <Box flex={1}></Box>
@ -247,8 +246,8 @@ const SelectFileModal = ({
</Button> </Button>
<Button <Button
isLoading={isLoading} isLoading={isLoading || btnLoading}
isDisabled={selecting || fileTextArr[0] === ''} isDisabled={isLoading || btnLoading || fileTextArr[0] === ''}
onClick={onclickImport} onClick={onclickImport}
> >

View File

@ -106,7 +106,7 @@ A2:
) )
.then((res) => { .then((res) => {
const rawContent = res?.data.choices[0].message?.content || ''; // chatgpt 原本的回复 const rawContent = res?.data.choices[0].message?.content || ''; // chatgpt 原本的回复
const result = splitText(res?.data.choices[0].message?.content || ''); // 格式化后的QA对 const result = formatSplitText(res?.data.choices[0].message?.content || ''); // 格式化后的QA对
console.log(`split result length: `, result.length); console.log(`split result length: `, result.length);
// 计费 // 计费
pushSplitDataBill({ pushSplitDataBill({
@ -190,7 +190,7 @@ A2:
/** /**
* *
*/ */
function splitText(text: string) { function formatSplitText(text: string) {
const regex = /Q\d+:(\s*)(.*)(\s*)A\d+:(\s*)([\s\S]*?)(?=Q|$)/g; // 匹配Q和A的正则表达式 const regex = /Q\d+:(\s*)(.*)(\s*)A\d+:(\s*)([\s\S]*?)(?=Q|$)/g; // 匹配Q和A的正则表达式
const matches = text.matchAll(regex); // 获取所有匹配到的结果 const matches = text.matchAll(regex); // 获取所有匹配到的结果

View File

@ -1,7 +1,7 @@
import crypto from 'crypto'; import crypto from 'crypto';
import jwt from 'jsonwebtoken'; import jwt from 'jsonwebtoken';
import { ChatItemSimpleType } from '@/types/chat'; import { ChatItemSimpleType } from '@/types/chat';
import { countChatTokens } from '@/utils/tools'; import { countChatTokens, sliceTextByToken } from '@/utils/tools';
import { ChatCompletionRequestMessageRoleEnum, ChatCompletionRequestMessage } from 'openai'; import { ChatCompletionRequestMessageRoleEnum, ChatCompletionRequestMessage } from 'openai';
import { ChatModelEnum } from '@/constants/model'; import { ChatModelEnum } from '@/constants/model';
@ -111,18 +111,11 @@ export const systemPromptFilter = ({
prompts: string[]; prompts: string[];
maxTokens: number; maxTokens: number;
}) => { }) => {
let splitText = ''; const systemPrompt = prompts.join('\n');
// 从前往前截取 return sliceTextByToken({
for (let i = 0; i < prompts.length; i++) { model,
const prompt = simplifyStr(prompts[i]); text: systemPrompt,
length: maxTokens
splitText += `${prompt}\n`; });
const tokens = countChatTokens({ model, messages: [{ role: 'system', content: splitText }] });
if (tokens >= maxTokens) {
break;
}
}
return splitText.slice(0, splitText.length - 1);
}; };

View File

@ -1,6 +1,6 @@
import mammoth from 'mammoth'; import mammoth from 'mammoth';
import Papa from 'papaparse'; import Papa from 'papaparse';
import { countChatTokens } from './tools'; import { getEncMap } from './tools';
/** /**
* txt * txt
@ -145,7 +145,7 @@ export const fileDownload = ({
* slideLen - The size of the before and after Text * slideLen - The size of the before and after Text
* maxLen > slideLen * maxLen > slideLen
*/ */
export const splitText = ({ export const splitText_token = ({
text, text,
maxLen, maxLen,
slideLen slideLen
@ -154,39 +154,32 @@ export const splitText = ({
maxLen: number; maxLen: number;
slideLen: number; slideLen: number;
}) => { }) => {
const textArr = const enc = getEncMap()['gpt-3.5-turbo'];
text.split(/(?<=[。!?\.!\?\n])/g)?.filter((item) => { // filter empty text. encode sentence
const text = item.replace(/(\\n)/g, '\n').trim(); const encodeText = enc.encode(text);
if (text && text !== '\n') return true;
return false;
}) || [];
const chunks: { sum: number; arr: string[] }[] = [{ sum: 0, arr: [] }]; const chunks: string[] = [];
let tokens = 0;
for (let i = 0; i < textArr.length; i++) { let startIndex = 0;
const tokenLen = countChatTokens({ messages: [{ role: 'system', content: textArr[i] }] }); let endIndex = Math.min(startIndex + maxLen, encodeText.length);
chunks[chunks.length - 1].sum += tokenLen; let chunkEncodeArr = encodeText.slice(startIndex, endIndex);
chunks[chunks.length - 1].arr.push(textArr[i]);
// current length is over maxLen. create new chunk const decoder = new TextDecoder();
if (chunks[chunks.length - 1].sum + tokenLen >= maxLen) {
// get slide len text as the initial value
const chunk: { sum: number; arr: string[] } = { sum: 0, arr: [] };
for (let j = chunks[chunks.length - 1].arr.length - 1; j >= 0; j--) {
const chunkText = chunks[chunks.length - 1].arr[j];
const tokenLen = countChatTokens({ messages: [{ role: 'system', content: chunkText }] });
chunk.sum += tokenLen;
chunk.arr.unshift(chunkText);
if (chunk.sum >= slideLen) { while (startIndex < encodeText.length) {
break; tokens += chunkEncodeArr.length;
} chunks.push(decoder.decode(enc.decode(chunkEncodeArr)));
}
chunks.push(chunk); startIndex += maxLen - slideLen;
} endIndex = Math.min(startIndex + maxLen, encodeText.length);
chunkEncodeArr = encodeText.slice(Math.min(encodeText.length - slideLen, startIndex), endIndex);
} }
const result = chunks.map((item) => item.arr.join(''));
return result; return {
chunks,
tokens
};
}; };
export const fileToBase64 = (file: File) => { export const fileToBase64 = (file: File) => {

View File

@ -7,7 +7,7 @@ import { ChatModelEnum } from '@/constants/model';
const textDecoder = new TextDecoder(); const textDecoder = new TextDecoder();
const graphemer = new Graphemer(); const graphemer = new Graphemer();
let encMap: Record<string, Tiktoken>; let encMap: Record<string, Tiktoken>;
const getEncMap = () => { export const getEncMap = () => {
if (encMap) return encMap; if (encMap) return encMap;
encMap = { encMap = {
'gpt-3.5-turbo': encoding_for_model('gpt-3.5-turbo', { 'gpt-3.5-turbo': encoding_for_model('gpt-3.5-turbo', {
@ -136,3 +136,18 @@ export const countChatTokens = ({
const text = getChatGPTEncodingText(messages, model); const text = getChatGPTEncodingText(messages, model);
return text2TokensLen(getEncMap()[model], text); return text2TokensLen(getEncMap()[model], text);
}; };
/**
 * Truncate a string to at most `length` tokens for the given chat model.
 * Works in token space (encode → slice → decode) rather than character
 * space, so the cut matches what the model actually counts.
 * NOTE(review): decoding a sliced token array may end mid-codepoint;
 * TextDecoder will substitute a replacement character in that case —
 * same behavior as the original implementation.
 */
export const sliceTextByToken = ({
  model = 'gpt-3.5-turbo',
  text,
  length
}: {
  model?: `${ChatModelEnum}`;
  text: string;
  length: number;
}) => {
  const tokenizer = getEncMap()[model];
  // Keep only the first `length` token ids, then turn them back into text.
  const truncatedIds = tokenizer.encode(text).slice(0, length);
  return new TextDecoder().decode(tokenizer.decode(truncatedIds));
};