diff --git a/client/src/pages/api/openapi/kb/pushData.ts b/client/src/pages/api/openapi/kb/pushData.ts index 50341b2af..cf1e97d3c 100644 --- a/client/src/pages/api/openapi/kb/pushData.ts +++ b/client/src/pages/api/openapi/kb/pushData.ts @@ -8,6 +8,7 @@ import { PgTrainingTableName, TrainingModeEnum } from '@/constants/plugin'; import { startQueue } from '@/service/utils/tools'; import { PgClient } from '@/service/pg'; import { modelToolMap } from '@/utils/plugin'; +import { getVectorModel } from '@/service/utils/data'; export type DateItemType = { a: string; q: string; source?: string }; @@ -22,17 +23,25 @@ export type Response = { insertLen: number; }; -const modeMaxToken = { - [TrainingModeEnum.index]: 6000, - [TrainingModeEnum.qa]: 12000 +const modeMap = { + [TrainingModeEnum.index]: true, + [TrainingModeEnum.qa]: true }; export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse) { try { - const { kbId, data, mode, prompt } = req.body as Props; + const { kbId, data, mode = TrainingModeEnum.index, prompt } = req.body as Props; if (!kbId || !Array.isArray(data)) { - throw new Error('缺少参数'); + throw new Error('KbId or data is empty'); + } + + if (modeMap[mode] === undefined) { + throw new Error('Mode is error'); + } + + if (data.length > 500) { + throw new Error('Data is too long, max 500'); } await connectToDatabase(); @@ -64,25 +73,42 @@ export async function pushDataToKb({ mode, prompt }: { userId: string } & Props): Promise { - await authKb({ - userId, - kbId - }); + const [kb, vectorModel] = await Promise.all([ + authKb({ + userId, + kbId + }), + (async () => { + if (mode === TrainingModeEnum.index) { + const vectorModel = (await KB.findById(kbId, 'vectorModel'))?.vectorModel; + + return getVectorModel(vectorModel || global.vectorModels[0].model); + } + return global.vectorModels[0]; + })() + ]); + + const modeMaxToken = { + [TrainingModeEnum.index]: vectorModel.maxToken, + [TrainingModeEnum.qa]: global.qaModel.maxToken * 0.8 + }; // 过滤重复的 qa 内容 const set = new Set(); const filterData: DateItemType[] = []; data.forEach((item) => { + if (!item.q) return; + const text = item.q + item.a; - // count token + // count q token const token = modelToolMap.countTokens({ model: 'gpt-3.5-turbo', messages: [{ obj: 'System', value: item.q }] }); - if (token > modeMaxToken[TrainingModeEnum.qa]) { + if (token > modeMaxToken[mode]) { return; } @@ -138,15 +164,8 @@ export async function pushDataToKb({ .filter((item) => item.status === 'fulfilled') .map((item: any) => item.value); - const vectorModel = await (async () => { - if (mode === TrainingModeEnum.index) { - return (await KB.findById(kbId, 'vectorModel'))?.vectorModel || global.vectorModels[0].model; - } - return global.vectorModels[0].model; - })(); - // 插入记录 - await TrainingData.insertMany( + const insertRes = await TrainingData.insertMany( insertData.map((item) => ({ q: item.q, a: item.a, @@ -155,21 +174,21 @@ export async function pushDataToKb({ kbId, mode, prompt, - vectorModel + vectorModel: vectorModel.model })) ); - insertData.length > 0 && startQueue(); + insertRes.length > 0 && startQueue(); return { - insertLen: insertData.length + insertLen: insertRes.length }; } export const config = { api: { bodyParser: { - sizeLimit: '20mb' + sizeLimit: '12mb' } } }; diff --git a/client/src/pages/api/openapi/text/sensitiveCheck.ts b/client/src/pages/api/openapi/text/sensitiveCheck.ts deleted file mode 100644 index 521cc567d..000000000 --- a/client/src/pages/api/openapi/text/sensitiveCheck.ts +++ /dev/null @@ -1,51 +0,0 @@ -// Next.js API route support: https://nextjs.org/docs/api-routes/introduction -import type { NextApiRequest, NextApiResponse } from 'next'; -import { jsonRes } from '@/service/response'; -import { authUser } from '@/service/utils/auth'; -import axios from 'axios'; -import { axiosConfig } from '@/service/ai/openai'; - -export type Props = { - input: string; -}; - -export default async function handler(req: NextApiRequest, res: NextApiResponse) { - try { - await authUser({ req }); - - const result = await sensitiveCheck(req.body); - - jsonRes(res, { - data: result, - message: result - }); - } catch (err) { - jsonRes(res, { - code: 500, - error: err - }); - } -} - -export async function sensitiveCheck({ input }: Props) { - const response = await axios({ - ...axiosConfig(), - method: 'POST', - url: `/moderations`, - data: { - input - } - }); - - const data = (response.data.results?.[0]?.category_scores as Record) || {}; - - const values = Object.values(data); - - for (const val of values) { - if (val > 0.2) { - return Promise.reject('您的内容不合规'); - } - } - - return ''; -} diff --git a/client/src/pages/kb/detail/components/Import/Chunk.tsx b/client/src/pages/kb/detail/components/Import/Chunk.tsx index b93f6a45f..871541ac6 100644 --- a/client/src/pages/kb/detail/components/Import/Chunk.tsx +++ b/client/src/pages/kb/detail/components/Import/Chunk.tsx @@ -66,7 +66,7 @@ const ChunkImport = ({ kbId }: { kbId: string }) => { // subsection import let success = 0; - const step = 500; + const step = 300; for (let i = 0; i < chunks.length; i += step) { const { insertLen } = await postKbDataFromList({ kbId, diff --git a/client/src/pages/kb/detail/components/Import/Csv.tsx b/client/src/pages/kb/detail/components/Import/Csv.tsx index f3f743f56..68c99214b 100644 --- a/client/src/pages/kb/detail/components/Import/Csv.tsx +++ b/client/src/pages/kb/detail/components/Import/Csv.tsx @@ -54,7 +54,7 @@ const CsvImport = ({ kbId }: { kbId: string }) => { // subsection import let success = 0; - const step = 500; + const step = 300; for (let i = 0; i < filterChunks.length; i += step) { const { insertLen } = await postKbDataFromList({ kbId, diff --git a/client/src/pages/kb/detail/components/Import/QA.tsx b/client/src/pages/kb/detail/components/Import/QA.tsx index eb3cc0b7e..dfb40231d 100644 --- a/client/src/pages/kb/detail/components/Import/QA.tsx +++ b/client/src/pages/kb/detail/components/Import/QA.tsx @@ -53,7 +53,7 @@ const QAImport = ({ kbId }: { kbId: string }) => { // subsection import let success = 0; - const step = 300; + const step = 200; for (let i = 0; i < chunks.length; i += step) { const { insertLen } = await postKbDataFromList({ kbId, diff --git a/client/src/pages/kb/detail/components/Info.tsx b/client/src/pages/kb/detail/components/Info.tsx index 9ad46b18e..9c9947f60 100644 --- a/client/src/pages/kb/detail/components/Info.tsx +++ b/client/src/pages/kb/detail/components/Info.tsx @@ -156,6 +156,12 @@ const Info = ( {getValues('vectorModel').name} + + + MaxTokens + + {getValues('vectorModel').maxToken} + 知识库头像