feat: 替换redis搜索
This commit is contained in:
parent
867d69659f
commit
1e5714da1b
@ -1,7 +1,7 @@
|
|||||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||||
import { connectToDatabase } from '@/service/mongo';
|
import { connectToDatabase } from '@/service/mongo';
|
||||||
import { authChat } from '@/service/utils/chat';
|
import { authChat } from '@/service/utils/chat';
|
||||||
import { httpsAgent, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools';
|
import { httpsAgent, systemPromptFilter } from '@/service/utils/tools';
|
||||||
import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai';
|
import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai';
|
||||||
import { ChatItemType } from '@/types/chat';
|
import { ChatItemType } from '@/types/chat';
|
||||||
import { jsonRes } from '@/service/response';
|
import { jsonRes } from '@/service/response';
|
||||||
@ -9,11 +9,9 @@ import type { ModelSchema } from '@/types/mongoSchema';
|
|||||||
import { PassThrough } from 'stream';
|
import { PassThrough } from 'stream';
|
||||||
import { modelList, ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model';
|
import { modelList, ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model';
|
||||||
import { pushChatBill } from '@/service/events/pushBill';
|
import { pushChatBill } from '@/service/events/pushBill';
|
||||||
import { connectRedis } from '@/service/redis';
|
|
||||||
import { VecModelDataPrefix } from '@/constants/redis';
|
|
||||||
import { vectorToBuffer } from '@/utils/tools';
|
|
||||||
import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai';
|
import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai';
|
||||||
import dayjs from 'dayjs';
|
import dayjs from 'dayjs';
|
||||||
|
import { PgClient } from '@/service/pg';
|
||||||
|
|
||||||
/* 发送提示词 */
|
/* 发送提示词 */
|
||||||
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
|
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
|
||||||
@ -43,7 +41,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
}
|
}
|
||||||
|
|
||||||
await connectToDatabase();
|
await connectToDatabase();
|
||||||
const redis = await connectRedis();
|
|
||||||
let startTime = Date.now();
|
let startTime = Date.now();
|
||||||
|
|
||||||
const { chat, userApiKey, systemKey, userId } = await authChat(chatId, authorization);
|
const { chat, userApiKey, systemKey, userId } = await authChat(chatId, authorization);
|
||||||
@ -65,38 +62,22 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
text: prompt.value
|
text: prompt.value
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// 相似度搜素
|
||||||
const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22;
|
const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22;
|
||||||
// 搜索系统提示词, 按相似度从 redis 中搜出相关的 q 和 text
|
const vectorSearch = await PgClient.select<{ id: string; q: string; a: string }>('modelData', {
|
||||||
const redisData: any[] = await redis.sendCommand([
|
fields: ['id', 'q', 'a'],
|
||||||
'FT.SEARCH',
|
order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
|
||||||
`idx:${VecModelDataPrefix}:hash`,
|
where: [
|
||||||
`@modelId:{${String(
|
['model_id', model._id],
|
||||||
chat.modelId._id
|
'AND',
|
||||||
)}} @vector:[VECTOR_RANGE ${similarity} $blob]=>{$YIELD_DISTANCE_AS: score}`,
|
['user_id', userId],
|
||||||
'RETURN',
|
'AND',
|
||||||
'1',
|
`vector <=> '[${promptVector}]' < ${similarity}`
|
||||||
'text',
|
],
|
||||||
'SORTBY',
|
limit: 30
|
||||||
'score',
|
});
|
||||||
'PARAMS',
|
|
||||||
'2',
|
|
||||||
'blob',
|
|
||||||
vectorToBuffer(promptVector),
|
|
||||||
'LIMIT',
|
|
||||||
'0',
|
|
||||||
'30',
|
|
||||||
'DIALECT',
|
|
||||||
'2'
|
|
||||||
]);
|
|
||||||
|
|
||||||
const formatRedisPrompt: string[] = [];
|
const formatRedisPrompt: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`);
|
||||||
// 格式化响应值,获取 qa
|
|
||||||
for (let i = 2; i < 61; i += 2) {
|
|
||||||
const text = redisData[i]?.[1];
|
|
||||||
if (text) {
|
|
||||||
formatRedisPrompt.push(text);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* 高相似度+退出,无法匹配时直接退出 */
|
/* 高相似度+退出,无法匹配时直接退出 */
|
||||||
if (
|
if (
|
||||||
@ -121,9 +102,9 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
|
|
||||||
prompts.unshift({
|
prompts.unshift({
|
||||||
obj: 'SYSTEM',
|
obj: 'SYSTEM',
|
||||||
value: `${model.systemPrompt} 用知识库内容回答,知识库内容为: "当前时间:${dayjs().format(
|
value: `${model.systemPrompt} 知识库是最新的,下面是知识库内容:当前时间为${dayjs().format(
|
||||||
'YYYY/MM/DD HH:mm:ss'
|
'YYYY/MM/DD HH:mm:ss'
|
||||||
)} ${systemPrompt}"`
|
)}\n${systemPrompt}`
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -2,8 +2,7 @@ import type { NextApiRequest, NextApiResponse } from 'next';
|
|||||||
import { jsonRes } from '@/service/response';
|
import { jsonRes } from '@/service/response';
|
||||||
import { connectToDatabase } from '@/service/mongo';
|
import { connectToDatabase } from '@/service/mongo';
|
||||||
import { authToken } from '@/service/utils/tools';
|
import { authToken } from '@/service/utils/tools';
|
||||||
import { connectRedis } from '@/service/redis';
|
import { PgClient } from '@/service/pg';
|
||||||
import { VecModelDataIdx } from '@/constants/redis';
|
|
||||||
|
|
||||||
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
|
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
|
||||||
try {
|
try {
|
||||||
@ -25,28 +24,23 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
|
|||||||
const userId = await authToken(authorization);
|
const userId = await authToken(authorization);
|
||||||
|
|
||||||
await connectToDatabase();
|
await connectToDatabase();
|
||||||
const redis = await connectRedis();
|
|
||||||
|
|
||||||
// 从 redis 中获取数据
|
// 统计数据
|
||||||
const searchRes = await redis.ft.search(
|
const count = await PgClient.count('modelData', {
|
||||||
VecModelDataIdx,
|
where: [['model_id', modelId], 'AND', ['user_id', userId]]
|
||||||
`@modelId:{${modelId}} @userId:{${userId}}`,
|
|
||||||
{
|
|
||||||
RETURN: ['q', 'text'],
|
|
||||||
LIMIT: {
|
|
||||||
from: 0,
|
|
||||||
size: 10000
|
|
||||||
}
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const data: [string, string][] = [];
|
|
||||||
|
|
||||||
searchRes.documents.forEach((item: any) => {
|
|
||||||
if (item.value.q && item.value.text) {
|
|
||||||
data.push([item.value.q.replace(/\n/g, '\\n'), item.value.text.replace(/\n/g, '\\n')]);
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
|
// 从 pg 中获取所有数据
|
||||||
|
const pgData = await PgClient.select<{ q: string; a: string }>('modelData', {
|
||||||
|
where: [['model_id', modelId], 'AND', ['user_id', userId]],
|
||||||
|
fields: ['q', 'a'],
|
||||||
|
order: [{ field: 'id', mode: 'DESC' }],
|
||||||
|
limit: count
|
||||||
|
});
|
||||||
|
|
||||||
|
const data: [string, string][] = pgData.rows.map((item) => [
|
||||||
|
item.q.replace(/\n/g, '\\n'),
|
||||||
|
item.a.replace(/\n/g, '\\n')
|
||||||
|
]);
|
||||||
|
|
||||||
jsonRes(res, {
|
jsonRes(res, {
|
||||||
data
|
data
|
||||||
|
|||||||
@ -37,7 +37,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
|
|||||||
await connectToDatabase();
|
await connectToDatabase();
|
||||||
|
|
||||||
const searchRes = await PgClient.select<PgModelDataItemType>('modelData', {
|
const searchRes = await PgClient.select<PgModelDataItemType>('modelData', {
|
||||||
field: ['id', 'q', 'a', 'status'],
|
fields: ['id', 'q', 'a', 'status'],
|
||||||
where: [['user_id', userId], 'AND', ['model_id', modelId]],
|
where: [['user_id', userId], 'AND', ['model_id', modelId]],
|
||||||
order: [{ field: 'id', mode: 'DESC' }],
|
order: [{ field: 'id', mode: 'DESC' }],
|
||||||
limit: pageSize,
|
limit: pageSize,
|
||||||
|
|||||||
@ -3,11 +3,8 @@ import { jsonRes } from '@/service/response';
|
|||||||
import { connectToDatabase, Model } from '@/service/mongo';
|
import { connectToDatabase, Model } from '@/service/mongo';
|
||||||
import { authToken } from '@/service/utils/tools';
|
import { authToken } from '@/service/utils/tools';
|
||||||
import { generateVector } from '@/service/events/generateVector';
|
import { generateVector } from '@/service/events/generateVector';
|
||||||
import { connectRedis } from '@/service/redis';
|
import { ModelDataStatusEnum } from '@/constants/model';
|
||||||
import { VecModelDataPrefix, ModelDataStatusEnum } from '@/constants/redis';
|
import { PgClient } from '@/service/pg';
|
||||||
import { VecModelDataIdx } from '@/constants/redis';
|
|
||||||
import { customAlphabet } from 'nanoid';
|
|
||||||
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);
|
|
||||||
|
|
||||||
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
|
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
|
||||||
try {
|
try {
|
||||||
@ -29,7 +26,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
|
|||||||
const userId = await authToken(authorization);
|
const userId = await authToken(authorization);
|
||||||
|
|
||||||
await connectToDatabase();
|
await connectToDatabase();
|
||||||
const redis = await connectRedis();
|
|
||||||
|
|
||||||
// 验证是否是该用户的 model
|
// 验证是否是该用户的 model
|
||||||
const model = await Model.findOne({
|
const model = await Model.findOne({
|
||||||
@ -47,10 +43,18 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
|
|||||||
try {
|
try {
|
||||||
q = q.replace(/\\n/g, '\n');
|
q = q.replace(/\\n/g, '\n');
|
||||||
a = a.replace(/\\n/g, '\n');
|
a = a.replace(/\\n/g, '\n');
|
||||||
const redisSearch = await redis.ft.search(VecModelDataIdx, `@q:${q} @text:${a}`, {
|
const count = await PgClient.count('modelData', {
|
||||||
RETURN: ['q', 'text']
|
where: [
|
||||||
|
['user_id', userId],
|
||||||
|
'AND',
|
||||||
|
['model_id', modelId],
|
||||||
|
'AND',
|
||||||
|
['q', q],
|
||||||
|
'AND',
|
||||||
|
['a', a]
|
||||||
|
]
|
||||||
});
|
});
|
||||||
if (redisSearch.total > 0) {
|
if (count > 0) {
|
||||||
return Promise.reject('已经存在');
|
return Promise.reject('已经存在');
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
@ -62,35 +66,26 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
|
|||||||
});
|
});
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
|
// 过滤重复的内容
|
||||||
const filterData = searchRes
|
const filterData = searchRes
|
||||||
.filter((item) => item.status === 'fulfilled')
|
.filter((item) => item.status === 'fulfilled')
|
||||||
.map<{ q: string; a: string }>((item: any) => item.value);
|
.map<{ q: string; a: string }>((item: any) => item.value);
|
||||||
|
|
||||||
// 插入 redis
|
// 插入 pg
|
||||||
const insertRedisRes = await Promise.allSettled(
|
const insertRes = await PgClient.insert('modelData', {
|
||||||
filterData.map((item) => {
|
values: filterData.map((item) => [
|
||||||
return redis.sendCommand([
|
{ key: 'user_id', value: userId },
|
||||||
'HMSET',
|
{ key: 'model_id', value: modelId },
|
||||||
`${VecModelDataPrefix}:${nanoid()}`,
|
{ key: 'q', value: item.q },
|
||||||
'userId',
|
{ key: 'a', value: item.a },
|
||||||
userId,
|
{ key: 'status', value: ModelDataStatusEnum.waiting }
|
||||||
'modelId',
|
])
|
||||||
String(modelId),
|
});
|
||||||
'q',
|
|
||||||
item.q,
|
|
||||||
'text',
|
|
||||||
item.a,
|
|
||||||
'status',
|
|
||||||
ModelDataStatusEnum.waiting
|
|
||||||
]);
|
|
||||||
})
|
|
||||||
);
|
|
||||||
|
|
||||||
generateVector();
|
generateVector();
|
||||||
|
|
||||||
jsonRes(res, {
|
jsonRes(res, {
|
||||||
data: insertRedisRes.filter((item) => item.status === 'fulfilled').length
|
data: insertRes.rowCount
|
||||||
});
|
});
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
jsonRes(res, {
|
jsonRes(res, {
|
||||||
|
|||||||
@ -1,13 +1,13 @@
|
|||||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||||
import { jsonRes } from '@/service/response';
|
import { jsonRes } from '@/service/response';
|
||||||
import { authToken } from '@/service/utils/tools';
|
import { authToken } from '@/service/utils/tools';
|
||||||
import { connectRedis } from '@/service/redis';
|
|
||||||
import { ModelDataStatusEnum } from '@/constants/redis';
|
import { ModelDataStatusEnum } from '@/constants/redis';
|
||||||
import { generateVector } from '@/service/events/generateVector';
|
import { generateVector } from '@/service/events/generateVector';
|
||||||
|
import { PgClient } from '@/service/pg';
|
||||||
|
|
||||||
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
|
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
|
||||||
try {
|
try {
|
||||||
const { dataId, text, q } = req.body as { dataId: string; text: string; q?: string };
|
const { dataId, a, q } = req.body as { dataId: string; a: string; q?: string };
|
||||||
const { authorization } = req.headers;
|
const { authorization } = req.headers;
|
||||||
|
|
||||||
if (!authorization) {
|
if (!authorization) {
|
||||||
@ -21,26 +21,21 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
|
|||||||
// 凭证校验
|
// 凭证校验
|
||||||
const userId = await authToken(authorization);
|
const userId = await authToken(authorization);
|
||||||
|
|
||||||
const redis = await connectRedis();
|
// 更新 pg 内容
|
||||||
|
await PgClient.update('modelData', {
|
||||||
|
where: [['id', dataId], 'AND', ['user_id', userId]],
|
||||||
|
values: [
|
||||||
|
{ key: 'a', value: a },
|
||||||
|
...(q
|
||||||
|
? [
|
||||||
|
{ key: 'q', value: q },
|
||||||
|
{ key: 'status', value: ModelDataStatusEnum.waiting }
|
||||||
|
]
|
||||||
|
: [])
|
||||||
|
]
|
||||||
|
});
|
||||||
|
|
||||||
// 校验是否为该用户的数据
|
q && generateVector();
|
||||||
const dataItemUserId = await redis.hGet(dataId, 'userId');
|
|
||||||
if (dataItemUserId !== userId) {
|
|
||||||
throw new Error('无权操作');
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新
|
|
||||||
await redis.sendCommand([
|
|
||||||
'HMSET',
|
|
||||||
dataId,
|
|
||||||
...(q ? ['q', q, 'status', ModelDataStatusEnum.waiting] : []),
|
|
||||||
'text',
|
|
||||||
text
|
|
||||||
]);
|
|
||||||
|
|
||||||
if (q) {
|
|
||||||
generateVector();
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonRes(res);
|
jsonRes(res);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
|
|||||||
@ -6,13 +6,12 @@ import { getUserApiOpenai } from '@/service/utils/openai';
|
|||||||
import { TrainingStatusEnum } from '@/constants/model';
|
import { TrainingStatusEnum } from '@/constants/model';
|
||||||
import { TrainingItemType } from '@/types/training';
|
import { TrainingItemType } from '@/types/training';
|
||||||
import { httpsAgent } from '@/service/utils/tools';
|
import { httpsAgent } from '@/service/utils/tools';
|
||||||
import { connectRedis } from '@/service/redis';
|
import { PgClient } from '@/service/pg';
|
||||||
import { VecModelDataIdx } from '@/constants/redis';
|
|
||||||
|
|
||||||
/* 获取我的模型 */
|
/* 获取我的模型 */
|
||||||
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
|
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
|
||||||
try {
|
try {
|
||||||
const { modelId } = req.query;
|
const { modelId } = req.query as { modelId: string };
|
||||||
const { authorization } = req.headers;
|
const { authorization } = req.headers;
|
||||||
|
|
||||||
if (!authorization) {
|
if (!authorization) {
|
||||||
@ -37,21 +36,11 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
|
|||||||
}
|
}
|
||||||
|
|
||||||
await connectToDatabase();
|
await connectToDatabase();
|
||||||
const redis = await connectRedis();
|
|
||||||
|
|
||||||
// 获取 redis 中模型关联的所有数据
|
// 删除 pg 中所有该模型的数据
|
||||||
const searchRes = await redis.ft.search(
|
await PgClient.delete('modelData', {
|
||||||
VecModelDataIdx,
|
where: [['user_id', userId], 'AND', ['model_id', modelId]]
|
||||||
`@modelId:{${modelId}} @userId:{${userId}}`,
|
});
|
||||||
{
|
|
||||||
LIMIT: {
|
|
||||||
from: 0,
|
|
||||||
size: 10000
|
|
||||||
}
|
|
||||||
}
|
|
||||||
);
|
|
||||||
// 删除 redis 内容
|
|
||||||
await Promise.all(searchRes.documents.map((item) => redis.del(item.id)));
|
|
||||||
|
|
||||||
// 删除对应的聊天
|
// 删除对应的聊天
|
||||||
await Chat.deleteMany({
|
await Chat.deleteMany({
|
||||||
|
|||||||
@ -7,12 +7,15 @@ import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } fr
|
|||||||
import { ChatItemType } from '@/types/chat';
|
import { ChatItemType } from '@/types/chat';
|
||||||
import { jsonRes } from '@/service/response';
|
import { jsonRes } from '@/service/response';
|
||||||
import { PassThrough } from 'stream';
|
import { PassThrough } from 'stream';
|
||||||
import { ChatModelNameEnum, modelList, ChatModelNameMap } from '@/constants/model';
|
import {
|
||||||
|
ChatModelNameEnum,
|
||||||
|
modelList,
|
||||||
|
ChatModelNameMap,
|
||||||
|
ModelVectorSearchModeMap
|
||||||
|
} from '@/constants/model';
|
||||||
import { pushChatBill } from '@/service/events/pushBill';
|
import { pushChatBill } from '@/service/events/pushBill';
|
||||||
import { connectRedis } from '@/service/redis';
|
|
||||||
import { VecModelDataPrefix } from '@/constants/redis';
|
|
||||||
import { vectorToBuffer } from '@/utils/tools';
|
|
||||||
import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai';
|
import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai';
|
||||||
|
import { PgClient } from '@/service/pg';
|
||||||
|
|
||||||
/* 发送提示词 */
|
/* 发送提示词 */
|
||||||
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
|
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
|
||||||
@ -46,7 +49,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
}
|
}
|
||||||
|
|
||||||
await connectToDatabase();
|
await connectToDatabase();
|
||||||
const redis = await connectRedis();
|
|
||||||
let startTime = Date.now();
|
let startTime = Date.now();
|
||||||
|
|
||||||
/* 凭证校验 */
|
/* 凭证校验 */
|
||||||
@ -144,39 +146,29 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
// 读取对话内容
|
// 读取对话内容
|
||||||
const prompts = [prompt];
|
const prompts = [prompt];
|
||||||
|
|
||||||
// 搜索系统提示词, 按相似度从 redis 中搜出相关的 q 和 text
|
// 相似度搜索
|
||||||
const redisData: any[] = await redis.sendCommand([
|
const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22;
|
||||||
'FT.SEARCH',
|
const vectorSearch = await PgClient.select<{ id: string; q: string; a: string }>('modelData', {
|
||||||
`idx:${VecModelDataPrefix}:hash`,
|
fields: ['id', 'q', 'a'],
|
||||||
`@modelId:{${String(model._id)}}=>[KNN 20 @vector $blob AS score]`,
|
order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
|
||||||
'RETURN',
|
where: [
|
||||||
'1',
|
['model_id', model._id],
|
||||||
'text',
|
'AND',
|
||||||
'SORTBY',
|
['user_id', userId],
|
||||||
'score',
|
'AND',
|
||||||
'PARAMS',
|
`vector <=> '[${promptVector}]' < ${similarity}`
|
||||||
'2',
|
],
|
||||||
'blob',
|
limit: 30
|
||||||
vectorToBuffer(promptVector),
|
});
|
||||||
'DIALECT',
|
|
||||||
'2'
|
|
||||||
]);
|
|
||||||
|
|
||||||
// 格式化响应值,获取 qa
|
const formatRedisPrompt: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`);
|
||||||
const formatRedisPrompt: string[] = [];
|
|
||||||
for (let i = 2; i < 42; i += 2) {
|
|
||||||
const text = redisData[i]?.[1];
|
|
||||||
if (text) {
|
|
||||||
formatRedisPrompt.push(text);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// textArr 筛选,最多 3000 tokens
|
// textArr 筛选,最多 3000 tokens
|
||||||
const systemPrompt = systemPromptFilter(formatRedisPrompt, 3000);
|
const systemPrompt = systemPromptFilter(formatRedisPrompt, 3000);
|
||||||
|
|
||||||
prompts.unshift({
|
prompts.unshift({
|
||||||
obj: 'SYSTEM',
|
obj: 'SYSTEM',
|
||||||
value: `${model.systemPrompt} 知识库内容是最新的,知识库内容为: "${systemPrompt}"`
|
value: `${model.systemPrompt} 知识库是最新的,下面是知识库内容:${systemPrompt}`
|
||||||
});
|
});
|
||||||
|
|
||||||
// 控制在 tokens 数量,防止超出
|
// 控制在 tokens 数量,防止超出
|
||||||
|
|||||||
@ -1,22 +1,15 @@
|
|||||||
import type { NextApiRequest, NextApiResponse } from 'next';
|
import type { NextApiRequest, NextApiResponse } from 'next';
|
||||||
import { connectToDatabase, Model } from '@/service/mongo';
|
import { connectToDatabase, Model } from '@/service/mongo';
|
||||||
import {
|
import { httpsAgent, systemPromptFilter, authOpenApiKey } from '@/service/utils/tools';
|
||||||
httpsAgent,
|
|
||||||
openaiChatFilter,
|
|
||||||
systemPromptFilter,
|
|
||||||
authOpenApiKey
|
|
||||||
} from '@/service/utils/tools';
|
|
||||||
import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai';
|
import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai';
|
||||||
import { ChatItemType } from '@/types/chat';
|
import { ChatItemType } from '@/types/chat';
|
||||||
import { jsonRes } from '@/service/response';
|
import { jsonRes } from '@/service/response';
|
||||||
import { PassThrough } from 'stream';
|
import { PassThrough } from 'stream';
|
||||||
import { modelList, ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model';
|
import { modelList, ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model';
|
||||||
import { pushChatBill } from '@/service/events/pushBill';
|
import { pushChatBill } from '@/service/events/pushBill';
|
||||||
import { connectRedis } from '@/service/redis';
|
|
||||||
import { VecModelDataPrefix } from '@/constants/redis';
|
|
||||||
import { vectorToBuffer } from '@/utils/tools';
|
|
||||||
import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai';
|
import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai';
|
||||||
import dayjs from 'dayjs';
|
import dayjs from 'dayjs';
|
||||||
|
import { PgClient } from '@/service/pg';
|
||||||
|
|
||||||
/* 发送提示词 */
|
/* 发送提示词 */
|
||||||
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
|
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
|
||||||
@ -56,7 +49,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
}
|
}
|
||||||
|
|
||||||
await connectToDatabase();
|
await connectToDatabase();
|
||||||
const redis = await connectRedis();
|
|
||||||
let startTime = Date.now();
|
let startTime = Date.now();
|
||||||
|
|
||||||
/* 凭证校验 */
|
/* 凭证校验 */
|
||||||
@ -84,38 +76,22 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
text: prompts[prompts.length - 1].value // 取最后一个
|
text: prompts[prompts.length - 1].value // 取最后一个
|
||||||
});
|
});
|
||||||
|
|
||||||
// 搜索系统提示词, 按相似度从 redis 中搜出相关的 q 和 text
|
// 相似度搜素
|
||||||
const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22;
|
const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22;
|
||||||
// 搜索系统提示词, 按相似度从 redis 中搜出相关的 q 和 text
|
const vectorSearch = await PgClient.select<{ id: string; q: string; a: string }>('modelData', {
|
||||||
const redisData: any[] = await redis.sendCommand([
|
fields: ['id', 'q', 'a'],
|
||||||
'FT.SEARCH',
|
order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
|
||||||
`idx:${VecModelDataPrefix}:hash`,
|
where: [
|
||||||
`@modelId:{${modelId}} @vector:[VECTOR_RANGE ${similarity} $blob]=>{$YIELD_DISTANCE_AS: score}`,
|
['model_id', model._id],
|
||||||
'RETURN',
|
'AND',
|
||||||
'1',
|
['user_id', userId],
|
||||||
'text',
|
'AND',
|
||||||
'SORTBY',
|
`vector <=> '[${promptVector}]' < ${similarity}`
|
||||||
'score',
|
],
|
||||||
'PARAMS',
|
limit: 30
|
||||||
'2',
|
});
|
||||||
'blob',
|
|
||||||
vectorToBuffer(promptVector),
|
|
||||||
'LIMIT',
|
|
||||||
'0',
|
|
||||||
'30',
|
|
||||||
'DIALECT',
|
|
||||||
'2'
|
|
||||||
]);
|
|
||||||
|
|
||||||
const formatRedisPrompt: string[] = [];
|
const formatRedisPrompt: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`);
|
||||||
|
|
||||||
// 格式化响应值,获取 qa
|
|
||||||
for (let i = 2; i < 61; i += 2) {
|
|
||||||
const text = redisData[i]?.[1];
|
|
||||||
if (text) {
|
|
||||||
formatRedisPrompt.push(text);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// system 合并
|
// system 合并
|
||||||
if (prompts[0].obj === 'SYSTEM') {
|
if (prompts[0].obj === 'SYSTEM') {
|
||||||
@ -145,9 +121,9 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
|
|||||||
|
|
||||||
prompts.unshift({
|
prompts.unshift({
|
||||||
obj: 'SYSTEM',
|
obj: 'SYSTEM',
|
||||||
value: `${model.systemPrompt} 用知识库内容回答,知识库内容为: "当前时间:${dayjs().format(
|
value: `${model.systemPrompt} 知识库是最新的,下面是知识库内容:当前时间为${dayjs().format(
|
||||||
'YYYY/MM/DD HH:mm:ss'
|
'YYYY/MM/DD HH:mm:ss'
|
||||||
)} ${systemPrompt}"`
|
)}\n${systemPrompt}`
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -16,7 +16,7 @@ export async function generateVector(next = false): Promise<any> {
|
|||||||
try {
|
try {
|
||||||
// 从找出一个 status = waiting 的数据
|
// 从找出一个 status = waiting 的数据
|
||||||
const searchRes = await PgClient.select('modelData', {
|
const searchRes = await PgClient.select('modelData', {
|
||||||
field: ['id', 'q', 'user_id'],
|
fields: ['id', 'q', 'user_id'],
|
||||||
where: [['status', 'waiting']],
|
where: [['status', 'waiting']],
|
||||||
limit: 1
|
limit: 1
|
||||||
});
|
});
|
||||||
|
|||||||
@ -34,9 +34,9 @@ export const connectPg = async () => {
|
|||||||
|
|
||||||
type WhereProps = (string | [string, string | number])[];
|
type WhereProps = (string | [string, string | number])[];
|
||||||
type GetProps = {
|
type GetProps = {
|
||||||
field?: string[];
|
fields?: string[];
|
||||||
where?: WhereProps;
|
where?: WhereProps;
|
||||||
order?: { field: string; mode: 'DESC' | 'ASC' }[];
|
order?: { field: string; mode: 'DESC' | 'ASC' | string }[];
|
||||||
limit?: number;
|
limit?: number;
|
||||||
offset?: number;
|
offset?: number;
|
||||||
};
|
};
|
||||||
@ -62,7 +62,7 @@ class Pg {
|
|||||||
if (typeof item === 'string') {
|
if (typeof item === 'string') {
|
||||||
return item;
|
return item;
|
||||||
}
|
}
|
||||||
const val = typeof item[1] === 'string' ? `'${item[1]}'` : item[1];
|
const val = typeof item[1] === 'number' ? item[1] : `'${String(item[1])}'`;
|
||||||
return `${item[0]}=${val}`;
|
return `${item[0]}=${val}`;
|
||||||
})
|
})
|
||||||
.join(' ')}`
|
.join(' ')}`
|
||||||
@ -95,7 +95,9 @@ class Pg {
|
|||||||
.join(',');
|
.join(',');
|
||||||
}
|
}
|
||||||
async select<T extends QueryResultRow = any>(table: string, props: GetProps) {
|
async select<T extends QueryResultRow = any>(table: string, props: GetProps) {
|
||||||
const sql = `SELECT ${!props.field || props.field?.length === 0 ? '*' : props.field?.join(',')}
|
const sql = `SELECT ${
|
||||||
|
!props.fields || props.fields?.length === 0 ? '*' : props.fields?.join(',')
|
||||||
|
}
|
||||||
FROM ${table}
|
FROM ${table}
|
||||||
${this.getWhereStr(props.where)}
|
${this.getWhereStr(props.where)}
|
||||||
${
|
${
|
||||||
@ -123,19 +125,34 @@ class Pg {
|
|||||||
return pg.query(sql);
|
return pg.query(sql);
|
||||||
}
|
}
|
||||||
async update(table: string, props: UpdateProps) {
|
async update(table: string, props: UpdateProps) {
|
||||||
|
if (props.values.length === 0) {
|
||||||
|
return {
|
||||||
|
rowCount: 0
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
const sql = `UPDATE ${table} SET ${this.getUpdateValStr(props.values)} ${this.getWhereStr(
|
const sql = `UPDATE ${table} SET ${this.getUpdateValStr(props.values)} ${this.getWhereStr(
|
||||||
props.where
|
props.where
|
||||||
)}`;
|
)}`;
|
||||||
|
|
||||||
const pg = await connectPg();
|
const pg = await connectPg();
|
||||||
return pg.query(sql);
|
return pg.query(sql);
|
||||||
}
|
}
|
||||||
async insert(table: string, props: InsertProps) {
|
async insert(table: string, props: InsertProps) {
|
||||||
|
if (props.values.length === 0) {
|
||||||
|
return {
|
||||||
|
rowCount: 0
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
const fields = props.values[0].map((item) => item.key).join(',');
|
const fields = props.values[0].map((item) => item.key).join(',');
|
||||||
const sql = `INSERT INTO ${table} (${fields}) VALUES ${this.getInsertValStr(props.values)} `;
|
const sql = `INSERT INTO ${table} (${fields}) VALUES ${this.getInsertValStr(props.values)} `;
|
||||||
const pg = await connectPg();
|
const pg = await connectPg();
|
||||||
return pg.query(sql);
|
return pg.query(sql);
|
||||||
}
|
}
|
||||||
|
async query<T extends QueryResultRow = any>(sql: string) {
|
||||||
|
const pg = await connectPg();
|
||||||
|
return pg.query<T>(sql);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export const PgClient = new Pg();
|
export const PgClient = new Pg();
|
||||||
|
|||||||
@ -137,5 +137,5 @@ export const systemPromptFilter = (prompts: string[], maxTokens: number) => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return splitText.slice(0, splitText.length - 1);
|
return splitText.slice(0, splitText.length - 1).replace(/\n+/g, '\n');
|
||||||
};
|
};
|
||||||
|
|||||||
@ -51,23 +51,3 @@ export const Obj2Query = (obj: Record<string, string | number>) => {
|
|||||||
}
|
}
|
||||||
return queryParams.toString();
|
return queryParams.toString();
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
|
||||||
* 向量转成 float32 buffer 格式
|
|
||||||
*/
|
|
||||||
export const vectorToBuffer = (vector: number[]) => {
|
|
||||||
const npVector = new Float32Array(vector);
|
|
||||||
|
|
||||||
const buffer = Buffer.from(npVector.buffer);
|
|
||||||
|
|
||||||
return buffer;
|
|
||||||
};
|
|
||||||
|
|
||||||
export const formatVector = (vector: number[]) => {
|
|
||||||
let formattedVector = vector.slice(0, 1536); // 截取前1536个元素
|
|
||||||
if (vector.length > 1536) {
|
|
||||||
formattedVector = formattedVector.concat(Array(1536 - formattedVector.length).fill(0)); // 在后面添加0
|
|
||||||
}
|
|
||||||
|
|
||||||
return formattedVector;
|
|
||||||
};
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user