perf: 知识库数据结构

This commit is contained in:
archer 2023-04-01 22:31:56 +08:00
parent 5759cbeae0
commit ae4243b522
No known key found for this signature in database
GPG Key ID: 166CA6BF2383B2BB
26 changed files with 611 additions and 518 deletions

View File

@ -107,5 +107,6 @@ echo "Restart clash"
```bash ```bash
# 索引 # 索引
# FT.CREATE idx:model:data ON JSON PREFIX 1 model:data: SCHEMA $.modelId AS modelId TAG $.dataId AS dataId TAG $.vector AS vector VECTOR FLAT 6 DIM 1536 DISTANCE_METRIC COSINE TYPE FLOAT32 # FT.CREATE idx:model:data ON JSON PREFIX 1 model:data: SCHEMA $.modelId AS modelId TAG $.dataId AS dataId TAG $.vector AS vector VECTOR FLAT 6 DIM 1536 DISTANCE_METRIC COSINE TYPE FLOAT32
FT.CREATE idx:model:data:hash ON HASH PREFIX 1 model:data: SCHEMA modelId TAG dataId TAG vector VECTOR FLAT 6 DIM 1536 DISTANCE_METRIC COSINE TYPE FLOAT32 # FT.CREATE idx:model:data:hash ON HASH PREFIX 1 model:data: SCHEMA modelId TAG dataId TAG vector VECTOR FLAT 6 DIM 1536 DISTANCE_METRIC COSINE TYPE FLOAT32
``` FT.CREATE idx:model:data ON HASH PREFIX 1 model:data: SCHEMA modelId TAG userId TAG q TEXT text TEXT vector VECTOR FLAT 6 DIM 1536 DISTANCE_METRIC COSINE TYPE FLOAT32
```

View File

@ -44,11 +44,16 @@ export const getModelSplitDataList = (modelId: string) =>
export const postModelDataInput = (data: { export const postModelDataInput = (data: {
modelId: string; modelId: string;
data: { text: ModelDataSchema['text']; q: ModelDataSchema['q'] }[]; data: { text: ModelDataSchema['text']; q: ModelDataSchema['q'] }[];
}) => POST(`/model/data/pushModelDataInput`, data); }) => POST<number>(`/model/data/pushModelDataInput`, data);
export const postModelDataFileText = (modelId: string, text: string) => export const postModelDataFileText = (modelId: string, text: string) =>
POST(`/model/data/splitData`, { modelId, text }); POST(`/model/data/splitData`, { modelId, text });
export const postModelDataJsonData = (
modelId: string,
jsonData: { prompt: string; completion: string; vector?: number[] }[]
) => POST(`/model/data/pushModelDataJson`, { modelId, data: jsonData });
export const putModelDataById = (data: { dataId: string; text: string }) => export const putModelDataById = (data: { dataId: string; text: string }) =>
PUT('/model/data/putModelData', data); PUT('/model/data/putModelData', data);
export const delOneModelData = (dataId: string) => export const delOneModelData = (dataId: string) =>

View File

@ -1,4 +1,5 @@
import type { ServiceName, ModelDataType, ModelSchema } from '@/types/mongoSchema'; import type { ServiceName, ModelDataType, ModelSchema } from '@/types/mongoSchema';
import type { RedisModelDataItemType } from '@/types/redis';
export enum ChatModelNameEnum { export enum ChatModelNameEnum {
GPT35 = 'gpt-3.5-turbo', GPT35 = 'gpt-3.5-turbo',
@ -93,9 +94,9 @@ export const formatModelStatus = {
} }
}; };
export const ModelDataStatusMap = { export const ModelDataStatusMap: Record<RedisModelDataItemType['status'], string> = {
0: '训练完成', ready: '训练完成',
1: '训练中' waiting: '训练中'
}; };
export const defaultModel: ModelSchema = { export const defaultModel: ModelSchema = {

View File

@ -1 +1,6 @@
export const VecModelDataIndex = 'model:data'; export const VecModelDataPrefix = 'model:data';
export const VecModelDataIdx = `idx:${VecModelDataPrefix}:hash`;
export enum ModelDataStatusEnum {
ready = 'ready',
waiting = 'waiting'
}

View File

@ -1,6 +1,6 @@
import type { NextApiRequest, NextApiResponse } from 'next'; import type { NextApiRequest, NextApiResponse } from 'next';
import { createParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser'; import { createParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser';
import { connectToDatabase, ModelData } from '@/service/mongo'; import { connectToDatabase } from '@/service/mongo';
import { getOpenAIApi, authChat } from '@/service/utils/chat'; import { getOpenAIApi, authChat } from '@/service/utils/chat';
import { httpsAgent, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools'; import { httpsAgent, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools';
import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai'; import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai';
@ -11,7 +11,7 @@ import { PassThrough } from 'stream';
import { modelList } from '@/constants/model'; import { modelList } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill'; import { pushChatBill } from '@/service/events/pushBill';
import { connectRedis } from '@/service/redis'; import { connectRedis } from '@/service/redis';
import { VecModelDataIndex } from '@/constants/redis'; import { VecModelDataPrefix } from '@/constants/redis';
import { vectorToBuffer } from '@/utils/tools'; import { vectorToBuffer } from '@/utils/tools';
/* 发送提示词 */ /* 发送提示词 */
@ -73,17 +73,17 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
) )
.then((res) => res?.data?.data?.[0]?.embedding || []); .then((res) => res?.data?.data?.[0]?.embedding || []);
// 搜索系统提示词, 按相似度从 redis 中搜出前3条不同 dataId 的数据 // 搜索系统提示词, 按相似度从 redis 中搜出相关的 q 和 text
const redisData: any[] = await redis.sendCommand([ const redisData: any[] = await redis.sendCommand([
'FT.SEARCH', 'FT.SEARCH',
`idx:${VecModelDataIndex}:hash`, `idx:${VecModelDataPrefix}:hash`,
`@modelId:{${String( `@modelId:{${String(
chat.modelId._id chat.modelId._id
)}} @vector:[VECTOR_RANGE 0.15 $blob]=>{$YIELD_DISTANCE_AS: score}`, )}} @vector:[VECTOR_RANGE 0.15 $blob]=>{$YIELD_DISTANCE_AS: score}`,
// `@modelId:{${String(chat.modelId._id)}}=>[KNN 10 @vector $blob AS score]`, // `@modelId:{${String(chat.modelId._id)}}=>[KNN 10 @vector $blob AS score]`,
'RETURN', 'RETURN',
'1', '1',
'dataId', 'text',
'SORTBY', 'SORTBY',
'score', 'score',
'PARAMS', 'PARAMS',
@ -97,42 +97,28 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
'2' '2'
]); ]);
// 格式化响应值,获取去重后的id // 格式化响应值,获取 qa
let formatIds = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20] const formatRedisPrompt = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
.map((i) => { .map((i) => {
if (!redisData[i] || !redisData[i][1]) return ''; if (!redisData[i]) return '';
return redisData[i][1]; const text = (redisData[i][1] as string) || '';
if (!text) return '';
return text;
}) })
.filter((item) => item); .filter((item) => item);
formatIds = Array.from(new Set(formatIds));
if (formatIds.length === 0) { if (formatRedisPrompt.length === 0) {
throw new Error('对不起,我没有找到你的问题'); throw new Error('对不起,我没有找到你的问题');
} }
// 从 mongo 中取出原文作为提示词
const textArr = (
await Promise.all(
[2, 4, 6, 8, 10, 12, 14, 16, 18, 20].map((i) => {
if (!redisData[i] || !redisData[i][1]) return '';
return ModelData.findById(redisData[i][1])
.select('text q')
.then((res) => {
if (!res) return '';
// const questions = res.q.map((item) => item.text).join(' ');
const answer = res.text;
return `${answer}`;
});
})
)
).filter((item) => item);
// textArr 筛选,最多 3000 tokens // textArr 筛选,最多 3000 tokens
const systemPrompt = systemPromptFilter(textArr, 3400); const systemPrompt = systemPromptFilter(formatRedisPrompt, 3400);
prompts.unshift({ prompts.unshift({
obj: 'SYSTEM', obj: 'SYSTEM',
value: `${model.systemPrompt} 我的知识库: "${systemPrompt}"` value: `${model.systemPrompt} 我的知识库: "${systemPrompt}"`
}); });
// 控制在 tokens 数量,防止超出 // 控制在 tokens 数量,防止超出

View File

@ -1,9 +1,7 @@
import type { NextApiRequest, NextApiResponse } from 'next'; import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { connectToDatabase, ModelData } from '@/service/mongo';
import { authToken } from '@/service/utils/tools'; import { authToken } from '@/service/utils/tools';
import { connectRedis } from '@/service/redis'; import { connectRedis } from '@/service/redis';
import { VecModelDataIndex } from '@/constants/redis';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) { export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try { try {
@ -23,25 +21,15 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
// 凭证校验 // 凭证校验
const userId = await authToken(authorization); const userId = await authToken(authorization);
await connectToDatabase();
const redis = await connectRedis(); const redis = await connectRedis();
const data = await ModelData.findById(dataId); // 校验是否为该用户的数据
const dataItemUserId = await redis.hGet(dataId, 'userId');
await ModelData.deleteOne({ if (dataItemUserId !== userId) {
_id: dataId, throw new Error('无权操作');
userId }
}); // 删除
await redis.del(dataId);
// 删除 redis 数据
data?.q.forEach(async (item) => {
try {
await redis.json.del(`${VecModelDataIndex}:${item.id}`);
} catch (error) {
console.log(error);
}
});
jsonRes(res); jsonRes(res);
} catch (err) { } catch (err) {
console.log(err); console.log(err);

View File

@ -1,7 +1,10 @@
import type { NextApiRequest, NextApiResponse } from 'next'; import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { connectToDatabase, ModelData } from '@/service/mongo'; import { connectToDatabase } from '@/service/mongo';
import { authToken } from '@/service/utils/tools'; import { authToken } from '@/service/utils/tools';
import { connectRedis } from '@/service/redis';
import { VecModelDataIdx } from '@/constants/redis';
import { SearchOptions } from 'redis';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) { export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try { try {
@ -32,24 +35,34 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
const userId = await authToken(authorization); const userId = await authToken(authorization);
await connectToDatabase(); await connectToDatabase();
const redis = await connectRedis();
const data = await ModelData.find({ // 从 redis 中获取数据
modelId, const searchRes = await redis.ft.search(
userId VecModelDataIdx,
}) `@modelId:{${modelId}} @userId:{${userId}}`,
.sort({ _id: -1 }) // 按照创建时间倒序排列 {
.skip((pageNum - 1) * pageSize) RETURN: ['q', 'text', 'status'],
.limit(pageSize); LIMIT: {
from: (pageNum - 1) * pageSize,
size: pageSize
},
SORTBY: {
BY: 'modelId',
DIRECTION: 'DESC'
}
}
);
jsonRes(res, { jsonRes(res, {
data: { data: {
pageNum, pageNum,
pageSize, pageSize,
data, data: searchRes.documents.map((item) => ({
total: await ModelData.countDocuments({ id: item.id,
modelId, ...item.value
userId })),
}) total: searchRes.total
} }
}); });
} catch (err) { } catch (err) {

View File

@ -1,9 +1,11 @@
import type { NextApiRequest, NextApiResponse } from 'next'; import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { connectToDatabase, ModelData, Model } from '@/service/mongo'; import { connectToDatabase, Model } from '@/service/mongo';
import { authToken } from '@/service/utils/tools'; import { authToken } from '@/service/utils/tools';
import { ModelDataSchema } from '@/types/mongoSchema'; import { ModelDataSchema } from '@/types/mongoSchema';
import { generateVector } from '@/service/events/generateVector'; import { generateVector } from '@/service/events/generateVector';
import { connectRedis } from '@/service/redis';
import { VecModelDataPrefix, ModelDataStatusEnum } from '@/constants/redis';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) { export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try { try {
@ -25,6 +27,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
const userId = await authToken(authorization); const userId = await authToken(authorization);
await connectToDatabase(); await connectToDatabase();
const redis = await connectRedis();
// 验证是否是该用户的 model // 验证是否是该用户的 model
const model = await Model.findOne({ const model = await Model.findOne({
@ -36,19 +39,29 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
throw new Error('无权操作该模型'); throw new Error('无权操作该模型');
} }
// push data const insertRes = await Promise.allSettled(
await ModelData.insertMany( data.map((item) => {
data.map((item) => ({ return redis.sendCommand([
...item, 'HMSET',
modelId, `${VecModelDataPrefix}:${item.q.id}`,
userId 'userId',
})) userId,
'modelId',
modelId,
'q',
item.q.text,
'text',
item.text,
'status',
ModelDataStatusEnum.waiting
]);
})
); );
generateVector(true); generateVector(true);
jsonRes(res, { jsonRes(res, {
data: model data: insertRes.filter((item) => item.status === 'rejected').length
}); });
} catch (err) { } catch (err) {
jsonRes(res, { jsonRes(res, {

View File

@ -0,0 +1,78 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, Model } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import { generateVector } from '@/service/events/generateVector';
import { vectorToBuffer, formatVector } from '@/utils/tools';
import { connectRedis } from '@/service/redis';
import { VecModelDataPrefix, ModelDataStatusEnum } from '@/constants/redis';
import { customAlphabet } from 'nanoid';
// Id generator producing 12-character ids from lowercase letters and digits;
// the handler below uses it as the random suffix of each Redis key.
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);
/**
 * API route: bulk-import prompt/completion pairs (optionally with a
 * pre-computed vector) into a model's knowledge base stored in Redis.
 *
 * Request body: `{ modelId, data: { prompt, completion, vector? }[] }`.
 * The caller is authenticated through the `authorization` header and must own
 * the model. Each item is written as a Redis hash under
 * `${VecModelDataPrefix}:<random id>`.
 *
 * Response: `data` is the number of items that could NOT be written
 * (0 means every item was stored). Any failure before the write phase is
 * reported with `code: 500`.
 */
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
  try {
    const { modelId, data } = req.body as {
      modelId: string;
      data: { prompt: string; completion: string; vector?: number[] }[];
    };
    const { authorization } = req.headers;

    if (!authorization) {
      throw new Error('无权操作');
    }

    if (!modelId || !Array.isArray(data)) {
      throw new Error('缺少参数');
    }

    // Credential check
    const userId = await authToken(authorization);

    await connectToDatabase();
    const redis = await connectRedis();

    // Verify that the model belongs to this user
    const model = await Model.findOne({
      _id: modelId,
      userId
    });

    if (!model) {
      throw new Error('无权操作该模型');
    }

    // Insert into redis. The mapper is async so that a malformed item (for
    // example `null`) becomes a rejected entry counted below, instead of
    // throwing synchronously after earlier items were already sent.
    const insertRedisRes = await Promise.allSettled(
      data.map(async (item) => {
        // Only a non-empty array counts as a supplied embedding. An empty
        // array is truthy and would otherwise be stored as a "ready" entry
        // without a usable vector.
        const vector =
          Array.isArray(item.vector) && item.vector.length > 0 ? item.vector : undefined;

        return redis.sendCommand([
          'HMSET',
          `${VecModelDataPrefix}:${nanoid()}`,
          'userId',
          userId,
          'modelId',
          String(modelId),
          ...(vector ? ['vector', vectorToBuffer(formatVector(vector))] : []),
          'q',
          item.prompt,
          'text',
          item.completion,
          'status',
          vector ? ModelDataStatusEnum.ready : ModelDataStatusEnum.waiting
        ]);
      })
    );

    // NOTE(review): presumably kicks off embedding of the "waiting" entries —
    // confirm against generateVector's implementation.
    generateVector(true);

    jsonRes(res, {
      data: insertRedisRes.filter((item) => item.status === 'rejected').length
    });
  } catch (err) {
    jsonRes(res, {
      code: 500,
      error: err
    });
  }
}

View File

@ -1,57 +0,0 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, DataItem, ModelData } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import { customAlphabet } from 'nanoid';
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);
/**
 * API route: copy previously split data items into a model's data set.
 *
 * Request body: `{ dataIds, modelId }`. For every `dataId` the caller's
 * `DataItem` records are loaded (fields `result` and `text`), converted to
 * `ModelData` documents with one freshly generated id per question, and
 * inserted in a single batch. Responds with the loaded data items, or with
 * `code: 500` on any error.
 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
  try {
    const { dataIds, modelId } = req.body as { dataIds: string[]; modelId: string };
    if (!dataIds) throw new Error('参数错误');

    await connectToDatabase();

    const userId = await authToken(req.headers.authorization);

    // One lookup per requested dataId, restricted to the caller's own records.
    const lookups = dataIds.map((dataId) =>
      DataItem.find<{ _id: string; result: { q: string }[]; text: string }>(
        {
          userId,
          dataId
        },
        'result text'
      )
    );
    const dataItems = (await Promise.all(lookups)).flat();

    // Build the documents to insert: each question gets its own id.
    const documents = dataItems.map((record) => {
      const text = record.text;
      const questions = record.result.map((entry) => ({
        id: nanoid(),
        text: entry.q
      }));
      return { modelId, userId, text, q: questions };
    });
    await ModelData.insertMany(documents);

    jsonRes(res, { data: dataItems });
  } catch (err) {
    jsonRes(res, { code: 500, error: err });
  }
}

View File

@ -1,7 +1,7 @@
import type { NextApiRequest, NextApiResponse } from 'next'; import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { connectToDatabase, ModelData } from '@/service/mongo';
import { authToken } from '@/service/utils/tools'; import { authToken } from '@/service/utils/tools';
import { connectRedis } from '@/service/redis';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) { export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try { try {
@ -22,17 +22,16 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
// 凭证校验 // 凭证校验
const userId = await authToken(authorization); const userId = await authToken(authorization);
await connectToDatabase(); const redis = await connectRedis();
await ModelData.updateOne( // 校验是否为该用户的数据
{ const dataItemUserId = await redis.hGet(dataId, 'userId');
_id: dataId, if (dataItemUserId !== userId) {
userId throw new Error('无权操作');
}, }
{
text // 更新
} await redis.hSet(dataId, 'text', text);
);
jsonRes(res); jsonRes(res);
} catch (err) { } catch (err) {

View File

@ -1,13 +1,12 @@
import type { NextApiRequest, NextApiResponse } from 'next'; import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { Chat, Model, Training, connectToDatabase, ModelData } from '@/service/mongo'; import { Chat, Model, Training, connectToDatabase } from '@/service/mongo';
import { authToken, getUserApiOpenai } from '@/service/utils/tools'; import { authToken, getUserApiOpenai } from '@/service/utils/tools';
import { TrainingStatusEnum } from '@/constants/model'; import { TrainingStatusEnum } from '@/constants/model';
import { getOpenAIApi } from '@/service/utils/chat';
import { TrainingItemType } from '@/types/training'; import { TrainingItemType } from '@/types/training';
import { httpsAgent } from '@/service/utils/tools'; import { httpsAgent } from '@/service/utils/tools';
import { connectRedis } from '@/service/redis'; import { connectRedis } from '@/service/redis';
import { VecModelDataIndex } from '@/constants/redis'; import { VecModelDataIdx } from '@/constants/redis';
/* 获取我的模型 */ /* 获取我的模型 */
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) { export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
@ -26,39 +25,38 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
// 凭证校验 // 凭证校验
const userId = await authToken(authorization); const userId = await authToken(authorization);
// 验证是否是该用户的 model
const model = await Model.findOne({
_id: modelId,
userId
});
if (!model) {
throw new Error('无权操作该模型');
}
await connectToDatabase(); await connectToDatabase();
const redis = await connectRedis(); const redis = await connectRedis();
const modelDataList = await ModelData.find({ // 获取 redis 中模型关联的所有数据
const searchRes = await redis.ft.search(
VecModelDataIdx,
`@modelId:{${modelId}} @userId:{${userId}}`,
{
LIMIT: {
from: 0,
size: 10000
}
}
);
// 删除 redis 内容
await Promise.all(searchRes.documents.map((item) => redis.del(item.id)));
// 删除对应的聊天
await Chat.deleteMany({
modelId modelId
}); });
// 删除 redis
modelDataList?.forEach((modelData) =>
modelData.q.forEach(async (item) => {
try {
await redis.json.del(`${VecModelDataIndex}:${item.id}`);
} catch (error) {
console.log(error);
}
})
);
let requestQueue: any[] = [];
// 删除对应的聊天
requestQueue.push(
Chat.deleteMany({
modelId
})
);
// 删除数据集
requestQueue.push(
ModelData.deleteMany({
modelId
})
);
// 查看是否正在训练 // 查看是否正在训练
const training: TrainingItemType | null = await Training.findOne({ const training: TrainingItemType | null = await Training.findOne({
modelId, modelId,
@ -78,21 +76,15 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
} }
// 删除对应训练记录 // 删除对应训练记录
requestQueue.push( await Training.deleteMany({
Training.deleteMany({ modelId
modelId });
})
);
// 删除模型 // 删除模型
requestQueue.push( await Model.deleteOne({
Model.deleteOne({ _id: modelId,
_id: modelId, userId
userId });
})
);
await Promise.all(requestQueue);
jsonRes(res); jsonRes(res);
} catch (err) { } catch (err) {

View File

@ -1,68 +0,0 @@
// Next.js API route support: https://nextjs.org/docs/api-routes/introduction
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, Bill } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import type { BillSchema } from '@/types/mongoSchema';
import { VecModelDataIndex } from '@/constants/redis';
import { connectRedis } from '@/service/redis';
import { vectorToBuffer } from '@/utils/tools';
// Fixture vectors: 12 explicit leading components followed by 1524 zeros
// (12 + 1524 = 1536 components in total).
const zeroTail = (): number[] => Array.from({ length: 1524 }, () => 0);

const vectorData: number[] = [
  -0.025028639,
  -0.010407282,
  0.026523087,
  -0.0107438695,
  -0.006967359,
  0.010043768,
  -0.012043097,
  0.008724345,
  -0.028919589,
  -0.0117738275,
  0.0050690062,
  0.02961969,
  ...zeroTail()
];

const vectorData2: number[] = [
  0.025028639,
  0.010407282,
  0.026523087,
  0.0107438695,
  -0.006967359,
  0.010043768,
  -0.012043097,
  0.008724345,
  0.028919589,
  0.0117738275,
  0.0050690062,
  0.02961969,
  ...zeroTail()
];
/**
 * Development-only API route used to try out Redis vector search.
 *
 * Rejects the request unless NODE_ENV is 'development'. Otherwise it writes one
 * fixed hash (`model:data:333`) containing `vectorData2`, runs an FT.SEARCH
 * range query against `idx:model:data:hash` with the same vector, and returns
 * the raw search reply. Any error is reported with `code: 500`.
 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
  try {
    const isDevelopment = process.env.NODE_ENV === 'development';
    if (!isDevelopment) {
      throw new Error('不是开发环境');
    }

    await connectToDatabase();
    const redis = await connectRedis();

    // Seed a single hash entry to search for.
    const seedArgs = [
      'HMSET',
      'model:data:333',
      'vector',
      vectorToBuffer(vectorData2),
      'modelId',
      '1133',
      'dataId',
      'safadfa'
    ];
    await redis.sendCommand(seedArgs);

    // Range search on the vector field, filtered by modelId and sorted by distance.
    const searchArgs = [
      'FT.SEARCH',
      'idx:model:data:hash',
      '@modelId:{1133} @vector:[VECTOR_RANGE 0.15 $blob]=>{$YIELD_DISTANCE_AS: score}',
      'RETURN',
      '2',
      'modelId',
      'dataId',
      'PARAMS',
      '2',
      'blob',
      vectorToBuffer(vectorData2),
      'SORTBY',
      'score',
      'DIALECT',
      '2'
    ];
    const response = await redis.sendCommand(searchArgs);

    jsonRes(res, { data: response });
  } catch (err) {
    jsonRes(res, { code: 500, error: err });
  }
}

View File

@ -190,97 +190,91 @@ const Chat = ({ chatId }: { chatId: string }) => {
/** /**
* *
*/ */
const sendPrompt = useCallback( const sendPrompt = useCallback(async () => {
async (e?: React.MouseEvent<HTMLDivElement>) => { const storeInput = inputVal;
e?.stopPropagation(); // 去除空行
e?.preventDefault(); const val = inputVal
.trim()
.split('\n')
.filter((val) => val)
.join('\n');
if (!chatData?.modelId || !val || !ChatBox.current || isChatting) {
return;
}
const storeInput = inputVal; // 长度校验
// 去除空行 const tokens = encode(val).length;
const val = inputVal const model = modelList.find((item) => item.model === chatData.modelName);
.trim()
.split('\n') if (model && tokens >= model.maxToken) {
.filter((val) => val) toast({
.join('\n'); title: '单次输入超出 4000 tokens',
if (!chatData?.modelId || !val || !ChatBox.current || isChatting) { status: 'warning'
return; });
return;
}
const newChatList: ChatSiteItemType[] = [
...chatData.history,
{
obj: 'Human',
value: val,
status: 'finish'
},
{
obj: 'AI',
value: '',
status: 'loading'
} }
];
// 长度校验 // 插入内容
const tokens = encode(val).length; setChatData((state) => ({
const model = modelList.find((item) => item.model === chatData.modelName); ...state,
history: newChatList
}));
if (model && tokens >= model.maxToken) { // 清空输入内容
toast({ resetInputVal('');
title: '单次输入超出 4000 tokens', scrollToBottom();
status: 'warning'
try {
await gptChatPrompt(newChatList[newChatList.length - 2]);
// 如果是 Human 第一次发送,插入历史记录
const humanChat = newChatList.filter((item) => item.obj === 'Human');
if (humanChat.length === 1) {
pushChatHistory({
chatId,
title: humanChat[0].value
}); });
return;
} }
} catch (err: any) {
toast({
title: typeof err === 'string' ? err : err?.message || '聊天出错了~',
status: 'warning',
duration: 5000,
isClosable: true
});
const newChatList: ChatSiteItemType[] = [ resetInputVal(storeInput);
...chatData.history,
{
obj: 'Human',
value: val,
status: 'finish'
},
{
obj: 'AI',
value: '',
status: 'loading'
}
];
// 插入内容
setChatData((state) => ({ setChatData((state) => ({
...state, ...state,
history: newChatList history: newChatList.slice(0, newChatList.length - 2)
})); }));
}
// 清空输入内容 }, [
resetInputVal(''); inputVal,
scrollToBottom(); chatData,
isChatting,
try { resetInputVal,
await gptChatPrompt(newChatList[newChatList.length - 2]); scrollToBottom,
toast,
// 如果是 Human 第一次发送,插入历史记录 gptChatPrompt,
const humanChat = newChatList.filter((item) => item.obj === 'Human'); pushChatHistory,
if (humanChat.length === 1) { chatId
pushChatHistory({ ]);
chatId,
title: humanChat[0].value
});
}
} catch (err: any) {
toast({
title: typeof err === 'string' ? err : err?.message || '聊天出错了~',
status: 'warning',
duration: 5000,
isClosable: true
});
resetInputVal(storeInput);
setChatData((state) => ({
...state,
history: newChatList.slice(0, newChatList.length - 2)
}));
}
},
[
inputVal,
chatData,
isChatting,
resetInputVal,
scrollToBottom,
toast,
gptChatPrompt,
pushChatHistory,
chatId
]
);
// 删除一句话 // 删除一句话
const delChatRecord = useCallback( const delChatRecord = useCallback(
@ -474,6 +468,7 @@ const Chat = ({ chatId }: { chatId: string }) => {
flex={1} flex={1}
w={0} w={0}
py={0} py={0}
pr={0}
border={'none'} border={'none'}
_focusVisible={{ _focusVisible={{
border: 'none' border: 'none'

View File

@ -45,24 +45,22 @@ const InputDataModal = ({
setImporting(true); setImporting(true);
try { try {
await postModelDataInput({ const res = await postModelDataInput({
modelId: modelId, modelId: modelId,
data: [ data: [
{ {
text: e.text, text: e.text,
q: [ q: {
{ id: nanoid(),
id: nanoid(), text: e.q
text: e.q }
}
]
} }
] ]
}); });
toast({ toast({
title: '导入数据成功,需要一段时间训练', title: res === 0 ? '导入数据成功,需要一段时间训练' : '数据导入异常',
status: 'success' status: res === 0 ? 'success' : 'warning'
}); });
onClose(); onClose();
onSuccess(); onSuccess();
@ -88,8 +86,15 @@ const InputDataModal = ({
<ModalHeader></ModalHeader> <ModalHeader></ModalHeader>
<ModalCloseButton /> <ModalCloseButton />
<Flex flex={'1 0 0'} h={0} px={6} pb={2}> <Box
<Box flex={2} mr={4} h={'100%'}> display={['block', 'flex']}
flex={'1 0 0'}
h={['100%', 0]}
overflowY={'auto'}
px={6}
pb={2}
>
<Box flex={2} mr={[0, 4]} mb={[4, 0]} h={['230px', '100%']}>
<Box h={'30px'}></Box> <Box h={'30px'}></Box>
<Textarea <Textarea
placeholder="相关问题,可以回车输入多个问法, 最多500字" placeholder="相关问题,可以回车输入多个问法, 最多500字"
@ -101,10 +106,11 @@ const InputDataModal = ({
})} })}
/> />
</Box> </Box>
<Box flex={3} h={'100%'}> <Box flex={3} h={['330px', '100%']}>
<Box h={'30px'}></Box> <Box h={'30px'}></Box>
<Textarea <Textarea
placeholder="知识点" placeholder="知识点,最多1000字"
maxLength={1000}
resize={'none'} resize={'none'}
h={'calc(100% - 30px)'} h={'calc(100% - 30px)'}
{...register(`text`, { {...register(`text`, {
@ -112,7 +118,7 @@ const InputDataModal = ({
})} })}
/> />
</Box> </Box>
</Flex> </Box>
<Flex px={6} pt={2} pb={4}> <Flex px={6} pt={2} pb={4}>
<Box flex={1}></Box> <Box flex={1}></Box>

View File

@ -19,7 +19,7 @@ import {
MenuItem MenuItem
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import type { ModelSchema } from '@/types/mongoSchema'; import type { ModelSchema } from '@/types/mongoSchema';
import { ModelDataSchema } from '@/types/mongoSchema'; import type { RedisModelDataItemType } from '@/types/redis';
import { ModelDataStatusMap } from '@/constants/model'; import { ModelDataStatusMap } from '@/constants/model';
import { usePagination } from '@/hooks/usePagination'; import { usePagination } from '@/hooks/usePagination';
import { import {
@ -35,7 +35,8 @@ import dynamic from 'next/dynamic';
import { useQuery } from '@tanstack/react-query'; import { useQuery } from '@tanstack/react-query';
const InputModel = dynamic(() => import('./InputDataModal')); const InputModel = dynamic(() => import('./InputDataModal'));
const SelectModel = dynamic(() => import('./SelectFileModal')); const SelectFileModel = dynamic(() => import('./SelectFileModal'));
const SelectJsonModel = dynamic(() => import('./SelectJsonModal'));
const ModelDataCard = ({ model }: { model: ModelSchema }) => { const ModelDataCard = ({ model }: { model: ModelSchema }) => {
const { toast } = useToast(); const { toast } = useToast();
@ -48,7 +49,7 @@ const ModelDataCard = ({ model }: { model: ModelSchema }) => {
total, total,
getData, getData,
pageNum pageNum
} = usePagination<ModelDataSchema>({ } = usePagination<RedisModelDataItemType>({
api: getModelDataList, api: getModelDataList,
pageSize: 8, pageSize: 8,
params: { params: {
@ -76,12 +77,17 @@ const ModelDataCard = ({ model }: { model: ModelSchema }) => {
onClose: onCloseInputModal onClose: onCloseInputModal
} = useDisclosure(); } = useDisclosure();
const { const {
isOpen: isOpenSelectModal, isOpen: isOpenSelectFileModal,
onOpen: onOpenSelectModal, onOpen: onOpenSelectFileModal,
onClose: onCloseSelectModal onClose: onCloseSelectFileModal
} = useDisclosure();
const {
isOpen: isOpenSelectJsonModal,
onOpen: onOpenSelectJsonModal,
onClose: onCloseSelectJsonModal
} = useDisclosure(); } = useDisclosure();
const { data, refetch } = useQuery(['getModelSplitDataList'], () => const { data: splitDataList, refetch } = useQuery(['getModelSplitDataList'], () =>
getModelSplitDataList(model._id) getModelSplitDataList(model._id)
); );
@ -113,13 +119,18 @@ const ModelDataCard = ({ model }: { model: ModelSchema }) => {
<MenuButton as={Button}></MenuButton> <MenuButton as={Button}></MenuButton>
<MenuList> <MenuList>
<MenuItem onClick={onOpenInputModal}></MenuItem> <MenuItem onClick={onOpenInputModal}></MenuItem>
<MenuItem onClick={onOpenSelectModal}></MenuItem> <MenuItem onClick={onOpenSelectFileModal}></MenuItem>
<MenuItem onClick={onOpenSelectJsonModal}>JSON导入</MenuItem>
</MenuList> </MenuList>
</Menu> </Menu>
</Flex> </Flex>
{data && data.length > 0 && <Box fontSize={'xs'}>{data.length}...</Box>} {splitDataList && splitDataList.length > 0 && (
<Box fontSize={'xs'}>
{splitDataList.map((item) => item.textList).flat().length}...
</Box>
)}
<Box mt={4}> <Box mt={4}>
<TableContainer> <TableContainer minH={'500px'}>
<Table variant={'simple'}> <Table variant={'simple'}>
<Thead> <Thead>
<Tr> <Tr>
@ -131,19 +142,11 @@ const ModelDataCard = ({ model }: { model: ModelSchema }) => {
</Thead> </Thead>
<Tbody> <Tbody>
{modelDataList.map((item) => ( {modelDataList.map((item) => (
<Tr key={item._id}> <Tr key={item.id}>
<Td w={'350px'}> <Td w={'350px'}>
{item.q.map((item, i) => ( <Box fontSize={'xs'} w={'100%'} whiteSpace={'pre-wrap'} _notLast={{ mb: 1 }}>
<Box {item.q}
key={item.id} </Box>
fontSize={'xs'}
w={'100%'}
whiteSpace={'pre-wrap'}
_notLast={{ mb: 1 }}
>
{item.text}
</Box>
))}
</Td> </Td>
<Td minW={'200px'}> <Td minW={'200px'}>
<Textarea <Textarea
@ -153,9 +156,9 @@ const ModelDataCard = ({ model }: { model: ModelSchema }) => {
fontSize={'xs'} fontSize={'xs'}
resize={'both'} resize={'both'}
onBlur={(e) => { onBlur={(e) => {
const oldVal = modelDataList.find((data) => item._id === data._id)?.text; const oldVal = modelDataList.find((data) => item.id === data.id)?.text;
if (oldVal !== e.target.value) { if (oldVal !== e.target.value) {
updateAnswer(item._id, e.target.value); updateAnswer(item.id, e.target.value);
} }
}} }}
></Textarea> ></Textarea>
@ -169,7 +172,7 @@ const ModelDataCard = ({ model }: { model: ModelSchema }) => {
aria-label={'delete'} aria-label={'delete'}
size={'sm'} size={'sm'}
onClick={async () => { onClick={async () => {
await delOneModelData(item._id); await delOneModelData(item.id);
refetchData(pageNum); refetchData(pageNum);
}} }}
/> />
@ -188,8 +191,19 @@ const ModelDataCard = ({ model }: { model: ModelSchema }) => {
{isOpenInputModal && ( {isOpenInputModal && (
<InputModel modelId={model._id} onClose={onCloseInputModal} onSuccess={refetchData} /> <InputModel modelId={model._id} onClose={onCloseInputModal} onSuccess={refetchData} />
)} )}
{isOpenSelectModal && ( {isOpenSelectFileModal && (
<SelectModel modelId={model._id} onClose={onCloseSelectModal} onSuccess={refetchData} /> <SelectFileModel
modelId={model._id}
onClose={onCloseSelectFileModal}
onSuccess={refetchData}
/>
)}
{isOpenSelectJsonModal && (
<SelectJsonModel
modelId={model._id}
onClose={onCloseSelectJsonModal}
onSuccess={refetchData}
/>
)} )}
</> </>
); );

View File

@ -100,40 +100,43 @@ const SelectFileModal = ({
}); });
return ( return (
<Modal isOpen={true} onClose={onClose}> <Modal isOpen={true} onClose={onClose} isCentered>
<ModalOverlay /> <ModalOverlay />
<ModalContent maxW={'min(900px, 90vw)'} position={'relative'}> <ModalContent maxW={'min(900px, 90vw)'} m={0} position={'relative'} h={['90vh', '70vh']}>
<ModalHeader></ModalHeader> <ModalHeader></ModalHeader>
<ModalCloseButton /> <ModalCloseButton />
<ModalBody> <ModalBody
<Flex display={'flex'}
flexDirection={'column'} flexDirection={'column'}
p={4}
h={'100%'}
alignItems={'center'}
justifyContent={'center'}
fontSize={'sm'}
>
<Button isLoading={selecting} onClick={onOpen}>
</Button>
<Box mt={2} maxW={['100%', '70%']}>
{fileExtension} QA
tokens0.04/1k tokens
</Box>
<Box mt={2}>
{fileText.length} {encode(fileText).length} tokens
</Box>
<Box
flex={'1 0 0'}
h={0}
w={'100%'}
overflowY={'auto'}
p={2} p={2}
h={'100%'} backgroundColor={'blackAlpha.50'}
alignItems={'center'} whiteSpace={'pre-wrap'}
justifyContent={'center'} fontSize={'xs'}
fontSize={'sm'}
> >
<Button isLoading={selecting} onClick={onOpen}> {fileText}
</Box>
</Button>
<Box mt={2}> {fileExtension} . </Box>
<Box mt={2}>
{fileText.length} {encode(fileText).length} tokens
</Box>
<Box
h={'300px'}
w={'100%'}
overflow={'auto'}
p={2}
backgroundColor={'blackAlpha.50'}
whiteSpace={'pre'}
fontSize={'xs'}
>
{fileText}
</Box>
</Flex>
</ModalBody> </ModalBody>
<Flex px={6} pt={2} pb={4}> <Flex px={6} pt={2} pb={4}>

View File

@ -0,0 +1,145 @@
import React, { useState, useCallback } from 'react';
import {
Box,
Flex,
Button,
Modal,
ModalOverlay,
ModalContent,
ModalHeader,
ModalCloseButton,
ModalBody
} from '@chakra-ui/react';
import { useToast } from '@/hooks/useToast';
import { useSelectFile } from '@/hooks/useSelectFile';
import { useConfirm } from '@/hooks/useConfirm';
import { readTxtContent } from '@/utils/tools';
import { useMutation } from '@tanstack/react-query';
import { postModelDataJsonData } from '@/api/model';
import Markdown from '@/components/Markdown';
/** One imported question/answer record; `vector` is an optional precomputed embedding of `prompt`. */
type JsonDataItem = { prompt: string; completion: string; vector?: number[] };

/**
 * Type guard for a parsed JSON value: true only for a non-null object whose
 * `prompt` and `completion` are both non-empty strings.
 */
const isJsonDataItem = (value: unknown): value is JsonDataItem => {
  if (typeof value !== 'object' || value === null) return false;
  const { prompt, completion } = value as Record<string, unknown>;
  return (
    typeof prompt === 'string' &&
    prompt !== '' &&
    typeof completion === 'string' &&
    completion !== ''
  );
};

/**
 * Modal for importing a JSON dataset into a model's knowledge base.
 *
 * The user picks one or more `.json` files; their contents are parsed, merged
 * and validated, previewed in the right-hand panel, and (after confirmation)
 * posted through `postModelDataJsonData`.
 *
 * @param onClose   Called to close the modal (close button, cancel, or after a successful import).
 * @param onSuccess Called after a successful import, right after `onClose`.
 * @param modelId   Id of the model the records are imported into.
 */
const SelectJsonModal = ({
  onClose,
  onSuccess,
  modelId
}: {
  onClose: () => void;
  onSuccess: () => void;
  modelId: string;
}) => {
  const [selecting, setSelecting] = useState(false);
  const { toast } = useToast();
  const { File, onOpen } = useSelectFile({ fileType: '.json', multiple: true });
  const [fileData, setFileData] = useState<JsonDataItem[]>([]);
  const { openConfirm, ConfirmChild } = useConfirm({
    content: '确认导入该数据集?'
  });

  // Read, parse and validate every selected file; the preview state is only
  // replaced when all records are valid.
  const onSelectFile = useCallback(
    async (e: File[]) => {
      setSelecting(true);
      try {
        const parsedFiles: unknown[] = await Promise.all(
          e.map((item) => readTxtContent(item).then((text) => JSON.parse(text)))
        );
        // flat() merges the top-level arrays of all files into one record list;
        // a file holding a single object contributes that object as one record.
        const jsonData = parsedFiles.flat();

        // Reject the whole selection if any record lacks a usable prompt/completion.
        if (!jsonData.every(isJsonDataItem)) {
          throw new Error('缺少 prompt 或 completion');
        }

        setFileData(jsonData);
      } catch (error: unknown) {
        console.log(error);
        const message = error instanceof Error ? error.message : '';
        toast({
          title: message || 'JSON文件格式有误',
          status: 'error'
        });
      } finally {
        // Always leave the "selecting" state, even if error reporting itself fails.
        setSelecting(false);
      }
    },
    [toast]
  );

  const { mutate, isLoading } = useMutation({
    mutationFn: async () => {
      // Nothing to import: the submit button is disabled in this state, this is a safety net.
      if (fileData.length === 0) return;
      await postModelDataJsonData(modelId, fileData);
      toast({
        title: '导入数据成功,需要一段拆解和训练',
        status: 'success'
      });
      onClose();
      onSuccess();
    },
    onError() {
      toast({
        title: '导入文件失败',
        status: 'error'
      });
    }
  });

  return (
    <Modal isOpen={true} onClose={onClose} isCentered>
      <ModalOverlay />
      <ModalContent maxW={'90vw'} position={'relative'} m={0} h={'90vh'}>
        <ModalHeader>JSON数据集</ModalHeader>
        <ModalCloseButton />

        <ModalBody h={'100%'} display={['block', 'flex']} fontSize={'sm'} overflowY={'auto'}>
          {/* Left column: format instructions and the file picker trigger */}
          <Box flex={'2 0 0'} w={['100%', 0]} mr={[0, 4]} mb={[4, 0]}>
            <Markdown
              source={`接受一个对象数组,每个对象必须包含 prompt 和 completion 格式可以包含vector。prompt 代表问题completion 代表回答的内容可以多个问题对应一个回答vector 为 prompt 的向量,如果没有讲有系统生成。例如:
~~~json
[
{
"prompt":"sealos是什么?\\n介绍下sealos\\nsealos有什么用",
"completion":"sealos是xxxxxx"
},
{
"prompt":"laf是什么?",
"completion":"laf是xxxxxx",
"vector":[-0.42,-0.4314314,0.43143]
}
]
~~~`}
            />
            <Flex alignItems={'center'}>
              <Button isLoading={selecting} onClick={onOpen}>
                JSON
              </Button>
              <Box ml={4}> {fileData.length} </Box>
            </Flex>
          </Box>
          {/* Right column: raw preview of the parsed records */}
          <Box flex={'2 0 0'} h={'100%'} overflow={'auto'} p={2} backgroundColor={'blackAlpha.50'}>
            {JSON.stringify(fileData)}
          </Box>
        </ModalBody>

        <Flex px={6} pt={2} pb={4}>
          <Box flex={1}></Box>
          <Button variant={'outline'} mr={3} onClick={onClose}>
          </Button>
          <Button
            isLoading={isLoading}
            isDisabled={fileData.length === 0}
            onClick={openConfirm(mutate)}
          >
          </Button>
        </Flex>
      </ModalContent>
      <ConfirmChild />
      <File onSelect={onSelectFile} />
    </Modal>
  );
};
export default SelectJsonModal;

View File

@ -1,10 +1,12 @@
import { SplitData, ModelData } from '@/service/mongo'; import { SplitData } from '@/service/mongo';
import { getOpenAIApi } from '@/service/utils/chat'; import { getOpenAIApi } from '@/service/utils/chat';
import { httpsAgent, getOpenApiKey } from '@/service/utils/tools'; import { httpsAgent, getOpenApiKey } from '@/service/utils/tools';
import type { ChatCompletionRequestMessage } from 'openai'; import type { ChatCompletionRequestMessage } from 'openai';
import { ChatModelNameEnum } from '@/constants/model'; import { ChatModelNameEnum } from '@/constants/model';
import { pushSplitDataBill } from '@/service/events/pushBill'; import { pushSplitDataBill } from '@/service/events/pushBill';
import { generateVector } from './generateVector'; import { generateVector } from './generateVector';
import { connectRedis } from '../redis';
import { VecModelDataPrefix } from '@/constants/redis';
import { customAlphabet } from 'nanoid'; import { customAlphabet } from 'nanoid';
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12); const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);
@ -18,6 +20,7 @@ export async function generateQA(next = false): Promise<any> {
}; };
try { try {
const redis = await connectRedis();
// 找出一个需要生成的 dataItem // 找出一个需要生成的 dataItem
const dataItem = await SplitData.findOne({ const dataItem = await SplitData.findOne({
textList: { $exists: true, $ne: [] } textList: { $exists: true, $ne: [] }
@ -29,8 +32,10 @@ export async function generateQA(next = false): Promise<any> {
return; return;
} }
// 源文本
const text = dataItem.textList[dataItem.textList.length - 1]; const text = dataItem.textList[dataItem.textList.length - 1];
if (!text) { if (!text) {
await SplitData.findByIdAndUpdate(dataItem._id, { $pop: { textList: 1 } }); // 弹出无效文本
throw new Error('无文本'); throw new Error('无文本');
} }
@ -63,7 +68,7 @@ export async function generateQA(next = false): Promise<any> {
.createChatCompletion( .createChatCompletion(
{ {
model: ChatModelNameEnum.GPT35, model: ChatModelNameEnum.GPT35,
temperature: 0.2, temperature: 0.4,
n: 1, n: 1,
messages: [ messages: [
systemPrompt, systemPrompt,
@ -79,26 +84,29 @@ export async function generateQA(next = false): Promise<any> {
} }
) )
.then((res) => ({ .then((res) => ({
rawContent: res?.data.choices[0].message?.content || '', rawContent: res?.data.choices[0].message?.content || '', // chatgpt原本的回复
result: splitText(res?.data.choices[0].message?.content || '') result: splitText(res?.data.choices[0].message?.content || '') // 格式化后的QA对
})); // 从 content 中提取 QA }));
await Promise.allSettled([ await Promise.allSettled([
SplitData.findByIdAndUpdate(dataItem._id, { $pop: { textList: 1 } }), SplitData.findByIdAndUpdate(dataItem._id, { $pop: { textList: 1 } }), // 弹出已经拆分的文本
ModelData.insertMany( ...response.result.map((item) => {
response.result.map((item) => ({ // 插入 redis
modelId: dataItem.modelId, return redis.sendCommand([
userId: dataItem.userId, 'HMSET',
text: item.a, `${VecModelDataPrefix}:${nanoid()}`,
q: [ 'userId',
{ String(dataItem.userId),
id: nanoid(), 'modelId',
text: item.q String(dataItem.modelId),
} 'q',
], item.q,
status: 1 'text',
})) item.a,
) 'status',
'waiting'
]);
})
]); ]);
console.log( console.log(

View File

@ -1,9 +1,9 @@
import { getOpenAIApi } from '@/service/utils/chat'; import { getOpenAIApi } from '@/service/utils/chat';
import { httpsAgent } from '@/service/utils/tools'; import { httpsAgent } from '@/service/utils/tools';
import { ModelData } from '../models/modelData';
import { connectRedis } from '../redis'; import { connectRedis } from '../redis';
import { VecModelDataIndex } from '@/constants/redis'; import { VecModelDataIdx } from '@/constants/redis';
import { vectorToBuffer } from '@/utils/tools'; import { vectorToBuffer } from '@/utils/tools';
import { ModelDataStatusEnum } from '@/constants/redis';
export async function generateVector(next = false): Promise<any> { export async function generateVector(next = false): Promise<any> {
if (global.generatingVector && !next) return; if (global.generatingVector && !next) return;
@ -12,74 +12,71 @@ export async function generateVector(next = false): Promise<any> {
try { try {
const redis = await connectRedis(); const redis = await connectRedis();
// 找出一个需要生成的 dataItem // 从找出一个 status = waiting 的数据
const dataItem = await ModelData.findOne({ const searchRes = await redis.ft.search(
status: { $ne: 0 } VecModelDataIdx,
}); `@status:{${ModelDataStatusEnum.waiting}}`,
{
RETURN: ['q'],
LIMIT: {
from: 0,
size: 1
}
}
);
if (!dataItem) { if (searchRes.total === 0) {
console.log('没有需要生成 【向量】 的数据'); console.log('没有需要生成 【向量】 的数据');
global.generatingVector = false; global.generatingVector = false;
return; return;
} }
const dataItem: { id: string; q: string } = {
id: searchRes.documents[0].id,
q: String(searchRes.documents[0]?.value?.q || '')
};
// 获取 openapi Key // 获取 openapi Key
const openAiKey = process.env.OPENAIKEY as string; const openAiKey = process.env.OPENAIKEY as string;
// 获取 openai 请求实例 // 获取 openai 请求实例
const chatAPI = getOpenAIApi(openAiKey); const chatAPI = getOpenAIApi(openAiKey);
const dataId = String(dataItem._id);
// 生成词向量 // 生成词向量
const response = await Promise.allSettled( const vector = await chatAPI
dataItem.q.map((item, i) => .createEmbedding(
chatAPI {
.createEmbedding( model: 'text-embedding-ada-002',
{ input: dataItem.q
model: 'text-embedding-ada-002', },
input: item.text {
}, timeout: 120000,
{ httpsAgent
timeout: 120000, }
httpsAgent
}
)
.then((res) => res?.data?.data?.[0]?.embedding || [])
.then((vector) =>
redis.sendCommand([
'HMSET',
`${VecModelDataIndex}:${item.id}`,
'vector',
vectorToBuffer(vector),
'modelId',
String(dataItem.modelId),
'dataId',
String(dataId)
])
)
) )
); .then((res) => res?.data?.data?.[0]?.embedding || []);
if (response.filter((item) => item.status === 'fulfilled').length === 0) { // 更新 redis 向量和状态数据
throw new Error(JSON.stringify(response)); await redis.sendCommand([
} 'HMSET',
// 修改该数据状态 dataItem.id,
await ModelData.findByIdAndUpdate(dataItem._id, { 'vector',
status: 0 vectorToBuffer(vector),
}); 'status',
ModelDataStatusEnum.ready
]);
console.log(`生成向量成功: ${dataItem._id}`); console.log(`生成向量成功: ${dataItem.id}`);
setTimeout(() => { setTimeout(() => {
generateVector(true); generateVector(true);
}, 3000); }, 2000);
} catch (error: any) { } catch (error: any) {
console.log(error); console.log('error: 生成向量错误', error?.response?.statusText);
console.log('error: 生成向量错误', error?.response?.data); !error?.response && console.log(error);
if (error?.response?.statusText === 'Too Many Requests') { if (error?.response?.statusText === 'Too Many Requests') {
console.log('次数限制1分钟后尝试'); console.log('生成向量次数限制1分钟后尝试');
// 限制次数1分钟后再试 // 限制次数1分钟后再试
setTimeout(() => { setTimeout(() => {
generateVector(true); generateVector(true);

View File

@ -1,37 +0,0 @@
/* 模型的知识库 */
import { Schema, model, models, Model as MongoModel } from 'mongoose';
import { ModelDataSchema as ModelDataType } from '@/types/mongoSchema';
// Mongoose schema for one knowledge-base entry of a model: an answer text
// plus the list of questions that map to it.
const ModelDataSchema = new Schema({
// Owning model (ObjectId reference to the 'model' collection); required.
modelId: {
type: Schema.Types.ObjectId,
ref: 'model',
required: true
},
// Owning user (ObjectId reference to the 'user' collection); required.
userId: {
type: Schema.Types.ObjectId,
ref: 'user',
required: true
},
// Entry text; required. NOTE(review): callers appear to store the answer here - confirm.
text: {
type: String,
required: true
},
// Questions attached to this entry; defaults to an empty list.
q: {
type: [
{
id: String, // corresponds to the Redis key
text: String
}
],
default: []
},
// Processing state, restricted to 0 or 1; new entries default to 1.
status: {
type: Number,
enum: [0, 1], // 1 = training in progress (0 is presumably "done" - confirm)
default: 1
}
});
/**
 * Mongoose model for knowledge-base entries.
 * Reuses the model already registered under 'modelData' when there is one,
 * otherwise registers it from ModelDataSchema.
 */
export const ModelData: MongoModel<ModelDataType> =
  models['modelData'] ?? model('modelData', ModelDataSchema);

View File

@ -35,7 +35,6 @@ export async function connectToDatabase(): Promise<void> {
export * from './models/authCode'; export * from './models/authCode';
export * from './models/chat'; export * from './models/chat';
export * from './models/model'; export * from './models/model';
export * from './models/modelData';
export * from './models/user'; export * from './models/user';
export * from './models/training'; export * from './models/training';
export * from './models/bill'; export * from './models/bill';

View File

@ -29,8 +29,8 @@ export const connectRedis = async () => {
await global.redisClient.connect(); await global.redisClient.connect();
// 0 - 测试库1 - 正式 // 1 - 测试库0 - 正式
await global.redisClient.select(0); await global.redisClient.select(process.env.NODE_ENV === 'development' ? 0 : 0);
return global.redisClient; return global.redisClient;
} catch (error) { } catch (error) {

View File

@ -60,7 +60,7 @@ export interface ModelDataSchema {
q: { q: {
id: string; id: string;
text: string; text: string;
}[]; };
status: ModelDataType; status: ModelDataType;
} }

View File

@ -1,6 +1,7 @@
import { ModelDataStatusEnum } from '@/constants/redis';
export interface RedisModelDataItemType { export interface RedisModelDataItemType {
id: string; id: string;
vector: number[]; q: string;
dataId: string; text: string;
modelId: string; status: `${ModelDataStatusEnum}`;
} }

View File

@ -127,3 +127,9 @@ export const vectorToBuffer = (vector: number[]) => {
return Buffer.from(npVector.buffer); return Buffer.from(npVector.buffer);
}; };
/**
 * Normalize a vector to exactly `dim` elements.
 *
 * Longer vectors are truncated to their first `dim` elements; shorter ones are
 * right-padded with zeros. The input array is not modified.
 *
 * @param vector Source vector.
 * @param dim    Target length; defaults to 1536.
 * @returns A new array of length `dim`.
 * @throws RangeError If `dim` is not a non-negative integer.
 */
export function formatVector(vector: number[], dim = 1536): number[] {
  if (!Number.isInteger(dim) || dim < 0) {
    throw new RangeError(`dim must be a non-negative integer, received ${dim}`);
  }
  // Keep at most the first `dim` elements.
  const head = vector.slice(0, dim);
  // Pad the tail with zeros up to `dim` (no-op when already long enough).
  const padding = new Array<number>(dim - head.length).fill(0);
  return head.concat(padding);
}