更新嵌入管理器以支持 GPU 加速,调整批处理大小,优化内容处理逻辑,并添加获取数据库最大修改时间的功能以提高文件索引效率。同时修复了向量管理器中的类型问题,确保模型加载和嵌入过程的稳定性。

This commit is contained in:
duanfuxiang 2025-07-05 07:40:54 +08:00
parent 558e3b3fe4
commit c657a50563
7 changed files with 398 additions and 340 deletions

View File

@ -64,7 +64,7 @@ export const getEmbeddingModel = (
// 确保模型已加载 // 确保模型已加载
if (!embeddingManager.modelLoaded || embeddingManager.currentModel !== settings.embeddingModelId) { if (!embeddingManager.modelLoaded || embeddingManager.currentModel !== settings.embeddingModelId) {
console.log(`Loading model: ${settings.embeddingModelId}`) console.log(`Loading model: ${settings.embeddingModelId}`)
await embeddingManager.loadModel(settings.embeddingModelId, false) await embeddingManager.loadModel(settings.embeddingModelId, true)
} }
const results = await embeddingManager.embedBatch(texts) const results = await embeddingManager.embedBatch(texts)

View File

@ -27,7 +27,7 @@ export class VectorManager {
constructor(app: App, dbManager: DBManager) { constructor(app: App, dbManager: DBManager) {
this.app = app this.app = app
this.dbManager = dbManager this.dbManager = dbManager
this.repository = new VectorRepository(app, dbManager.getPgClient()) this.repository = new VectorRepository(app, dbManager.getPgClient() as any)
} }
async performSimilaritySearch( async performSimilaritySearch(
@ -88,6 +88,7 @@ export class VectorManager {
): Promise<void> { ): Promise<void> {
let filesToIndex: TFile[] let filesToIndex: TFile[]
if (options.reindexAll) { if (options.reindexAll) {
console.log("updateVaultIndex reindexAll")
filesToIndex = await this.getFilesToIndex({ filesToIndex = await this.getFilesToIndex({
embeddingModel: embeddingModel, embeddingModel: embeddingModel,
excludePatterns: options.excludePatterns, excludePatterns: options.excludePatterns,
@ -96,17 +97,22 @@ export class VectorManager {
}) })
await this.repository.clearAllVectors(embeddingModel) await this.repository.clearAllVectors(embeddingModel)
} else { } else {
console.log("updateVaultIndex for update files")
await this.cleanVectorsForDeletedFiles(embeddingModel) await this.cleanVectorsForDeletedFiles(embeddingModel)
console.log("updateVaultIndex cleanVectorsForDeletedFiles")
filesToIndex = await this.getFilesToIndex({ filesToIndex = await this.getFilesToIndex({
embeddingModel: embeddingModel, embeddingModel: embeddingModel,
excludePatterns: options.excludePatterns, excludePatterns: options.excludePatterns,
includePatterns: options.includePatterns, includePatterns: options.includePatterns,
}) })
console.log("get files to index: ", filesToIndex.length)
await this.repository.deleteVectorsForMultipleFiles( await this.repository.deleteVectorsForMultipleFiles(
filesToIndex.map((file) => file.path), filesToIndex.map((file) => file.path),
embeddingModel, embeddingModel,
) )
console.log("delete vectors for multiple files: ", filesToIndex.length)
} }
console.log("get files to index: ", filesToIndex.length)
if (filesToIndex.length === 0) { if (filesToIndex.length === 0) {
return return
@ -131,6 +137,7 @@ export class VectorManager {
"", "",
], ],
}); });
console.log("textSplitter chunkSize: ", options.chunkSize, "overlap: ", overlap)
const skippedFiles: string[] = [] const skippedFiles: string[] = []
const contentChunks: InsertVector[] = ( const contentChunks: InsertVector[] = (
@ -145,15 +152,16 @@ export class VectorManager {
]) ])
return fileDocuments return fileDocuments
.map((chunk): InsertVector | null => { .map((chunk): InsertVector | null => {
const content = removeMarkdown(chunk.pageContent).replace(/\0/g, '') // 保存原始内容,不在此处调用 removeMarkdown
if (!content || content.trim().length === 0) { const rawContent = chunk.pageContent.replace(/\0/g, '')
if (!rawContent || rawContent.trim().length === 0) {
console.log("skipped chunk", chunk.pageContent) console.log("skipped chunk", chunk.pageContent)
return null return null
} }
return { return {
path: file.path, path: file.path,
mtime: file.stat.mtime, mtime: file.stat.mtime,
content, content: rawContent, // 保存原始内容
embedding: [], embedding: [],
metadata: { metadata: {
startLine: Number(chunk.metadata.loc.lines.from), startLine: Number(chunk.metadata.loc.lines.from),
@ -171,6 +179,8 @@ export class VectorManager {
) )
).flat() ).flat()
console.log("contentChunks: ", contentChunks.length)
if (skippedFiles.length > 0) { if (skippedFiles.length > 0) {
console.warn(`跳过了 ${skippedFiles.length} 个有问题的文件:`, skippedFiles) console.warn(`跳过了 ${skippedFiles.length} 个有问题的文件:`, skippedFiles)
new Notice(`跳过了 ${skippedFiles.length} 个有问题的文件`) new Notice(`跳过了 ${skippedFiles.length} 个有问题的文件`)
@ -195,22 +205,33 @@ export class VectorManager {
for (let i = 0; i < contentChunks.length; i += embeddingBatchSize) { for (let i = 0; i < contentChunks.length; i += embeddingBatchSize) {
batchCount++ batchCount++
const batchChunks = contentChunks.slice(i, Math.min(i + embeddingBatchSize, contentChunks.length)) const batchChunks = contentChunks.slice(i, Math.min(i + embeddingBatchSize, contentChunks.length))
const batchTexts = batchChunks.map(chunk => chunk.content)
const embeddedBatch: InsertVector[] = [] const embeddedBatch: InsertVector[] = []
await backOff( await backOff(
async () => { async () => {
// 在嵌入之前处理 markdown只处理一次
const cleanedBatchData = batchChunks.map(chunk => {
const cleanContent = removeMarkdown(chunk.content).replace(/\0/g, '')
return { chunk, cleanContent }
}).filter(({ cleanContent }) => cleanContent && cleanContent.trim().length > 0)
if (cleanedBatchData.length === 0) {
return
}
const batchTexts = cleanedBatchData.map(({ cleanContent }) => cleanContent)
const batchEmbeddings = await embeddingModel.getBatchEmbeddings(batchTexts) const batchEmbeddings = await embeddingModel.getBatchEmbeddings(batchTexts)
// 合并embedding结果到chunk数据 // 合并embedding结果到chunk数据
for (let j = 0; j < batchChunks.length; j++) { for (let j = 0; j < cleanedBatchData.length; j++) {
const { chunk, cleanContent } = cleanedBatchData[j]
const embeddedChunk: InsertVector = { const embeddedChunk: InsertVector = {
path: batchChunks[j].path, path: chunk.path,
mtime: batchChunks[j].mtime, mtime: chunk.mtime,
content: batchChunks[j].content, content: cleanContent, // 使用已经清理过的内容
embedding: batchEmbeddings[j], embedding: batchEmbeddings[j],
metadata: batchChunks[j].metadata, metadata: chunk.metadata,
} }
embeddedBatch.push(embeddedChunk) embeddedBatch.push(embeddedChunk)
} }
@ -263,11 +284,18 @@ export class VectorManager {
try { try {
await backOff( await backOff(
async () => { async () => {
const embedding = await embeddingModel.getEmbedding(chunk.content) // 在嵌入之前处理 markdown
const cleanContent = removeMarkdown(chunk.content).replace(/\0/g, '')
// 跳过清理后为空的内容
if (!cleanContent || cleanContent.trim().length === 0) {
return
}
const embedding = await embeddingModel.getEmbedding(cleanContent)
const embeddedChunk = { const embeddedChunk = {
path: chunk.path, path: chunk.path,
mtime: chunk.mtime, mtime: chunk.mtime,
content: chunk.content, content: cleanContent, // 使用清理后的内容
embedding, embedding,
metadata: chunk.metadata, metadata: chunk.metadata,
} }
@ -339,9 +367,23 @@ export class VectorManager {
) )
// Embed the files // Embed the files
const textSplitter = new MarkdownTextSplitter({ const overlap = Math.floor(chunkSize * 0.15)
const textSplitter = new RecursiveCharacterTextSplitter({
chunkSize: chunkSize, chunkSize: chunkSize,
chunkOverlap: Math.floor(chunkSize * 0.15) chunkOverlap: overlap,
separators: [
"\n\n",
"\n",
".",
",",
" ",
"\u200b", // Zero-width space
"\uff0c", // Fullwidth comma
"\u3001", // Ideographic comma
"\uff0e", // Fullwidth full stop
"\u3002", // Ideographic full stop
"",
],
}); });
let fileContent = await this.app.vault.cachedRead(file) let fileContent = await this.app.vault.cachedRead(file)
// 清理null字节防止PostgreSQL UTF8编码错误 // 清理null字节防止PostgreSQL UTF8编码错误
@ -352,14 +394,15 @@ export class VectorManager {
const contentChunks: InsertVector[] = fileDocuments const contentChunks: InsertVector[] = fileDocuments
.map((chunk): InsertVector | null => { .map((chunk): InsertVector | null => {
const content = removeMarkdown(chunk.pageContent).replace(/\0/g, '') // 保存原始内容,不在此处调用 removeMarkdown
if (!content || content.trim().length === 0) { const rawContent = String(chunk.pageContent || '').replace(/\0/g, '')
if (!rawContent || rawContent.trim().length === 0) {
return null return null
} }
return { return {
path: file.path, path: file.path,
mtime: file.stat.mtime, mtime: file.stat.mtime,
content, content: rawContent, // 保存原始内容
embedding: [], embedding: [],
metadata: { metadata: {
startLine: Number(chunk.metadata.loc.lines.from), startLine: Number(chunk.metadata.loc.lines.from),
@ -382,22 +425,33 @@ export class VectorManager {
batchCount++ batchCount++
console.log(`Embedding batch ${batchCount} of ${Math.ceil(contentChunks.length / embeddingBatchSize)}`) console.log(`Embedding batch ${batchCount} of ${Math.ceil(contentChunks.length / embeddingBatchSize)}`)
const batchChunks = contentChunks.slice(i, Math.min(i + embeddingBatchSize, contentChunks.length)) const batchChunks = contentChunks.slice(i, Math.min(i + embeddingBatchSize, contentChunks.length))
const batchTexts = batchChunks.map(chunk => chunk.content)
const embeddedBatch: InsertVector[] = [] const embeddedBatch: InsertVector[] = []
await backOff( await backOff(
async () => { async () => {
// 在嵌入之前处理 markdown只处理一次
const cleanedBatchData = batchChunks.map(chunk => {
const cleanContent = removeMarkdown(chunk.content).replace(/\0/g, '')
return { chunk, cleanContent }
}).filter(({ cleanContent }) => cleanContent && cleanContent.trim().length > 0)
if (cleanedBatchData.length === 0) {
return
}
const batchTexts = cleanedBatchData.map(({ cleanContent }) => cleanContent)
const batchEmbeddings = await embeddingModel.getBatchEmbeddings(batchTexts) const batchEmbeddings = await embeddingModel.getBatchEmbeddings(batchTexts)
// 合并embedding结果到chunk数据 // 合并embedding结果到chunk数据
for (let j = 0; j < batchChunks.length; j++) { for (let j = 0; j < cleanedBatchData.length; j++) {
const { chunk, cleanContent } = cleanedBatchData[j]
const embeddedChunk: InsertVector = { const embeddedChunk: InsertVector = {
path: batchChunks[j].path, path: chunk.path,
mtime: batchChunks[j].mtime, mtime: chunk.mtime,
content: batchChunks[j].content, content: cleanContent, // 使用已经清理过的内容
embedding: batchEmbeddings[j], embedding: batchEmbeddings[j],
metadata: batchChunks[j].metadata, metadata: chunk.metadata,
} }
embeddedBatch.push(embeddedChunk) embeddedBatch.push(embeddedChunk)
} }
@ -443,11 +497,18 @@ export class VectorManager {
try { try {
await backOff( await backOff(
async () => { async () => {
const embedding = await embeddingModel.getEmbedding(chunk.content) // 在嵌入之前处理 markdown
const cleanContent = removeMarkdown(chunk.content).replace(/\0/g, '')
// 跳过清理后为空的内容
if (!cleanContent || cleanContent.trim().length === 0) {
return
}
const embedding = await embeddingModel.getEmbedding(cleanContent)
const embeddedChunk = { const embeddedChunk = {
path: chunk.path, path: chunk.path,
mtime: chunk.mtime, mtime: chunk.mtime,
content: chunk.content, content: cleanContent, // 使用清理后的内容
embedding, embedding,
metadata: chunk.metadata, metadata: chunk.metadata,
} }
@ -522,8 +583,9 @@ export class VectorManager {
excludePatterns: string[] excludePatterns: string[]
includePatterns: string[] includePatterns: string[]
reindexAll?: boolean reindexAll?: boolean
}): Promise<TFile[]> { }): Promise<TFile[]> {
let filesToIndex = this.app.vault.getMarkdownFiles() let filesToIndex = this.app.vault.getMarkdownFiles()
console.log("get all vault files: ", filesToIndex.length)
filesToIndex = filesToIndex.filter((file) => { filesToIndex = filesToIndex.filter((file) => {
return !excludePatterns.some((pattern) => minimatch(file.path, pattern)) return !excludePatterns.some((pattern) => minimatch(file.path, pattern))
@ -538,39 +600,24 @@ export class VectorManager {
if (reindexAll) { if (reindexAll) {
return filesToIndex return filesToIndex
} }
// Check for updated or new files
filesToIndex = await Promise.all(
filesToIndex.map(async (file) => {
try {
const fileChunks = await this.repository.getVectorsByFilePath(
file.path,
embeddingModel,
)
if (fileChunks.length === 0) {
// File is not indexed, so we need to index it
let fileContent = await this.app.vault.cachedRead(file)
// 清理null字节防止PostgreSQL UTF8编码错误
fileContent = fileContent.replace(/\0/g, '')
if (fileContent.length === 0) {
// Ignore empty files
return null
}
return file
}
const outOfDate = file.stat.mtime > fileChunks[0].mtime
if (outOfDate) {
// File has changed, so we need to re-index it
console.log("File has changed, so we need to re-index it", file.path)
return file
}
return null
} catch (error) {
console.warn(`跳过文件 ${file.path}:`, error.message)
return null
}
}),
).then((files) => files.filter(Boolean))
return filesToIndex // 优化流程使用数据库最大mtime来过滤需要更新的文件
try {
const maxMtime = await this.repository.getMaxMtime(embeddingModel)
console.log("Database max mtime:", maxMtime)
if (maxMtime === null) {
// 数据库中没有任何向量,需要索引所有文件
return filesToIndex
}
// 筛选出在数据库最后更新时间之后修改的文件
return filesToIndex.filter((file) => {
return file.stat.mtime > maxMtime
})
} catch (error) {
console.error("Error getting max mtime from database:", error)
return []
}
} }
} }

View File

@ -33,6 +33,17 @@ export class VectorRepository {
return result.rows.map((row: { path: string }) => row.path) return result.rows.map((row: { path: string }) => row.path)
} }
async getMaxMtime(embeddingModel: EmbeddingModel): Promise<number | null> {
if (!this.db) {
throw new DatabaseNotInitializedException()
}
const tableName = this.getTableName(embeddingModel)
const result = await this.db.query<{ max_mtime: number | null }>(
`SELECT MAX(mtime) as max_mtime FROM "${tableName}"`
)
return result.rows[0]?.max_mtime || null
}
async getVectorsByFilePath( async getVectorsByFilePath(
filePath: string, filePath: string,
embeddingModel: EmbeddingModel, embeddingModel: EmbeddingModel,

View File

@ -4,284 +4,285 @@ import EmbedWorker from './embed.worker';
// 类型定义 // 类型定义
export interface EmbedResult { export interface EmbedResult {
vec: number[]; vec: number[];
tokens: number; tokens: number;
embed_input?: string; embed_input?: string;
} }
export interface ModelLoadResult { export interface ModelLoadResult {
model_loaded: boolean; model_loaded: boolean;
} }
export interface ModelUnloadResult { export interface ModelUnloadResult {
model_unloaded: boolean; model_unloaded: boolean;
} }
export interface TokenCountResult { export interface TokenCountResult {
tokens: number; tokens: number;
} }
export class EmbeddingManager { export class EmbeddingManager {
private worker: Worker; private worker: Worker;
private requests = new Map<number, { resolve: (value: any) => void; reject: (reason?: any) => void }>(); private requests = new Map<number, { resolve: (value: any) => void; reject: (reason?: any) => void }>();
private nextRequestId = 0; private nextRequestId = 0;
private isModelLoaded = false; private isModelLoaded = false;
private currentModelId: string | null = null; private currentModelId: string | null = null;
constructor() { constructor() {
// 创建 Worker使用与 pgworker 相同的模式 // 创建 Worker使用与 pgworker 相同的模式
this.worker = new EmbedWorker(); this.worker = new EmbedWorker();
// 统一监听来自 Worker 的所有消息 // 统一监听来自 Worker 的所有消息
this.worker.onmessage = (event) => { this.worker.onmessage = (event) => {
try { try {
const { id, result, error } = event.data; const { id, result, error } = event.data;
// 根据返回的 id 找到对应的 Promise 回调 // 根据返回的 id 找到对应的 Promise 回调
const request = this.requests.get(id); const request = this.requests.get(id);
if (request) { if (request) {
if (error) { if (error) {
request.reject(new Error(error)); request.reject(new Error(error));
} else { } else {
request.resolve(result); request.resolve(result);
} }
// 完成后从 Map 中删除 // 完成后从 Map 中删除
this.requests.delete(id); this.requests.delete(id);
} }
} catch (err) { } catch (err) {
console.error("Error processing worker message:", err); console.error("Error processing worker message:", err);
// 拒绝所有待处理的请求 // 拒绝所有待处理的请求
this.requests.forEach(request => { this.requests.forEach(request => {
request.reject(new Error(`Worker message processing error: ${err.message}`)); request.reject(new Error(`Worker message processing error: ${err.message}`));
}); });
this.requests.clear(); this.requests.clear();
} }
}; };
this.worker.onerror = (error) => { this.worker.onerror = (error) => {
console.error("EmbeddingWorker error:", error); console.error("EmbeddingWorker error:", error);
// 拒绝所有待处理的请求 // 拒绝所有待处理的请求
this.requests.forEach(request => { this.requests.forEach(request => {
request.reject(new Error(`Worker error: ${error.message || 'Unknown worker error'}`)); request.reject(new Error(`Worker error: ${error.message || 'Unknown worker error'}`));
}); });
this.requests.clear(); this.requests.clear();
// 重置状态 // 重置状态
this.isModelLoaded = false; this.isModelLoaded = false;
this.currentModelId = null; this.currentModelId = null;
}; };
} }
/** /**
* Worker Promise Promise * Worker Promise Promise
* @param method (e.g., 'load', 'embed_batch') * @param method (e.g., 'load', 'embed_batch')
* @param params * @param params
*/ */
private postRequest<T>(method: string, params: any): Promise<T> { private postRequest<T>(method: string, params: any): Promise<T> {
return new Promise<T>((resolve, reject) => { return new Promise<T>((resolve, reject) => {
const id = this.nextRequestId++; const id = this.nextRequestId++;
this.requests.set(id, { resolve, reject }); this.requests.set(id, { resolve, reject });
this.worker.postMessage({ method, params, id }); this.worker.postMessage({ method, params, id });
}); });
} }
/** /**
* Worker * Worker
* @param modelId ID, 'TaylorAI/bge-micro-v2' * @param modelId ID, 'TaylorAI/bge-micro-v2'
* @param useGpu 使GPU加速false * @param useGpu 使GPU加速false
*/ */
public async loadModel(modelId: string, useGpu: boolean = false): Promise<ModelLoadResult> { public async loadModel(modelId: string, useGpu: boolean = false): Promise<ModelLoadResult> {
console.log(`Loading embedding model: ${modelId}, GPU: ${useGpu}`); console.log(`Loading embedding model: ${modelId}, GPU: ${useGpu}`);
try { try {
// 如果已经加载了相同的模型,直接返回 // 如果已经加载了相同的模型,直接返回
if (this.isModelLoaded && this.currentModelId === modelId) { if (this.isModelLoaded && this.currentModelId === modelId) {
console.log(`Model ${modelId} already loaded`); console.log(`Model ${modelId} already loaded`);
return { model_loaded: true }; return { model_loaded: true };
} }
// 如果加载了不同的模型,先卸载 // 如果加载了不同的模型,先卸载
if (this.isModelLoaded && this.currentModelId !== modelId) { if (this.isModelLoaded && this.currentModelId !== modelId) {
console.log(`Unloading previous model: ${this.currentModelId}`); console.log(`Unloading previous model: ${this.currentModelId}`);
await this.unloadModel(); await this.unloadModel();
} }
const result = await this.postRequest<ModelLoadResult>('load', { const result = await this.postRequest<ModelLoadResult>('load', {
model_key: modelId, model_key: modelId,
use_gpu: useGpu use_gpu: useGpu
}); });
this.isModelLoaded = result.model_loaded; this.isModelLoaded = result.model_loaded;
this.currentModelId = result.model_loaded ? modelId : null; this.currentModelId = result.model_loaded ? modelId : null;
if (result.model_loaded) { if (result.model_loaded) {
console.log(`Model ${modelId} loaded successfully`); console.log(`Model ${modelId} loaded successfully`);
} }
return result; return result;
} catch (error) { } catch (error) {
console.error(`Failed to load model ${modelId}:`, error); console.error(`Failed to load model ${modelId}:`, error);
this.isModelLoaded = false; this.isModelLoaded = false;
this.currentModelId = null; this.currentModelId = null;
throw error; throw error;
} }
} }
/** /**
* *
* @param texts * @param texts
* @returns token * @returns token
*/ */
public async embedBatch(texts: string[]): Promise<EmbedResult[]> { public async embedBatch(texts: string[]): Promise<EmbedResult[]> {
if (!this.isModelLoaded) { if (!this.isModelLoaded) {
throw new Error('Model not loaded. Please call loadModel() first.'); throw new Error('Model not loaded. Please call loadModel() first.');
} }
if (!texts || texts.length === 0) { if (!texts || texts.length === 0) {
return []; return [];
} }
console.log(`Generating embeddings for ${texts.length} texts`); console.log(`Generating embeddings for ${texts.length} texts`);
try { try {
const inputs = texts.map(text => ({ embed_input: text })); const inputs = texts.map(text => ({ embed_input: text }));
const results = await this.postRequest<EmbedResult[]>('embed_batch', { inputs }); const results = await this.postRequest<EmbedResult[]>('embed_batch', { inputs });
console.log(`Generated ${results.length} embeddings`); console.log(`Generated ${results.length} embeddings`);
return results; return results;
} catch (error) { } catch (error) {
console.error('Failed to generate embeddings:', error); console.error('Failed to generate embeddings:', error);
throw error; throw error;
} }
} }
/** /**
* *
* @param text * @param text
* @returns token * @returns token
*/ */
public async embed(text: string): Promise<EmbedResult> { public async embed(text: string): Promise<EmbedResult> {
if (!text || text.trim().length === 0) { if (!text || text.trim().length === 0) {
throw new Error('Text cannot be empty'); throw new Error('Text cannot be empty');
} }
const results = await this.embedBatch([text]); const results = await this.embedBatch([text]);
if (results.length === 0) { if (results.length === 0) {
throw new Error('Failed to generate embedding'); throw new Error('Failed to generate embedding');
} }
return results[0]; return results[0];
} }
/** /**
* token * token
* @param text * @param text
*/ */
public async countTokens(text: string): Promise<TokenCountResult> { public async countTokens(text: string): Promise<TokenCountResult> {
if (!this.isModelLoaded) { if (!this.isModelLoaded) {
throw new Error('Model not loaded. Please call loadModel() first.'); throw new Error('Model not loaded. Please call loadModel() first.');
} }
if (!text) { if (!text) {
return { tokens: 0 }; return { tokens: 0 };
} }
try { try {
return await this.postRequest<TokenCountResult>('count_tokens', text); return await this.postRequest<TokenCountResult>('count_tokens', text);
} catch (error) { } catch (error) {
console.error('Failed to count tokens:', error); console.error('Failed to count tokens:', error);
throw error; throw error;
} }
} }
/** /**
* *
*/ */
public async unloadModel(): Promise<ModelUnloadResult> { public async unloadModel(): Promise<ModelUnloadResult> {
if (!this.isModelLoaded) { if (!this.isModelLoaded) {
console.log('No model to unload'); console.log('No model to unload');
return { model_unloaded: true }; return { model_unloaded: true };
} }
try { try {
console.log(`Unloading model: ${this.currentModelId}`); console.log(`Unloading model: ${this.currentModelId}`);
const result = await this.postRequest<ModelUnloadResult>('unload', {}); const result = await this.postRequest<ModelUnloadResult>('unload', {});
this.isModelLoaded = false; this.isModelLoaded = false;
this.currentModelId = null; this.currentModelId = null;
console.log('Model unloaded successfully'); console.log('Model unloaded successfully');
return result; return result;
} catch (error) { } catch (error) {
console.error('Failed to unload model:', error); console.error('Failed to unload model:', error);
// 即使卸载失败,也重置状态 // 即使卸载失败,也重置状态
this.isModelLoaded = false; this.isModelLoaded = false;
this.currentModelId = null; this.currentModelId = null;
throw error; throw error;
} }
} }
/** /**
* *
*/ */
public get modelLoaded(): boolean { public get modelLoaded(): boolean {
return this.isModelLoaded; return this.isModelLoaded;
} }
/** /**
* ID * ID
*/ */
public get currentModel(): string | null { public get currentModel(): string | null {
return this.currentModelId; return this.currentModelId;
} }
/** /**
* *
*/ */
public getSupportedModels(): string[] { public getSupportedModels(): string[] {
return [ return [
'Xenova/all-MiniLM-L6-v2', 'TaylorAI/bge-micro-v2',
'Xenova/bge-small-en-v1.5', 'Xenova/all-MiniLM-L6-v2',
'Xenova/bge-base-en-v1.5', 'Xenova/bge-small-en-v1.5',
'Xenova/jina-embeddings-v2-base-zh', 'Xenova/bge-base-en-v1.5',
'Xenova/jina-embeddings-v2-small-en', 'Xenova/jina-embeddings-v2-base-zh',
'Xenova/multilingual-e5-small', 'Xenova/jina-embeddings-v2-small-en',
'Xenova/multilingual-e5-base', 'Xenova/multilingual-e5-small',
'Xenova/gte-small', 'Xenova/multilingual-e5-base',
'Xenova/e5-small-v2', 'Xenova/gte-small',
'Xenova/e5-base-v2' 'Xenova/e5-small-v2',
]; 'Xenova/e5-base-v2'
} ];
}
/** /**
* *
*/ */
public getModelInfo(modelId: string): { dims: number; maxTokens: number; description: string } | null { public getModelInfo(modelId: string): { dims: number; maxTokens: number; description: string } | null {
const modelInfoMap: Record<string, { dims: number; maxTokens: number; description: string }> = { const modelInfoMap: Record<string, { dims: number; maxTokens: number; description: string }> = {
'Xenova/all-MiniLM-L6-v2': { dims: 384, maxTokens: 512, description: 'All-MiniLM-L6-v2 (推荐,轻量级)' }, 'Xenova/all-MiniLM-L6-v2': { dims: 384, maxTokens: 512, description: 'All-MiniLM-L6-v2 (推荐,轻量级)' },
'Xenova/bge-small-en-v1.5': { dims: 384, maxTokens: 512, description: 'BGE-small-en-v1.5' }, 'Xenova/bge-small-en-v1.5': { dims: 384, maxTokens: 512, description: 'BGE-small-en-v1.5' },
'Xenova/bge-base-en-v1.5': { dims: 768, maxTokens: 512, description: 'BGE-base-en-v1.5 (更高质量)' }, 'Xenova/bge-base-en-v1.5': { dims: 768, maxTokens: 512, description: 'BGE-base-en-v1.5 (更高质量)' },
'Xenova/jina-embeddings-v2-base-zh': { dims: 768, maxTokens: 8192, description: 'Jina-v2-base-zh (中英双语)' }, 'Xenova/jina-embeddings-v2-base-zh': { dims: 768, maxTokens: 8192, description: 'Jina-v2-base-zh (中英双语)' },
'Xenova/jina-embeddings-v2-small-en': { dims: 512, maxTokens: 8192, description: 'Jina-v2-small-en' }, 'Xenova/jina-embeddings-v2-small-en': { dims: 512, maxTokens: 8192, description: 'Jina-v2-small-en' },
'Xenova/multilingual-e5-small': { dims: 384, maxTokens: 512, description: 'E5-small (多语言)' }, 'Xenova/multilingual-e5-small': { dims: 384, maxTokens: 512, description: 'E5-small (多语言)' },
'Xenova/multilingual-e5-base': { dims: 768, maxTokens: 512, description: 'E5-base (多语言,更高质量)' }, 'Xenova/multilingual-e5-base': { dims: 768, maxTokens: 512, description: 'E5-base (多语言,更高质量)' },
'Xenova/gte-small': { dims: 384, maxTokens: 512, description: 'GTE-small' }, 'Xenova/gte-small': { dims: 384, maxTokens: 512, description: 'GTE-small' },
'Xenova/e5-small-v2': { dims: 384, maxTokens: 512, description: 'E5-small-v2' }, 'Xenova/e5-small-v2': { dims: 384, maxTokens: 512, description: 'E5-small-v2' },
'Xenova/e5-base-v2': { dims: 768, maxTokens: 512, description: 'E5-base-v2 (更高质量)' } 'Xenova/e5-base-v2': { dims: 768, maxTokens: 512, description: 'E5-base-v2 (更高质量)' }
}; };
return modelInfoMap[modelId] || null; return modelInfoMap[modelId] || null;
} }
/** /**
* Worker * Worker
*/ */
public terminate() { public terminate() {
this.worker.terminate(); this.worker.terminate();
this.requests.clear(); this.requests.clear();
this.isModelLoaded = false; this.isModelLoaded = false;
} }
} }

View File

@ -48,7 +48,7 @@ async function loadTransformers() {
env.allowRemoteModels = true; env.allowRemoteModels = true;
// 配置 WASM 后端 - 修复线程配置 // 配置 WASM 后端 - 修复线程配置
env.backends.onnx.wasm.numThreads = 4; // 在 Worker 中使用单线程,避免竞态条件 env.backends.onnx.wasm.numThreads = 1; // 在 Worker 中使用单线程,避免竞态条件
env.backends.onnx.wasm.simd = true; env.backends.onnx.wasm.simd = true;
// 禁用 Node.js 特定功能 // 禁用 Node.js 特定功能
@ -201,7 +201,7 @@ async function embedBatch(inputs: EmbedInput[]): Promise<EmbedResult[]> {
} }
// 批处理大小(可以根据需要调整) // 批处理大小(可以根据需要调整)
const batchSize = 1; const batchSize = 8;
if (filteredInputs.length > batchSize) { if (filteredInputs.length > batchSize) {
console.log(`Processing ${filteredInputs.length} inputs in batches of ${batchSize}`); console.log(`Processing ${filteredInputs.length} inputs in batches of ${batchSize}`);

View File

@ -8,8 +8,8 @@ export { EmbeddingManager };
// 导出类型定义 // 导出类型定义
export type { export type {
EmbedResult, EmbedResult,
ModelLoadResult, ModelLoadResult,
ModelUnloadResult, ModelUnloadResult,
TokenCountResult TokenCountResult
} from './EmbeddingManager'; } from './EmbeddingManager';

View File

@ -1641,6 +1641,7 @@ export const localProviderDefaultAutoCompleteModelId = null // this is not suppo
export const localProviderDefaultEmbeddingModelId: keyof typeof localProviderEmbeddingModels = "TaylorAI/bge-micro-v2" export const localProviderDefaultEmbeddingModelId: keyof typeof localProviderEmbeddingModels = "TaylorAI/bge-micro-v2"
export const localProviderEmbeddingModels = { export const localProviderEmbeddingModels = {
'TaylorAI/bge-micro-v2': { dimensions: 384, description: 'BGE-micro-v2 (本地512令牌384维)' },
'Xenova/all-MiniLM-L6-v2': { dimensions: 384, description: 'All-MiniLM-L6-v2 (推荐,轻量级)' }, 'Xenova/all-MiniLM-L6-v2': { dimensions: 384, description: 'All-MiniLM-L6-v2 (推荐,轻量级)' },
'Xenova/bge-small-en-v1.5': { dimensions: 384, description: 'BGE-small-en-v1.5' }, 'Xenova/bge-small-en-v1.5': { dimensions: 384, description: 'BGE-small-en-v1.5' },
'Xenova/bge-base-en-v1.5': { dimensions: 768, description: 'BGE-base-en-v1.5 (更高质量)' }, 'Xenova/bge-base-en-v1.5': { dimensions: 768, description: 'BGE-base-en-v1.5 (更高质量)' },
@ -1651,8 +1652,6 @@ export const localProviderEmbeddingModels = {
'Xenova/gte-small': { dimensions: 384, description: 'GTE-small' }, 'Xenova/gte-small': { dimensions: 384, description: 'GTE-small' },
'Xenova/e5-small-v2': { dimensions: 384, description: 'E5-small-v2' }, 'Xenova/e5-small-v2': { dimensions: 384, description: 'E5-small-v2' },
'Xenova/e5-base-v2': { dimensions: 768, description: 'E5-base-v2 (更高质量)' }, 'Xenova/e5-base-v2': { dimensions: 768, description: 'E5-base-v2 (更高质量)' },
// 新增的模型
'TaylorAI/bge-micro-v2': { dimensions: 384, description: 'BGE-micro-v2 (本地512令牌384维)' },
'Snowflake/snowflake-arctic-embed-xs': { dimensions: 384, description: 'Snowflake Arctic Embed XS (本地512令牌384维)' }, 'Snowflake/snowflake-arctic-embed-xs': { dimensions: 384, description: 'Snowflake Arctic Embed XS (本地512令牌384维)' },
'Snowflake/snowflake-arctic-embed-s': { dimensions: 384, description: 'Snowflake Arctic Embed Small (本地512令牌384维)' }, 'Snowflake/snowflake-arctic-embed-s': { dimensions: 384, description: 'Snowflake Arctic Embed Small (本地512令牌384维)' },
'Snowflake/snowflake-arctic-embed-m': { dimensions: 768, description: 'Snowflake Arctic Embed Medium (本地512令牌768维)' }, 'Snowflake/snowflake-arctic-embed-m': { dimensions: 768, description: 'Snowflake Arctic Embed Medium (本地512令牌768维)' },