feat: 替换redis搜索

This commit is contained in:
archer 2023-04-19 12:00:28 +08:00
parent 867d69659f
commit 1e5714da1b
No known key found for this signature in database
GPG Key ID: 569A5660D2379E28
12 changed files with 147 additions and 228 deletions

View File

@ -1,7 +1,7 @@
import type { NextApiRequest, NextApiResponse } from 'next'; import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase } from '@/service/mongo'; import { connectToDatabase } from '@/service/mongo';
import { authChat } from '@/service/utils/chat'; import { authChat } from '@/service/utils/chat';
import { httpsAgent, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools'; import { httpsAgent, systemPromptFilter } from '@/service/utils/tools';
import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai'; import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai';
import { ChatItemType } from '@/types/chat'; import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
@ -9,11 +9,9 @@ import type { ModelSchema } from '@/types/mongoSchema';
import { PassThrough } from 'stream'; import { PassThrough } from 'stream';
import { modelList, ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model'; import { modelList, ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill'; import { pushChatBill } from '@/service/events/pushBill';
import { connectRedis } from '@/service/redis';
import { VecModelDataPrefix } from '@/constants/redis';
import { vectorToBuffer } from '@/utils/tools';
import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai'; import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai';
import dayjs from 'dayjs'; import dayjs from 'dayjs';
import { PgClient } from '@/service/pg';
/* 发送提示词 */ /* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) { export default async function handler(req: NextApiRequest, res: NextApiResponse) {
@ -43,7 +41,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
} }
await connectToDatabase(); await connectToDatabase();
const redis = await connectRedis();
let startTime = Date.now(); let startTime = Date.now();
const { chat, userApiKey, systemKey, userId } = await authChat(chatId, authorization); const { chat, userApiKey, systemKey, userId } = await authChat(chatId, authorization);
@ -65,38 +62,22 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
text: prompt.value text: prompt.value
}); });
// 相似度搜素
const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22; const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22;
// 搜索系统提示词, 按相似度从 redis 中搜出相关的 q 和 text const vectorSearch = await PgClient.select<{ id: string; q: string; a: string }>('modelData', {
const redisData: any[] = await redis.sendCommand([ fields: ['id', 'q', 'a'],
'FT.SEARCH', order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
`idx:${VecModelDataPrefix}:hash`, where: [
`@modelId:{${String( ['model_id', model._id],
chat.modelId._id 'AND',
)}} @vector:[VECTOR_RANGE ${similarity} $blob]=>{$YIELD_DISTANCE_AS: score}`, ['user_id', userId],
'RETURN', 'AND',
'1', `vector <=> '[${promptVector}]' < ${similarity}`
'text', ],
'SORTBY', limit: 30
'score', });
'PARAMS',
'2',
'blob',
vectorToBuffer(promptVector),
'LIMIT',
'0',
'30',
'DIALECT',
'2'
]);
const formatRedisPrompt: string[] = []; const formatRedisPrompt: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`);
// 格式化响应值,获取 qa
for (let i = 2; i < 61; i += 2) {
const text = redisData[i]?.[1];
if (text) {
formatRedisPrompt.push(text);
}
}
/* 高相似度+退出,无法匹配时直接退出 */ /* 高相似度+退出,无法匹配时直接退出 */
if ( if (
@ -121,9 +102,9 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
prompts.unshift({ prompts.unshift({
obj: 'SYSTEM', obj: 'SYSTEM',
value: `${model.systemPrompt} 用知识库内容回答,知识库内容为: "当前时间:${dayjs().format( value: `${model.systemPrompt} 知识库是最新的,下面是知识库内容:当前时间为${dayjs().format(
'YYYY/MM/DD HH:mm:ss' 'YYYY/MM/DD HH:mm:ss'
)} ${systemPrompt}"` )}\n${systemPrompt}`
}); });
} }

View File

@ -2,8 +2,7 @@ import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { connectToDatabase } from '@/service/mongo'; import { connectToDatabase } from '@/service/mongo';
import { authToken } from '@/service/utils/tools'; import { authToken } from '@/service/utils/tools';
import { connectRedis } from '@/service/redis'; import { PgClient } from '@/service/pg';
import { VecModelDataIdx } from '@/constants/redis';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) { export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try { try {
@ -25,28 +24,23 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
const userId = await authToken(authorization); const userId = await authToken(authorization);
await connectToDatabase(); await connectToDatabase();
const redis = await connectRedis();
// 从 redis 中获取数据 // 统计数据
const searchRes = await redis.ft.search( const count = await PgClient.count('modelData', {
VecModelDataIdx, where: [['model_id', modelId], 'AND', ['user_id', userId]]
`@modelId:{${modelId}} @userId:{${userId}}`,
{
RETURN: ['q', 'text'],
LIMIT: {
from: 0,
size: 10000
}
}
);
const data: [string, string][] = [];
searchRes.documents.forEach((item: any) => {
if (item.value.q && item.value.text) {
data.push([item.value.q.replace(/\n/g, '\\n'), item.value.text.replace(/\n/g, '\\n')]);
}
}); });
// 从 pg 中获取所有数据
const pgData = await PgClient.select<{ q: string; a: string }>('modelData', {
where: [['model_id', modelId], 'AND', ['user_id', userId]],
fields: ['q', 'a'],
order: [{ field: 'id', mode: 'DESC' }],
limit: count
});
const data: [string, string][] = pgData.rows.map((item) => [
item.q.replace(/\n/g, '\\n'),
item.a.replace(/\n/g, '\\n')
]);
jsonRes(res, { jsonRes(res, {
data data

View File

@ -37,7 +37,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
await connectToDatabase(); await connectToDatabase();
const searchRes = await PgClient.select<PgModelDataItemType>('modelData', { const searchRes = await PgClient.select<PgModelDataItemType>('modelData', {
field: ['id', 'q', 'a', 'status'], fields: ['id', 'q', 'a', 'status'],
where: [['user_id', userId], 'AND', ['model_id', modelId]], where: [['user_id', userId], 'AND', ['model_id', modelId]],
order: [{ field: 'id', mode: 'DESC' }], order: [{ field: 'id', mode: 'DESC' }],
limit: pageSize, limit: pageSize,

View File

@ -3,11 +3,8 @@ import { jsonRes } from '@/service/response';
import { connectToDatabase, Model } from '@/service/mongo'; import { connectToDatabase, Model } from '@/service/mongo';
import { authToken } from '@/service/utils/tools'; import { authToken } from '@/service/utils/tools';
import { generateVector } from '@/service/events/generateVector'; import { generateVector } from '@/service/events/generateVector';
import { connectRedis } from '@/service/redis'; import { ModelDataStatusEnum } from '@/constants/model';
import { VecModelDataPrefix, ModelDataStatusEnum } from '@/constants/redis'; import { PgClient } from '@/service/pg';
import { VecModelDataIdx } from '@/constants/redis';
import { customAlphabet } from 'nanoid';
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) { export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try { try {
@ -29,7 +26,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
const userId = await authToken(authorization); const userId = await authToken(authorization);
await connectToDatabase(); await connectToDatabase();
const redis = await connectRedis();
// 验证是否是该用户的 model // 验证是否是该用户的 model
const model = await Model.findOne({ const model = await Model.findOne({
@ -47,10 +43,18 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
try { try {
q = q.replace(/\\n/g, '\n'); q = q.replace(/\\n/g, '\n');
a = a.replace(/\\n/g, '\n'); a = a.replace(/\\n/g, '\n');
const redisSearch = await redis.ft.search(VecModelDataIdx, `@q:${q} @text:${a}`, { const count = await PgClient.count('modelData', {
RETURN: ['q', 'text'] where: [
['user_id', userId],
'AND',
['model_id', modelId],
'AND',
['q', q],
'AND',
['a', a]
]
}); });
if (redisSearch.total > 0) { if (count > 0) {
return Promise.reject('已经存在'); return Promise.reject('已经存在');
} }
} catch (error) { } catch (error) {
@ -62,35 +66,26 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
}); });
}) })
); );
// 过滤重复的内容
const filterData = searchRes const filterData = searchRes
.filter((item) => item.status === 'fulfilled') .filter((item) => item.status === 'fulfilled')
.map<{ q: string; a: string }>((item: any) => item.value); .map<{ q: string; a: string }>((item: any) => item.value);
// 插入 redis // 插入 pg
const insertRedisRes = await Promise.allSettled( const insertRes = await PgClient.insert('modelData', {
filterData.map((item) => { values: filterData.map((item) => [
return redis.sendCommand([ { key: 'user_id', value: userId },
'HMSET', { key: 'model_id', value: modelId },
`${VecModelDataPrefix}:${nanoid()}`, { key: 'q', value: item.q },
'userId', { key: 'a', value: item.a },
userId, { key: 'status', value: ModelDataStatusEnum.waiting }
'modelId', ])
String(modelId), });
'q',
item.q,
'text',
item.a,
'status',
ModelDataStatusEnum.waiting
]);
})
);
generateVector(); generateVector();
jsonRes(res, { jsonRes(res, {
data: insertRedisRes.filter((item) => item.status === 'fulfilled').length data: insertRes.rowCount
}); });
} catch (err) { } catch (err) {
jsonRes(res, { jsonRes(res, {

View File

@ -1,13 +1,13 @@
import type { NextApiRequest, NextApiResponse } from 'next'; import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { authToken } from '@/service/utils/tools'; import { authToken } from '@/service/utils/tools';
import { connectRedis } from '@/service/redis';
import { ModelDataStatusEnum } from '@/constants/redis'; import { ModelDataStatusEnum } from '@/constants/redis';
import { generateVector } from '@/service/events/generateVector'; import { generateVector } from '@/service/events/generateVector';
import { PgClient } from '@/service/pg';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) { export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try { try {
const { dataId, text, q } = req.body as { dataId: string; text: string; q?: string }; const { dataId, a, q } = req.body as { dataId: string; a: string; q?: string };
const { authorization } = req.headers; const { authorization } = req.headers;
if (!authorization) { if (!authorization) {
@ -21,26 +21,21 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
// 凭证校验 // 凭证校验
const userId = await authToken(authorization); const userId = await authToken(authorization);
const redis = await connectRedis(); // 更新 pg 内容
await PgClient.update('modelData', {
where: [['id', dataId], 'AND', ['user_id', userId]],
values: [
{ key: 'a', value: a },
...(q
? [
{ key: 'q', value: q },
{ key: 'status', value: ModelDataStatusEnum.waiting }
]
: [])
]
});
// 校验是否为该用户的数据 q && generateVector();
const dataItemUserId = await redis.hGet(dataId, 'userId');
if (dataItemUserId !== userId) {
throw new Error('无权操作');
}
// 更新
await redis.sendCommand([
'HMSET',
dataId,
...(q ? ['q', q, 'status', ModelDataStatusEnum.waiting] : []),
'text',
text
]);
if (q) {
generateVector();
}
jsonRes(res); jsonRes(res);
} catch (err) { } catch (err) {

View File

@ -6,13 +6,12 @@ import { getUserApiOpenai } from '@/service/utils/openai';
import { TrainingStatusEnum } from '@/constants/model'; import { TrainingStatusEnum } from '@/constants/model';
import { TrainingItemType } from '@/types/training'; import { TrainingItemType } from '@/types/training';
import { httpsAgent } from '@/service/utils/tools'; import { httpsAgent } from '@/service/utils/tools';
import { connectRedis } from '@/service/redis'; import { PgClient } from '@/service/pg';
import { VecModelDataIdx } from '@/constants/redis';
/* 获取我的模型 */ /* 获取我的模型 */
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) { export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try { try {
const { modelId } = req.query; const { modelId } = req.query as { modelId: string };
const { authorization } = req.headers; const { authorization } = req.headers;
if (!authorization) { if (!authorization) {
@ -37,21 +36,11 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
} }
await connectToDatabase(); await connectToDatabase();
const redis = await connectRedis();
// 获取 redis 中模型关联的所有数据 // 删除 pg 中所有该模型的数据
const searchRes = await redis.ft.search( await PgClient.delete('modelData', {
VecModelDataIdx, where: [['user_id', userId], 'AND', ['model_id', modelId]]
`@modelId:{${modelId}} @userId:{${userId}}`, });
{
LIMIT: {
from: 0,
size: 10000
}
}
);
// 删除 redis 内容
await Promise.all(searchRes.documents.map((item) => redis.del(item.id)));
// 删除对应的聊天 // 删除对应的聊天
await Chat.deleteMany({ await Chat.deleteMany({

View File

@ -7,12 +7,15 @@ import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } fr
import { ChatItemType } from '@/types/chat'; import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { PassThrough } from 'stream'; import { PassThrough } from 'stream';
import { ChatModelNameEnum, modelList, ChatModelNameMap } from '@/constants/model'; import {
ChatModelNameEnum,
modelList,
ChatModelNameMap,
ModelVectorSearchModeMap
} from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill'; import { pushChatBill } from '@/service/events/pushBill';
import { connectRedis } from '@/service/redis';
import { VecModelDataPrefix } from '@/constants/redis';
import { vectorToBuffer } from '@/utils/tools';
import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai'; import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai';
import { PgClient } from '@/service/pg';
/* 发送提示词 */ /* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) { export default async function handler(req: NextApiRequest, res: NextApiResponse) {
@ -46,7 +49,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
} }
await connectToDatabase(); await connectToDatabase();
const redis = await connectRedis();
let startTime = Date.now(); let startTime = Date.now();
/* 凭证校验 */ /* 凭证校验 */
@ -144,39 +146,29 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
// 读取对话内容 // 读取对话内容
const prompts = [prompt]; const prompts = [prompt];
// 搜索系统提示词, 按相似度从 redis 中搜出相关的 q 和 text // 相似度搜索
const redisData: any[] = await redis.sendCommand([ const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22;
'FT.SEARCH', const vectorSearch = await PgClient.select<{ id: string; q: string; a: string }>('modelData', {
`idx:${VecModelDataPrefix}:hash`, fields: ['id', 'q', 'a'],
`@modelId:{${String(model._id)}}=>[KNN 20 @vector $blob AS score]`, order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
'RETURN', where: [
'1', ['model_id', model._id],
'text', 'AND',
'SORTBY', ['user_id', userId],
'score', 'AND',
'PARAMS', `vector <=> '[${promptVector}]' < ${similarity}`
'2', ],
'blob', limit: 30
vectorToBuffer(promptVector), });
'DIALECT',
'2'
]);
// 格式化响应值,获取 qa const formatRedisPrompt: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`);
const formatRedisPrompt: string[] = [];
for (let i = 2; i < 42; i += 2) {
const text = redisData[i]?.[1];
if (text) {
formatRedisPrompt.push(text);
}
}
// textArr 筛选,最多 3000 tokens // textArr 筛选,最多 3000 tokens
const systemPrompt = systemPromptFilter(formatRedisPrompt, 3000); const systemPrompt = systemPromptFilter(formatRedisPrompt, 3000);
prompts.unshift({ prompts.unshift({
obj: 'SYSTEM', obj: 'SYSTEM',
value: `${model.systemPrompt} 知识库内容是最新的,知识库内容为: "${systemPrompt}"` value: `${model.systemPrompt} 知识库是最新的,下面是知识库内容:${systemPrompt}`
}); });
// 控制在 tokens 数量,防止超出 // 控制在 tokens 数量,防止超出

View File

@ -1,22 +1,15 @@
import type { NextApiRequest, NextApiResponse } from 'next'; import type { NextApiRequest, NextApiResponse } from 'next';
import { connectToDatabase, Model } from '@/service/mongo'; import { connectToDatabase, Model } from '@/service/mongo';
import { import { httpsAgent, systemPromptFilter, authOpenApiKey } from '@/service/utils/tools';
httpsAgent,
openaiChatFilter,
systemPromptFilter,
authOpenApiKey
} from '@/service/utils/tools';
import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai'; import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai';
import { ChatItemType } from '@/types/chat'; import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { PassThrough } from 'stream'; import { PassThrough } from 'stream';
import { modelList, ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model'; import { modelList, ModelVectorSearchModeMap, ModelVectorSearchModeEnum } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill'; import { pushChatBill } from '@/service/events/pushBill';
import { connectRedis } from '@/service/redis';
import { VecModelDataPrefix } from '@/constants/redis';
import { vectorToBuffer } from '@/utils/tools';
import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai'; import { openaiCreateEmbedding, gpt35StreamResponse } from '@/service/utils/openai';
import dayjs from 'dayjs'; import dayjs from 'dayjs';
import { PgClient } from '@/service/pg';
/* 发送提示词 */ /* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) { export default async function handler(req: NextApiRequest, res: NextApiResponse) {
@ -56,7 +49,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
} }
await connectToDatabase(); await connectToDatabase();
const redis = await connectRedis();
let startTime = Date.now(); let startTime = Date.now();
/* 凭证校验 */ /* 凭证校验 */
@ -84,38 +76,22 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
text: prompts[prompts.length - 1].value // 取最后一个 text: prompts[prompts.length - 1].value // 取最后一个
}); });
// 搜索系统提示词, 按相似度从 redis 中搜出相关的 q 和 text // 相似度搜素
const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22; const similarity = ModelVectorSearchModeMap[model.search.mode]?.similarity || 0.22;
// 搜索系统提示词, 按相似度从 redis 中搜出相关的 q 和 text const vectorSearch = await PgClient.select<{ id: string; q: string; a: string }>('modelData', {
const redisData: any[] = await redis.sendCommand([ fields: ['id', 'q', 'a'],
'FT.SEARCH', order: [{ field: 'vector', mode: `<=> '[${promptVector}]'` }],
`idx:${VecModelDataPrefix}:hash`, where: [
`@modelId:{${modelId}} @vector:[VECTOR_RANGE ${similarity} $blob]=>{$YIELD_DISTANCE_AS: score}`, ['model_id', model._id],
'RETURN', 'AND',
'1', ['user_id', userId],
'text', 'AND',
'SORTBY', `vector <=> '[${promptVector}]' < ${similarity}`
'score', ],
'PARAMS', limit: 30
'2', });
'blob',
vectorToBuffer(promptVector),
'LIMIT',
'0',
'30',
'DIALECT',
'2'
]);
const formatRedisPrompt: string[] = []; const formatRedisPrompt: string[] = vectorSearch.rows.map((item) => `${item.q}\n${item.a}`);
// 格式化响应值,获取 qa
for (let i = 2; i < 61; i += 2) {
const text = redisData[i]?.[1];
if (text) {
formatRedisPrompt.push(text);
}
}
// system 合并 // system 合并
if (prompts[0].obj === 'SYSTEM') { if (prompts[0].obj === 'SYSTEM') {
@ -145,9 +121,9 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
prompts.unshift({ prompts.unshift({
obj: 'SYSTEM', obj: 'SYSTEM',
value: `${model.systemPrompt} 用知识库内容回答,知识库内容为: "当前时间:${dayjs().format( value: `${model.systemPrompt} 知识库是最新的,下面是知识库内容:当前时间为${dayjs().format(
'YYYY/MM/DD HH:mm:ss' 'YYYY/MM/DD HH:mm:ss'
)} ${systemPrompt}"` )}\n${systemPrompt}`
}); });
} }

View File

@ -16,7 +16,7 @@ export async function generateVector(next = false): Promise<any> {
try { try {
// 从找出一个 status = waiting 的数据 // 从找出一个 status = waiting 的数据
const searchRes = await PgClient.select('modelData', { const searchRes = await PgClient.select('modelData', {
field: ['id', 'q', 'user_id'], fields: ['id', 'q', 'user_id'],
where: [['status', 'waiting']], where: [['status', 'waiting']],
limit: 1 limit: 1
}); });

View File

@ -34,9 +34,9 @@ export const connectPg = async () => {
type WhereProps = (string | [string, string | number])[]; type WhereProps = (string | [string, string | number])[];
type GetProps = { type GetProps = {
field?: string[]; fields?: string[];
where?: WhereProps; where?: WhereProps;
order?: { field: string; mode: 'DESC' | 'ASC' }[]; order?: { field: string; mode: 'DESC' | 'ASC' | string }[];
limit?: number; limit?: number;
offset?: number; offset?: number;
}; };
@ -62,7 +62,7 @@ class Pg {
if (typeof item === 'string') { if (typeof item === 'string') {
return item; return item;
} }
const val = typeof item[1] === 'string' ? `'${item[1]}'` : item[1]; const val = typeof item[1] === 'number' ? item[1] : `'${String(item[1])}'`;
return `${item[0]}=${val}`; return `${item[0]}=${val}`;
}) })
.join(' ')}` .join(' ')}`
@ -95,7 +95,9 @@ class Pg {
.join(','); .join(',');
} }
async select<T extends QueryResultRow = any>(table: string, props: GetProps) { async select<T extends QueryResultRow = any>(table: string, props: GetProps) {
const sql = `SELECT ${!props.field || props.field?.length === 0 ? '*' : props.field?.join(',')} const sql = `SELECT ${
!props.fields || props.fields?.length === 0 ? '*' : props.fields?.join(',')
}
FROM ${table} FROM ${table}
${this.getWhereStr(props.where)} ${this.getWhereStr(props.where)}
${ ${
@ -123,19 +125,34 @@ class Pg {
return pg.query(sql); return pg.query(sql);
} }
async update(table: string, props: UpdateProps) { async update(table: string, props: UpdateProps) {
if (props.values.length === 0) {
return {
rowCount: 0
};
}
const sql = `UPDATE ${table} SET ${this.getUpdateValStr(props.values)} ${this.getWhereStr( const sql = `UPDATE ${table} SET ${this.getUpdateValStr(props.values)} ${this.getWhereStr(
props.where props.where
)}`; )}`;
const pg = await connectPg(); const pg = await connectPg();
return pg.query(sql); return pg.query(sql);
} }
async insert(table: string, props: InsertProps) { async insert(table: string, props: InsertProps) {
if (props.values.length === 0) {
return {
rowCount: 0
};
}
const fields = props.values[0].map((item) => item.key).join(','); const fields = props.values[0].map((item) => item.key).join(',');
const sql = `INSERT INTO ${table} (${fields}) VALUES ${this.getInsertValStr(props.values)} `; const sql = `INSERT INTO ${table} (${fields}) VALUES ${this.getInsertValStr(props.values)} `;
const pg = await connectPg(); const pg = await connectPg();
return pg.query(sql); return pg.query(sql);
} }
async query<T extends QueryResultRow = any>(sql: string) {
const pg = await connectPg();
return pg.query<T>(sql);
}
} }
export const PgClient = new Pg(); export const PgClient = new Pg();

View File

@ -137,5 +137,5 @@ export const systemPromptFilter = (prompts: string[], maxTokens: number) => {
} }
} }
return splitText.slice(0, splitText.length - 1); return splitText.slice(0, splitText.length - 1).replace(/\n+/g, '\n');
}; };

View File

@ -51,23 +51,3 @@ export const Obj2Query = (obj: Record<string, string | number>) => {
} }
return queryParams.toString(); return queryParams.toString();
}; };
/**
* float32 buffer
*/
export const vectorToBuffer = (vector: number[]) => {
const npVector = new Float32Array(vector);
const buffer = Buffer.from(npVector.buffer);
return buffer;
};
export const formatVector = (vector: number[]) => {
let formattedVector = vector.slice(0, 1536); // 截取前1536个元素
if (vector.length > 1536) {
formattedVector = formattedVector.concat(Array(1536 - formattedVector.length).fill(0)); // 在后面添加0
}
return formattedVector;
};