feat: 模型数据管理

feat: 模型数据导入

feat: redis 向量入库

feat: 向量索引

feat: 文件导入模型

perf: 交互

perf: prompt
This commit is contained in:
archer 2023-03-29 00:22:48 +08:00
parent 713332522f
commit 2099a87908
No known key found for this signature in database
GPG Key ID: 569A5660D2379E28
45 changed files with 1522 additions and 284 deletions

View File

@ -3,4 +3,6 @@ AXIOS_PROXY_PORT=33210
MONGODB_URI= MONGODB_URI=
MY_MAIL= MY_MAIL=
MAILE_CODE= MAILE_CODE=
TOKEN_KEY= TOKEN_KEY=
OPENAIKEY=
REDIS_URL=

View File

@ -41,6 +41,7 @@
"react-hook-form": "^7.43.1", "react-hook-form": "^7.43.1",
"react-markdown": "^8.0.5", "react-markdown": "^8.0.5",
"react-syntax-highlighter": "^15.5.0", "react-syntax-highlighter": "^15.5.0",
"redis": "^4.6.5",
"rehype-katex": "^6.0.2", "rehype-katex": "^6.0.2",
"remark-gfm": "^3.0.1", "remark-gfm": "^3.0.1",
"remark-math": "^5.1.1", "remark-math": "^5.1.1",

95
pnpm-lock.yaml generated
View File

@ -47,6 +47,7 @@ specifiers:
react-hook-form: ^7.43.1 react-hook-form: ^7.43.1
react-markdown: ^8.0.5 react-markdown: ^8.0.5
react-syntax-highlighter: ^15.5.0 react-syntax-highlighter: ^15.5.0
redis: ^4.6.5
rehype-katex: ^6.0.2 rehype-katex: ^6.0.2
remark-gfm: ^3.0.1 remark-gfm: ^3.0.1
remark-math: ^5.1.1 remark-math: ^5.1.1
@ -87,6 +88,7 @@ dependencies:
react-hook-form: registry.npmmirror.com/react-hook-form/7.43.1_react@18.2.0 react-hook-form: registry.npmmirror.com/react-hook-form/7.43.1_react@18.2.0
react-markdown: registry.npmmirror.com/react-markdown/8.0.5_pmekkgnqduwlme35zpnqhenc34 react-markdown: registry.npmmirror.com/react-markdown/8.0.5_pmekkgnqduwlme35zpnqhenc34
react-syntax-highlighter: registry.npmmirror.com/react-syntax-highlighter/15.5.0_react@18.2.0 react-syntax-highlighter: registry.npmmirror.com/react-syntax-highlighter/15.5.0_react@18.2.0
redis: registry.npmmirror.com/redis/4.6.5
rehype-katex: registry.npmmirror.com/rehype-katex/6.0.2 rehype-katex: registry.npmmirror.com/rehype-katex/6.0.2
remark-gfm: registry.npmmirror.com/remark-gfm/3.0.1 remark-gfm: registry.npmmirror.com/remark-gfm/3.0.1
remark-math: registry.npmmirror.com/remark-math/5.1.1 remark-math: registry.npmmirror.com/remark-math/5.1.1
@ -4504,6 +4506,72 @@ packages:
version: 2.11.6 version: 2.11.6
dev: false dev: false
registry.npmmirror.com/@redis/bloom/1.2.0_@redis+client@1.5.6:
resolution: {integrity: sha512-HG2DFjYKbpNmVXsa0keLHp/3leGJz1mjh09f2RLGGLQZzSHpkmZWuwJbAvo3QcRY8p80m5+ZdXZdYOSBLlp7Cg==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/@redis/bloom/-/bloom-1.2.0.tgz}
id: registry.npmmirror.com/@redis/bloom/1.2.0
name: '@redis/bloom'
version: 1.2.0
peerDependencies:
'@redis/client': ^1.0.0
dependencies:
'@redis/client': registry.npmmirror.com/@redis/client/1.5.6
dev: false
registry.npmmirror.com/@redis/client/1.5.6:
resolution: {integrity: sha512-dFD1S6je+A47Lj22jN/upVU2fj4huR7S9APd7/ziUXsIXDL+11GPYti4Suv5y8FuXaN+0ZG4JF+y1houEJ7ToA==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/@redis/client/-/client-1.5.6.tgz}
name: '@redis/client'
version: 1.5.6
engines: {node: '>=14'}
dependencies:
cluster-key-slot: registry.npmmirror.com/cluster-key-slot/1.1.2
generic-pool: registry.npmmirror.com/generic-pool/3.9.0
yallist: registry.npmmirror.com/yallist/4.0.0
dev: false
registry.npmmirror.com/@redis/graph/1.1.0_@redis+client@1.5.6:
resolution: {integrity: sha512-16yZWngxyXPd+MJxeSr0dqh2AIOi8j9yXKcKCwVaKDbH3HTuETpDVPcLujhFYVPtYrngSco31BUcSa9TH31Gqg==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/@redis/graph/-/graph-1.1.0.tgz}
id: registry.npmmirror.com/@redis/graph/1.1.0
name: '@redis/graph'
version: 1.1.0
peerDependencies:
'@redis/client': ^1.0.0
dependencies:
'@redis/client': registry.npmmirror.com/@redis/client/1.5.6
dev: false
registry.npmmirror.com/@redis/json/1.0.4_@redis+client@1.5.6:
resolution: {integrity: sha512-LUZE2Gdrhg0Rx7AN+cZkb1e6HjoSKaeeW8rYnt89Tly13GBI5eP4CwDVr+MY8BAYfCg4/N15OUrtLoona9uSgw==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/@redis/json/-/json-1.0.4.tgz}
id: registry.npmmirror.com/@redis/json/1.0.4
name: '@redis/json'
version: 1.0.4
peerDependencies:
'@redis/client': ^1.0.0
dependencies:
'@redis/client': registry.npmmirror.com/@redis/client/1.5.6
dev: false
registry.npmmirror.com/@redis/search/1.1.2_@redis+client@1.5.6:
resolution: {integrity: sha512-/cMfstG/fOh/SsE+4/BQGeuH/JJloeWuH+qJzM8dbxuWvdWibWAOAHHCZTMPhV3xIlH4/cUEIA8OV5QnYpaVoA==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/@redis/search/-/search-1.1.2.tgz}
id: registry.npmmirror.com/@redis/search/1.1.2
name: '@redis/search'
version: 1.1.2
peerDependencies:
'@redis/client': ^1.0.0
dependencies:
'@redis/client': registry.npmmirror.com/@redis/client/1.5.6
dev: false
registry.npmmirror.com/@redis/time-series/1.0.4_@redis+client@1.5.6:
resolution: {integrity: sha512-ThUIgo2U/g7cCuZavucQTQzA9g9JbDDY2f64u3AbAoz/8vE2lt2U37LamDUVChhaDA3IRT9R6VvJwqnUfTJzng==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/@redis/time-series/-/time-series-1.0.4.tgz}
id: registry.npmmirror.com/@redis/time-series/1.0.4
name: '@redis/time-series'
version: 1.0.4
peerDependencies:
'@redis/client': ^1.0.0
dependencies:
'@redis/client': registry.npmmirror.com/@redis/client/1.5.6
dev: false
registry.npmmirror.com/@rushstack/eslint-patch/1.2.0: registry.npmmirror.com/@rushstack/eslint-patch/1.2.0:
resolution: {integrity: sha512-sXo/qW2/pAcmT43VoRKOJbDOfV3cYpq3szSVfIThQXNt+E4DfKj361vaAt3c88U5tPUxzEswam7GW48PJqtKAg==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/@rushstack/eslint-patch/-/eslint-patch-1.2.0.tgz} resolution: {integrity: sha512-sXo/qW2/pAcmT43VoRKOJbDOfV3cYpq3szSVfIThQXNt+E4DfKj361vaAt3c88U5tPUxzEswam7GW48PJqtKAg==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/@rushstack/eslint-patch/-/eslint-patch-1.2.0.tgz}
name: '@rushstack/eslint-patch' name: '@rushstack/eslint-patch'
@ -5562,6 +5630,13 @@ packages:
version: 0.0.1 version: 0.0.1
dev: false dev: false
registry.npmmirror.com/cluster-key-slot/1.1.2:
resolution: {integrity: sha512-RMr0FhtfXemyinomL4hrWcYJxmX6deFdCxpJzhDttxgO1+bcCnkk+9drydLVDmAMG7NE6aN/fl4F7ucU/90gAA==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/cluster-key-slot/-/cluster-key-slot-1.1.2.tgz}
name: cluster-key-slot
version: 1.1.2
engines: {node: '>=0.10.0'}
dev: false
registry.npmmirror.com/color-convert/1.9.3: registry.npmmirror.com/color-convert/1.9.3:
resolution: {integrity: sha512-QfAUtd+vFdAtFQcC8CCyYt1fYWxSqAiK2cSD6zDB8N3cpsEBAvRxp9zOGg6G/SHHJYAT88/az/IuDGALsNVbGg==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/color-convert/-/color-convert-1.9.3.tgz} resolution: {integrity: sha512-QfAUtd+vFdAtFQcC8CCyYt1fYWxSqAiK2cSD6zDB8N3cpsEBAvRxp9zOGg6G/SHHJYAT88/az/IuDGALsNVbGg==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/color-convert/-/color-convert-1.9.3.tgz}
name: color-convert name: color-convert
@ -6799,6 +6874,13 @@ packages:
version: 1.2.3 version: 1.2.3
dev: true dev: true
registry.npmmirror.com/generic-pool/3.9.0:
resolution: {integrity: sha512-hymDOu5B53XvN4QT9dBmZxPX4CWhBPPLguTZ9MMFeFa/Kg0xWVfylOVNlJji/E7yTZWFd/q9GO5TxDLq156D7g==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/generic-pool/-/generic-pool-3.9.0.tgz}
name: generic-pool
version: 3.9.0
engines: {node: '>= 4'}
dev: false
registry.npmmirror.com/gensync/1.0.0-beta.2: registry.npmmirror.com/gensync/1.0.0-beta.2:
resolution: {integrity: sha512-3hN7NaskYvMDLQY55gnW3NQ+mesEAepTqlg+VEbj7zzqEMBVNhzcGYYeqFo/TlYz6eQiFcp1HcsCZO+nGgS8zg==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/gensync/-/gensync-1.0.0-beta.2.tgz} resolution: {integrity: sha512-3hN7NaskYvMDLQY55gnW3NQ+mesEAepTqlg+VEbj7zzqEMBVNhzcGYYeqFo/TlYz6eQiFcp1HcsCZO+nGgS8zg==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/gensync/-/gensync-1.0.0-beta.2.tgz}
name: gensync name: gensync
@ -9367,6 +9449,19 @@ packages:
picomatch: registry.npmmirror.com/picomatch/2.3.1 picomatch: registry.npmmirror.com/picomatch/2.3.1
dev: false dev: false
registry.npmmirror.com/redis/4.6.5:
resolution: {integrity: sha512-O0OWA36gDQbswOdUuAhRL6mTZpHFN525HlgZgDaVNgCJIAZR3ya06NTESb0R+TUZ+BFaDpz6NnnVvoMx9meUFg==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/redis/-/redis-4.6.5.tgz}
name: redis
version: 4.6.5
dependencies:
'@redis/bloom': registry.npmmirror.com/@redis/bloom/1.2.0_@redis+client@1.5.6
'@redis/client': registry.npmmirror.com/@redis/client/1.5.6
'@redis/graph': registry.npmmirror.com/@redis/graph/1.1.0_@redis+client@1.5.6
'@redis/json': registry.npmmirror.com/@redis/json/1.0.4_@redis+client@1.5.6
'@redis/search': registry.npmmirror.com/@redis/search/1.1.2_@redis+client@1.5.6
'@redis/time-series': registry.npmmirror.com/@redis/time-series/1.0.4_@redis+client@1.5.6
dev: false
registry.npmmirror.com/refractor/3.6.0: registry.npmmirror.com/refractor/3.6.0:
resolution: {integrity: sha512-MY9W41IOWxxk31o+YvFCNyNzdkc9M20NoZK5vq6jkv4I/uh2zkWcfudj0Q1fovjUQJrNewS9NMzeTtqPf+n5EA==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/refractor/-/refractor-3.6.0.tgz} resolution: {integrity: sha512-MY9W41IOWxxk31o+YvFCNyNzdkc9M20NoZK5vq6jkv4I/uh2zkWcfudj0Q1fovjUQJrNewS9NMzeTtqPf+n5EA==, registry: https://registry.npm.taobao.org/, tarball: https://registry.npmmirror.com/refractor/-/refractor-3.6.0.tgz}
name: refractor name: refractor

View File

@ -1,7 +1,10 @@
import { GET, POST, DELETE, PUT } from './request'; import { GET, POST, DELETE, PUT } from './request';
import type { ModelSchema } from '@/types/mongoSchema'; import type { ModelSchema, ModelDataSchema } from '@/types/mongoSchema';
import { ModelUpdateParams } from '@/types/model'; import { ModelUpdateParams } from '@/types/model';
import { TrainingItemType } from '../types/training'; import { TrainingItemType } from '../types/training';
import { PagingData } from '@/types';
import { RequestPaging } from '../types/index';
import { Obj2Query } from '@/utils/tools';
export const getMyModels = () => GET<ModelSchema[]>('/model/list'); export const getMyModels = () => GET<ModelSchema[]>('/model/list');
@ -16,13 +19,35 @@ export const putModelById = (id: string, data: ModelUpdateParams) =>
PUT(`/model/update?modelId=${id}`, data); PUT(`/model/update?modelId=${id}`, data);
export const postTrainModel = (id: string, form: FormData) => export const postTrainModel = (id: string, form: FormData) =>
POST(`/model/train?modelId=${id}`, form, { POST(`/model/train/train?modelId=${id}`, form, {
headers: { headers: {
'content-type': 'multipart/form-data' 'content-type': 'multipart/form-data'
} }
}); });
export const putModelTrainingStatus = (id: string) => PUT(`/model/putTrainStatus?modelId=${id}`); export const putModelTrainingStatus = (id: string) =>
PUT(`/model/train/putTrainStatus?modelId=${id}`);
export const getModelTrainings = (id: string) => export const getModelTrainings = (id: string) =>
GET<TrainingItemType[]>(`/model/getTrainings?modelId=${id}`); GET<TrainingItemType[]>(`/model/train/getTrainings?modelId=${id}`);
/* 模型 data */
type GetModelDataListProps = RequestPaging & {
modelId: string;
};
export const getModelDataList = (props: GetModelDataListProps) =>
GET(`/model/data/getModelData?${Obj2Query(props)}`);
export const postModelDataInput = (data: {
modelId: string;
data: { text: ModelDataSchema['text']; q: ModelDataSchema['q'] }[];
}) => POST(`/model/data/pushModelDataInput`, data);
export const postModelDataFileText = (modelId: string, text: string) =>
POST(`/model/data/splitData`, { modelId, text });
export const putModelDataById = (data: { dataId: string; text: string }) =>
PUT('/model/data/putModelData', data);
export const delOneModelData = (dataId: string) =>
DELETE(`/model/data/delModelDataById?dataId=${dataId}`);

View File

@ -26,12 +26,12 @@ const navbarList = [
link: '/model/list', link: '/model/list',
activeLink: ['/model/list', '/model/detail'] activeLink: ['/model/list', '/model/detail']
}, },
{ // {
label: '数据', // label: '数据',
icon: 'icon-datafull', // icon: 'icon-datafull',
link: '/data/list', // link: '/data/list',
activeLink: ['/data/list', '/data/detail'] // activeLink: ['/data/list', '/data/detail']
}, // },
{ {
label: '账号', label: '账号',
icon: 'icon-yonghu-yuan', icon: 'icon-yonghu-yuan',

View File

@ -1,11 +1,17 @@
import type { ServiceName } from '@/types/mongoSchema'; import type { ServiceName, ModelDataType, ModelSchema } from '@/types/mongoSchema';
import { ModelSchema } from '../types/mongoSchema';
export enum ChatModelNameEnum { export enum ChatModelNameEnum {
GPT35 = 'gpt-3.5-turbo', GPT35 = 'gpt-3.5-turbo',
VECTOR_GPT = 'VECTOR_GPT',
GPT3 = 'text-davinci-003' GPT3 = 'text-davinci-003'
} }
export const ChatModelNameMap = {
[ChatModelNameEnum.GPT35]: 'gpt-3.5-turbo',
[ChatModelNameEnum.VECTOR_GPT]: 'gpt-3.5-turbo',
[ChatModelNameEnum.GPT3]: 'text-davinci-003'
};
export type ModelConstantsData = { export type ModelConstantsData = {
serviceCompany: `${ServiceName}`; serviceCompany: `${ServiceName}`;
name: string; name: string;
@ -29,6 +35,17 @@ export const modelList: ModelConstantsData[] = [
trainedMaxToken: 2000, trainedMaxToken: 2000,
maxTemperature: 2, maxTemperature: 2,
price: 3 price: 3
},
{
serviceCompany: 'openai',
name: '知识库',
model: ChatModelNameEnum.VECTOR_GPT,
trainName: 'vector',
maxToken: 4000,
contextMaxToken: 7500,
trainedMaxToken: 2000,
maxTemperature: 1,
price: 3
} }
// { // {
// serviceCompany: 'openai', // serviceCompany: 'openai',
@ -76,6 +93,11 @@ export const formatModelStatus = {
} }
}; };
export const ModelDataStatusMap = {
0: '训练完成',
1: '训练中'
};
export const defaultModel: ModelSchema = { export const defaultModel: ModelSchema = {
_id: '', _id: '',
userId: '', userId: '',

1
src/constants/redis.ts Normal file
View File

@ -0,0 +1 @@
export const VecModelDataIndex = 'model:data';

View File

@ -8,7 +8,7 @@ export const usePaging = <T = any>({
pageSize = 10, pageSize = 10,
params = {} params = {}
}: { }: {
api: (data: any) => Promise<PagingData<T>>; api: (data: any) => any;
pageSize?: number; pageSize?: number;
params?: Record<string, any>; params?: Record<string, any>;
}) => { }) => {
@ -30,7 +30,7 @@ export const usePaging = <T = any>({
setRequesting(true); setRequesting(true);
try { try {
const res = await api({ const res: PagingData<T> = await api({
pageNum: num, pageNum: num,
pageSize, pageSize,
...params ...params
@ -75,6 +75,7 @@ export const usePaging = <T = any>({
requesting, requesting,
isLoadAll, isLoadAll,
nextPage, nextPage,
initRequesting initRequesting,
setData
}; };
}; };

View File

@ -46,7 +46,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
const model: ModelSchema = chat.modelId; const model: ModelSchema = chat.modelId;
const modelConstantsData = modelList.find((item) => item.model === model.service.modelName); const modelConstantsData = modelList.find((item) => item.model === model.service.modelName);
if (!modelConstantsData) { if (!modelConstantsData) {
throw new Error('模型异常,请用 chatgpt 模型'); throw new Error('模型加载异常');
} }
// 读取对话内容 // 读取对话内容

View File

@ -0,0 +1,241 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { createParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser';
import { connectToDatabase, ModelData } from '@/service/mongo';
import { getOpenAIApi, authChat } from '@/service/utils/chat';
import { httpsAgent, openaiChatFilter, systemPromptFilter } from '@/service/utils/tools';
import { ChatCompletionRequestMessage, ChatCompletionRequestMessageRoleEnum } from 'openai';
import { ChatItemType } from '@/types/chat';
import { jsonRes } from '@/service/response';
import type { ModelSchema } from '@/types/mongoSchema';
import { PassThrough } from 'stream';
import { modelList } from '@/constants/model';
import { pushChatBill } from '@/service/events/pushBill';
import { connectRedis } from '@/service/redis';
import { VecModelDataIndex } from '@/constants/redis';
import { vectorToBuffer } from '@/utils/tools';
let vectorData = [
-0.025028639, -0.010407282, 0.026523087, -0.0107438695, -0.006967359, 0.010043768, -0.012043097,
0.008724345, -0.028919589, -0.0117738275, 0.0050690062, 0.02961969
].concat(new Array(1524).fill(0));
/* 发送提示词 */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
let step = 0; // step=1时表示开始了流响应
const stream = new PassThrough();
stream.on('error', () => {
console.log('error: ', 'stream error');
stream.destroy();
});
res.on('close', () => {
stream.destroy();
});
res.on('error', () => {
console.log('error: ', 'request error');
stream.destroy();
});
try {
const { chatId, prompt } = req.body as {
prompt: ChatItemType;
chatId: string;
};
const { authorization } = req.headers;
if (!chatId || !prompt) {
throw new Error('缺少参数');
}
await connectToDatabase();
const redis = await connectRedis();
const { chat, userApiKey, systemKey, userId } = await authChat(chatId, authorization);
const model: ModelSchema = chat.modelId;
const modelConstantsData = modelList.find((item) => item.model === model.service.modelName);
if (!modelConstantsData) {
throw new Error('模型加载异常');
}
// 读取对话内容
const prompts = [...chat.content, prompt];
// 获取 chatAPI
const chatAPI = getOpenAIApi(userApiKey || systemKey);
// 把输入的内容转成向量
const promptVector = await chatAPI
.createEmbedding(
{
model: 'text-embedding-ada-002',
input: prompt.value
},
{
timeout: 120000,
httpsAgent
}
)
.then((res) => res?.data?.data?.[0]?.embedding || []);
const binary = vectorToBuffer(promptVector);
// 搜索系统提示词, 按相似度从 redis 中搜出前3条不同 dataId 的数据
const redisData: any[] = await redis.sendCommand([
'FT.SEARCH',
`idx:${VecModelDataIndex}`,
`@modelId:{${String(chat.modelId._id)}} @vector:[VECTOR_RANGE 0.2 $blob]`,
// `@modelId:{${String(chat.modelId._id)}}=>[KNN 10 @vector $blob AS score]`,
'RETURN',
'1',
'dataId',
// 'SORTBY',
// 'score',
'PARAMS',
'2',
'blob',
binary,
'DIALECT',
'2'
]);
// 格式化响应值获取去重后的id
let formatIds = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
.map((i) => {
if (!redisData[i] || !redisData[i][1]) return '';
return redisData[i][1];
})
.filter((item) => item);
formatIds = Array.from(new Set(formatIds));
if (formatIds.length === 0) {
throw new Error('对不起,我没有找到你的问题');
}
// 从 mongo 中取出原文作为提示词
const textArr = (
await Promise.all(
[2, 4, 6, 8, 10, 12, 14, 16, 18, 20].map((i) => {
if (!redisData[i] || !redisData[i][1]) return '';
return ModelData.findById(redisData[i][1])
.select('text')
.then((res) => res?.text || '');
})
)
).filter((item) => item);
// textArr 筛选,最多 3000 tokens
const systemPrompt = systemPromptFilter(textArr, 2800);
prompts.unshift({
obj: 'SYSTEM',
value: `请根据下面的知识回答问题: ${systemPrompt}`
});
// 控制在 tokens 数量,防止超出
const filterPrompts = openaiChatFilter(prompts, modelConstantsData.contextMaxToken);
// 格式化文本内容成 chatgpt 格式
const map = {
Human: ChatCompletionRequestMessageRoleEnum.User,
AI: ChatCompletionRequestMessageRoleEnum.Assistant,
SYSTEM: ChatCompletionRequestMessageRoleEnum.System
};
const formatPrompts: ChatCompletionRequestMessage[] = filterPrompts.map(
(item: ChatItemType) => ({
role: map[item.obj],
content: item.value
})
);
// console.log(formatPrompts);
// 计算温度
const temperature = modelConstantsData.maxTemperature * (model.temperature / 10);
let startTime = Date.now();
// 发出请求
const chatResponse = await chatAPI.createChatCompletion(
{
model: model.service.chatModel,
temperature: temperature,
// max_tokens: modelConstantsData.maxToken,
messages: formatPrompts,
frequency_penalty: 0.5, // 越大,重复内容越少
presence_penalty: -0.5, // 越大,越容易出现新内容
stream: true
},
{
timeout: 40000,
responseType: 'stream',
httpsAgent
}
);
console.log('api response time:', `${(Date.now() - startTime) / 1000}s`);
// 创建响应流
res.setHeader('Content-Type', 'text/event-stream;charset-utf-8');
res.setHeader('Access-Control-Allow-Origin', '*');
res.setHeader('X-Accel-Buffering', 'no');
res.setHeader('Cache-Control', 'no-cache, no-transform');
step = 1;
let responseContent = '';
stream.pipe(res);
const onParse = async (event: ParsedEvent | ReconnectInterval) => {
if (event.type !== 'event') return;
const data = event.data;
if (data === '[DONE]') return;
try {
const json = JSON.parse(data);
const content: string = json?.choices?.[0].delta.content || '';
if (!content || (responseContent === '' && content === '\n')) return;
responseContent += content;
// console.log('content:', content)
!stream.destroyed && stream.push(content.replace(/\n/g, '<br/>'));
} catch (error) {
error;
}
};
const decoder = new TextDecoder();
try {
for await (const chunk of chatResponse.data as any) {
if (stream.destroyed) {
// 流被中断了,直接忽略后面的内容
break;
}
const parser = createParser(onParse);
parser.feed(decoder.decode(chunk));
}
} catch (error) {
console.log('pipe error', error);
}
// close stream
!stream.destroyed && stream.push(null);
stream.destroy();
const promptsContent = formatPrompts.map((item) => item.content).join('');
// 只有使用平台的 key 才计费
pushChatBill({
isPay: !userApiKey,
modelName: model.service.modelName,
userId,
chatId,
text: promptsContent + responseContent
});
// jsonRes(res);
} catch (err: any) {
if (step === 1) {
// 直接结束流
console.log('error结束');
stream.destroy();
} else {
res.status(500);
jsonRes(res, {
code: 500,
error: err
});
}
}
}

View File

@ -24,7 +24,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
if (!DataRecord) { if (!DataRecord) {
throw new Error('找不到数据集'); throw new Error('找不到数据集');
} }
const replaceText = text.replace(/[\r\n\\n]+/g, ' '); const replaceText = text.replace(/[\\n]+/g, ' ');
// 文本拆分成 chunk // 文本拆分成 chunk
let chunks = replaceText.match(/[^!?.。]+[!?.。]/g) || []; let chunks = replaceText.match(/[^!?.。]+[!?.。]/g) || [];
@ -35,7 +35,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
chunks.forEach((chunk) => { chunks.forEach((chunk) => {
splitText += chunk; splitText += chunk;
const tokens = encode(splitText).length; const tokens = encode(splitText).length;
if (tokens >= 980) { if (tokens >= 780) {
dataItems.push({ dataItems.push({
userId, userId,
dataId, dataId,

View File

@ -3,7 +3,7 @@ import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { connectToDatabase } from '@/service/mongo'; import { connectToDatabase } from '@/service/mongo';
import { authToken } from '@/service/utils/tools'; import { authToken } from '@/service/utils/tools';
import { ModelStatusEnum, modelList, ChatModelNameEnum } from '@/constants/model'; import { ModelStatusEnum, modelList, ChatModelNameEnum, ChatModelNameMap } from '@/constants/model';
import { Model } from '@/service/models/model'; import { Model } from '@/service/models/model';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) { export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
@ -33,15 +33,6 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
await connectToDatabase(); await connectToDatabase();
// 重名校验
const authRepeatName = await Model.findOne({
name,
userId
});
if (authRepeatName) {
throw new Error('模型名重复');
}
// 上限校验 // 上限校验
const authCount = await Model.countDocuments({ const authCount = await Model.countDocuments({
userId userId
@ -57,9 +48,9 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
status: ModelStatusEnum.running, status: ModelStatusEnum.running,
service: { service: {
company: modelItem.serviceCompany, company: modelItem.serviceCompany,
trainId: modelItem.trainName, trainId: '',
chatModel: modelItem.model, chatModel: ChatModelNameMap[modelItem.model], // 聊天时用的模型
modelName: modelItem.model modelName: modelItem.model // 最底层的模型,不会变,用于计费等核心操作
} }
}); });

View File

@ -5,8 +5,8 @@ import { authToken } from '@/service/utils/tools';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) { export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try { try {
let { modelId } = req.query as { let { dataId } = req.query as {
modelId: string; dataId: string;
}; };
const { authorization } = req.headers; const { authorization } = req.headers;
@ -14,7 +14,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
throw new Error('无权操作'); throw new Error('无权操作');
} }
if (!modelId) { if (!dataId) {
throw new Error('缺少参数'); throw new Error('缺少参数');
} }
@ -24,7 +24,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
await connectToDatabase(); await connectToDatabase();
await ModelData.deleteOne({ await ModelData.deleteOne({
modelId, _id: dataId,
userId userId
}); });

View File

@ -14,6 +14,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
pageNum: string; pageNum: string;
pageSize: string; pageSize: string;
}; };
const { authorization } = req.headers; const { authorization } = req.headers;
pageNum = +pageNum; pageNum = +pageNum;
@ -41,7 +42,15 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
.limit(pageSize); .limit(pageSize);
jsonRes(res, { jsonRes(res, {
data data: {
pageNum,
pageSize,
data,
total: await ModelData.countDocuments({
modelId,
userId
})
}
}); });
} catch (err) { } catch (err) {
jsonRes(res, { jsonRes(res, {

View File

@ -2,12 +2,14 @@ import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { connectToDatabase, ModelData, Model } from '@/service/mongo'; import { connectToDatabase, ModelData, Model } from '@/service/mongo';
import { authToken } from '@/service/utils/tools'; import { authToken } from '@/service/utils/tools';
import { ModelDataSchema } from '@/types/mongoSchema';
import { generateVector } from '@/service/events/generateVector';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) { export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try { try {
const { modelId, data } = req.body as { const { modelId, data } = req.body as {
modelId: string; modelId: string;
data: { q: string; a: string }[]; data: { text: ModelDataSchema['text']; q: ModelDataSchema['q'] }[];
}; };
const { authorization } = req.headers; const { authorization } = req.headers;
@ -43,6 +45,8 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
})) }))
); );
generateVector(true);
jsonRes(res, { jsonRes(res, {
data: model data: model
}); });

View File

@ -0,0 +1,57 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, DataItem, ModelData } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import { customAlphabet } from 'nanoid';
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
try {
let { dataIds, modelId } = req.body as { dataIds: string[]; modelId: string };
if (!dataIds) {
throw new Error('参数错误');
}
await connectToDatabase();
const { authorization } = req.headers;
const userId = await authToken(authorization);
const dataItems = (
await Promise.all(
dataIds.map((dataId) =>
DataItem.find<{ _id: string; result: { q: string }[]; text: string }>(
{
userId,
dataId
},
'result text'
)
)
)
).flat();
// push data
await ModelData.insertMany(
dataItems.map((item) => ({
modelId: modelId,
userId,
text: item.text,
q: item.result.map((item) => ({
id: nanoid(),
text: item.q
}))
}))
);
jsonRes(res, {
data: dataItems
});
} catch (err) {
jsonRes(res, {
code: 500,
error: err
});
}
}

View File

@ -5,9 +5,9 @@ import { authToken } from '@/service/utils/tools';
export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) { export default async function handler(req: NextApiRequest, res: NextApiResponse<any>) {
try { try {
let { modelId, answer } = req.body as { let { dataId, text } = req.body as {
modelId: string; dataId: string;
answer: string; text: string;
}; };
const { authorization } = req.headers; const { authorization } = req.headers;
@ -15,7 +15,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
throw new Error('无权操作'); throw new Error('无权操作');
} }
if (!modelId) { if (!dataId) {
throw new Error('缺少参数'); throw new Error('缺少参数');
} }
@ -26,11 +26,11 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
await ModelData.updateOne( await ModelData.updateOne(
{ {
modelId, _id: dataId,
userId userId
}, },
{ {
a: answer text
} }
); );

View File

@ -0,0 +1,67 @@
import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response';
import { connectToDatabase, SplitData, Model } from '@/service/mongo';
import { authToken } from '@/service/utils/tools';
import { generateQA } from '@/service/events/generateQA';
import { encode } from 'gpt-token-utils';
/* 拆分数据成QA */
export default async function handler(req: NextApiRequest, res: NextApiResponse) {
try {
const { text, modelId } = req.body as { text: string; modelId: string };
if (!text || !modelId) {
throw new Error('参数错误');
}
await connectToDatabase();
const { authorization } = req.headers;
const userId = await authToken(authorization);
// 验证是否是该用户的 model
const model = await Model.findOne({
_id: modelId,
userId
});
if (!model) {
throw new Error('无权操作该模型');
}
const replaceText = text.replace(/(\\n|\n)+/g, ' ');
// 文本拆分成 chunk
let chunks = replaceText.match(/[^!?.。]+[!?.。]/g) || [];
const textList: string[] = [];
let splitText = '';
chunks.forEach((chunk) => {
splitText += chunk;
const tokens = encode(splitText).length;
if (tokens >= 980) {
textList.push(splitText);
splitText = '';
}
});
// 批量插入数据
await SplitData.create({
userId,
modelId,
rawText: text,
textList
});
// generateQA();
jsonRes(res, {
data: { chunks, replaceText }
});
} catch (err) {
jsonRes(res, {
code: 500,
error: err
});
}
}

View File

@ -1,6 +1,6 @@
import type { NextApiRequest, NextApiResponse } from 'next'; import type { NextApiRequest, NextApiResponse } from 'next';
import { jsonRes } from '@/service/response'; import { jsonRes } from '@/service/response';
import { Chat, Model, Training, connectToDatabase } from '@/service/mongo'; import { Chat, Model, Training, connectToDatabase, ModelData } from '@/service/mongo';
import { authToken, getUserOpenaiKey } from '@/service/utils/tools'; import { authToken, getUserOpenaiKey } from '@/service/utils/tools';
import { TrainingStatusEnum } from '@/constants/model'; import { TrainingStatusEnum } from '@/constants/model';
import { getOpenAIApi } from '@/service/utils/chat'; import { getOpenAIApi } from '@/service/utils/chat';
@ -26,16 +26,20 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
await connectToDatabase(); await connectToDatabase();
// 删除模型 let requestQueue: any[] = [];
await Model.deleteOne({
_id: modelId,
userId
});
// 删除对应的聊天 // 删除对应的聊天
await Chat.deleteMany({ requestQueue.push(
modelId Chat.deleteMany({
}); modelId
})
);
// 删除数据集
requestQueue.push(
ModelData.deleteMany({
modelId
})
);
// 查看是否正在训练 // 查看是否正在训练
const training: TrainingItemType | null = await Training.findOne({ const training: TrainingItemType | null = await Training.findOne({
@ -56,9 +60,20 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
} }
// 删除对应训练记录 // 删除对应训练记录
await Training.deleteMany({ requestQueue.push(
modelId Training.deleteMany({
}); modelId
})
);
// 删除模型
requestQueue.push(
Model.deleteOne({
_id: modelId,
userId
})
);
await requestQueue;
jsonRes(res); jsonRes(res);
} catch (err) { } catch (err) {

View File

@ -37,7 +37,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse<
systemPrompt, systemPrompt,
intro, intro,
temperature, temperature,
service, // service,
security security
} }
); );

View File

@ -119,6 +119,7 @@ const Chat = ({ chatId }: { chatId: string }) => {
async (prompts: ChatSiteItemType) => { async (prompts: ChatSiteItemType) => {
const urlMap: Record<string, string> = { const urlMap: Record<string, string> = {
[ChatModelNameEnum.GPT35]: '/api/chat/chatGpt', [ChatModelNameEnum.GPT35]: '/api/chat/chatGpt',
[ChatModelNameEnum.VECTOR_GPT]: '/api/chat/vectorGpt',
[ChatModelNameEnum.GPT3]: '/api/chat/gpt3' [ChatModelNameEnum.GPT3]: '/api/chat/gpt3'
}; };

View File

@ -184,7 +184,7 @@ const DataList = () => {
> >
</Button> </Button>
<Menu> {/* <Menu>
<MenuButton as={Button} mr={2} size={'sm'} isLoading={isExporting}> <MenuButton as={Button} mr={2} size={'sm'} isLoading={isExporting}>
</MenuButton> </MenuButton>
@ -200,7 +200,7 @@ const DataList = () => {
</MenuItem> </MenuItem>
)} )}
</MenuList> </MenuList>
</Menu> </Menu> */}
<Button <Button
size={'sm'} size={'sm'}

View File

@ -0,0 +1,141 @@
import React, { useState, useCallback } from 'react';
import {
Box,
IconButton,
Flex,
Button,
Modal,
ModalOverlay,
ModalContent,
ModalHeader,
ModalCloseButton,
Input,
Textarea
} from '@chakra-ui/react';
import { useForm, useFieldArray } from 'react-hook-form';
import { postModelDataInput } from '@/api/model';
import { useToast } from '@/hooks/useToast';
import { DeleteIcon } from '@chakra-ui/icons';
import { customAlphabet } from 'nanoid';
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);
type FormData = { text: string; q: { val: string }[] };
const InputDataModal = ({
onClose,
onSuccess,
modelId
}: {
onClose: () => void;
onSuccess: () => void;
modelId: string;
}) => {
const [importing, setImporting] = useState(false);
const { toast } = useToast();
const { register, handleSubmit, control } = useForm<FormData>({
defaultValues: {
text: '',
q: [{ val: '' }]
}
});
const {
fields: inputQ,
append: appendQ,
remove: removeQ
} = useFieldArray({
control,
name: 'q'
});
const sureImportData = useCallback(
async (e: FormData) => {
setImporting(true);
try {
await postModelDataInput({
modelId: modelId,
data: [
{
text: e.text,
q: e.q.map((item) => ({
id: nanoid(),
text: item.val
}))
}
]
});
toast({
title: '导入数据成功,需要一段时间训练',
status: 'success'
});
onClose();
onSuccess();
} catch (err) {
console.log(err);
}
setImporting(false);
},
[modelId, onClose, onSuccess, toast]
);
return (
<Modal isOpen={true} onClose={onClose}>
<ModalOverlay />
<ModalContent maxW={'min(900px, 90vw)'} maxH={'80vh'} position={'relative'}>
<ModalHeader></ModalHeader>
<ModalCloseButton />
<Box px={6} pb={2} overflowY={'auto'}>
<Box mb={2}>:</Box>
<Textarea
mb={4}
placeholder="知识点"
rows={3}
maxH={'200px'}
{...register(`text`, {
required: '知识点'
})}
/>
{inputQ.map((item, index) => (
<Box key={item.id} mb={5}>
<Box mb={2}>{index + 1}:</Box>
<Flex>
<Input
placeholder="问法"
{...register(`q.${index}.val`, {
required: '问法不能为空'
})}
></Input>
{inputQ.length > 1 && (
<IconButton
icon={<DeleteIcon />}
aria-label={'delete'}
colorScheme={'gray'}
variant={'unstyled'}
onClick={() => removeQ(index)}
/>
)}
</Flex>
</Box>
))}
</Box>
<Flex px={6} pt={2} pb={4}>
<Button alignSelf={'flex-start'} variant={'outline'} onClick={() => appendQ({ val: '' })}>
</Button>
<Box flex={1}></Box>
<Button variant={'outline'} mr={3} onClick={onClose}>
</Button>
<Button isLoading={importing} onClick={handleSubmit(sureImportData)}>
</Button>
</Flex>
</ModalContent>
</Modal>
);
};
export default InputDataModal;

View File

@ -0,0 +1,202 @@
import React, { useCallback } from 'react';
import {
Box,
TableContainer,
Table,
Thead,
Tbody,
Tr,
Th,
Td,
IconButton,
Flex,
Button,
useDisclosure,
Textarea,
Menu,
MenuButton,
MenuList,
MenuItem
} from '@chakra-ui/react';
import type { ModelSchema } from '@/types/mongoSchema';
import { ModelDataSchema } from '@/types/mongoSchema';
import { ModelDataStatusMap } from '@/constants/model';
import { usePaging } from '@/hooks/usePaging';
import ScrollData from '@/components/ScrollData';
import { getModelDataList, delOneModelData, putModelDataById } from '@/api/model';
import { DeleteIcon, RepeatIcon } from '@chakra-ui/icons';
import { useToast } from '@/hooks/useToast';
import { useLoading } from '@/hooks/useLoading';
import dynamic from 'next/dynamic';
const InputModel = dynamic(() => import('./InputDataModal'));
const SelectModel = dynamic(() => import('./SelectFileModal'));
const ModelDataCard = ({ model }: { model: ModelSchema }) => {
const { toast } = useToast();
const { Loading } = useLoading();
const {
nextPage,
isLoadAll,
requesting,
data: modelDataList,
total,
setData,
getData
} = usePaging<ModelDataSchema>({
api: getModelDataList,
pageSize: 20,
params: {
modelId: model._id
}
});
const updateAnswer = useCallback(
async (dataId: string, text: string) => {
await putModelDataById({
dataId,
text
});
toast({
title: '修改回答成功',
status: 'success'
});
},
[toast]
);
const {
isOpen: isOpenInputModal,
onOpen: onOpenInputModal,
onClose: onCloseInputModal
} = useDisclosure();
const {
isOpen: isOpenSelectModal,
onOpen: onOpenSelectModal,
onClose: onCloseSelectModal
} = useDisclosure();
return (
<>
<Flex>
<Box fontWeight={'bold'} fontSize={'lg'} flex={1}>
: {total}{' '}
<Box as={'span'} fontSize={'sm'}>
</Box>
</Box>
<IconButton
icon={<RepeatIcon />}
aria-label={'refresh'}
variant={'outline'}
mr={4}
onClick={() => getData(1, true)}
/>
<Menu>
<MenuButton as={Button}></MenuButton>
<MenuList>
<MenuItem onClick={onOpenInputModal}></MenuItem>
<MenuItem onClick={onOpenSelectModal}></MenuItem>
</MenuList>
</Menu>
</Flex>
<ScrollData
h={'100%'}
px={6}
mt={3}
isLoadAll={isLoadAll}
requesting={requesting}
nextPage={nextPage}
position={'relative'}
>
<TableContainer mt={4}>
<Table variant={'simple'}>
<Thead>
<Tr>
<Th>Question</Th>
<Th>Text</Th>
<Th>Status</Th>
<Th></Th>
</Tr>
</Thead>
<Tbody>
{modelDataList.map((item) => (
<Tr key={item._id}>
<Td w={'350px'}>
{item.q.map((item, i) => (
<Box
key={item.id}
fontSize={'xs'}
w={'100%'}
whiteSpace={'pre-wrap'}
_notLast={{ mb: 1 }}
>
Q{i + 1}:{' '}
<Box as={'span'} userSelect={'all'}>
{item.text}
</Box>
</Box>
))}
</Td>
<Td minW={'200px'}>
<Textarea
w={'100%'}
h={'100%'}
defaultValue={item.text}
fontSize={'xs'}
resize={'both'}
onBlur={(e) => {
const oldVal = modelDataList.find((data) => item._id === data._id)?.text;
if (oldVal !== e.target.value) {
updateAnswer(item._id, e.target.value);
setData((state) =>
state.map((data) => ({
...data,
text: data._id === item._id ? e.target.value : data.text
}))
);
}
}}
></Textarea>
</Td>
<Td w={'100px'}>{ModelDataStatusMap[item.status]}</Td>
<Td>
<IconButton
icon={<DeleteIcon />}
variant={'outline'}
colorScheme={'gray'}
aria-label={'delete'}
size={'sm'}
onClick={async () => {
delOneModelData(item._id);
setData((state) => state.filter((data) => data._id !== item._id));
}}
/>
</Td>
</Tr>
))}
</Tbody>
</Table>
</TableContainer>
<Loading loading={requesting} fixed={false} />
</ScrollData>
{isOpenInputModal && (
<InputModel
modelId={model._id}
onClose={onCloseInputModal}
onSuccess={() => getData(1, true)}
/>
)}
{isOpenSelectModal && (
<SelectModel
modelId={model._id}
onClose={onCloseSelectModal}
onSuccess={() => getData(1, true)}
/>
)}
</>
);
};
export default ModelDataCard;

View File

@ -11,13 +11,28 @@ import {
SliderFilledTrack, SliderFilledTrack,
SliderThumb, SliderThumb,
SliderMark, SliderMark,
Tooltip Tooltip,
Button
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { QuestionOutlineIcon } from '@chakra-ui/icons'; import { QuestionOutlineIcon } from '@chakra-ui/icons';
import type { ModelSchema } from '@/types/mongoSchema'; import type { ModelSchema } from '@/types/mongoSchema';
import { UseFormReturn } from 'react-hook-form'; import { UseFormReturn } from 'react-hook-form';
import { modelList } from '@/constants/model';
import { formatPrice } from '@/utils/user';
import { useConfirm } from '@/hooks/useConfirm';
const ModelEditForm = ({ formHooks }: { formHooks: UseFormReturn<ModelSchema> }) => { const ModelEditForm = ({
formHooks,
canTrain,
handleDelModel
}: {
formHooks: UseFormReturn<ModelSchema>;
canTrain: boolean;
handleDelModel: () => void;
}) => {
const { openConfirm, ConfirmChild } = useConfirm({
content: '确认删除该模型?'
});
const { register, setValue, getValues } = formHooks; const { register, setValue, getValues } = formHooks;
const [refresh, setRefresh] = useState(false); const [refresh, setRefresh] = useState(false);
@ -29,7 +44,7 @@ const ModelEditForm = ({ formHooks }: { formHooks: UseFormReturn<ModelSchema> })
</Flex> </Flex>
<FormControl mt={4}> <FormControl mt={4}>
<Flex alignItems={'center'}> <Flex alignItems={'center'}>
<Box flex={'0 0 50px'} w={0}> <Box flex={'0 0 80px'} w={0}>
: :
</Box> </Box>
<Input <Input
@ -39,7 +54,36 @@ const ModelEditForm = ({ formHooks }: { formHooks: UseFormReturn<ModelSchema> })
></Input> ></Input>
</Flex> </Flex>
</FormControl> </FormControl>
<FormControl mt={4}> <Flex alignItems={'center'} mt={4}>
<Box flex={'0 0 80px'} w={0}>
:
</Box>
<Box>{getValues('service.modelName')}</Box>
</Flex>
<Flex alignItems={'center'} mt={4}>
<Box flex={'0 0 80px'} w={0}>
:
</Box>
<Box>
{formatPrice(
modelList.find((item) => item.model === getValues('service.modelName'))?.price || 0,
1000
)}
/1K tokens()
</Box>
</Flex>
<Flex mt={5} alignItems={'center'}>
<Box flex={'0 0 80px'}>:</Box>
<Button
colorScheme={'gray'}
variant={'outline'}
size={'sm'}
onClick={openConfirm(handleDelModel)}
>
</Button>
</Flex>
{/* <FormControl mt={4}>
<Box mb={1}>:</Box> <Box mb={1}>:</Box>
<Textarea <Textarea
rows={5} rows={5}
@ -47,7 +91,7 @@ const ModelEditForm = ({ formHooks }: { formHooks: UseFormReturn<ModelSchema> })
{...register('intro')} {...register('intro')}
placeholder={'模型的介绍,仅做展示,不影响模型的效果'} placeholder={'模型的介绍,仅做展示,不影响模型的效果'}
/> />
</FormControl> </FormControl> */}
</Card> </Card>
<Card p={4}> <Card p={4}>
<Box fontWeight={'bold'}></Box> <Box fontWeight={'bold'}></Box>
@ -94,15 +138,24 @@ const ModelEditForm = ({ formHooks }: { formHooks: UseFormReturn<ModelSchema> })
</Flex> </Flex>
</FormControl> </FormControl>
<Box mt={4}> <Box mt={4}>
<Box mb={1}></Box> {canTrain ? (
<Textarea <Box fontWeight={'bold'}>
rows={6} prompt
maxLength={-1} 使 tokens
{...register('systemPrompt')} </Box>
placeholder={ ) : (
'模型默认的 prompt 词,通过调整该内容,可以生成一个限定范围的模型。\n\n注意改功能会影响对话的整体朝向' <>
} <Box mb={1}></Box>
/> <Textarea
rows={6}
maxLength={-1}
{...register('systemPrompt')}
placeholder={
'模型默认的 prompt 词,通过调整该内容,可以生成一个限定范围的模型。\n\n注意改功能会影响对话的整体朝向'
}
/>
</>
)}
</Box> </Box>
</Card> </Card>
{/* <Card p={4}> {/* <Card p={4}>
@ -202,6 +255,7 @@ const ModelEditForm = ({ formHooks }: { formHooks: UseFormReturn<ModelSchema> })
</Flex> </Flex>
</FormControl> </FormControl>
</Card> */} </Card> */}
<ConfirmChild />
</> </>
); );
}; };

View File

@ -0,0 +1,155 @@
import React, { useState, useCallback } from 'react';
import {
Box,
Flex,
Button,
Modal,
ModalOverlay,
ModalContent,
ModalHeader,
ModalCloseButton,
ModalBody
} from '@chakra-ui/react';
import { useToast } from '@/hooks/useToast';
import { useSelectFile } from '@/hooks/useSelectFile';
import { customAlphabet } from 'nanoid';
import { encode } from 'gpt-token-utils';
import { useConfirm } from '@/hooks/useConfirm';
import { readTxtContent, readPdfContent, readDocContent } from '@/utils/tools';
import { useMutation } from '@tanstack/react-query';
import { postModelDataFileText } from '@/api/model';
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);
const fileExtension = '.txt,.doc,.docx,.pdf,.md';
const SelectFileModal = ({
onClose,
onSuccess,
modelId
}: {
onClose: () => void;
onSuccess: () => void;
modelId: string;
}) => {
const [selecting, setSelecting] = useState(false);
const { toast } = useToast();
const { File, onOpen } = useSelectFile({ fileType: fileExtension, multiple: true });
const [fileText, setFileText] = useState('');
const { openConfirm, ConfirmChild } = useConfirm({
content: '确认导入该文件,需要一定时间进行拆解,该任务无法终止!'
});
const onSelectFile = useCallback(
async (e: File[]) => {
setSelecting(true);
try {
const fileTexts = (
await Promise.all(
e.map((file) => {
// @ts-ignore
const extension = file?.name?.split('.').pop().toLowerCase();
switch (extension) {
case 'txt':
case 'md':
return readTxtContent(file);
case 'pdf':
return readPdfContent(file);
case 'doc':
case 'docx':
return readDocContent(file);
default:
return '';
}
})
)
)
.join('\n')
.replace(/\n+/g, '\n');
setFileText(fileTexts);
console.log(encode(fileTexts));
} catch (error: any) {
console.log(error);
toast({
title: typeof error === 'string' ? error : '解析文件失败',
status: 'error'
});
}
setSelecting(false);
},
[setSelecting, toast]
);
const { mutate, isLoading } = useMutation({
mutationFn: async () => {
if (!fileText) return;
await postModelDataFileText(modelId, fileText);
toast({
title: '导入数据成功,需要一段拆解和训练',
status: 'success'
});
onClose();
onSuccess();
},
onError() {
toast({
title: '导入文件失败',
status: 'error'
});
}
});
return (
<Modal isOpen={true} onClose={onClose}>
<ModalOverlay />
<ModalContent maxW={'min(900px, 90vw)'} position={'relative'}>
<ModalHeader></ModalHeader>
<ModalCloseButton />
<ModalBody>
<Flex
flexDirection={'column'}
p={2}
h={'100%'}
alignItems={'center'}
justifyContent={'center'}
fontSize={'sm'}
>
<Button isLoading={selecting} onClick={onOpen}>
</Button>
<Box mt={2}> {fileExtension} . </Box>
<Box mt={2}>
{fileText.length} {encode(fileText).length} tokens
</Box>
<Box
h={'300px'}
w={'100%'}
overflow={'auto'}
p={2}
backgroundColor={'blackAlpha.50'}
whiteSpace={'pre'}
fontSize={'xs'}
>
{fileText}
</Box>
</Flex>
</ModalBody>
<Flex px={6} pt={2} pb={4}>
<Box flex={1}></Box>
<Button variant={'outline'} mr={3} onClick={onClose}>
</Button>
<Button isLoading={isLoading} isDisabled={fileText === ''} onClick={openConfirm(mutate)}>
</Button>
</Flex>
</ModalContent>
<ConfirmChild />
<File onSelect={onSelectFile} />
</Modal>
);
};
export default SelectFileModal;

View File

@ -1,37 +1,27 @@
import React, { useCallback, useState, useRef, useMemo, useEffect } from 'react'; import React, { useCallback, useState, useRef, useMemo, useEffect } from 'react';
import { useRouter } from 'next/router'; import { useRouter } from 'next/router';
import { import { getModelById, delModelById, putModelTrainingStatus, putModelById } from '@/api/model';
getModelById,
delModelById,
postTrainModel,
putModelTrainingStatus,
putModelById
} from '@/api/model';
import { getChatSiteId } from '@/api/chat'; import { getChatSiteId } from '@/api/chat';
import type { ModelSchema } from '@/types/mongoSchema'; import type { ModelSchema } from '@/types/mongoSchema';
import { Card, Box, Flex, Button, Tag, Grid } from '@chakra-ui/react'; import { Card, Box, Flex, Button, Tag, Grid } from '@chakra-ui/react';
import { useToast } from '@/hooks/useToast'; import { useToast } from '@/hooks/useToast';
import { useConfirm } from '@/hooks/useConfirm';
import { useForm } from 'react-hook-form'; import { useForm } from 'react-hook-form';
import { formatModelStatus, ModelStatusEnum, modelList, defaultModel } from '@/constants/model'; import { formatModelStatus, ModelStatusEnum, modelList, defaultModel } from '@/constants/model';
import { useGlobalStore } from '@/store/global'; import { useGlobalStore } from '@/store/global';
import { useScreen } from '@/hooks/useScreen'; import { useScreen } from '@/hooks/useScreen';
import ModelEditForm from './components/ModelEditForm'; import ModelEditForm from './components/ModelEditForm';
import Icon from '@/components/Iconfont';
import { useQuery } from '@tanstack/react-query'; import { useQuery } from '@tanstack/react-query';
import dynamic from 'next/dynamic'; import dynamic from 'next/dynamic';
const Training = dynamic(() => import('./components/Training')); const ModelDataCard = dynamic(() => import('./components/ModelDataCard'));
const ModelDetail = ({ modelId }: { modelId: string }) => { const ModelDetail = ({ modelId }: { modelId: string }) => {
const { toast } = useToast(); const { toast } = useToast();
const router = useRouter(); const router = useRouter();
const { isPc, media } = useScreen(); const { isPc, media } = useScreen();
const { setLoading } = useGlobalStore(); const { setLoading } = useGlobalStore();
const { openConfirm, ConfirmChild } = useConfirm({
content: '确认删除该模型?' // const SelectFileDom = useRef<HTMLInputElement>(null);
});
const SelectFileDom = useRef<HTMLInputElement>(null);
const [model, setModel] = useState<ModelSchema>(defaultModel); const [model, setModel] = useState<ModelSchema>(defaultModel);
const formHooks = useForm<ModelSchema>({ const formHooks = useForm<ModelSchema>({
defaultValues: model defaultValues: model
@ -39,7 +29,7 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
const canTrain = useMemo(() => { const canTrain = useMemo(() => {
const openai = modelList.find((item) => item.model === model?.service.modelName); const openai = modelList.find((item) => item.model === model?.service.modelName);
return openai && openai.trainName; return !!(openai && openai.trainName);
}, [model]); }, [model]);
/* 加载模型数据 */ /* 加载模型数据 */
@ -91,34 +81,34 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
}, [setLoading, model, router]); }, [setLoading, model, router]);
/* 上传数据集,触发微调 */ /* 上传数据集,触发微调 */
const startTraining = useCallback( // const startTraining = useCallback(
async (e: React.ChangeEvent<HTMLInputElement>) => { // async (e: React.ChangeEvent<HTMLInputElement>) => {
if (!modelId || !e.target.files || e.target.files?.length === 0) return; // if (!modelId || !e.target.files || e.target.files?.length === 0) return;
setLoading(true); // setLoading(true);
try { // try {
const file = e.target.files[0]; // const file = e.target.files[0];
const formData = new FormData(); // const formData = new FormData();
formData.append('file', file); // formData.append('file', file);
await postTrainModel(modelId, formData); // await postTrainModel(modelId, formData);
toast({ // toast({
title: '开始训练...', // title: '开始训练...',
status: 'success' // status: 'success'
}); // });
// 重新获取模型 // // 重新获取模型
loadModel(); // loadModel();
} catch (err: any) { // } catch (err: any) {
toast({ // toast({
title: err?.message || '上传文件失败', // title: err?.message || '上传文件失败',
status: 'error' // status: 'error'
}); // });
console.log('error->', err); // console.log('error->', err);
} // }
setLoading(false); // setLoading(false);
}, // },
[setLoading, loadModel, modelId, toast] // [setLoading, loadModel, modelId, toast]
); // );
/* 点击更新模型状态 */ /* 点击更新模型状态 */
const handleClickUpdateStatus = useCallback(async () => { const handleClickUpdateStatus = useCallback(async () => {
@ -250,87 +240,34 @@ const ModelDetail = ({ modelId }: { modelId: string }) => {
)} )}
</Card> </Card>
<Grid mt={5} gridTemplateColumns={media('1fr 1fr', '1fr')} gridGap={5}> <Grid mt={5} gridTemplateColumns={media('1fr 1fr', '1fr')} gridGap={5}>
<ModelEditForm formHooks={formHooks} /> <ModelEditForm formHooks={formHooks} handleDelModel={handleDelModel} canTrain={canTrain} />
{canTrain && ( {/* {canTrain && (
<Card p={4}> <Card p={4}>
<Training model={model} /> <Training model={model} />
</Card> </Card>
)} */}
{canTrain && model._id && (
<Card
p={4}
height={'700px'}
{...media(
{
gridColumnStart: 1,
gridColumnEnd: 3
},
{}
)}
>
<ModelDataCard model={model} />
</Card>
)} )}
<Card p={4}>
<Box fontWeight={'bold'} fontSize={'lg'}>
</Box>
<Flex mt={5} alignItems={'center'}>
<Box flex={'0 0 80px'}>:</Box>
<Button
size={'sm'}
onClick={() => {
SelectFileDom.current?.click();
}}
title={!canTrain ? '模型不支持微调' : ''}
isDisabled={!canTrain}
>
</Button>
<Flex
as={'a'}
href="/TrainingTemplate.jsonl"
download
ml={5}
cursor={'pointer'}
alignItems={'center'}
color={'blue.500'}
>
<Icon name={'icon-yunxiazai'} color={'#3182ce'} />
</Flex>
</Flex>
{/* 提示 */}
<Box mt={3} py={3} color={'blackAlpha.600'}>
<Box as={'li'} lineHeight={1.9}>
使openai key
</Box>
<Box as={'li'} lineHeight={1.9}>
使
<Box
as={'span'}
fontWeight={'bold'}
textDecoration={'underline'}
color={'blackAlpha.800'}
mx={2}
cursor={'pointer'}
onClick={() => router.push('/data/list')}
>
</Box>
</Box>
<Box as={'li'} lineHeight={1.9}>
prompt completion
</Box>
<Box as={'li'} lineHeight={1.9}>
prompt {'</s>'}
</Box>
<Box as={'li'} lineHeight={1.9}>
completion {'</s>'}
</Box>
</Box>
<Flex mt={5} alignItems={'center'}>
<Box flex={'0 0 80px'}>:</Box>
<Button colorScheme={'red'} size={'sm'} onClick={openConfirm(handleDelModel)}>
</Button>
</Flex>
</Card>
</Grid> </Grid>
{/* 文件选择 */} {/* 文件选择 */}
<Box position={'absolute'} w={0} h={0} overflow={'hidden'}> {/* <Box position={'absolute'} w={0} h={0} overflow={'hidden'}>
<input ref={SelectFileDom} type="file" accept=".jsonl" onChange={startTraining} /> <input ref={SelectFileDom} type="file" accept=".jsonl" onChange={startTraining} />
</Box> </Box> */}
<ConfirmChild />
</> </>
); );
}; };

View File

@ -1,29 +1,26 @@
import { DataItem } from '@/service/mongo'; import { SplitData, ModelData } from '@/service/mongo';
import { getOpenAIApi } from '@/service/utils/chat'; import { getOpenAIApi } from '@/service/utils/chat';
import { httpsAgent, getOpenApiKey } from '@/service/utils/tools'; import { httpsAgent, getOpenApiKey } from '@/service/utils/tools';
import type { ChatCompletionRequestMessage } from 'openai'; import type { ChatCompletionRequestMessage } from 'openai';
import { DataItemSchema } from '@/types/mongoSchema';
import { ChatModelNameEnum } from '@/constants/model'; import { ChatModelNameEnum } from '@/constants/model';
import { pushSplitDataBill } from '@/service/events/pushBill'; import { pushSplitDataBill } from '@/service/events/pushBill';
import { generateVector } from './generateVector';
import { customAlphabet } from 'nanoid';
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 12);
export async function generateQA(next = false): Promise<any> { export async function generateQA(next = false): Promise<any> {
if (process.env.NODE_ENV === 'development') return;
if (global.generatingQA && !next) return; if (global.generatingQA && !next) return;
global.generatingQA = true; global.generatingQA = true;
const systemPrompt: ChatCompletionRequestMessage = { const systemPrompt: ChatCompletionRequestMessage = {
role: 'system', role: 'system',
content: `总结助手。我会向你发送一段长文本,请从中总结出5至15个问题和答案,答案请尽量详细,按以下格式返回: Q1:\nA1:\nQ2:\nA2:\n` content: `总结助手。我会向你发送一段长文本,请从中总结出5至15个问题和答案,答案请尽量详细,按以下格式返回: Q1:\nA1:\nQ2:\nA2:\n`
}; };
let dataItem: DataItemSchema | null = null;
try { try {
// 找出一个需要生成的 dataItem // 找出一个需要生成的 dataItem
dataItem = await DataItem.findOne({ const dataItem = await SplitData.findOne({
status: { $ne: 0 }, textList: { $exists: true, $ne: [] }
times: { $gt: 0 },
type: 'QA'
}); });
if (!dataItem) { if (!dataItem) {
@ -32,10 +29,13 @@ export async function generateQA(next = false): Promise<any> {
return; return;
} }
// 更新状态为生成中 // 弹出文本
await DataItem.findByIdAndUpdate(dataItem._id, { await SplitData.findByIdAndUpdate(dataItem._id, { $pop: { textList: 1 } });
status: 2
}); const text = dataItem.textList[dataItem.textList.length - 1];
if (!text) {
throw new Error('无文本');
}
// 获取 openapi Key // 获取 openapi Key
let userApiKey, systemKey; let userApiKey, systemKey;
@ -44,10 +44,10 @@ export async function generateQA(next = false): Promise<any> {
userApiKey = key.userApiKey; userApiKey = key.userApiKey;
systemKey = key.systemKey; systemKey = key.systemKey;
} catch (error) { } catch (error) {
// 余额不够了, 把用户所有记录改成闲置 // 余额不够了, 清空该记录
await DataItem.updateMany({ await SplitData.findByIdAndUpdate(dataItem._id, {
userId: dataItem.userId, textList: [],
status: 0 errorText: '余额不足,生成数据集任务终止'
}); });
throw new Error('获取 openai key 失败'); throw new Error('获取 openai key 失败');
} }
@ -59,84 +59,71 @@ export async function generateQA(next = false): Promise<any> {
// 获取 openai 请求实例 // 获取 openai 请求实例
const chatAPI = getOpenAIApi(userApiKey || systemKey); const chatAPI = getOpenAIApi(userApiKey || systemKey);
// 请求 chatgpt 获取回答 // 请求 chatgpt 获取回答
const response = await Promise.allSettled( const response = await chatAPI
[0.2, 0.8].map( .createChatCompletion(
(temperature) => {
chatAPI model: ChatModelNameEnum.GPT35,
.createChatCompletion( temperature: 0.2,
{ n: 1,
model: ChatModelNameEnum.GPT35, messages: [
temperature: temperature, systemPrompt,
n: 1, {
messages: [ role: 'user',
systemPrompt, content: text
{ }
role: 'user', ]
content: dataItem?.text || ''
}
]
},
{
timeout: 120000,
httpsAgent
}
)
.then((res) => ({
rawContent: res?.data.choices[0].message?.content || '',
result: splitText(res?.data.choices[0].message?.content || '')
})) // 从 content 中提取 QA
)
);
// 过滤出成功的响应
const successResponse: {
rawContent: string;
result: { q: string; a: string }[];
}[] = response.filter((item) => item.status === 'fulfilled').map((item: any) => item.value);
const rawContents = successResponse.map((item) => item.rawContent);
const results = successResponse.map((item) => item.result).flat();
// 插入数据库,并修改状态
await DataItem.findByIdAndUpdate(dataItem._id, {
status: 0,
$push: {
rawResponse: {
$each: successResponse.map((item) => item.rawContent)
}, },
result: { {
$each: results timeout: 120000,
httpsAgent
} }
} )
}); .then((res) => ({
rawContent: res?.data.choices[0].message?.content || '',
result: splitText(res?.data.choices[0].message?.content || '')
})); // 从 content 中提取 QA
// 插入 modelData 表,生成向量
await ModelData.insertMany(
response.result.map((item) => ({
modelId: dataItem.modelId,
userId: dataItem.userId,
text: item.a,
q: [
{
id: nanoid(),
text: item.q
}
],
status: 1
}))
);
console.log( console.log(
'生成QA成功time:', '生成QA成功time:',
`${(Date.now() - startTime) / 1000}s`, `${(Date.now() - startTime) / 1000}s`,
'QA数量', 'QA数量',
results.length response.result.length
); );
// 计费 // 计费
pushSplitDataBill({ pushSplitDataBill({
isPay: !userApiKey && results.length > 0, isPay: !userApiKey && response.result.length > 0,
userId: dataItem.userId, userId: dataItem.userId,
type: 'QA', type: 'QA',
text: systemPrompt.content + dataItem.text + rawContents.join('') text: systemPrompt.content + text + response.rawContent
}); });
} catch (error: any) {
console.log('error: 生成QA错误', dataItem?._id);
console.log('response:', error?.response);
if (dataItem?._id) {
await DataItem.findByIdAndUpdate(dataItem._id, {
status: dataItem.times > 0 ? 1 : 0, // 还有重试次数则可以继续进行
$inc: {
// 剩余尝试次数-1
times: -1
}
});
}
}
generateQA(true); generateQA(true);
generateVector(true);
} catch (error: any) {
console.log(error);
console.log('生成QA错误:', error?.response);
setTimeout(() => {
generateQA(true);
}, 10000);
}
} }
/** /**

View File

@ -0,0 +1,88 @@
import { getOpenAIApi } from '@/service/utils/chat';
import { httpsAgent } from '@/service/utils/tools';
import { ModelData } from '../models/modelData';
import { connectRedis } from '../redis';
import { VecModelDataIndex } from '@/constants/redis';
export async function generateVector(next = false): Promise<any> {
if (global.generatingVector && !next) return;
global.generatingVector = true;
try {
const redis = await connectRedis();
// 找出一个需要生成的 dataItem
const dataItem = await ModelData.findOne({
status: { $ne: 0 }
});
if (!dataItem) {
console.log('没有需要生成 【向量】 的数据');
global.generatingVector = false;
return;
}
// 获取 openapi Key
const openAiKey = process.env.OPENAIKEY as string;
// 获取 openai 请求实例
const chatAPI = getOpenAIApi(openAiKey);
const dataId = String(dataItem._id);
// 生成词向量
const response = await Promise.allSettled(
dataItem.q.map((item, i) =>
chatAPI
.createEmbedding(
{
model: 'text-embedding-ada-002',
input: item.text
},
{
timeout: 120000,
httpsAgent
}
)
.then((res) => res?.data?.data?.[0]?.embedding || [])
.then((vector) =>
redis.sendCommand([
'JSON.SET',
`${VecModelDataIndex}:${dataId}:${i}`,
'$',
JSON.stringify({
dataId,
modelId: String(dataItem.modelId),
vector
})
])
)
)
);
if (response.filter((item) => item.status === 'fulfilled').length === 0) {
throw new Error(JSON.stringify(response));
}
// 修改该数据状态
await ModelData.findByIdAndUpdate(dataItem._id, {
status: 0
});
console.log(`生成向量成功: ${dataItem._id}`);
setTimeout(() => {
generateVector(true);
}, 3000);
} catch (error: any) {
console.log(error);
console.log('error: 生成向量错误', error?.response?.data);
if (error?.response?.statusText === 'Too Many Requests') {
console.log('次数限制1分钟后尝试');
// 限制次数1分钟后再试
setTimeout(() => {
generateVector(true);
}, 60000);
}
}
}

View File

@ -34,7 +34,7 @@ export const pushChatBill = async ({
// 计算价格 // 计算价格
const unitPrice = modelItem?.price || 5; const unitPrice = modelItem?.price || 5;
const price = unitPrice * tokens.length; const price = unitPrice * tokens.length;
console.log(`chat bill, price: ${formatPrice(price)}`); console.log(`chat bill, unit price: ${unitPrice}, price: ${formatPrice(price)}`);
try { try {
// 插入 Bill 记录 // 插入 Bill 记录

View File

@ -13,22 +13,23 @@ const ModelDataSchema = new Schema({
ref: 'user', ref: 'user',
required: true required: true
}, },
q: { text: {
type: String, type: String,
required: true required: true
}, },
a: { q: {
type: String, type: [
default: '' {
id: String, // 对应redis的key
text: String
}
],
default: []
}, },
status: { status: {
type: Number, type: Number,
enum: [0, 1, 2], enum: [0, 1], // 1 训练ing
default: 1 default: 1
},
createTime: {
type: Date,
default: () => new Date()
} }
}); });

View File

@ -0,0 +1,31 @@
/* 模型的知识库 */
import { Schema, model, models, Model as MongoModel } from 'mongoose';
import { ModelSplitDataSchema as SplitDataType } from '@/types/mongoSchema';
const SplitDataSchema = new Schema({
userId: {
type: Schema.Types.ObjectId,
ref: 'user',
required: true
},
modelId: {
type: Schema.Types.ObjectId,
ref: 'model',
required: true
},
rawText: {
type: String,
required: true
},
textList: {
type: [String],
default: []
},
errorText: {
type: String,
default: ''
}
});
export const SplitData: MongoModel<SplitDataType> =
models['splitData'] || model('splitData', SplitDataSchema);

View File

@ -1,6 +1,7 @@
import mongoose from 'mongoose'; import mongoose from 'mongoose';
import { generateQA } from './events/generateQA'; import { generateQA } from './events/generateQA';
import { generateAbstract } from './events/generateAbstract'; import { generateAbstract } from './events/generateAbstract';
import { generateVector } from './events/generateVector';
/** /**
* MongoDB * MongoDB
@ -27,7 +28,8 @@ export async function connectToDatabase(): Promise<void> {
} }
generateQA(); generateQA();
generateAbstract(); // generateAbstract();
generateVector();
} }
export * from './models/authCode'; export * from './models/authCode';
@ -40,3 +42,4 @@ export * from './models/bill';
export * from './models/pay'; export * from './models/pay';
export * from './models/data'; export * from './models/data';
export * from './models/dataItem'; export * from './models/dataItem';
export * from './models/splitData';

45
src/service/redis.ts Normal file
View File

@ -0,0 +1,45 @@
import { createClient } from 'redis';
import { customAlphabet } from 'nanoid';
const nanoid = customAlphabet('abcdefghijklmnopqrstuvwxyz1234567890', 10);
export const connectRedis = async () => {
// 断开了,重连
if (global.redisClient && !global.redisClient.isOpen) {
await global.redisClient.disconnect();
} else if (global.redisClient) {
// 没断开,不再连接
return global.redisClient;
}
try {
global.redisClient = createClient({
url: process.env.REDIS_URL
});
global.redisClient.on('error', (err) => {
console.log('Redis Client Error', err);
global.redisClient = null;
});
global.redisClient.on('end', () => {
global.redisClient = null;
});
global.redisClient.on('ready', () => {
console.log('redis connected');
});
await global.redisClient.connect();
// 0 - 测试库1 - 正式
await global.redisClient.select(0);
return global.redisClient;
} catch (error) {
console.log(error, '==');
global.redisClient = null;
return Promise.reject('redis 连接失败');
}
};
export const getKey = (prefix = '') => {
return `${prefix}:${nanoid()}`;
};

View File

@ -119,3 +119,21 @@ export const openaiChatFilter = (prompts: ChatItemType[], maxTokens: number) =>
return systemPrompt ? [systemPrompt, ...res] : res; return systemPrompt ? [systemPrompt, ...res] : res;
}; };
/* system 内容截断 */
export const systemPromptFilter = (prompts: string[], maxTokens: number) => {
let splitText = '';
// 从前往前截取
for (let i = 0; i < prompts.length; i++) {
const prompt = prompts[i];
splitText += `${prompt}\n`;
const tokens = encode(splitText).length;
if (tokens >= maxTokens) {
break;
}
}
return splitText;
};

View File

@ -1,9 +1,12 @@
import type { Mongoose } from 'mongoose'; import type { Mongoose } from 'mongoose';
import type { RedisClientType } from 'redis';
declare global { declare global {
var mongodb: Mongoose | string | null; var mongodb: Mongoose | string | null;
var redisClient: RedisClientType | null;
var generatingQA: boolean; var generatingQA: boolean;
var generatingAbstract: boolean; var generatingAbstract: boolean;
var generatingVector: boolean;
var QRCode: any; var QRCode: any;
interface Window { interface Window {
['pdfjs-dist/build/pdf']: any; ['pdfjs-dist/build/pdf']: any;

View File

@ -8,3 +8,12 @@ export interface ModelUpdateParams {
service: ModelSchema.service; service: ModelSchema.service;
security: ModelSchema.security; security: ModelSchema.security;
} }
export interface ModelDataItemType {
id: string;
status: 0 | 1; // 1代表向量生成完毕
q: string; // 提问词
a: string; // 原文
modelId: string;
userId: string;
}

View File

@ -51,12 +51,26 @@ export interface ModelPopulate extends ModelSchema {
userId: UserModelSchema; userId: UserModelSchema;
} }
export type ModelDataType = 0 | 1;
export interface ModelDataSchema { export interface ModelDataSchema {
_id: string; _id: string;
q: string; modelId: string;
a: string; userId: string;
status: 0 | 1 | 2; text: string;
createTime: Date; q: {
id: string;
text: string;
}[];
status: ModelDataType;
}
export interface ModelSplitDataSchema {
_id: string;
userId: string;
modelId: string;
rawText: string;
errorText: string;
textList: string[];
} }
export interface TrainingSchema { export interface TrainingSchema {

6
src/types/redis.d.ts vendored Normal file
View File

@ -0,0 +1,6 @@
export interface RedisModelDataItemType {
id: string;
vector: number[];
dataId: string;
modelId: string;
}

View File

@ -124,3 +124,15 @@ export const readDocContent = (file: File) =>
reject('读取 doc 文件失败'); reject('读取 doc 文件失败');
}; };
}); });
export const vectorToBuffer = (vector: number[]) => {
const float32Arr = new Float32Array(vector);
const myBuffer = new ArrayBuffer(float32Arr.length * Float32Array.BYTES_PER_ELEMENT);
const myView = new DataView(myBuffer);
for (let i = 0; i < float32Arr.length; i++) {
myView.setFloat32(i * Float32Array.BYTES_PER_ELEMENT, float32Arr[i], true);
}
return Buffer.from(myBuffer);
};