From c89186a40d2eda062f08fcf501730002bc079e1c Mon Sep 17 00:00:00 2001 From: duanfuxiang Date: Mon, 7 Jul 2025 16:56:12 +0800 Subject: [PATCH] Optimize the search view component, add model selection functionality, support multiple search modes (notes, insights, all), update internationalization support, improve user interaction prompts, enhance log output, and ensure better user experience and code readability. --- src/components/chat-view/ChatView.tsx | 2 +- src/components/chat-view/InsightView.tsx | 481 +++++++------- src/components/chat-view/SearchView.tsx | 590 ++++++++++++------ .../chat-input/LexicalContentEditable.tsx | 4 +- .../chat-view/chat-input/ModelSelect.tsx | 447 ++++++++++--- .../chat-input/SearchInputWithActions.tsx | 45 +- .../chat-view/chat-input/SearchModeSelect.tsx | 164 +++++ src/database/modules/vector/vector-manager.ts | 2 +- src/embedworker/embed.worker.ts | 215 +++++-- src/types/settings.ts | 18 + styles.css | 18 + 11 files changed, 1393 insertions(+), 593 deletions(-) create mode 100644 src/components/chat-view/chat-input/SearchModeSelect.tsx diff --git a/src/components/chat-view/ChatView.tsx b/src/components/chat-view/ChatView.tsx index 7251af3..3035326 100644 --- a/src/components/chat-view/ChatView.tsx +++ b/src/components/chat-view/ChatView.tsx @@ -193,7 +193,7 @@ const Chat = forwardRef((props, ref) => { } } - const [tab, setTab] = useState<'chat' | 'commands' | 'custom-mode' | 'mcp' | 'search' | 'history' | 'workspace' | 'insights'>('chat') + const [tab, setTab] = useState<'chat' | 'commands' | 'custom-mode' | 'mcp' | 'search' | 'history' | 'workspace' | 'insights'>('search') const [selectedSerializedNodes, setSelectedSerializedNodes] = useState([]) diff --git a/src/components/chat-view/InsightView.tsx b/src/components/chat-view/InsightView.tsx index 54e0816..4674585 100644 --- a/src/components/chat-view/InsightView.tsx +++ b/src/components/chat-view/InsightView.tsx @@ -1,4 +1,4 @@ -import { ChevronDown, ChevronRight } from 'lucide-react' +import { ChevronDown, ChevronRight, RotateCcw } from 'lucide-react' import { useCallback, useEffect, useMemo, useState } from 'react' import { useApp } from '../../contexts/AppContext' @@ -12,6 +12,8 @@ import { t } from '../../lang/helpers' import { getFilesWithTag } from '../../utils/glob-utils' import { openMarkdownFile } from '../../utils/obsidian' +import { ModelSelect } from './chat-input/ModelSelect' + // 洞察源分组结果接口 interface InsightFileGroup { path: string @@ -52,7 +54,7 @@ const InsightView = () => { result?: InitWorkspaceInsightResult workspaceName?: string }>({ show: false }) - + // 删除洞察状态 const [isDeleting, setIsDeleting] = useState(false) const [deletingInsightId, setDeletingInsightId] = useState(null) @@ -95,25 +97,25 @@ const InsightView = () => { for (const item of currentWorkspace.content) { if (item.type === 'folder') { const folderPath = item.content - + // 添加文件夹路径本身 workspacePaths.add(folderPath) - + // 获取文件夹下的所有文件 - const files = app.vault.getMarkdownFiles().filter(file => + const files = app.vault.getMarkdownFiles().filter(file => file.path.startsWith(folderPath === '/' ? '' : folderPath + '/') ) - + // 添加所有文件路径 files.forEach(file => { workspacePaths.add(file.path) - + // 添加中间文件夹路径 const dirPath = file.path.substring(0, file.path.lastIndexOf('/')) if (dirPath && dirPath !== folderPath) { let currentPath = folderPath === '/' ? '' : folderPath const pathParts = dirPath.substring(currentPath.length).split('/').filter(Boolean) - + for (let i = 0; i < pathParts.length; i++) { currentPath += (currentPath ? '/' : '') + pathParts[i] workspacePaths.add(currentPath) @@ -124,16 +126,16 @@ const InsightView = () => { } else if (item.type === 'tag') { // 获取标签对应的所有文件 const tagFiles = getFilesWithTag(item.content, app) - + tagFiles.forEach(filePath => { workspacePaths.add(filePath) - + // 添加文件所在的文件夹路径 const dirPath = filePath.substring(0, filePath.lastIndexOf('/')) if (dirPath) { const pathParts = dirPath.split('/').filter(Boolean) let currentPath = '' - + for (let i = 0; i < pathParts.length; i++) { currentPath += (currentPath ? '/' : '') + pathParts[i] workspacePaths.add(currentPath) @@ -147,7 +149,7 @@ const InsightView = () => { // 过滤洞察 let filteredInsights = allInsights if (workspacePaths) { - filteredInsights = allInsights.filter(insight => + filteredInsights = allInsights.filter(insight => workspacePaths.has(insight.source_path) ) } @@ -196,13 +198,13 @@ const InsightView = () => { } const transEngine = await getTransEngine() - + // 使用新的 initWorkspaceInsight 方法 const result = await transEngine.initWorkspaceInsight({ workspace: currentWorkspace, model: { - provider: settings.applyModelProvider, - modelId: settings.applyModelId, + provider: settings.insightModelProvider || settings.chatModelProvider, + modelId: settings.insightModelId || settings.chatModelId, }, onProgress: (progress) => { setInitProgress({ @@ -215,10 +217,10 @@ const InsightView = () => { } }) - if (result.success) { + if (result.success) { // 刷新洞察列表 await loadInsights() - + // 显示成功消息和统计信息 console.log(t('insights.success.workspaceInitialized', { name: currentWorkspace.name })) console.log(`✅ 深度处理完成统计:`) @@ -232,19 +234,19 @@ const InsightView = () => { console.log(`🔍 洞察ID: ${result.insightId}`) } console.log(`💡 工作区摘要仅使用顶层配置项目,避免内容重叠`) - + // 显示成功状态 setInitSuccess({ show: true, result: result, workspaceName: currentWorkspace.name }) - + // 3秒后自动隐藏成功消息 setTimeout(() => { setInitSuccess({ show: false }) }, 5000) - + } else { console.error(t('insights.error.initializationFailed'), result.error) throw new Error(String(result.error || t('insights.error.initializationFailed'))) @@ -260,10 +262,7 @@ const InsightView = () => { } }, [getTransEngine, settings, workspaceManager, loadInsights]) - // 确认删除工作区洞察 - const handleDeleteWorkspaceInsights = useCallback(() => { - setShowDeleteConfirm(true) - }, []) + // 确认初始化/更新洞察 const handleInitWorkspaceInsights = useCallback(() => { @@ -282,17 +281,17 @@ const InsightView = () => { } const transEngine = await getTransEngine() - + // 删除工作区的所有转换 const result = await transEngine.deleteWorkspaceTransformations(currentWorkspace) if (result.success) { const workspaceName = currentWorkspace?.name || 'vault' console.log(t('insights.success.workspaceDeleted', { name: workspaceName, count: result.deletedCount })) - + // 刷新洞察列表 await loadInsights() - + // 可以在这里添加用户通知,比如显示删除成功的消息 } else { console.error(t('insights.error.deletionFailed'), result.error) @@ -335,13 +334,13 @@ const InsightView = () => { try { const transEngine = await getTransEngine() - + // 删除单个洞察 const result = await transEngine.deleteSingleInsight(insightId) if (result.success) { console.log(t('insights.success.insightDeleted', { id: insightId })) - + // 刷新洞察列表 await loadInsights() } else { @@ -388,7 +387,7 @@ const InsightView = () => { } else if (insight.source_type === 'folder') { // 文件夹洞察 - 在文件管理器中显示文件夹 console.debug('📁 [InsightView] 点击文件夹洞察:', insight.source_path) - + // 尝试在 Obsidian 文件管理器中显示文件夹 const folder = app.vault.getAbstractFileByPath(insight.source_path) if (folder) { @@ -496,7 +495,7 @@ const InsightView = () => { const typeOrder = { workspace: 0, folder: 1, file: 2 } const typeComparison = typeOrder[a.groupType || 'file'] - typeOrder[b.groupType || 'file'] if (typeComparison !== 0) return typeComparison - + // 同类型按时间排序 return b.maxCreatedAt - a.maxCreatedAt }) @@ -522,76 +521,85 @@ const InsightView = () => {

{t('insights.title')}

- - + */}
- {/* 结果统计 */} - {hasLoaded && !isLoading && ( -
-
-
- {insightResults.length} - 个洞察 + {/* 结果统计 & 洞察操作 */} +
+ {hasLoaded && !isLoading && ( +
+
+ {insightResults.length} + 个洞察
-
+
{insightGroupedResults.length > 0 && ( -
+
{insightGroupedResults.filter(g => g.groupType === 'workspace').length > 0 && ( -
- 🌐 - +
+ 🌐 + {insightGroupedResults.filter(g => g.groupType === 'workspace').length} - 工作区 + 工作区
)} {insightGroupedResults.filter(g => g.groupType === 'folder').length > 0 && ( -
- 📂 - +
+ 📂 + {insightGroupedResults.filter(g => g.groupType === 'folder').length} - 文件夹 + 文件夹
)} {insightGroupedResults.filter(g => g.groupType === 'file').length > 0 && ( -
- 📄 - +
+ 📄 + {insightGroupedResults.filter(g => g.groupType === 'file').length} - 文件 + 文件
)}
)}
+ )} +
+
+ 洞察模型: + +
+
+ +
- )} +
{/* 加载进度 */} @@ -617,10 +625,10 @@ const InsightView = () => {
-
@@ -662,7 +670,7 @@ const InsightView = () => { {t('insights.success.workspaceInitialized', { name: initSuccess.workspaceName })}
- - -
{/* 统计信息 */} - {!isLoadingStats && statisticsInfo && ( -
+
+ {!isLoadingStats && statisticsInfo && (
{statisticsInfo.totalChunks} @@ -566,6 +638,79 @@ const SearchView = () => {
+ )} +
+
+ 嵌入模型: + +
+
+ + +
+
+
+ + {/* 索引进度 */} + {isInitializingRAG && ( +
+
+

正在初始化工作区 RAG 向量索引

+

为当前工作区的文件建立向量索引,提高搜索精度

+
+ {ragInitProgress && ragInitProgress.type === 'indexing' && ragInitProgress.indexProgress && ( +
+
+ 建立向量索引 + + {ragInitProgress.indexProgress.completedChunks} / {ragInitProgress.indexProgress.totalChunks} 块 + +
+
+
+
+
+
+ 共 {ragInitProgress.indexProgress.totalFiles} 个文件 +
+
+ {Math.round((ragInitProgress.indexProgress.completedChunks / Math.max(ragInitProgress.indexProgress.totalChunks, 1)) * 100)}% +
+
+
+ )} +
+ )} + + {/* RAG 初始化成功消息 */} + {ragInitSuccess.show && ( +
+
+ +
+ + 工作区 RAG 向量索引初始化完成: {ragInitSuccess.workspaceName} + +
+ +
)} @@ -581,107 +726,26 @@ const SearchView = () => { placeholder="语义搜索(按回车键搜索)..." autoFocus={true} disabled={isSearching} + searchMode={searchMode} + onSearchModeChange={setSearchMode} /> - - {/* 搜索模式切换 */} -
- - -
- {/* 结果统计 */} + {/* 索引统计 */} {hasSearched && !isSearching && (
{searchMode === 'notes' ? ( `${totalFiles} 个文件,${totalBlocks} 个块` - ) : ( + ) : searchMode === 'insights' ? ( `${insightGroupedResults.length} 个文件,${insightResults.length} 个洞察` + ) : ( + `${totalAllFiles} 个文件,${totalBlocks} 个块,${insightResults.length} 个洞察` )}
)} - - {/* 搜索进度 */} - {isSearching && ( -
- 正在搜索... -
- )} - - {/* RAG 初始化进度 */} - {isInitializingRAG && ( -
-
-

正在初始化工作区 RAG 向量索引

-

为当前工作区的文件建立向量索引,提高搜索精度

-
- {ragInitProgress && ragInitProgress.type === 'indexing' && ragInitProgress.indexProgress && ( -
-
- 建立向量索引 - - {ragInitProgress.indexProgress.completedChunks} / {ragInitProgress.indexProgress.totalChunks} 块 - -
-
-
-
-
-
- 共 {ragInitProgress.indexProgress.totalFiles} 个文件 -
-
- {Math.round((ragInitProgress.indexProgress.completedChunks / Math.max(ragInitProgress.indexProgress.totalChunks, 1)) * 100)}% -
-
-
- )} -
- )} - - {/* RAG 初始化成功消息 */} - {ragInitSuccess.show && ( -
-
- -
- - 工作区 RAG 向量索引初始化完成: {ragInitSuccess.workspaceName} - - - 处理了 {ragInitSuccess.totalFiles} 个文件,生成 {ragInitSuccess.totalChunks} 个向量块 - -
- -
-
- )} - {/* 确认删除对话框 */} {showDeleteConfirm && (
@@ -727,20 +791,20 @@ const SearchView = () => {

- {statisticsInfo && (statisticsInfo.totalFiles > 0 || statisticsInfo.totalChunks > 0) + {statisticsInfo && (statisticsInfo.totalFiles > 0 || statisticsInfo.totalChunks > 0) ? '将更新当前工作区的向量索引,重新处理所有文件以确保索引最新。' : '将为当前工作区的所有文件建立向量索引,这将提高语义搜索的准确性。' }

- 嵌入模型: + 嵌入模型: - {settings.embeddingModelProvider} / {settings.embeddingModelId || '默认模型'} + {settings.embeddingModelId}
- 工作区: + 工作区: {settings.workspace === 'vault' ? '整个 Vault' : settings.workspace} @@ -768,6 +832,13 @@ const SearchView = () => {
)} + {/* 搜索进度 */} + {isSearching && ( +
+ 正在搜索... +
+ )} + {/* 搜索结果 */}
{searchMode === 'notes' ? ( @@ -777,7 +848,7 @@ const SearchView = () => { {groupedResults.map((fileGroup) => (
{/* 文件头部 */} -
toggleFileExpansion(fileGroup.path)} > @@ -827,14 +898,14 @@ const SearchView = () => { ))}
) - ) : ( + ) : searchMode === 'insights' ? ( // AI 洞察搜索结果 !isSearching && insightGroupedResults.length > 0 && (
{insightGroupedResults.map((fileGroup) => (
{/* 文件头部 */} -
toggleFileExpansion(fileGroup.path)} > @@ -885,21 +956,131 @@ const SearchView = () => { ))}
) + ) : ( + // 全部搜索结果:按文件聚合显示原始笔记和洞察 + !isSearching && allGroupedResults.length > 0 && ( +
+ {allGroupedResults.map((fileGroup) => ( +
+ {/* 文件头部 */} +
toggleFileExpansion(fileGroup.path)} + > +
+
+
+ {expandedFiles.has(fileGroup.path) ? ( + + ) : ( + + )} + {fileGroup.fileName} +
+
+
+ {fileGroup.path} +
+
+
+ + {/* 文件内容:混合显示笔记块和洞察 */} + {expandedFiles.has(fileGroup.path) && ( +
+ {/* AI 洞察 */} + {fileGroup.insights.map((insight, insightIndex) => ( +
+
+ {insightIndex + 1} + + {insight.insight_type.toUpperCase()} + + + {insight.similarity.toFixed(3)} + +
+
+
+ {insight.insight} +
+
+
+ ))} + {/* 原始笔记块 */} + {fileGroup.blocks.map((result, blockIndex) => ( +
handleResultClick(result)} + > +
+ {blockIndex + 1} + + L{result.metadata.startLine}-{result.metadata.endLine} + + + {result.similarity.toFixed(3)} + +
+
+ {renderMarkdownContent(result.content)} +
+
+ ))} +
+ )} +
+ ))} +
+ ) )} - + {!isSearching && hasSearched && ( - (searchMode === 'notes' && groupedResults.length === 0) || - (searchMode === 'insights' && insightGroupedResults.length === 0) + (searchMode === 'notes' && groupedResults.length === 0) || + (searchMode === 'insights' && insightGroupedResults.length === 0) || + (searchMode === 'all' && allGroupedResults.length === 0) ) && ( -
-

未找到相关结果

-
- )} +
+

未找到相关结果

+
+ )}
{/* 样式 */} diff --git a/src/components/chat-view/chat-input/SearchInputWithActions.tsx b/src/components/chat-view/chat-input/SearchInputWithActions.tsx index 85e0ae6..d506611 100644 --- a/src/components/chat-view/chat-input/SearchInputWithActions.tsx +++ b/src/components/chat-view/chat-input/SearchInputWithActions.tsx @@ -6,10 +6,12 @@ import { useState } from 'react' + import { Mentionable } from '../../../types/mentionable' import LexicalContentEditable from './LexicalContentEditable' import { SearchButton } from './SearchButton' +import { SearchModeSelect } from './SearchModeSelect' export type SearchInputRef = { focus: () => void @@ -25,26 +27,28 @@ export type SearchInputProps = { placeholder?: string autoFocus?: boolean disabled?: boolean + searchMode?: 'notes' | 'insights' | 'all' + onSearchModeChange?: (mode: 'notes' | 'insights' | 'all') => void } -// 检查编辑器状态是否为空的辅助函数 +// 检查编辑器状态是否为空 const isEditorStateEmpty = (editorState: SerializedEditorState): boolean => { - if (!editorState || !editorState.root || !editorState.root.children) { + try { + const root = editorState.root + if (!root || !root.children) return true + + // 检查是否有实际内容 + const hasContent = root.children.some((child: any) => { + if (child.type === 'paragraph') { + return child.children && child.children.length > 0 + } + return true + }) + + return !hasContent + } catch (error) { return true } - - const children = editorState.root.children - if (children.length === 0) { - return true - } - - // 检查是否只有空的段落 - if (children.length === 1 && children[0].type === 'paragraph') { - const paragraph = children[0] as any - return !paragraph.children || paragraph.children.length === 0 - } - - return false } const SearchInputWithActions = forwardRef( @@ -56,6 +60,8 @@ const SearchInputWithActions = forwardRef( placeholder = '', autoFocus = false, disabled = false, + searchMode = 'all', + onSearchModeChange, }, ref ) => { @@ -112,6 +118,7 @@ const SearchInputWithActions = forwardRef(
)} { if (initialSerializedEditorState) { editor.setEditorState( @@ -139,7 +146,13 @@ const SearchInputWithActions = forwardRef(
- {/* TODO: add model select */} + {onSearchModeChange && ( + + )} +
handleSubmit()} /> diff --git a/src/components/chat-view/chat-input/SearchModeSelect.tsx b/src/components/chat-view/chat-input/SearchModeSelect.tsx new file mode 100644 index 0000000..63a9c38 --- /dev/null +++ b/src/components/chat-view/chat-input/SearchModeSelect.tsx @@ -0,0 +1,164 @@ +import * as DropdownMenu from '@radix-ui/react-dropdown-menu' +import { ChevronDown, ChevronUp, FileText, Lightbulb, Globe } from 'lucide-react' +import { useState } from 'react' + +interface SearchModeSelectProps { + searchMode: 'notes' | 'insights' | 'all' + onSearchModeChange: (mode: 'notes' | 'insights' | 'all') => void +} + +export function SearchModeSelect({ searchMode, onSearchModeChange }: SearchModeSelectProps) { + const [isOpen, setIsOpen] = useState(false) + + const searchModes = [ + { + value: 'all' as const, + name: '全部', + icon: , + description: '聚合搜索原始笔记和 AI 洞察' + }, + { + value: 'notes' as const, + name: '原始笔记', + icon: , + description: '搜索原始笔记内容' + }, + { + value: 'insights' as const, + name: 'AI 洞察', + icon: , + description: '搜索 AI 洞察内容' + } + ] + + const currentMode = searchModes.find((m) => m.value === searchMode) + + return ( + <> + + + {currentMode?.icon} +
+ {currentMode?.name} +
+
+ {isOpen ? : } +
+
+ + + +
    + {searchModes.map((mode) => ( + { + onSearchModeChange(mode.value) + }} + asChild + > +
  • +
    + {mode.icon} +
    + {mode.name} + {mode.description} +
    +
    +
  • +
    + ))} +
+
+
+
+ + + ) +} diff --git a/src/database/modules/vector/vector-manager.ts b/src/database/modules/vector/vector-manager.ts index 6f72a1c..fa807d9 100644 --- a/src/database/modules/vector/vector-manager.ts +++ b/src/database/modules/vector/vector-manager.ts @@ -254,7 +254,7 @@ export class VectorManager { await backOff( async () => { - // 在嵌入之前处理 markdown,只处理一次 + // 在嵌入之前处理 markdown const cleanedBatchData = batchChunks.map(chunk => { const cleanContent = removeMarkdown(chunk.content).replace(/\0/g, '') return { chunk, cleanContent } diff --git a/src/embedworker/embed.worker.ts b/src/embedworker/embed.worker.ts index f2d2eb6..4c0bf8a 100644 --- a/src/embedworker/embed.worker.ts +++ b/src/embedworker/embed.worker.ts @@ -8,36 +8,153 @@ interface EmbedResult { vec: number[]; tokens: number; embed_input?: string; + error?: string; } +// 定义工作器消息的参数类型 +interface LoadParams { + model_key: string; + use_gpu?: boolean; +} + +interface EmbedBatchParams { + inputs: EmbedInput[]; +} + +type WorkerParams = LoadParams | EmbedBatchParams | string | undefined; + interface WorkerMessage { method: string; - params: any; + params: WorkerParams; id: number; worker_id?: string; } interface WorkerResponse { id: number; - result?: any; + result?: unknown; error?: string; worker_id?: string; } +// 定义 Transformers.js 相关类型 +interface TransformersEnv { + allowLocalModels: boolean; + allowRemoteModels: boolean; + backends: { + onnx: { + wasm: { + numThreads: number; + simd: boolean; + }; + }; + }; + useFS: boolean; + useBrowserCache: boolean; + remoteHost?: string; +} + +interface PipelineOptions { + quantized?: boolean; + progress_callback?: (progress: unknown) => void; + device?: string; + dtype?: string; +} + +interface ModelInfo { + loaded: boolean; + model_key: string; + use_gpu: boolean; +} + +interface TokenizerResult { + input_ids: { + data: number[]; + }; +} + +interface GlobalTransformers { + pipelineFactory: (task: string, model: string, options?: PipelineOptions) => Promise; + AutoTokenizer: { + from_pretrained: (model: string) => Promise; + }; + env: TransformersEnv; +} + // 全局变量 -let model: any = null; -let pipeline: any = null; -let tokenizer: any = null; +let model: ModelInfo | null = null; +let pipeline: unknown = null; +let tokenizer: unknown = null; let processing_message = false; let transformersLoaded = false; +/** + * 测试一个网络端点是否可访问 + * @param {string} url 要测试的 URL + * @param {number} timeout 超时时间 (毫秒) + * @returns {Promise} 如果可访问则返回 true,否则返回 false + */ +async function testEndpoint(url: string, timeout = 3000): Promise { + // AbortController 用于在超时后取消 fetch 请求 + const controller = new AbortController(); + const signal = controller.signal; + + const timeoutId = setTimeout(() => { + console.log(`请求 ${url} 超时。`); + controller.abort(); + }, timeout); + + try { + console.log(`正在测试端点: ${url}`); + // 我们使用 'HEAD' 方法,因为它只请求头部信息,非常快速,适合做存活检测。 + // 'no-cors' 模式允许我们在浏览器环境中进行跨域请求以进行简单的可达性测试, + // 即使我们不能读取响应内容,请求成功也意味着网络是通的。 + await fetch(url, { method: 'HEAD', mode: 'no-cors', signal }); + + // 如果 fetch 成功,清除超时定时器并返回 true + clearTimeout(timeoutId); + console.log(`端点 ${url} 可访问。`); + return true; + } catch (error) { + // 如果发生网络错误或请求被中止 (超时),则进入 catch 块 + clearTimeout(timeoutId); // 同样需要清除定时器 + console.warn(`无法访问端点 ${url}:`, error instanceof Error && error.name === 'AbortError' ? '超时' : (error as Error).message); + return false; + } +} + +/** + * 初始化 Hugging Face 端点,如果默认的不可用,则自动切换到备用镜像。 + */ +async function initializeEndpoint(): Promise { + const defaultEndpoint = 'https://huggingface.co'; + const fallbackEndpoint = 'https://hf-mirror.com'; + + const isDefaultReachable = await testEndpoint(defaultEndpoint); + + const globalTransformers = globalThis as unknown as { transformers?: GlobalTransformers }; + + if (!isDefaultReachable) { + console.log(`默认端点不可达,将切换到备用镜像: ${fallbackEndpoint}`); + // 这是关键步骤:在代码中设置 endpoint + if (globalTransformers.transformers?.env) { + globalTransformers.transformers.env.remoteHost = fallbackEndpoint; + } + } else { + console.log(`将使用默认端点: ${defaultEndpoint}`); + } +} + // 动态导入 Transformers.js -async function loadTransformers() { +async function loadTransformers(): Promise { if (transformersLoaded) return; try { console.log('Loading Transformers.js...'); + // 首先初始化端点 + await initializeEndpoint(); + // 尝试使用旧版本的 Transformers.js,它在 Worker 中更稳定 const { pipeline: pipelineFactory, env, AutoTokenizer } = await import('@xenova/transformers'); @@ -53,9 +170,12 @@ async function loadTransformers() { env.useFS = false; env.useBrowserCache = true; - (globalThis as any).pipelineFactory = pipelineFactory; - (globalThis as any).AutoTokenizer = AutoTokenizer; - (globalThis as any).env = env; + const globalTransformers = globalThis as unknown as { transformers?: GlobalTransformers }; + globalTransformers.transformers = { + pipelineFactory, + AutoTokenizer, + env + }; transformersLoaded = true; console.log('Transformers.js loaded successfully'); @@ -65,22 +185,27 @@ async function loadTransformers() { } } -async function loadModel(modelKey: string, useGpu: boolean = false) { +async function loadModel(modelKey: string, useGpu: boolean = false): Promise<{ model_loaded: boolean }> { try { console.log(`Loading model: ${modelKey}, GPU: ${useGpu}`); // 确保 Transformers.js 已加载 await loadTransformers(); - const pipelineFactory = (globalThis as any).pipelineFactory; - const AutoTokenizer = (globalThis as any).AutoTokenizer; - const env = (globalThis as any).env; + const globalTransformers = globalThis as unknown as { transformers?: GlobalTransformers }; + const transformers = globalTransformers.transformers; + + if (!transformers) { + throw new Error('Transformers.js not loaded'); + } + + const { pipelineFactory, AutoTokenizer } = transformers; // 配置管道选项 - const pipelineOpts: any = { + const pipelineOpts: PipelineOptions = { quantized: true, // 修复进度回调,添加错误处理 - progress_callback: (progress: any) => { + progress_callback: (progress: unknown) => { try { if (progress && typeof progress === 'object') { // console.log('Model loading progress:', progress); @@ -96,9 +221,9 @@ async function loadModel(modelKey: string, useGpu: boolean = false) { if (useGpu) { try { // 检查 WebGPU 支持 - console.log("useGpu", useGpu) + console.log("useGpu", useGpu); if (typeof navigator !== 'undefined' && 'gpu' in navigator) { - const gpu = (navigator as any).gpu; + const gpu = (navigator as { gpu?: { requestAdapter?: () => unknown } }).gpu; if (gpu && typeof gpu.requestAdapter === 'function') { console.log('[Transformers] Attempting to use GPU'); pipelineOpts.device = 'webgpu'; @@ -137,21 +262,17 @@ async function loadModel(modelKey: string, useGpu: boolean = false) { } } -async function unloadModel() { +async function unloadModel(): Promise<{ model_unloaded: boolean }> { try { console.log('Unloading model...'); - if (pipeline) { - if (pipeline.destroy) { - pipeline.destroy(); - } - pipeline = null; - } - - if (tokenizer) { - tokenizer = null; + if (pipeline && typeof pipeline === 'object' && 'destroy' in pipeline) { + const pipelineWithDestroy = pipeline as { destroy: () => void }; + pipelineWithDestroy.destroy(); } + pipeline = null; + tokenizer = null; model = null; console.log('Model unloaded successfully'); @@ -163,13 +284,14 @@ async function unloadModel() { } } -async function countTokens(input: string) { +async function countTokens(input: string): Promise<{ tokens: number }> { try { if (!tokenizer) { throw new Error('Tokenizer not loaded'); } - const { input_ids } = await tokenizer(input); + const tokenizerWithCall = tokenizer as (input: string) => Promise; + const { input_ids } = await tokenizerWithCall(input); return { tokens: input_ids.data.length }; } catch (error) { @@ -249,7 +371,8 @@ async function processBatch(batchInputs: EmbedInput[]): Promise { ); // 生成嵌入向量 - const resp = await pipeline(embedInputs, { pooling: 'mean', normalize: true }); + const pipelineCall = pipeline as (inputs: string[], options: { pooling: string; normalize: boolean }) => Promise<{ data: number[] }[]>; + const resp = await pipelineCall(embedInputs, { pooling: 'mean', normalize: true }); // 处理结果 return batchInputs.map((item, i) => ({ @@ -262,10 +385,11 @@ async function processBatch(batchInputs: EmbedInput[]): Promise { console.error('Error processing batch:', error); // 如果批处理失败,尝试逐个处理 - return Promise.all( - batchInputs.map(async (item) => { + const results = await Promise.all( + batchInputs.map(async (item): Promise => { try { - const result = await pipeline(item.embed_input, { pooling: 'mean', normalize: true }); + const pipelineCall = pipeline as (input: string, options: { pooling: string; normalize: boolean }) => Promise<{ data: number[] }[]>; + const result = await pipelineCall(item.embed_input, { pooling: 'mean', normalize: true }); const tokenCount = await countTokens(item.embed_input); return { @@ -279,11 +403,13 @@ async function processBatch(batchInputs: EmbedInput[]): Promise { vec: [], tokens: 0, embed_input: item.embed_input, - error: (singleError as Error).message - } as any; + error: singleError instanceof Error ? singleError.message : 'Unknown error' + }; } }) ); + + return results; } } @@ -291,12 +417,13 @@ async function processMessage(data: WorkerMessage): Promise { const { method, params, id, worker_id } = data; try { - let result: any; + let result: unknown; switch (method) { case 'load': console.log('Load method called with params:', params); - result = await loadModel(params.model_key, params.use_gpu || false); + const loadParams = params as LoadParams; + result = await loadModel(loadParams.model_key, loadParams.use_gpu || false); break; case 'unload': @@ -318,7 +445,8 @@ async function processMessage(data: WorkerMessage): Promise { } processing_message = true; - result = await embedBatch(params.inputs); + const embedParams = params as EmbedBatchParams; + result = await embedBatch(embedParams.inputs); processing_message = false; break; @@ -336,7 +464,8 @@ async function processMessage(data: WorkerMessage): Promise { } processing_message = true; - result = await countTokens(params); + const tokenParams = params as string; + result = await countTokens(tokenParams); processing_message = false; break; @@ -349,7 +478,7 @@ async function processMessage(data: WorkerMessage): Promise { } catch (error) { console.error('Error processing message:', error); processing_message = false; - return { id, error: (error as Error).message, worker_id }; + return { id, error: error instanceof Error ? error.message : 'Unknown error', worker_id }; } } @@ -367,14 +496,14 @@ self.addEventListener('message', async (event) => { return; } - const response = await processMessage(event.data); + const response = await processMessage(event.data as WorkerMessage); console.log('Worker sending response:', response); self.postMessage(response); } catch (error) { console.error('Unhandled error in worker message handler:', error); self.postMessage({ - id: event.data?.id || -1, - error: `Worker error: ${error.message || 'Unknown error'}` + id: (event.data as { id?: number })?.id || -1, + error: `Worker error: ${error instanceof Error ? error.message : 'Unknown error'}` }); } }); diff --git a/src/types/settings.ts b/src/types/settings.ts index e1d59c6..b96e142 100644 --- a/src/types/settings.ts +++ b/src/types/settings.ts @@ -282,6 +282,24 @@ export const InfioSettingsSchema = z.object({ modelId: z.string(), })).catch([]), + // Insight Model start list + collectedInsightModels: z.array(z.object({ + provider: z.nativeEnum(ApiProvider), + modelId: z.string(), + })).catch([]), + + // Apply Model start list + collectedApplyModels: z.array(z.object({ + provider: z.nativeEnum(ApiProvider), + modelId: z.string(), + })).catch([]), + + // Embedding Model start list + collectedEmbeddingModels: z.array(z.object({ + provider: z.nativeEnum(ApiProvider), + modelId: z.string(), + })).catch([]), + // Active Provider Tab (for UI state) activeProviderTab: z.nativeEnum(ApiProvider).catch(ApiProvider.Infio), diff --git a/styles.css b/styles.css index 675106f..264add3 100644 --- a/styles.css +++ b/styles.css @@ -828,6 +828,24 @@ input[type='text'].infio-chat-list-dropdown-item-title-input { word-break: break-all; } +.infio-search-lexical-content-editable-root { + min-height: 36px; + max-height: 500px; + overflow-y: auto; +} + +.infio-search-lexical-content-editable-root .mention { + background-color: var(--tag-background); + color: var(--tag-color); + padding: var(--size-2-1) calc(var(--size-2-1)); + border-radius: var(--radius-s); + background-color: var(--tag-background); + color: var(--tag-color); + padding: 0 calc(var(--size-2-1)); + border-radius: var(--radius-s); + word-break: break-all; +} + .infio-chat-lexical-content-editable-paragraph { margin: 0; line-height: 1.6;