diff --git a/src/components/chat-view/Chat.tsx b/src/components/chat-view/Chat.tsx
index d4e5233..4d62074 100644
--- a/src/components/chat-view/Chat.tsx
+++ b/src/components/chat-view/Chat.tsx
@@ -535,7 +535,7 @@ const Chat = forwardRef((props, ref) => {
 						}
 					}
 				} else if (toolArgs.type === 'search_web') {
-					const results = await webSearch(toolArgs.query, settings.serperApiKey)
+					const results = await webSearch(toolArgs.query, settings.serperApiKey, settings.jinaApiKey, (await getRAGEngine()))
 					const formattedContent = `[search_web for '${toolArgs.query}'] Result:\n${results}\n`;
 					return {
 						type: 'search_web',
diff --git a/src/core/rag/rag-engine.ts b/src/core/rag/rag-engine.ts
index 1e0d48a..5cee91d 100644
--- a/src/core/rag/rag-engine.ts
+++ b/src/core/rag/rag-engine.ts
@@ -97,7 +97,7 @@ export class RAGEngine {
 		if (!this.initialized) {
 			await this.updateVaultIndex({ reindexAll: false }, onQueryProgressChange)
 		}
-		const queryEmbedding = await this.getQueryEmbedding(query)
+		const queryEmbedding = await this.getEmbedding(query)
 		onQueryProgressChange?.({
 			type: 'querying',
 		})
@@ -123,7 +123,7 @@ export class RAGEngine {
 		return queryResult
 	}
 
-	private async getQueryEmbedding(query: string): Promise {
+	async getEmbedding(query: string): Promise<number[]> {
 		if (!this.embeddingModel) {
 			throw new Error('Embedding model is not set')
 		}
diff --git a/src/utils/web-search.ts b/src/utils/web-search.ts
index 6928831..cfe0146 100644
--- a/src/utils/web-search.ts
+++ b/src/utils/web-search.ts
@@ -3,20 +3,33 @@ import https from 'https';
 
 import { htmlToMarkdown, requestUrl } from 'obsidian';
 
 import { JINA_BASE_URL, SERPER_BASE_URL } from '../constants';
+import { RAGEngine } from '../core/rag/rag-engine';
 import { YoutubeTranscript, isYoutubeUrl } from './youtube-transcript';
+
 
 interface SearchResult {
 	title: string;
 	link: string;
 	snippet: string;
+	snippet_embedding?: number[];
+	content?: string;
 }
 
 interface SearchResponse {
 	organic_results?: SearchResult[];
 }
 
-export async function webSearch(query: string, serperApiKey: string): Promise {
+// Cosine similarity of two vectors; yields 0 (not NaN) when either vector has zero magnitude.
+function cosineSimilarity(vecA: number[], vecB: number[]): number {
+	const dotProduct = vecA.reduce((sum, a, i) => sum + a * vecB[i], 0);
+	const magnitudeA = Math.sqrt(vecA.reduce((sum, a) => sum + a * a, 0));
+	const magnitudeB = Math.sqrt(vecB.reduce((sum, b) => sum + b * b, 0));
+	const denominator = magnitudeA * magnitudeB;
+	return denominator === 0 ? 0 : dotProduct / denominator;
+}
+
+async function serperSearch(query: string, serperApiKey: string): Promise<SearchResult[]> {
 	return new Promise((resolve, reject) => {
 		const url = `${SERPER_BASE_URL}?q=${encodeURIComponent(query)}&engine=google&api_key=${serperApiKey}&num=20`;
 
@@ -38,15 +51,17 @@ export async function webSearch(query: string, serperApiKey: string): Promise {
-						return `title: ${item.title}\nurl: ${item.link}\nsnippet: ${item.snippet}\n`;
-					}).join('\n\n');
+					resolve(results);
 
-					resolve(formattedResults);
+					// const formattedResults = results.map((item: SearchResult) => {
+					// 	return `title: ${item.title}\nurl: ${item.link}\nsnippet: ${item.snippet}\n`;
+					// }).join('\n\n');
+
+					// resolve(formattedResults);
 				} catch (error) {
 					reject(error);
 				}
@@ -57,7 +72,40 @@ export async function webSearch(query: string, serperApiKey: string): Promise {
+// Scores each result snippet against the query by embedding similarity and keeps the best matches.
+async function filterByEmbedding(query: string, results: SearchResult[], ragEngine: RAGEngine): Promise<SearchResult[]> {
+	// Nothing to rank when the search came back empty.
+	if (results.length === 0) {
+		return [];
+	}
+
+	// Embed the query once.
+	const queryEmbedding = await ragEngine.getEmbedding(query);
+
+	// Embed every snippet in parallel and score it against the query.
+	const processedResults = await Promise.all(
+		results.map(async (result) => {
+			const resultEmbedding = await ragEngine.getEmbedding(result.snippet);
+			const similarity = cosineSimilarity(queryEmbedding, resultEmbedding);
+
+			return {
+				...result,
+				similarity,
+				snippet_embedding: resultEmbedding
+			};
+		})
+	);
+
+	// Keep results scoring above 0.5, best first, at most five.
+	const filteredResults = processedResults
+		.filter(result => result.similarity > 0.5)
+		.sort((a, b) => b.similarity - a.similarity)
+		.slice(0, 5);
+
+	return filteredResults;
+}
+
+async function fetchByLocalTool(url: string): Promise<string> {
 	if (isYoutubeUrl(url)) {
 		// TODO: pass language based on user preferences
 		const { title, transcript } =
@@ -73,30 +121,7 @@ ${transcript.map((t) => `${t.offset}: ${t.text}`).join('\n')}`
 
 	return htmlToMarkdown(response.text)
 }
-
-export async function fetchUrlsContent(urls: string[], apiKey: string): Promise {
-	const use_jina = apiKey && apiKey != '' ? true : false
-	return new Promise((resolve) => {
-		const results = urls.map(async (url) => {
-			try {
-				const content = use_jina ? await fetchJina(url, apiKey) : await getWebsiteContent(url);
-				return `\n${content}\n`;
-			} catch (error) {
-				console.error(`Failed to fetch URL content: ${url}`, error);
-				return `\n fetch content error: ${error}\n`;
-			}
-		});
-
-		Promise.all(results).then((texts) => {
-			resolve(texts.join('\n\n'));
-		}).catch((error) => {
-			console.error('fetch urls content error', error);
-			resolve('fetch urls content error'); // even if error, return some content
-		});
-	});
-}
-
-function fetchJina(url: string, apiKey: string): Promise {
+async function fetchByJina(url: string, apiKey: string): Promise<string> {
 	return new Promise((resolve) => {
 		const jinaUrl = `${JINA_BASE_URL}/${url}`;
 
@@ -141,4 +166,74 @@ function fetchJina(url: string, apiKey: string): Promise {
 
 		req.end();
 	});
-}
\ No newline at end of file
+}
+
+/**
+ * Fetches one URL as text. YouTube links always go through fetchByLocalTool; other links
+ * use Jina when an API key is supplied and fall back to fetchByLocalTool if that call throws.
+ * Resolves to null (after logging) when fetching fails.
+ */
+export async function fetchUrlContent(url: string, apiKey: string): Promise<string | null> {
+	try {
+		if (isYoutubeUrl(url)) {
+			return await fetchByLocalTool(url);
+		}
+		let content: string | null = null;
+		const validJinaKey = apiKey && apiKey !== '';
+		if (validJinaKey) {
+			try {
+				content = await fetchByJina(url, apiKey);
+			} catch (error) {
+				console.error(`Failed to fetch URL by jina: ${url}`, error);
+				content = await fetchByLocalTool(url);
+			}
+		} else {
+			content = await fetchByLocalTool(url);
+		}
+		return content.replaceAll(/\n{2,}/g, '\n');
+	} catch (error) {
+		console.error(`Failed to fetch URL content: ${url}`, error);
+		return null;
+	}
+}
+
+/**
+ * Runs a Serper search, keeps the results whose snippets best match the query, and
+ * returns the fetched content of those results joined by blank lines.
+ * Resolves to "web search error" instead of rejecting.
+ */
+export async function webSearch(query: string, serperApiKey: string, jinaApiKey: string, ragEngine: RAGEngine): Promise<string> {
+	try {
+		const results = await serperSearch(query, serperApiKey);
+		const filteredResults = await filterByEmbedding(query, results, ragEngine);
+		const filteredResultsWithContent = await Promise.all(filteredResults.map(async (result) => {
+			// fetchUrlContent resolves to null on failure; use the snippet so one bad page cannot abort the search.
+			const content = (await fetchUrlContent(result.link, jinaApiKey)) || result.snippet;
+			return `\n${content}\n`;
+		}));
+		return filteredResultsWithContent.join('\n\n');
+	} catch (error) {
+		console.error(`Failed to web search: ${query}`, error);
+		return "web search error";
+	}
+}
+
+/**
+ * Fetches several URLs in parallel and joins the results with blank lines.
+ * A URL that cannot be fetched contributes an error line instead of the text "null".
+ */
+export async function fetchUrlsContent(urls: string[], apiKey: string): Promise<string> {
+	try {
+		const texts = await Promise.all(urls.map(async (url) => {
+			const content = await fetchUrlContent(url, apiKey);
+			if (content === null) {
+				return `\n fetch content error: ${url}\n`;
+			}
+			return `\n${content}\n`;
+		}));
+		return texts.join('\n\n');
+	} catch (error) {
+		console.error('fetch urls content error', error);
+		return 'fetch urls content error'; // even if error, return some content
+	}
+}