* move file * perf: dataset file manage * v441 description * fix: qa csv update file * feat: rename file * frontend show system-version
183 lines
4.2 KiB
TypeScript
183 lines
4.2 KiB
TypeScript
import type { NextApiRequest, NextApiResponse } from 'next';
|
|
import { jsonRes } from '@/service/response';
|
|
import { connectToDatabase, TrainingData, KB } from '@/service/mongo';
|
|
import { authUser } from '@/service/utils/auth';
|
|
import { authKb } from '@/service/utils/auth';
|
|
import { withNextCors } from '@/service/utils/tools';
|
|
import { PgDatasetTableName, TrainingModeEnum } from '@/constants/plugin';
|
|
import { startQueue } from '@/service/utils/tools';
|
|
import { PgClient } from '@/service/pg';
|
|
import { modelToolMap } from '@/utils/plugin';
|
|
import { getVectorModel } from '@/service/utils/data';
|
|
import { DatasetItemType } from '@/types/plugin';
|
|
|
|
export type Props = {
|
|
kbId: string;
|
|
data: DatasetItemType[];
|
|
mode: `${TrainingModeEnum}`;
|
|
prompt?: string;
|
|
};
|
|
|
|
export type Response = {
|
|
insertLen: number;
|
|
};
|
|
|
|
const modeMap = {
|
|
[TrainingModeEnum.index]: true,
|
|
[TrainingModeEnum.qa]: true
|
|
};
|
|
|
|
export default withNextCors(async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
|
|
try {
|
|
const { kbId, data, mode = TrainingModeEnum.index, prompt } = req.body as Props;
|
|
|
|
if (!kbId || !Array.isArray(data)) {
|
|
throw new Error('KbId or data is empty');
|
|
}
|
|
|
|
if (modeMap[mode] === undefined) {
|
|
throw new Error('Mode is error');
|
|
}
|
|
|
|
if (data.length > 500) {
|
|
throw new Error('Data is too long, max 500');
|
|
}
|
|
|
|
await connectToDatabase();
|
|
|
|
// 凭证校验
|
|
const { userId } = await authUser({ req });
|
|
|
|
jsonRes<Response>(res, {
|
|
data: await pushDataToKb({
|
|
kbId,
|
|
data,
|
|
userId,
|
|
mode,
|
|
prompt
|
|
})
|
|
});
|
|
} catch (err) {
|
|
jsonRes(res, {
|
|
code: 500,
|
|
error: err
|
|
});
|
|
}
|
|
});
|
|
|
|
export async function pushDataToKb({
|
|
userId,
|
|
kbId,
|
|
data,
|
|
mode,
|
|
prompt
|
|
}: { userId: string } & Props): Promise<Response> {
|
|
const [kb, vectorModel] = await Promise.all([
|
|
authKb({
|
|
userId,
|
|
kbId
|
|
}),
|
|
(async () => {
|
|
if (mode === TrainingModeEnum.index) {
|
|
const vectorModel = (await KB.findById(kbId, 'vectorModel'))?.vectorModel;
|
|
|
|
return getVectorModel(vectorModel || global.vectorModels[0].model);
|
|
}
|
|
return global.vectorModels[0];
|
|
})()
|
|
]);
|
|
|
|
const modeMaxToken = {
|
|
[TrainingModeEnum.index]: vectorModel.maxToken * 1.5,
|
|
[TrainingModeEnum.qa]: global.qaModel.maxToken * 0.8
|
|
};
|
|
|
|
// 过滤重复的 qa 内容
|
|
const set = new Set();
|
|
const filterData: DatasetItemType[] = [];
|
|
|
|
data.forEach((item) => {
|
|
if (!item.q) return;
|
|
|
|
const text = item.q + item.a;
|
|
|
|
// count q token
|
|
const token = modelToolMap.countTokens({
|
|
messages: [{ obj: 'System', value: item.q }]
|
|
});
|
|
|
|
if (token > modeMaxToken[mode]) {
|
|
return;
|
|
}
|
|
|
|
if (!set.has(text)) {
|
|
filterData.push(item);
|
|
set.add(text);
|
|
}
|
|
});
|
|
|
|
// 数据库去重
|
|
const insertData = (
|
|
await Promise.allSettled(
|
|
filterData.map(async (data) => {
|
|
let { q, a } = data;
|
|
if (mode !== TrainingModeEnum.index) {
|
|
return Promise.resolve(data);
|
|
}
|
|
|
|
if (!q) {
|
|
return Promise.reject('q为空');
|
|
}
|
|
|
|
q = q.replace(/\\n/g, '\n').trim().replace(/'/g, '"');
|
|
a = a.replace(/\\n/g, '\n').trim().replace(/'/g, '"');
|
|
|
|
// Exactly the same data, not push
|
|
try {
|
|
const { rows } = await PgClient.query(`
|
|
SELECT COUNT(*) > 0 AS exists
|
|
FROM ${PgDatasetTableName}
|
|
WHERE md5(q)=md5('${q}') AND md5(a)=md5('${a}') AND user_id='${userId}' AND kb_id='${kbId}'
|
|
`);
|
|
const exists = rows[0]?.exists || false;
|
|
|
|
if (exists) {
|
|
return Promise.reject('已经存在');
|
|
}
|
|
} catch (error) {
|
|
console.log(error);
|
|
}
|
|
return Promise.resolve(data);
|
|
})
|
|
)
|
|
)
|
|
.filter((item) => item.status === 'fulfilled')
|
|
.map<DatasetItemType>((item: any) => item.value);
|
|
|
|
// 插入记录
|
|
const insertRes = await TrainingData.insertMany(
|
|
insertData.map((item) => ({
|
|
...item,
|
|
userId,
|
|
kbId,
|
|
mode,
|
|
prompt,
|
|
vectorModel: vectorModel.model
|
|
}))
|
|
);
|
|
|
|
insertRes.length > 0 && startQueue();
|
|
|
|
return {
|
|
insertLen: insertRes.length
|
|
};
|
|
}
|
|
|
|
export const config = {
|
|
api: {
|
|
bodyParser: {
|
|
sizeLimit: '12mb'
|
|
}
|
|
}
|
|
};
|