From d15681b0d5fd36415b27d15c3984420de7e24399 Mon Sep 17 00:00:00 2001 From: duanfuxiang Date: Tue, 18 Feb 2025 11:02:24 +0800 Subject: [PATCH] add reasoning block --- src/components/chat-view/Chat.tsx | 15 +- .../chat-view/MarkdownReasoningBlock.tsx | 56 ++++ .../chat-view/SyntaxHighlighterWrapper.tsx | 4 + src/core/llm/openai-message-adapter.ts | 6 +- .../conversation/conversation-manager.ts | 276 +++++++++--------- .../conversation/conversation-repository.ts | 31 +- src/database/schema.ts | 22 +- src/database/sql.ts | 12 + src/hooks/use-chat-history.ts | 6 +- src/types/chat.ts | 2 + src/types/llm/response.ts | 3 +- styles.css | 21 ++ 12 files changed, 285 insertions(+), 169 deletions(-) create mode 100644 src/components/chat-view/MarkdownReasoningBlock.tsx diff --git a/src/components/chat-view/Chat.tsx b/src/components/chat-view/Chat.tsx index d2bb336..7553ad4 100644 --- a/src/components/chat-view/Chat.tsx +++ b/src/components/chat-view/Chat.tsx @@ -44,10 +44,12 @@ import AssistantMessageActions from './AssistantMessageActions' import PromptInputWithActions, { ChatUserInputRef } from './chat-input/PromptInputWithActions' import { editorStateToPlainText } from './chat-input/utils/editor-state-to-plain-text' import { ChatHistory } from './ChatHistory' +import MarkdownReasoningBlock from './MarkdownReasoningBlock' import QueryProgress, { QueryProgressState } from './QueryProgress' import ReactMarkdown from './ReactMarkdown' import ShortcutInfo from './ShortcutInfo' import SimilaritySearchResults from './SimilaritySearchResults' + // Add an empty line here const getNewInputMessage = (app: App): ChatUserMessage => { return { @@ -242,6 +244,7 @@ const Chat = forwardRef((props, ref) => { { role: 'assistant', content: '', + reasoningContent: '', id: responseMessageId, metadata: { usage: undefined, @@ -269,6 +272,7 @@ const Chat = forwardRef((props, ref) => { { role: 'assistant', content: '', + reasoningContent: '', id: responseMessageId, metadata: { usage: undefined, @@ -290,12 +294,14 @@ const Chat = forwardRef((props, ref) => { for await (const chunk of stream) { const content = chunk.choices[0]?.delta?.content ?? '' + const reasoning_content = chunk.choices[0]?.delta?.reasoning_content ?? '' setChatMessages((prevChatHistory) => prevChatHistory.map((message) => message.role === 'assistant' && message.id === responseMessageId ? { ...message, content: message.content + content, + reasoningContent: message.reasoningContent + reasoning_content, metadata: { ...message.metadata, usage: chunk.usage ?? message.metadata?.usage, // Keep existing usage if chunk has no usage data @@ -584,13 +590,7 @@ const Chat = forwardRef((props, ref) => { { // If the chat is empty, show a message to start a new chat chatMessages.length === 0 && ( -
+
) @@ -638,6 +638,7 @@ const Chat = forwardRef((props, ref) => {
) : (
+ ) { + const { isDarkMode } = useDarkModeContext() + const containerRef = useRef(null) + const [isOpen, setIsOpen] = useState(true) + + useEffect(() => { + if (containerRef.current) { + containerRef.current.scrollTop = containerRef.current.scrollHeight + } + }, [reasoningContent]) + + return ( + reasoningContent && ( +
+
+
+ Reasoning +
+ +
+
+ + {reasoningContent} + +
+
+ ) + ) +} diff --git a/src/components/chat-view/SyntaxHighlighterWrapper.tsx b/src/components/chat-view/SyntaxHighlighterWrapper.tsx index 31f5fba..0e2ecb9 100644 --- a/src/components/chat-view/SyntaxHighlighterWrapper.tsx +++ b/src/components/chat-view/SyntaxHighlighterWrapper.tsx @@ -11,13 +11,17 @@ function SyntaxHighlighterWrapper({ hasFilename, wrapLines, children, + isOpen = true, }: { isDarkMode: boolean language: string | undefined hasFilename: boolean wrapLines: boolean children: string + isOpen?: boolean }) { + if (!isOpen) return null; + return ( ({ finish_reason: choice.finish_reason, message: { - content: choice.message.content, + content: choice.message.content, + reasoning_content: choice.message.reasoning_content, role: choice.message.role, }, })), @@ -135,13 +136,14 @@ export class OpenAIMessageAdapter { static parseStreamingResponseChunk( chunk: ChatCompletionChunk, - ): LLMResponseStreaming { + ): LLMResponseStreaming { return { id: chunk.id, choices: chunk.choices.map((choice) => ({ finish_reason: choice.finish_reason ?? null, delta: { content: choice.delta.content ?? null, + reasoning_content: choice.delta.reasoning_content ?? null, role: choice.delta.role, }, })), diff --git a/src/database/modules/conversation/conversation-manager.ts b/src/database/modules/conversation/conversation-manager.ts index 9b490e8..4270b8b 100644 --- a/src/database/modules/conversation/conversation-manager.ts +++ b/src/database/modules/conversation/conversation-manager.ts @@ -1,162 +1,164 @@ -import { SerializedEditorState } from 'lexical' import { App } from 'obsidian' import { editorStateToPlainText } from '../../../components/chat-view/chat-input/utils/editor-state-to-plain-text' import { ChatAssistantMessage, ChatConversationMeta, ChatMessage, ChatUserMessage } from '../../../types/chat' -import { ContentPart } from '../../../types/llm/request' -import { Mentionable, SerializedMentionable } from '../../../types/mentionable' +import { Mentionable } from '../../../types/mentionable' import { deserializeMentionable, serializeMentionable } from '../../../utils/mentionable' import { DBManager } from '../../database-manager' -import { InsertMessage } from '../../schema' +import { InsertMessage, SelectConversation, SelectMessage } from '../../schema' import { ConversationRepository } from './conversation-repository' export class ConversationManager { - private app: App - private repository: ConversationRepository - private dbManager: DBManager + private app: App + private repository: ConversationRepository + private dbManager: DBManager - constructor(app: App, dbManager: DBManager) { - this.app = app - this.dbManager = dbManager - const db = dbManager.getPgClient() - if (!db) throw new Error('Database not initialized') - this.repository = new ConversationRepository(app, db) - } + constructor(app: App, dbManager: DBManager) { + this.app = app + this.dbManager = dbManager + const db = dbManager.getPgClient() + if (!db) throw new Error('Database not initialized') + this.repository = new ConversationRepository(app, db) + } - async createConversation(id: string, title = 'New chat'): Promise { - const conversation = { - id, - title, - createdAt: new Date(), - updatedAt: new Date(), - } - await this.repository.create(conversation) - await this.dbManager.save() - } + async createConversation(id: string, title = 'New chat'): Promise { + const conversation = { + id, + title, + createdAt: new Date(), + updatedAt: new Date(), + } + await this.repository.create(conversation) + await this.dbManager.save() + } - async saveConversation(id: string, messages: ChatMessage[]): Promise { - const conversation = await this.repository.findById(id) - if (!conversation) { - let title = 'New chat' - if (messages.length > 0 && messages[0].role === 'user') { - const query = editorStateToPlainText(messages[0].content) - if (query.length > 20) { - title = `${query.slice(0, 20)}...` - } else { - title = query - } - } - await this.createConversation(id, title) - } + async saveConversation(id: string, messages: ChatMessage[]): Promise { + const conversation = await this.repository.findById(id) + if (!conversation) { + let title = 'New chat' + if (messages.length > 0 && messages[0].role === 'user') { + const query = editorStateToPlainText(messages[0].content) + if (query.length > 20) { + title = `${query.slice(0, 20)}...` + } else { + title = query + } + } + await this.createConversation(id, title) + } - // Delete existing messages - await this.repository.deleteAllMessagesFromConversation(id) + // Delete existing messages + await this.repository.deleteAllMessagesFromConversation(id) - // Insert new messages - for (const message of messages) { - const insertMessage = this.serializeMessage(message, id) - await this.repository.createMessage(insertMessage) - } + // Insert new messages + for (const message of messages) { + const insertMessage = this.serializeMessage(message, id) + await this.repository.createMessage(insertMessage) + } - // Update conversation timestamp - await this.repository.update(id, { updatedAt: new Date() }) - await this.dbManager.save() - } + // Update conversation timestamp + await this.repository.update(id, { updatedAt: new Date() }) + await this.dbManager.save() + } - async findConversation(id: string): Promise { - const conversation = await this.repository.findById(id) - if (!conversation) { - return null - } + async findConversation(id: string): Promise { + const conversation = await this.repository.findById(id) + if (!conversation) { + return null + } - const messages = await this.repository.findMessagesByConversationId(id) - return messages.map(msg => this.deserializeMessage(msg)) - } + const messages = await this.repository.findMessagesByConversationId(id) + return messages.map(msg => this.deserializeMessage(msg)) + } - async deleteConversation(id: string): Promise { - await this.repository.delete(id) - await this.dbManager.save() - } + async deleteConversation(id: string): Promise { + await this.repository.delete(id) + await this.dbManager.save() + } - getAllConversations(callback: (conversations: ChatConversationMeta[]) => void): void { - const db = this.dbManager.getPgClient() - db?.live.query('SELECT * FROM conversations ORDER BY updated_at', [], (results) => { - callback(results.rows.map(conv => ({ - id: conv.id, - title: conv.title, - schemaVersion: 2, - createdAt: conv.createdAt instanceof Date ? conv.createdAt.getTime() : conv.createdAt, - updatedAt: conv.updatedAt instanceof Date ? conv.updatedAt.getTime() : conv.updatedAt, - }))) - }) - } + getAllConversations(callback: (conversations: ChatConversationMeta[]) => void): void { + const db = this.dbManager.getPgClient() + db?.live.query('SELECT * FROM conversations ORDER BY updated_at DESC', [], (results: { rows: Array }) => { + callback(results.rows.map(conv => ({ + schemaVersion: 2, + id: conv.id, + title: conv.title, + createdAt: conv.created_at instanceof Date ? conv.created_at.getTime() : conv.created_at, + updatedAt: conv.updated_at instanceof Date ? conv.updated_at.getTime() : conv.updated_at, + }))) + }) + } - async updateConversationTitle(id: string, title: string): Promise { - await this.repository.update(id, { title }) - await this.dbManager.save() - } + async updateConversationTitle(id: string, title: string): Promise { + await this.repository.update(id, { title }) + await this.dbManager.save() + } - private serializeMessage(message: ChatMessage, conversationId: string): InsertMessage { - const base = { - id: message.id, - conversationId, - role: message.role, - createdAt: new Date(), - } + // convert ChatMessage to InsertMessage + private serializeMessage(message: ChatMessage, conversationId: string): InsertMessage { + const base = { + id: message.id, + conversationId: conversationId, + role: message.role, + createdAt: new Date(), + } - if (message.role === 'user') { - const userMessage: ChatUserMessage = message - return { - ...base, - content: userMessage.content ? JSON.stringify(userMessage.content) : null, - promptContent: userMessage.promptContent - ? typeof userMessage.promptContent === 'string' - ? userMessage.promptContent - : JSON.stringify(userMessage.promptContent) - : null, - mentionables: JSON.stringify(userMessage.mentionables.map(serializeMentionable)), - similaritySearchResults: userMessage.similaritySearchResults - ? JSON.stringify(userMessage.similaritySearchResults) - : null, - } - } else { - const assistantMessage: ChatAssistantMessage = message - return { - ...base, - content: assistantMessage.content, - metadata: assistantMessage.metadata ? JSON.stringify(assistantMessage.metadata) : null, - } - } - } + if (message.role === 'user') { + const userMessage: ChatUserMessage = message + return { + ...base, + content: userMessage.content ? JSON.stringify(userMessage.content) : null, + promptContent: userMessage.promptContent + ? typeof userMessage.promptContent === 'string' + ? userMessage.promptContent + : JSON.stringify(userMessage.promptContent) + : null, + mentionables: JSON.stringify(userMessage.mentionables.map(serializeMentionable)), + similaritySearchResults: userMessage.similaritySearchResults + ? JSON.stringify(userMessage.similaritySearchResults) + : null, + } + } else { + const assistantMessage: ChatAssistantMessage = message + return { + ...base, + content: assistantMessage.content, + reasoningContent: assistantMessage.reasoningContent, + metadata: assistantMessage.metadata ? JSON.stringify(assistantMessage.metadata) : null, + } + } + } - private deserializeMessage(message: InsertMessage): ChatMessage { - if (message.role === 'user') { - return { - id: message.id, - role: 'user', - content: message.content ? JSON.parse(message.content) as SerializedEditorState : null, - promptContent: message.promptContent - ? message.promptContent.startsWith('{') - ? JSON.parse(message.promptContent) as ContentPart[] - : message.promptContent - : null, - mentionables: message.mentionables - ? (JSON.parse(message.mentionables) as SerializedMentionable[]) - .map(m => deserializeMentionable(m, this.app)) - .filter((m: Mentionable | null): m is Mentionable => m !== null) - : [], - similaritySearchResults: message.similaritySearchResults - ? JSON.parse(message.similaritySearchResults) - : undefined, - } - } else { - return { - id: message.id, - role: 'assistant', - content: message.content || '', - metadata: message.metadata ? JSON.parse(message.metadata) : undefined, - } - } - } + // convert SelectMessage to ChatMessage + private deserializeMessage(message: SelectMessage): ChatMessage { + if (message.role === 'user') { + return { + id: message.id, + role: 'user', + content: message.content ? JSON.parse(message.content) : null, + promptContent: message.prompt_content + ? message.prompt_content.startsWith('{') + ? JSON.parse(message.prompt_content) + : message.prompt_content + : null, + mentionables: message.mentionables + ? JSON.parse(message.mentionables) + .map(m => deserializeMentionable(m, this.app)) + .filter((m: Mentionable | null): m is Mentionable => m !== null) + : [], + similaritySearchResults: message.similarity_search_results + ? JSON.parse(message.similarity_search_results) + : undefined, + } + } else { + return { + id: message.id, + role: 'assistant', + content: message.content || '', + reasoningContent: message.reasoning_content || '', + metadata: message.metadata ? JSON.parse(message.metadata) : undefined, + } + } + } } diff --git a/src/database/modules/conversation/conversation-repository.ts b/src/database/modules/conversation/conversation-repository.ts index cbd0c0e..4106ec8 100644 --- a/src/database/modules/conversation/conversation-repository.ts +++ b/src/database/modules/conversation/conversation-repository.ts @@ -8,10 +8,6 @@ import { SelectMessage, } from '../../schema' -type QueryResult = { - rows: T[] -} - export class ConversationRepository { private app: App private db: PGliteInterface @@ -32,31 +28,32 @@ export class ConversationRepository { conversation.createdAt || new Date(), conversation.updatedAt || new Date() ] - ) as QueryResult + ) return result.rows[0] } async createMessage(message: InsertMessage): Promise { - const result = await this.db.query( + const result = await this.db.query( `INSERT INTO messages ( - id, conversation_id, role, content, + id, conversation_id, role, content, reasoning_content, prompt_content, metadata, mentionables, similarity_search_results, created_at ) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING *`, [ message.id, message.conversationId, message.role, message.content, + message.reasoningContent, message.promptContent, message.metadata, message.mentionables, message.similaritySearchResults, message.createdAt || new Date() ] - ) as QueryResult + ) return result.rows[0] } @@ -64,30 +61,30 @@ export class ConversationRepository { const result = await this.db.query( `SELECT * FROM conversations WHERE id = $1 LIMIT 1`, [id] - ) as QueryResult + ) return result.rows[0] } - async findMessagesByConversationId(conversationId: string): Promise { + async findMessagesByConversationId(conversationId: string): Promise { const result = await this.db.query( `SELECT * FROM messages WHERE conversation_id = $1 ORDER BY created_at`, [conversationId] - ) as QueryResult + ) return result.rows } async findAll(): Promise { const result = await this.db.query( - `SELECT * FROM conversations ORDER BY updated_at DESC` - ) as QueryResult + `SELECT * FROM conversations ORDER BY created_at DESC` + ) return result.rows } async update(id: string, data: Partial): Promise { const setClauses: string[] = [] - const values: any[] = [] + const values: (string | Date)[] = [] let paramIndex = 1 if (data.title !== undefined) { @@ -110,7 +107,7 @@ export class ConversationRepository { WHERE id = $${paramIndex} RETURNING *`, values - ) as QueryResult + ) return result.rows[0] } @@ -118,7 +115,7 @@ export class ConversationRepository { const result = await this.db.query( `DELETE FROM conversations WHERE id = $1 RETURNING *`, [id] - ) as QueryResult + ) return result.rows.length > 0 } diff --git a/src/database/schema.ts b/src/database/schema.ts index f3ed40b..167dd92 100644 --- a/src/database/schema.ts +++ b/src/database/schema.ts @@ -125,6 +125,7 @@ export type Message = { conversationId: string // uuid role: 'user' | 'assistant' content: string | null + reasoningContent?: string | null promptContent?: string | null metadata?: string | null mentionables?: string | null @@ -139,13 +140,19 @@ export type InsertConversation = { updatedAt?: Date } -export type SelectConversation = Conversation +export type SelectConversation = { + id: string // uuid + title: string + created_at: Date + updated_at: Date +} export type InsertMessage = { id: string conversationId: string role: 'user' | 'assistant' content: string | null + reasoningContent?: string | null promptContent?: string | null metadata?: string | null mentionables?: string | null @@ -153,4 +160,15 @@ export type InsertMessage = { createdAt?: Date } -export type SelectMessage = Message +export type SelectMessage = { + id: string // uuid + conversation_id: string // uuid + role: 'user' | 'assistant' + content: string | null + reasoning_content?: string | null + prompt_content?: string | null + metadata?: string | null + mentionables?: string | null + similarity_search_results?: string | null + created_at: Date +} diff --git a/src/database/sql.ts b/src/database/sql.ts index 28636a7..0320ed0 100644 --- a/src/database/sql.ts +++ b/src/database/sql.ts @@ -102,11 +102,23 @@ export const migrations: Record = { "updated_at" timestamp DEFAULT now() NOT NULL ); + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'messages' + AND column_name = 'reasoning_content' + ) THEN + ALTER TABLE "messages" ADD COLUMN "reasoning_content" text; + END IF; + END $$; + CREATE TABLE IF NOT EXISTS "messages" ( "id" uuid PRIMARY KEY NOT NULL, "conversation_id" uuid NOT NULL REFERENCES "conversations"("id") ON DELETE CASCADE, "role" text NOT NULL, "content" text, + "reasoning_content" text, "prompt_content" text, "metadata" text, "mentionables" text, diff --git a/src/hooks/use-chat-history.ts b/src/hooks/use-chat-history.ts index 8b53479..33fd2bf 100644 --- a/src/hooks/use-chat-history.ts +++ b/src/hooks/use-chat-history.ts @@ -1,6 +1,5 @@ import { useCallback, useEffect, useState } from 'react' -import { useApp } from '../contexts/AppContext' import { useDatabase } from '../contexts/DatabaseContext' import { DBManager } from '../database/database-manager' import { ChatConversationMeta, ChatMessage } from '../types/chat' @@ -17,7 +16,6 @@ type UseChatHistory = { } export function useChatHistory(): UseChatHistory { - const app = useApp() const { getDatabaseManager } = useDatabase() // 这里更新有点繁琐, 但是能保持 chatList 实时更新 @@ -29,7 +27,9 @@ export function useChatHistory(): UseChatHistory { const fetchChatList = useCallback(async () => { const dbManager = await getManager() - dbManager.getConversationManager().getAllConversations(setChatList) + dbManager.getConversationManager().getAllConversations((conversations) => { + setChatList(conversations) + }) }, [getManager]) useEffect(() => { diff --git a/src/types/chat.ts b/src/types/chat.ts index d8be2eb..577ed69 100644 --- a/src/types/chat.ts +++ b/src/types/chat.ts @@ -21,6 +21,7 @@ export type ChatUserMessage = { export type ChatAssistantMessage = { role: 'assistant' content: string + reasoningContent: string id: string metadata?: { usage?: ResponseUsage @@ -44,6 +45,7 @@ export type SerializedChatUserMessage = { export type SerializedChatAssistantMessage = { role: 'assistant' content: string + reasoningContent: string id: string metadata?: { usage?: ResponseUsage diff --git a/src/types/llm/response.ts b/src/types/llm/response.ts index 8d3486e..f185d98 100644 --- a/src/types/llm/response.ts +++ b/src/types/llm/response.ts @@ -39,7 +39,8 @@ type NonStreamingChoice = { type StreamingChoice = { finish_reason: string | null delta: { - content: string | null + content: string | null + reasoning_content: string | null role?: string } error?: Error diff --git a/styles.css b/styles.css index e15ff47..3a4611f 100644 --- a/styles.css +++ b/styles.css @@ -622,6 +622,19 @@ input[type='text'].infio-chat-list-dropdown-item-title-input { border-radius: var(--radius-s); } +.infio-chat-code-block.infio-reasoning-block { + max-height: 222px; + overflow: hidden; + margin-top: 22px; + margin-bottom: 22px; +} + +.infio-reasoning-content-wrapper { + height: calc(100% - 28px); + overflow-y: auto; + scroll-behavior: smooth; +} + .infio-chat-code-block code { padding: 0; } @@ -1777,3 +1790,11 @@ button.infio-chat-input-model-select { position: absolute; display: block; } + +.infio-chat-empty-state { + display: flex; + justify-content: center; + align-items: center; + height: 100%; + width: 100%; +}