diff --git a/src/pages/api/openapi/kb/pushData.ts b/src/pages/api/openapi/kb/pushData.ts index fff18ad17..9dd950a7d 100644 --- a/src/pages/api/openapi/kb/pushData.ts +++ b/src/pages/api/openapi/kb/pushData.ts @@ -7,6 +7,7 @@ import { authKb } from '@/service/utils/auth'; import { withNextCors } from '@/service/utils/tools'; import { TrainingModeEnum } from '@/constants/plugin'; import { startQueue } from '@/service/utils/tools'; +import { PgClient } from '@/service/pg'; export type Props = { kbId: string; @@ -60,10 +61,23 @@ export async function pushDataToKb({ return {}; } - // 去重 // 过滤重复的 qa 内容 + const set = new Set(); + const filterData: { + a: string; + q: string; + }[] = []; + + data.forEach((item) => { + const text = item.q + item.a; + if (!set.has(text)) { + filterData.push(item); + set.add(text); + } + }); + // 数据库去重 // const searchRes = await Promise.allSettled( - // dataItems.map(async ({ q, a = '' }) => { + // data.map(async ({ q, a = '' }) => { // if (!q) { // return Promise.reject('q为空'); // } diff --git a/src/service/events/generateQA.ts b/src/service/events/generateQA.ts index d6e8909c2..c239403bb 100644 --- a/src/service/events/generateQA.ts +++ b/src/service/events/generateQA.ts @@ -10,6 +10,10 @@ import { pushDataToKb } from '@/pages/api/openapi/kb/pushData'; import { TrainingModeEnum } from '@/constants/plugin'; import { ERROR_ENUM } from '../errorCode'; +const reduceQueue = () => { + global.qaQueueLen = global.qaQueueLen > 0 ? global.qaQueueLen - 1 : 0; +}; + export async function generateQA(): Promise { const maxProcess = Number(process.env.QA_MAX_PROCESS || 10); @@ -20,11 +24,34 @@ export async function generateQA(): Promise { let userId = ''; try { - // 找出一个需要生成的 dataItem (4分钟锁) + const match = { + mode: TrainingModeEnum.qa, + lockTime: { $lte: new Date(Date.now() - 4 * 60 * 1000) } + }; + // random get task + const agree = await TrainingData.aggregate([ + { + $match: match + }, + { $sample: { size: 1 } }, + { + $project: { + _id: 1 + } + } + ]); + + // no task + if (agree.length === 0) { + reduceQueue(); + global.qaQueueLen <= 0 && console.log(`没有需要【QA】的数据, ${global.qaQueueLen}`); + return; + } + const data = await TrainingData.findOneAndUpdate( { - mode: TrainingModeEnum.qa, - lockTime: { $lte: new Date(Date.now() - 2 * 60 * 1000) } + _id: agree[0]._id, + ...match }, { lockTime: new Date() @@ -37,11 +64,10 @@ export async function generateQA(): Promise { q: 1 }); - /* 无待生成的任务 */ + // task preemption if (!data) { - global.qaQueueLen--; - !global.qaQueueLen && console.log(`没有需要【QA】的数据`); - return; + reduceQueue(); + return generateQA(); } trainingId = data._id; @@ -123,10 +149,10 @@ A2: console.log('生成QA成功,time:', `${(Date.now() - startTime) / 1000}s`); - global.qaQueueLen--; + reduceQueue(); generateQA(); } catch (err: any) { - global.qaQueueLen--; + reduceQueue(); // log if (err?.response) { console.log('openai error: 生成QA错误'); diff --git a/src/service/events/generateVector.ts b/src/service/events/generateVector.ts index 8ce7fb559..9c77581c0 100644 --- a/src/service/events/generateVector.ts +++ b/src/service/events/generateVector.ts @@ -1,10 +1,14 @@ import { openaiError2 } from '../errorCode'; -import { insertKbItem, PgClient } from '@/service/pg'; +import { insertKbItem } from '@/service/pg'; import { openaiEmbedding } from '@/pages/api/openapi/plugin/openaiEmbedding'; import { TrainingData } from '../models/trainingData'; import { ERROR_ENUM } from '../errorCode'; import { TrainingModeEnum } from '@/constants/plugin'; +const reduceQueue = () => { + global.vectorQueueLen = global.vectorQueueLen > 0 ? global.vectorQueueLen - 1 : 0; +}; + /* 索引生成队列。每导入一次,就是一个单独的线程 */ export async function generateVector(): Promise { const maxProcess = Number(process.env.VECTOR_MAX_PROCESS || 10); @@ -16,10 +20,34 @@ export async function generateVector(): Promise { let userId = ''; try { + const match = { + mode: TrainingModeEnum.index, + lockTime: { $lte: new Date(Date.now() - 2 * 60 * 1000) } + }; + // random get task + const agree = await TrainingData.aggregate([ + { + $match: match + }, + { $sample: { size: 1 } }, + { + $project: { + _id: 1 + } + } + ]); + + // no task + if (agree.length === 0) { + reduceQueue(); + global.vectorQueueLen <= 0 && console.log(`没有需要【索引】的数据, ${global.vectorQueueLen}`); + return; + } + const data = await TrainingData.findOneAndUpdate( { - mode: TrainingModeEnum.index, - lockTime: { $lte: new Date(Date.now() - 2 * 60 * 1000) } + _id: agree[0]._id, + ...match }, { lockTime: new Date() @@ -32,11 +60,10 @@ export async function generateVector(): Promise { a: 1 }); - /* 无待生成的任务 */ + // task preemption if (!data) { - global.vectorQueueLen--; - !global.vectorQueueLen && console.log(`没有需要【索引】的数据`); - return; + reduceQueue(); + return generateVector(); } trainingId = data._id; @@ -72,10 +99,10 @@ export async function generateVector(): Promise { await TrainingData.findByIdAndDelete(data._id); console.log(`生成向量成功: ${data._id}`); - global.vectorQueueLen--; + reduceQueue(); generateVector(); } catch (err: any) { - global.vectorQueueLen--; + reduceQueue(); // log if (err?.response) { console.log('openai error: 生成向量错误');