diff --git a/src/pages/api/chat/vectorGpt.ts b/src/pages/api/chat/vectorGpt.ts index 704f2f9dc..2b7b38460 100644 --- a/src/pages/api/chat/vectorGpt.ts +++ b/src/pages/api/chat/vectorGpt.ts @@ -1,7 +1,7 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { connectToDatabase } from '@/service/mongo'; import { authChat } from '@/service/utils/chat'; -import { httpsAgent, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools'; +import { httpsAgent, systemPromptFilter } from '@/service/utils/tools'; import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai'; import { ChatItemType } from '@/types/chat'; import { jsonRes } from '@/service/response'; @@ -9,11 +9,9 @@ import type { ModelSchema } from '@/types/mongoSchema'; import { PassThrough } from 'stream'; import { modelList, ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model'; import { pushChatBill } from '@/service/events/pushBill'; -import { connectRedis } from '@/service/redis'; -import { VecModelDataPrefix } from '@/constants/redis'; -import { vectorToBuffer } from '@/utils/tools'; import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai'; import dayjs from 'dayjs'; +import { PgClient } from '@/service/pg'; /* 发送提示词 */ export default async function handler(req: NextApiRequest, res: NextApiResponse) { @@ -43,7 +41,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) } await connectToDatabase(); - const redis = await connectRedis(); let startTime = Date.now(); const { chat, userApiKey, systemKey, userId } = await authChat(chatId, authorization); @@ -65,38 +62,22 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) text: prompt.value }); + // 相似度搜素 const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22; - // 搜索系统提示词, 按相似度从 redis 中搜出相关的 q 和 text - const redisData: any[] = await redis.sendCommand([ - 'FT.SEARCH', - `idx:${VecModelDataPrefix}:hash`, - `@modelId:{${String( - chat.modelId._id - )}} @vector:[VECTOR_RANGE ${similarity} $blob]=>{$YIELD_DISTANCE_AS: score}`, - 'RETURN', - '1', - 'text', - 'SORTBY', - 'score', - 'PARAMS', - '2', - 'blob', - vectorToBuffer(promptVector), - 'LIMIT', - '0', - '30', - 'DIALECT', - '2' - ]); + const vectorSearch = await PgClient.select<{ id: string; q: string; a: string }>('modelData', { + fields: ['id', 'q', 'a'], + order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }], + where: [ + ['model_id', model._id], + 'AND', + ['user_id', userId], + 'AND', + `vector <=> '[${promptVector}]' < ${similarity}` + ], + limit: 30 + }); - const formatRedisPrompt: string[] = []; - // 格式化响应值,获取 qa - for (let i = 2; i < 61; i += 2) { - const text = redisData[i]?.[1]; - if (text) { - formatRedisPrompt.push(text); - } - } + const formatRedisPrompt: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`); /* 高相似度+退出,无法匹配时直接退出 */ if ( @@ -121,9 +102,9 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) prompts.unshift({ obj: 'SYSTEM', - value: `${model.systemPrompt} 用知识库内容回答,知识库内容为: "当前时间:${dayjs().format( + value: `${model.systemPrompt} 知识库是最新的,下面是知识库内容:当前时间为${dayjs().format( 'YYYY/MM/DD HH:mm:ss' - )} ${systemPrompt}"` + )}\n${systemPrompt}` }); } diff --git a/src/pages/api/model/data/exportModelData.ts b/src/pages/api/model/data/exportModelData.ts index 215a72615..99639760b 100644 --- a/src/pages/api/model/data/exportModelData.ts +++ b/src/pages/api/model/data/exportModelData.ts @@ -2,8 +2,7 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; import { connectToDatabase } from '@/service/mongo'; import { authToken } from '@/service/utils/tools'; -import { connectRedis } from '@/service/redis'; -import { VecModelDataIdx } from '@/constants/redis'; +import { PgClient } from '@/service/pg'; export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { @@ -25,28 +24,23 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< const userId = await authToken(authorization); await connectToDatabase(); - const redis = await connectRedis(); - // 从 redis 中获取数据 - const searchRes = await redis.ft.search( - VecModelDataIdx, - `@modelId:{${modelId}} @userId:{${userId}}`, - { - RETURN: ['q', 'text'], - LIMIT: { - from: 0, - size: 10000 - } - } - ); - - const data: [string, string][] = []; - - searchRes.documents.forEach((item: any) => { - if (item.value.q && item.value.text) { - data.push([item.value.q.replace(/\n/g, '\\n'), item.value.text.replace(/\n/g, '\\n')]); - } + // 统计数据 + const count = await PgClient.count('modelData', { + where: [['model_id', modelId], 'AND', ['user_id', userId]] }); + // 从 pg 中获取所有数据 + const pgData = await PgClient.select<{ q: string; a: string }>('modelData', { + where: [['model_id', modelId], 'AND', ['user_id', userId]], + fields: ['q', 'a'], + order: [{ field: 'id', mode: 'DESC' }], + limit: count + }); + + const data: [string, string][] = pgData.rows.map((item) => [ + item.q.replace(/\n/g, '\\n'), + item.a.replace(/\n/g, '\\n') + ]); jsonRes(res, { data diff --git a/src/pages/api/model/data/getModelData.ts b/src/pages/api/model/data/getModelData.ts index 38b400fa7..64a29916c 100644 --- a/src/pages/api/model/data/getModelData.ts +++ b/src/pages/api/model/data/getModelData.ts @@ -37,7 +37,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< await connectToDatabase(); const searchRes = await PgClient.select('modelData', { - field: ['id', 'q', 'a', 'status'], + fields: ['id', 'q', 'a', 'status'], where: [['user_id', userId], 'AND', ['model_id', modelId]], order: [{ field: 'id', mode: 'DESC' }], limit: pageSize, diff --git a/src/pages/api/model/data/pushModelDataCsv.ts b/src/pages/api/model/data/pushModelDataCsv.ts index 425c3ba99..99920cd77 100644 --- a/src/pages/api/model/data/pushModelDataCsv.ts +++ b/src/pages/api/model/data/pushModelDataCsv.ts @@ -3,11 +3,8 @@ import { jsonRes } from '@/service/response'; import { connectToDatabase, Model } from '@/service/mongo'; import { authToken } from '@/service/utils/tools'; import { generateVector } from '@/service/events/generateVector'; -import { connectRedis } from '@/service/redis'; -import { VecModelDataPrefix, ModelDataStatusEnum } from '@/constants/redis'; -import { VecModelDataIdx } from '@/constants/redis'; -import { customAlphabet } from 'nanoid'; -const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12); +import { ModelDataStatusEnum } from '@/constants/model'; +import { PgClient } from '@/service/pg'; export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { @@ -29,7 +26,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< const userId = await authToken(authorization); await connectToDatabase(); - const redis = await connectRedis(); // 验证是否是该用户的 model const model = await Model.findOne({ @@ -47,10 +43,18 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< try { q = q.replace(/\\n/g, '\n'); a = a.replace(/\\n/g, '\n'); - const redisSearch = await redis.ft.search(VecModelDataIdx, `@q:${q} @text:${a}`, { - RETURN: ['q', 'text'] + const count = await PgClient.count('modelData', { + where: [ + ['user_id', userId], + 'AND', + ['model_id', modelId], + 'AND', + ['q', q], + 'AND', + ['a', a] + ] }); - if (redisSearch.total > 0) { + if (count > 0) { return Promise.reject('已经存在'); } } catch (error) { @@ -62,35 +66,26 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< }); }) ); - + // 过滤重复的内容 const filterData = searchRes .filter((item) => item.status === 'fulfilled') .map<{ q: string; a: string }>((item: any) => item.value); - // 插入 redis - const insertRedisRes = await Promise.allSettled( - filterData.map((item) => { - return redis.sendCommand([ - 'HMSET', - `${VecModelDataPrefix}:${nanoid()}`, - 'userId', - userId, - 'modelId', - String(modelId), - 'q', - item.q, - 'text', - item.a, - 'status', - ModelDataStatusEnum.waiting - ]); - }) - ); + // 插入 pg + const insertRes = await PgClient.insert('modelData', { + values: filterData.map((item) => [ + { key: 'user_id', value: userId }, + { key: 'model_id', value: modelId }, + { key: 'q', value: item.q }, + { key: 'a', value: item.a }, + { key: 'status', value: ModelDataStatusEnum.waiting } + ]) + }); generateVector(); jsonRes(res, { - data: insertRedisRes.filter((item) => item.status === 'fulfilled').length + data: insertRes.rowCount }); } catch (err) { jsonRes(res, { diff --git a/src/pages/api/model/data/putModelData.ts b/src/pages/api/model/data/putModelData.ts index c146cd175..9e07e650e 100644 --- a/src/pages/api/model/data/putModelData.ts +++ b/src/pages/api/model/data/putModelData.ts @@ -1,13 +1,13 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { jsonRes } from '@/service/response'; import { authToken } from '@/service/utils/tools'; -import { connectRedis } from '@/service/redis'; import { ModelDataStatusEnum } from '@/constants/redis'; import { generateVector } from '@/service/events/generateVector'; +import { PgClient } from '@/service/pg'; export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { - const { dataId, text, q } = req.body as { dataId: string; text: string; q?: string }; + const { dataId, a, q } = req.body as { dataId: string; a: string; q?: string }; const { authorization } = req.headers; if (!authorization) { @@ -21,26 +21,21 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< // 凭证校验 const userId = await authToken(authorization); - const redis = await connectRedis(); + // 更新 pg 内容 + await PgClient.update('modelData', { + where: [['id', dataId], 'AND', ['user_id', userId]], + values: [ + { key: 'a', value: a }, + ...(q + ? [ + { key: 'q', value: q }, + { key: 'status', value: ModelDataStatusEnum.waiting } + ] + : []) + ] + }); - // 校验是否为该用户的数据 - const dataItemUserId = await redis.hGet(dataId, 'userId'); - if (dataItemUserId !== userId) { - throw new Error('无权操作'); - } - - // 更新 - await redis.sendCommand([ - 'HMSET', - dataId, - ...(q ? ['q', q, 'status', ModelDataStatusEnum.waiting] : []), - 'text', - text - ]); - - if (q) { - generateVector(); - } + q && generateVector(); jsonRes(res); } catch (err) { diff --git a/src/pages/api/model/del.ts b/src/pages/api/model/del.ts index be2780619..e82cade53 100644 --- a/src/pages/api/model/del.ts +++ b/src/pages/api/model/del.ts @@ -6,13 +6,12 @@ import { getUserApiOpenai } from '@/service/utils/openai'; import { TrainingStatusEnum } from '@/constants/model'; import { TrainingItemType } from '@/types/training'; import { httpsAgent } from '@/service/utils/tools'; -import { connectRedis } from '@/service/redis'; -import { VecModelDataIdx } from '@/constants/redis'; +import { PgClient } from '@/service/pg'; /* 获取我的模型 */ export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { - const { modelId } = req.query; + const { modelId } = req.query as { modelId: string }; const { authorization } = req.headers; if (!authorization) { @@ -37,21 +36,11 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse< } await connectToDatabase(); - const redis = await connectRedis(); - // 获取 redis 中模型关联的所有数据 - const searchRes = await redis.ft.search( - VecModelDataIdx, - `@modelId:{${modelId}} @userId:{${userId}}`, - { - LIMIT: { - from: 0, - size: 10000 - } - } - ); - // 删除 redis 内容 - await Promise.all(searchRes.documents.map((item) => redis.del(item.id))); + // 删除 pg 中所有该模型的数据 + await PgClient.delete('modelData', { + where: [['user_id', userId], 'AND', ['model_id', modelId]] + }); // 删除对应的聊天 await Chat.deleteMany({ diff --git a/src/pages/api/openapi/chat/lafGpt.ts b/src/pages/api/openapi/chat/lafGpt.ts index 2019fca06..76fafe353 100644 --- a/src/pages/api/openapi/chat/lafGpt.ts +++ b/src/pages/api/openapi/chat/lafGpt.ts @@ -7,12 +7,15 @@ import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } fr import { ChatItemType } from '@/types/chat'; import { jsonRes } from '@/service/response'; import { PassThrough } from 'stream'; -import { ChatModelNameEnum, modelList, ChatModelNameMap } from '@/constants/model'; +import { + ChatModelNameEnum, + modelList, + ChatModelNameMap, + ModelVectorSearchModeMap +} from '@/constants/model'; import { pushChatBill } from '@/service/events/pushBill'; -import { connectRedis } from '@/service/redis'; -import { VecModelDataPrefix } from '@/constants/redis'; -import { vectorToBuffer } from '@/utils/tools'; import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai'; +import { PgClient } from '@/service/pg'; /* 发送提示词 */ export default async function handler(req: NextApiRequest, res: NextApiResponse) { @@ -46,7 +49,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) } await connectToDatabase(); - const redis = await connectRedis(); let startTime = Date.now(); /* 凭证校验 */ @@ -144,39 +146,29 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) // 读取对话内容 const prompts = [prompt]; - // 搜索系统提示词, 按相似度从 redis 中搜出相关的 q 和 text - const redisData: any[] = await redis.sendCommand([ - 'FT.SEARCH', - `idx:${VecModelDataPrefix}:hash`, - `@modelId:{${String(model._id)}}=>[KNN 20 @vector $blob AS score]`, - 'RETURN', - '1', - 'text', - 'SORTBY', - 'score', - 'PARAMS', - '2', - 'blob', - vectorToBuffer(promptVector), - 'DIALECT', - '2' - ]); + // 相似度搜索 + const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22; + const vectorSearch = await PgClient.select<{ id: string; q: string; a: string }>('modelData', { + fields: ['id', 'q', 'a'], + order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }], + where: [ + ['model_id', model._id], + 'AND', + ['user_id', userId], + 'AND', + `vector <=> '[${promptVector}]' < ${similarity}` + ], + limit: 30 + }); - // 格式化响应值,获取 qa - const formatRedisPrompt: string[] = []; - for (let i = 2; i < 42; i += 2) { - const text = redisData[i]?.[1]; - if (text) { - formatRedisPrompt.push(text); - } - } + const formatRedisPrompt: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`); // textArr 筛选,最多 3000 tokens const systemPrompt = systemPromptFilter(formatRedisPrompt, 3000); prompts.unshift({ obj: 'SYSTEM', - value: `${model.systemPrompt} 知识库内容是最新的,知识库内容为: "${systemPrompt}"` + value: `${model.systemPrompt} 知识库是最新的,下面是知识库内容:${systemPrompt}` }); // 控制在 tokens 数量,防止超出 diff --git a/src/pages/api/openapi/chat/vectorGpt.ts b/src/pages/api/openapi/chat/vectorGpt.ts index 84749e3f5..6e740fb91 100644 --- a/src/pages/api/openapi/chat/vectorGpt.ts +++ b/src/pages/api/openapi/chat/vectorGpt.ts @@ -1,22 +1,15 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { connectToDatabase, Model } from '@/service/mongo'; -import { - httpsAgent, - openaiChatFilter, - systemPromptFilter, - authOpenApiKey -} from '@/service/utils/tools'; +import { httpsAgent, systemPromptFilter, authOpenApiKey } from '@/service/utils/tools'; import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai'; import { ChatItemType } from '@/types/chat'; import { jsonRes } from '@/service/response'; import { PassThrough } from 'stream'; import { modelList, ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model'; import { pushChatBill } from '@/service/events/pushBill'; -import { connectRedis } from '@/service/redis'; -import { VecModelDataPrefix } from '@/constants/redis'; -import { vectorToBuffer } from '@/utils/tools'; import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai'; import dayjs from 'dayjs'; +import { PgClient } from '@/service/pg'; /* 发送提示词 */ export default async function handler(req: NextApiRequest, res: NextApiResponse) { @@ -56,7 +49,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) } await connectToDatabase(); - const redis = await connectRedis(); let startTime = Date.now(); /* 凭证校验 */ @@ -84,38 +76,22 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) text: prompts[prompts.length - 1].value // 取最后一个 }); - // 搜索系统提示词, 按相似度从 redis 中搜出相关的 q 和 text + // 相似度搜素 const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22; - // 搜索系统提示词, 按相似度从 redis 中搜出相关的 q 和 text - const redisData: any[] = await redis.sendCommand([ - 'FT.SEARCH', - `idx:${VecModelDataPrefix}:hash`, - `@modelId:{${modelId}} @vector:[VECTOR_RANGE ${similarity} $blob]=>{$YIELD_DISTANCE_AS: score}`, - 'RETURN', - '1', - 'text', - 'SORTBY', - 'score', - 'PARAMS', - '2', - 'blob', - vectorToBuffer(promptVector), - 'LIMIT', - '0', - '30', - 'DIALECT', - '2' - ]); + const vectorSearch = await PgClient.select<{ id: string; q: string; a: string }>('modelData', { + fields: ['id', 'q', 'a'], + order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }], + where: [ + ['model_id', model._id], + 'AND', + ['user_id', userId], + 'AND', + `vector <=> '[${promptVector}]' < ${similarity}` + ], + limit: 30 + }); - const formatRedisPrompt: string[] = []; - - // 格式化响应值,获取 qa - for (let i = 2; i < 61; i += 2) { - const text = redisData[i]?.[1]; - if (text) { - formatRedisPrompt.push(text); - } - } + const formatRedisPrompt: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`); // system 合并 if (prompts[0].obj === 'SYSTEM') { @@ -145,9 +121,9 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) prompts.unshift({ obj: 'SYSTEM', - value: `${model.systemPrompt} 用知识库内容回答,知识库内容为: "当前时间:${dayjs().format( + value: `${model.systemPrompt} 知识库是最新的,下面是知识库内容:当前时间为${dayjs().format( 'YYYY/MM/DD HH:mm:ss' - )} ${systemPrompt}"` + )}\n${systemPrompt}` }); } diff --git a/src/service/events/generateVector.ts b/src/service/events/generateVector.ts index 76f462572..7d88d58bf 100644 --- a/src/service/events/generateVector.ts +++ b/src/service/events/generateVector.ts @@ -16,7 +16,7 @@ export async function generateVector(next = false): Promise { try { // 从找出一个 status = waiting 的数据 const searchRes = await PgClient.select('modelData', { - field: ['id', 'q', 'user_id'], + fields: ['id', 'q', 'user_id'], where: [['status', 'waiting']], limit: 1 }); diff --git a/src/service/pg.ts b/src/service/pg.ts index 53c51e937..342a09d3f 100644 --- a/src/service/pg.ts +++ b/src/service/pg.ts @@ -34,9 +34,9 @@ export const connectPg = async () => { type WhereProps = (string | [string, string | number])[]; type GetProps = { - field?: string[]; + fields?: string[]; where?: WhereProps; - order?: { field: string; mode: 'DESC' | 'ASC' }[]; + order?: { field: string; mode: 'DESC' | 'ASC' | string }[]; limit?: number; offset?: number; }; @@ -62,7 +62,7 @@ class Pg { if (typeof item === 'string') { return item; } - const val = typeof item[1] === 'string' ? `'${item[1]}'` : item[1]; + const val = typeof item[1] === 'number' ? item[1] : `'${String(item[1])}'`; return `${item[0]}=${val}`; }) .join(' ')}` @@ -95,7 +95,9 @@ class Pg { .join(','); } async select(table: string, props: GetProps) { - const sql = `SELECT ${!props.field || props.field?.length === 0 ? '*' : props.field?.join(',')} + const sql = `SELECT ${ + !props.fields || props.fields?.length === 0 ? '*' : props.fields?.join(',') + } FROM ${table} ${this.getWhereStr(props.where)} ${ @@ -123,19 +125,34 @@ class Pg { return pg.query(sql); } async update(table: string, props: UpdateProps) { + if (props.values.length === 0) { + return { + rowCount: 0 + }; + } + const sql = `UPDATE ${table} SET ${this.getUpdateValStr(props.values)} ${this.getWhereStr( props.where )}`; - const pg = await connectPg(); return pg.query(sql); } async insert(table: string, props: InsertProps) { + if (props.values.length === 0) { + return { + rowCount: 0 + }; + } + const fields = props.values[0].map((item) => item.key).join(','); const sql = `INSERT INTO ${table} (${fields}) VALUES ${this.getInsertValStr(props.values)} `; const pg = await connectPg(); return pg.query(sql); } + async query(sql: string) { + const pg = await connectPg(); + return pg.query(sql); + } } export const PgClient = new Pg(); diff --git a/src/service/utils/tools.ts b/src/service/utils/tools.ts index 38da94be9..b2cab6f91 100644 --- a/src/service/utils/tools.ts +++ b/src/service/utils/tools.ts @@ -137,5 +137,5 @@ export const systemPromptFilter = (prompts: string[], maxTokens: number) => { } } - return splitText.slice(0, splitText.length - 1); + return splitText.slice(0, splitText.length - 1).replace(/\n+/g, '\n'); }; diff --git a/src/utils/tools.ts b/src/utils/tools.ts index 71632d233..64d935402 100644 --- a/src/utils/tools.ts +++ b/src/utils/tools.ts @@ -51,23 +51,3 @@ export const Obj2Query = (obj: Record) => { } return queryParams.toString(); }; - -/** - * 向量转成 float32 buffer 格式 - */ -export const vectorToBuffer = (vector: number[]) => { - const npVector = new Float32Array(vector); - - const buffer = Buffer.from(npVector.buffer); - - return buffer; -}; - -export const formatVector = (vector: number[]) => { - let formattedVector = vector.slice(0, 1536); // 截取前1536个元素 - if (vector.length > 1536) { - formattedVector = formattedVector.concat(Array(1536 - formattedVector.length).fill(0)); // 在后面添加0 - } - - return formattedVector; -};