From 741381ecb011e589089d7619c6f37d3b040253f3 Mon Sep 17 00:00:00 2001 From: archer <545436317@qq.com> Date: Sat, 27 May 2023 04:38:00 +0800 Subject: [PATCH] perf: generate queue --- src/api/plugins/kb.ts | 9 +- src/constants/plugin.ts | 6 +- src/pages/api/admin/countTraining.ts | 37 ++++ src/pages/api/openapi/kb/pushData.ts | 51 +++--- src/pages/api/openapi/kb/updateData.ts | 10 +- src/pages/api/openapi/text/splitText.ts | 69 ------- .../api/plugins/kb/data/getTrainingData.ts | 25 ++- src/pages/kb/components/InputDataModal.tsx | 4 +- src/pages/kb/components/SelectCsvModal.tsx | 4 +- src/pages/kb/components/SelectFileModal.tsx | 14 +- src/service/events/generateQA.ts | 78 ++++---- src/service/events/generateVector.ts | 170 +++++++++--------- src/service/models/trainingData.ts | 23 ++- src/service/mongo.ts | 24 +-- src/service/pg.ts | 4 +- src/service/response.ts | 2 +- src/service/utils/tools.ts | 16 +- src/types/index.d.ts | 2 + src/types/mongoSchema.d.ts | 5 +- 19 files changed, 288 insertions(+), 265 deletions(-) create mode 100644 src/pages/api/admin/countTraining.ts delete mode 100644 src/pages/api/openapi/text/splitText.ts diff --git a/src/api/plugins/kb.ts b/src/api/plugins/kb.ts index 00681fb88..09d890e23 100644 --- a/src/api/plugins/kb.ts +++ b/src/api/plugins/kb.ts @@ -2,7 +2,7 @@ import { GET, POST, PUT, DELETE } from '../request'; import type { KbItemType } from '@/types/plugin'; import { RequestPaging } from '@/types/index'; import { TrainingTypeEnum } from '@/constants/plugin'; -import { KbDataItemType } from '@/types/plugin'; +import { Props as PushDataProps } from '@/pages/api/openapi/kb/pushData'; export type KbUpdateParams = { id: string; name: string; tags: string; avatar: string }; @@ -46,10 +46,7 @@ export const getKbDataItemById = (dataId: string) => /** * 直接push数据 */ -export const postKbDataFromList = (data: { - kbId: string; - data: { a: KbDataItemType['a']; q: KbDataItemType['q'] }[]; -}) => POST(`/openapi/kb/pushData`, data); +export const 
postKbDataFromList = (data: PushDataProps) => POST(`/openapi/kb/pushData`, data); /** * 更新一条数据 @@ -70,4 +67,4 @@ export const postSplitData = (data: { chunks: string[]; prompt: string; mode: `${TrainingTypeEnum}`; -}) => POST(`/openapi/text/splitText`, data); +}) => POST(`/openapi/text/pushData`, data); diff --git a/src/constants/plugin.ts b/src/constants/plugin.ts index 0368090b4..a464e448b 100644 --- a/src/constants/plugin.ts +++ b/src/constants/plugin.ts @@ -1,4 +1,8 @@ export enum TrainingTypeEnum { 'qa' = 'qa', - 'subsection' = 'subsection' + 'index' = 'index' } +export const TrainingTypeMap = { + [TrainingTypeEnum.qa]: 'qa', + [TrainingTypeEnum.index]: 'index' +}; diff --git a/src/pages/api/admin/countTraining.ts b/src/pages/api/admin/countTraining.ts new file mode 100644 index 000000000..48de12e3d --- /dev/null +++ b/src/pages/api/admin/countTraining.ts @@ -0,0 +1,37 @@ +// Next.js API route support: https://nextjs.org/docs/api-routes/introduction +import type { NextApiRequest, NextApiResponse } from 'next'; +import { jsonRes } from '@/service/response'; +import { authUser } from '@/service/utils/auth'; +import { connectToDatabase, TrainingData } from '@/service/mongo'; +import { TrainingTypeEnum } from '@/constants/plugin'; + +export default async function handler(req: NextApiRequest, res: NextApiResponse) { + try { + await authUser({ req, authRoot: true }); + + await connectToDatabase(); + + // split queue data + const result = await TrainingData.aggregate([ + { + $group: { + _id: '$mode', + count: { $sum: 1 } + } + } + ]); + + jsonRes(res, { + data: { + qaListLen: result.find((item) => item._id === TrainingTypeEnum.qa)?.count || 0, + vectorListLen: result.find((item) => item._id === TrainingTypeEnum.index)?.count || 0 + } + }); + } catch (error) { + console.log(error); + jsonRes(res, { + code: 500, + error + }); + } +} diff --git a/src/pages/api/openapi/kb/pushData.ts b/src/pages/api/openapi/kb/pushData.ts index 0d34af332..ecf9c560b 100644 --- 
a/src/pages/api/openapi/kb/pushData.ts +++ b/src/pages/api/openapi/kb/pushData.ts @@ -3,19 +3,21 @@ import type { KbDataItemType } from '@/types/plugin'; import { jsonRes } from '@/service/response'; import { connectToDatabase, TrainingData } from '@/service/mongo'; import { authUser } from '@/service/utils/auth'; -import { generateVector } from '@/service/events/generateVector'; -import { PgClient } from '@/service/pg'; import { authKb } from '@/service/utils/auth'; import { withNextCors } from '@/service/utils/tools'; +import { TrainingTypeEnum } from '@/constants/plugin'; +import { startQueue } from '@/service/utils/tools'; -interface Props { +export type Props = { kbId: string; data: { a: KbDataItemType['a']; q: KbDataItemType['q'] }[]; -} + mode: `${TrainingTypeEnum}`; + prompt?: string; +}; export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse) { try { - const { kbId, data } = req.body as Props; + const { kbId, data, mode, prompt } = req.body as Props; if (!kbId || !Array.isArray(data)) { throw new Error('缺少参数'); @@ -29,7 +31,9 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex data: await pushDataToKb({ kbId, data, - userId + userId, + mode, + prompt }) }); } catch (err) { @@ -40,36 +44,43 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex } }); -export async function pushDataToKb({ userId, kbId, data }: { userId: string } & Props) { +export async function pushDataToKb({ + userId, + kbId, + data, + mode, + prompt +}: { userId: string } & Props) { await authKb({ userId, kbId }); if (data.length === 0) { - return { - trainingId: '' - }; + return {}; } // 插入记录 - const { _id } = await TrainingData.create({ - userId, - kbId, - vectorList: data - }); + await TrainingData.insertMany( + data.map((item) => ({ + q: item.q, + a: item.a, + userId, + kbId, + mode, + prompt + })) + ); - generateVector(_id); + startQueue(); - return { - trainingId: _id - }; + return {}; 
} export const config = { api: { bodyParser: { - sizeLimit: '100mb' + sizeLimit: '20mb' } } }; diff --git a/src/pages/api/openapi/kb/updateData.ts b/src/pages/api/openapi/kb/updateData.ts index 5a7402846..59154999a 100644 --- a/src/pages/api/openapi/kb/updateData.ts +++ b/src/pages/api/openapi/kb/updateData.ts @@ -33,7 +33,15 @@ export default withNextCors(async function handler(req: NextApiRequest, res: Nex // 更新 pg 内容.仅修改a,不需要更新向量。 await PgClient.update('modelData', { where: [['id', dataId], 'AND', ['user_id', userId]], - values: [{ key: 'a', value: a }, ...(q ? [{ key: 'q', value: `${vector[0]}` }] : [])] + values: [ + { key: 'a', value: a }, + ...(q + ? [ + { key: 'q', value: q }, + { key: 'vector', value: `[${vector[0]}]` } + ] + : []) + ] }); jsonRes(res); diff --git a/src/pages/api/openapi/text/splitText.ts b/src/pages/api/openapi/text/splitText.ts deleted file mode 100644 index 642730d28..000000000 --- a/src/pages/api/openapi/text/splitText.ts +++ /dev/null @@ -1,69 +0,0 @@ -import type { NextApiRequest, NextApiResponse } from 'next'; -import { jsonRes } from '@/service/response'; -import { connectToDatabase, TrainingData } from '@/service/mongo'; -import { authKb, authUser } from '@/service/utils/auth'; -import { generateQA } from '@/service/events/generateQA'; -import { TrainingTypeEnum } from '@/constants/plugin'; -import { withNextCors } from '@/service/utils/tools'; -import { pushDataToKb } from '../kb/pushData'; - -/* split text */ -export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse) { - try { - const { chunks, kbId, prompt, mode } = req.body as { - kbId: string; - chunks: string[]; - prompt: string; - mode: `${TrainingTypeEnum}`; - }; - if (!chunks || !kbId || !prompt) { - throw new Error('参数错误'); - } - await connectToDatabase(); - - const { userId } = await authUser({ req }); - - // 验证是否是该用户的 model - await authKb({ - kbId, - userId - }); - - if (mode === TrainingTypeEnum.qa) { - // 批量QA拆分插入数据 - const { _id } 
= await TrainingData.create({ - userId, - kbId, - qaList: chunks, - prompt - }); - generateQA(_id); - } else if (mode === TrainingTypeEnum.subsection) { - // 分段导入,直接插入向量队列 - const response = await pushDataToKb({ - kbId, - data: chunks.map((item) => ({ q: item, a: '' })), - userId - }); - - return jsonRes(res, { - data: response - }); - } - - jsonRes(res); - } catch (err) { - jsonRes(res, { - code: 500, - error: err - }); - } -}); - -export const config = { - api: { - bodyParser: { - sizeLimit: '100mb' - } - } -}; diff --git a/src/pages/api/plugins/kb/data/getTrainingData.ts b/src/pages/api/plugins/kb/data/getTrainingData.ts index d7369f90b..e6bb8ec24 100644 --- a/src/pages/api/plugins/kb/data/getTrainingData.ts +++ b/src/pages/api/plugins/kb/data/getTrainingData.ts @@ -2,9 +2,10 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; import { connectToDatabase, TrainingData } from '@/service/mongo'; import { authUser } from '@/service/utils/auth'; -import { Types } from 'mongoose'; import { generateQA } from '@/service/events/generateQA'; import { generateVector } from '@/service/events/generateVector'; +import { TrainingTypeEnum } from '@/constants/plugin'; +import { Types } from 'mongoose'; /* 拆分数据成QA */ export default async function handler(req: NextApiRequest, res: NextApiResponse) { @@ -19,26 +20,24 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) // split queue data const result = await TrainingData.aggregate([ - { $match: { userId: new Types.ObjectId(userId), kbId: new Types.ObjectId(kbId) } }, { - $project: { - qaListLength: { $size: { $ifNull: ['$qaList', []] } }, - vectorListLength: { $size: { $ifNull: ['$vectorList', []] } } + $match: { + userId: new Types.ObjectId(userId), + kbId: new Types.ObjectId(kbId) } }, { $group: { - _id: null, - totalQaListLength: { $sum: '$qaListLength' }, - totalVectorListLength: { $sum: '$vectorListLength' } + _id: '$mode', + count: { 
$sum: 1 } } } ]); jsonRes(res, { data: { - qaListLen: result[0]?.totalQaListLength || 0, - vectorListLen: result[0]?.totalVectorListLength || 0 + qaListLen: result.find((item) => item._id === TrainingTypeEnum.qa)?.count || 0, + vectorListLen: result.find((item) => item._id === TrainingTypeEnum.index)?.count || 0 } }); @@ -49,10 +48,10 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) kbId }, '_id' - ); + ).limit(10); list.forEach((item) => { - generateQA(item._id); - generateVector(item._id); + generateQA(); + generateVector(); }); } } catch (err) { diff --git a/src/pages/kb/components/InputDataModal.tsx b/src/pages/kb/components/InputDataModal.tsx index 839b97552..b1c442748 100644 --- a/src/pages/kb/components/InputDataModal.tsx +++ b/src/pages/kb/components/InputDataModal.tsx @@ -13,6 +13,7 @@ import { import { useForm } from 'react-hook-form'; import { postKbDataFromList, putKbDataById } from '@/api/plugins/kb'; import { useToast } from '@/hooks/useToast'; +import { TrainingTypeEnum } from '@/constants/plugin'; export type FormData = { dataId?: string; a: string; q: string }; @@ -59,7 +60,8 @@ const InputDataModal = ({ a: e.a, q: e.q } - ] + ], + mode: TrainingTypeEnum.index }); toast({ diff --git a/src/pages/kb/components/SelectCsvModal.tsx b/src/pages/kb/components/SelectCsvModal.tsx index 54f7f8376..3faac308c 100644 --- a/src/pages/kb/components/SelectCsvModal.tsx +++ b/src/pages/kb/components/SelectCsvModal.tsx @@ -19,6 +19,7 @@ import { postKbDataFromList } from '@/api/plugins/kb'; import Markdown from '@/components/Markdown'; import { useMarkdown } from '@/hooks/useMarkdown'; import { fileDownload } from '@/utils/file'; +import { TrainingTypeEnum } from '@/constants/plugin'; const csvTemplate = `question,answer\n"什么是 laf","laf 是一个云函数开发平台……"\n"什么是 sealos","Sealos 是以 kubernetes 为内核的云操作系统发行版,可以……"`; @@ -72,7 +73,8 @@ const SelectJsonModal = ({ const res = await postKbDataFromList({ kbId, - data: fileData + data: fileData, + 
mode: TrainingTypeEnum.index }); toast({ diff --git a/src/pages/kb/components/SelectFileModal.tsx b/src/pages/kb/components/SelectFileModal.tsx index 08a122124..c26f204c7 100644 --- a/src/pages/kb/components/SelectFileModal.tsx +++ b/src/pages/kb/components/SelectFileModal.tsx @@ -17,7 +17,7 @@ import { useSelectFile } from '@/hooks/useSelectFile'; import { useConfirm } from '@/hooks/useConfirm'; import { readTxtContent, readPdfContent, readDocContent } from '@/utils/file'; import { useMutation } from '@tanstack/react-query'; -import { postSplitData } from '@/api/plugins/kb'; +import { postKbDataFromList } from '@/api/plugins/kb'; import Radio from '@/components/Radio'; import { splitText_token } from '@/utils/file'; import { TrainingTypeEnum } from '@/constants/plugin'; @@ -32,7 +32,7 @@ const modeMap = { price: 4, isPrompt: true }, - subsection: { + index: { maxLen: 800, slideLen: 300, price: 0.4, @@ -53,7 +53,7 @@ const SelectFileModal = ({ const { toast } = useToast(); const [prompt, setPrompt] = useState(''); const { File, onOpen } = useSelectFile({ fileType: fileExtension, multiple: true }); - const [mode, setMode] = useState<`${TrainingTypeEnum}`>(TrainingTypeEnum.subsection); + const [mode, setMode] = useState<`${TrainingTypeEnum}`>(TrainingTypeEnum.index); const [fileTextArr, setFileTextArr] = useState(['']); const [splitRes, setSplitRes] = useState<{ tokens: number; chunks: string[] }>({ tokens: 0, @@ -108,9 +108,9 @@ const SelectFileModal = ({ mutationFn: async () => { if (splitRes.chunks.length === 0) return; - await postSplitData({ + await postKbDataFromList({ kbId, - chunks: splitRes.chunks, + data: splitRes.chunks.map((text) => ({ q: text, a: '' })), prompt: `下面是"${prompt || '一段长文本'}"`, mode }); @@ -195,11 +195,11 @@ const SelectFileModal = ({ setMode(e as 'subsection' | 'qa')} + onChange={(e) => setMode(e as 'index' | 'qa')} /> {/* 内容介绍 */} diff --git a/src/service/events/generateQA.ts b/src/service/events/generateQA.ts index a815f30b8..6227f3d0e 
100644 --- a/src/service/events/generateQA.ts +++ b/src/service/events/generateQA.ts @@ -7,49 +7,61 @@ import { modelServiceToolMap } from '../utils/chat'; import { ChatRoleEnum } from '@/constants/chat'; import { BillTypeEnum } from '@/constants/user'; import { pushDataToKb } from '@/pages/api/openapi/kb/pushData'; +import { TrainingTypeEnum } from '@/constants/plugin'; import { ERROR_ENUM } from '../errorCode'; -// 每次最多选 1 组 -const listLen = 1; +export async function generateQA(): Promise { + const maxProcess = Number(process.env.QA_MAX_PROCESS || 10); + + if (global.qaQueueLen >= maxProcess) return; + global.qaQueueLen++; + + let trainingId = ''; + let userId = ''; -export async function generateQA(trainingId: string): Promise { try { // 找出一个需要生成的 dataItem (4分钟锁) const data = await TrainingData.findOneAndUpdate( { - _id: trainingId, - lockTime: { $lte: Date.now() - 4 * 60 * 1000 } + mode: TrainingTypeEnum.qa, + lockTime: { $lte: new Date(Date.now() - 2 * 60 * 1000) } }, { lockTime: new Date() } - ); + ).select({ + _id: 1, + userId: 1, + kbId: 1, + prompt: 1, + q: 1 + }); - if (!data || data.qaList.length === 0) { - await TrainingData.findOneAndDelete({ - _id: trainingId, - qaList: [], - vectorList: [] - }); + /* 无待生成的任务 */ + if (!data) { + global.qaQueueLen--; + !global.qaQueueLen && console.log(`没有需要【QA】的数据`); return; } - const qaList: string[] = data.qaList.slice(-listLen); + trainingId = data._id; + userId = String(data.userId); + const kbId = String(data.kbId); // 余额校验并获取 openapi Key const { userOpenAiKey, systemAuthKey } = await getApiKey({ model: OpenAiChatEnum.GPT35, - userId: data.userId, + userId, type: 'training' }); - console.log(`正在生成一组QA, 包含 ${qaList.length} 组文本。ID: ${data._id}`); + console.log(`正在生成一组QA。ID: ${trainingId}`); const startTime = Date.now(); // 请求 chatgpt 获取回答 const response = await Promise.all( - qaList.map((text) => + [data.q].map((text) => modelServiceToolMap[OpenAiChatEnum.GPT35] .chatCompletion({ apiKey: userOpenAiKey || 
systemAuthKey, @@ -100,24 +112,19 @@ A2: // 创建 向量生成 队列 pushDataToKb({ - kbId: data.kbId, + kbId, data: responseList, - userId: data.userId + userId, + mode: TrainingTypeEnum.index }); - // 删除 QA 队列。如果小于 n 条,整个数据删掉。 如果大于 n 条,仅删数组后 n 个 - if (data.vectorList.length <= listLen) { - await TrainingData.findByIdAndDelete(data._id); - } else { - await TrainingData.findByIdAndUpdate(data._id, { - qaList: data.qaList.slice(0, -listLen), - lockTime: new Date('2000/1/1') - }); - } + // delete data from training + await TrainingData.findByIdAndDelete(data._id); console.log('生成QA成功,time:', `${(Date.now() - startTime) / 1000}s`); - generateQA(trainingId); + global.qaQueueLen--; + generateQA(); } catch (err: any) { // log if (err?.response) { @@ -130,25 +137,28 @@ A2: // openai 账号异常或者账号余额不足,删除任务 if (openaiError2[err?.response?.data?.error?.type] || err === ERROR_ENUM.insufficientQuota) { console.log('余额不足,删除向量生成任务'); - await TrainingData.findByIdAndDelete(trainingId); - return; + await TrainingData.deleteMany({ + userId + }); + return generateQA(); } // unlock + global.qaQueueLen--; await TrainingData.findByIdAndUpdate(trainingId, { lockTime: new Date('2000/1/1') }); // 频率限制 if (err?.response?.statusText === 'Too Many Requests') { - console.log('生成向量次数限制,30s后尝试'); + console.log('生成向量次数限制,20s后尝试'); return setTimeout(() => { - generateQA(trainingId); - }, 30000); + generateQA(); + }, 20000); } setTimeout(() => { - generateQA(trainingId); + generateQA(); }, 1000); } } diff --git a/src/service/events/generateVector.ts b/src/service/events/generateVector.ts index a8c32a137..1861c8191 100644 --- a/src/service/events/generateVector.ts +++ b/src/service/events/generateVector.ts @@ -3,104 +3,109 @@ import { insertKbItem, PgClient } from '@/service/pg'; import { openaiEmbedding } from '@/pages/api/openapi/plugin/openaiEmbedding'; import { TrainingData } from '../models/trainingData'; import { ERROR_ENUM } from '../errorCode'; - -// 每次最多选 5 组 -const listLen = 5; +import { TrainingTypeEnum } 
from '@/constants/plugin'; /* 索引生成队列。每导入一次,就是一个单独的线程 */ -export async function generateVector(trainingId: string): Promise { +export async function generateVector(): Promise { + const maxProcess = Number(process.env.VECTOR_MAX_PROCESS || 10); + + if (global.vectorQueueLen >= maxProcess) return; + global.vectorQueueLen++; + + let trainingId = ''; + let userId = ''; + try { - // 找出一个需要生成的 dataItem (2分钟锁) const data = await TrainingData.findOneAndUpdate( { - _id: trainingId, - lockTime: { $lte: Date.now() - 2 * 60 * 1000 } + mode: TrainingTypeEnum.index, + lockTime: { $lte: new Date(Date.now() - 2 * 60 * 1000) } }, { lockTime: new Date() } - ); + ).select({ + _id: 1, + userId: 1, + kbId: 1, + q: 1, + a: 1 + }); + /* 无待生成的任务 */ if (!data) { - await TrainingData.findOneAndDelete({ - _id: trainingId, - qaList: [], - vectorList: [] - }); + global.vectorQueueLen--; + !global.vectorQueueLen && console.log(`没有需要【索引】的数据`); return; } - const userId = String(data.userId); + trainingId = data._id; + userId = String(data.userId); const kbId = String(data.kbId); - const dataItems: { q: string; a: string }[] = data.vectorList.slice(-listLen).map((item) => ({ - q: item.q, - a: item.a - })); + const dataItems = [ + { + q: data.q, + a: data.a + } + ]; // 过滤重复的 qa 内容 - const searchRes = await Promise.allSettled( - dataItems.map(async ({ q, a = '' }) => { - if (!q) { - return Promise.reject('q为空'); - } + // const searchRes = await Promise.allSettled( + // dataItems.map(async ({ q, a = '' }) => { + // if (!q) { + // return Promise.reject('q为空'); + // } - q = q.replace(/\\n/g, '\n'); - a = a.replace(/\\n/g, '\n'); + // q = q.replace(/\\n/g, '\n'); + // a = a.replace(/\\n/g, '\n'); - // Exactly the same data, not push - try { - const count = await PgClient.count('modelData', { - where: [['user_id', userId], 'AND', ['kb_id', kbId], 'AND', ['q', q], 'AND', ['a', a]] - }); - if (count > 0) { - return Promise.reject('已经存在'); - } - } catch (error) { - error; - } - return Promise.resolve({ - q, 
- a - }); - }) - ); - const filterData = searchRes - .filter((item) => item.status === 'fulfilled') - .map<{ q: string; a: string }>((item: any) => item.value); + // // Exactly the same data, not push + // try { + // const count = await PgClient.count('modelData', { + // where: [['user_id', userId], 'AND', ['kb_id', kbId], 'AND', ['q', q], 'AND', ['a', a]] + // }); - if (filterData.length > 0) { - // 生成词向量 - const vectors = await openaiEmbedding({ - input: filterData.map((item) => item.q), - userId, - type: 'training' - }); + // if (count > 0) { + // return Promise.reject('已经存在'); + // } + // } catch (error) { + // error; + // } + // return Promise.resolve({ + // q, + // a + // }); + // }) + // ); + // const filterData = searchRes + // .filter((item) => item.status === 'fulfilled') + // .map<{ q: string; a: string }>((item: any) => item.value); - // 生成结果插入到 pg - await insertKbItem({ - userId, - kbId, - data: vectors.map((vector, i) => ({ - q: filterData[i].q, - a: filterData[i].a, - vector - })) - }); - } + // 生成词向量 + const vectors = await openaiEmbedding({ + input: dataItems.map((item) => item.q), + userId, + type: 'training' + }); - // 删除 mongo 训练队列. 
如果小于 n 条,整个数据删掉。 如果大于 n 条,仅删数组后 n 个 - if (data.vectorList.length <= listLen) { - await TrainingData.findByIdAndDelete(trainingId); - console.log(`全部向量生成完毕: ${trainingId}`); - } else { - await TrainingData.findByIdAndUpdate(trainingId, { - vectorList: data.vectorList.slice(0, -listLen), - lockTime: new Date('2000/1/1') - }); - console.log(`生成向量成功: ${trainingId}`); - generateVector(trainingId); - } + // 生成结果插入到 pg + await insertKbItem({ + userId, + kbId, + data: vectors.map((vector, i) => ({ + q: dataItems[i].q, + a: dataItems[i].a, + vector + })) + }); + + // delete data from training + await TrainingData.findByIdAndDelete(data._id); + console.log(`生成向量成功: ${data._id}`); + + global.vectorQueueLen--; + generateVector(); } catch (err: any) { // log if (err?.response) { @@ -113,25 +118,28 @@ export async function generateVector(trainingId: string): Promise { // openai 账号异常或者账号余额不足,删除任务 if (openaiError2[err?.response?.data?.error?.type] || err === ERROR_ENUM.insufficientQuota) { console.log('余额不足,删除向量生成任务'); - await TrainingData.findByIdAndDelete(trainingId); - return; + await TrainingData.deleteMany({ + userId + }); + return generateVector(); } // unlock + global.vectorQueueLen--; await TrainingData.findByIdAndUpdate(trainingId, { lockTime: new Date('2000/1/1') }); // 频率限制 if (err?.response?.statusText === 'Too Many Requests') { - console.log('生成向量次数限制,30s后尝试'); + console.log('生成向量次数限制,20s后尝试'); return setTimeout(() => { - generateVector(trainingId); - }, 30000); + generateVector(); + }, 20000); } setTimeout(() => { - generateVector(trainingId); + generateVector(); }, 1000); } } diff --git a/src/service/models/trainingData.ts b/src/service/models/trainingData.ts index 30f3ee5ae..a631e8ff5 100644 --- a/src/service/models/trainingData.ts +++ b/src/service/models/trainingData.ts @@ -1,9 +1,9 @@ /* 模型的知识库 */ import { Schema, model, models, Model as MongoModel } from 'mongoose'; import { TrainingDataSchema as TrainingDateType } from '@/types/mongoSchema'; +import { 
TrainingTypeMap } from '@/constants/plugin'; // pgList and vectorList, Only one of them will work - const TrainingDataSchema = new Schema({ userId: { type: Schema.Types.ObjectId, @@ -19,18 +19,27 @@ const TrainingDataSchema = new Schema({ type: Date, default: () => new Date('2000/1/1') }, - vectorList: { - type: [{ q: String, a: String }], - default: [] + mode: { + type: String, + enum: Object.keys(TrainingTypeMap), + required: true }, prompt: { // 拆分时的提示词 type: String, default: '' }, - qaList: { - type: [String], - default: [] + q: { + // 如果是 + type: String, + default: '' + }, + a: { + type: String, + default: '' + }, + vectorList: { + type: Object } }); diff --git a/src/service/mongo.ts b/src/service/mongo.ts index 9436ba270..cebc69feb 100644 --- a/src/service/mongo.ts +++ b/src/service/mongo.ts @@ -1,8 +1,7 @@ import mongoose from 'mongoose'; -import { generateQA } from './events/generateQA'; -import { generateVector } from './events/generateVector'; import tunnel from 'tunnel'; import { TrainingData } from './mongo'; +import { startQueue } from './utils/tools'; /** * 连接 MongoDB 数据库 @@ -38,7 +37,10 @@ export async function connectToDatabase(): Promise { }); } - startTrain(); + global.qaQueueLen = 0; + global.vectorQueueLen = 0; + + startQueue(); // 5 分钟后解锁不正常的数据,并触发开始训练 setTimeout(async () => { await TrainingData.updateMany( @@ -49,24 +51,10 @@ export async function connectToDatabase(): Promise { lockTime: new Date('2000/1/1') } ); - startTrain(); + startQueue(); }, 5 * 60 * 1000); } -async function startTrain() { - const qa = await TrainingData.find({ - qaList: { $exists: true, $ne: [] } - }); - - qa.map((item) => generateQA(String(item._id))); - - const vector = await TrainingData.find({ - vectorList: { $exists: true, $ne: [] } - }); - - vector.map((item) => generateVector(String(item._id))); -} - export * from './models/authCode'; export * from './models/chat'; export * from './models/model'; diff --git a/src/service/pg.ts b/src/service/pg.ts index 
588fc7287..a5740d58b 100644 --- a/src/service/pg.ts +++ b/src/service/pg.ts @@ -14,8 +14,8 @@ export const connectPg = async () => { password: process.env.PG_PASSWORD, database: process.env.PG_DB_NAME, max: 20, - idleTimeoutMillis: 30000, - connectionTimeoutMillis: 2000 + idleTimeoutMillis: 60000, + connectionTimeoutMillis: 20000 }); global.pgClient.on('error', (err) => { diff --git a/src/service/response.ts b/src/service/response.ts index 2dce57220..1230d66f6 100644 --- a/src/service/response.ts +++ b/src/service/response.ts @@ -45,7 +45,7 @@ export const jsonRes = ( } else if (openaiError[error?.response?.statusText]) { msg = openaiError[error.response.statusText]; } - console.log(error); + console.log(error?.message || error); } res.json({ diff --git a/src/service/utils/tools.ts b/src/service/utils/tools.ts index 8615020f5..b33f2047f 100644 --- a/src/service/utils/tools.ts +++ b/src/service/utils/tools.ts @@ -2,6 +2,8 @@ import type { NextApiResponse, NextApiHandler, NextApiRequest } from 'next'; import NextCors from 'nextjs-cors'; import crypto from 'crypto'; import jwt from 'jsonwebtoken'; +import { generateQA } from '../events/generateQA'; +import { generateVector } from '../events/generateVector'; /* 密码加密 */ export const hashPassword = (psw: string) => { @@ -45,7 +47,7 @@ export function withNextCors(handler: NextApiHandler): NextApiHandler { req: NextApiRequest, res: NextApiResponse ) { - const methods = ['GET', 'HEAD', 'PUT', 'PATCH', 'POST', 'DELETE']; + const methods = ['GET', 'HEAD', 'PUT', 'PATCH', 'POST', 'DELETE']; const origin = req.headers.origin; await NextCors(req, res, { methods, @@ -56,3 +58,15 @@ export function withNextCors(handler: NextApiHandler): NextApiHandler { return handler(req, res); }; } + +export const startQueue = () => { + const qaMax = Number(process.env.QA_MAX_PROCESS || 10); + const vectorMax = Number(process.env.VECTOR_MAX_PROCESS || 10); + + for (let i = 0; i < qaMax; i++) { + generateQA(); + } + for (let i = 0; i <
vectorMax; i++) { + generateVector(); + } +}; diff --git a/src/types/index.d.ts b/src/types/index.d.ts index 316862907..05c687782 100644 --- a/src/types/index.d.ts +++ b/src/types/index.d.ts @@ -9,6 +9,8 @@ declare global { var particlesJS: any; var grecaptcha: any; var QRCode: any; + var qaQueueLen: number; + var vectorQueueLen: number; interface Window { ['pdfjs-dist/build/pdf']: any; diff --git a/src/types/mongoSchema.d.ts b/src/types/mongoSchema.d.ts index 6cb947e34..7299d3986 100644 --- a/src/types/mongoSchema.d.ts +++ b/src/types/mongoSchema.d.ts @@ -74,9 +74,10 @@ export interface TrainingDataSchema { userId: string; kbId: string; lockTime: Date; - vectorList: { q: string; a: string }[]; + mode: `${TrainingTypeEnum}`; prompt: string; - qaList: string[]; + q: string; + a: string; } export interface ChatSchema {