merge web_search and fetch_urls_contents

This commit is contained in:
duanfuxiang 2025-03-20 08:56:18 +08:00
parent 679d7142eb
commit 76ecca0da9
3 changed files with 124 additions and 35 deletions

View File

@ -535,7 +535,7 @@ const Chat = forwardRef<ChatRef, ChatProps>((props, ref) => {
} }
} }
} else if (toolArgs.type === 'search_web') { } else if (toolArgs.type === 'search_web') {
const results = await webSearch(toolArgs.query, settings.serperApiKey) const results = await webSearch(toolArgs.query, settings.serperApiKey, settings.jinaApiKey, (await getRAGEngine()))
const formattedContent = `[search_web for '${toolArgs.query}'] Result:\n${results}\n`; const formattedContent = `[search_web for '${toolArgs.query}'] Result:\n${results}\n`;
return { return {
type: 'search_web', type: 'search_web',

View File

@ -97,7 +97,7 @@ export class RAGEngine {
if (!this.initialized) { if (!this.initialized) {
await this.updateVaultIndex({ reindexAll: false }, onQueryProgressChange) await this.updateVaultIndex({ reindexAll: false }, onQueryProgressChange)
} }
const queryEmbedding = await this.getQueryEmbedding(query) const queryEmbedding = await this.getEmbedding(query)
onQueryProgressChange?.({ onQueryProgressChange?.({
type: 'querying', type: 'querying',
}) })
@ -123,7 +123,7 @@ export class RAGEngine {
return queryResult return queryResult
} }
private async getQueryEmbedding(query: string): Promise<number[]> { async getEmbedding(query: string): Promise<number[]> {
if (!this.embeddingModel) { if (!this.embeddingModel) {
throw new Error('Embedding model is not set') throw new Error('Embedding model is not set')
} }

View File

@ -3,20 +3,33 @@ import https from 'https';
import { htmlToMarkdown, requestUrl } from 'obsidian'; import { htmlToMarkdown, requestUrl } from 'obsidian';
import { JINA_BASE_URL, SERPER_BASE_URL } from '../constants'; import { JINA_BASE_URL, SERPER_BASE_URL } from '../constants';
import { RAGEngine } from '../core/rag/rag-engine';
import { YoutubeTranscript, isYoutubeUrl } from './youtube-transcript'; import { YoutubeTranscript, isYoutubeUrl } from './youtube-transcript';
interface SearchResult { interface SearchResult {
title: string; title: string;
link: string; link: string;
snippet: string; snippet: string;
snippet_embedding: number[];
content?: string;
} }
interface SearchResponse { interface SearchResponse {
organic_results?: SearchResult[]; organic_results?: SearchResult[];
} }
export async function webSearch(query: string, serperApiKey: string): Promise<string> { // 添加余弦相似度计算函数
function cosineSimilarity(vecA: number[], vecB: number[]): number {
const dotProduct = vecA.reduce((sum, a, i) => sum + a * vecB[i], 0);
const magnitudeA = Math.sqrt(vecA.reduce((sum, a) => sum + a * a, 0));
const magnitudeB = Math.sqrt(vecB.reduce((sum, b) => sum + b * b, 0));
return dotProduct / (magnitudeA * magnitudeB);
}
async function serperSearch(query: string, serperApiKey: string): Promise<SearchResult[]> {
return new Promise((resolve, reject) => { return new Promise((resolve, reject) => {
const url = `${SERPER_BASE_URL}?q=${encodeURIComponent(query)}&engine=google&api_key=${serperApiKey}&num=20`; const url = `${SERPER_BASE_URL}?q=${encodeURIComponent(query)}&engine=google&api_key=${serperApiKey}&num=20`;
@ -38,15 +51,17 @@ export async function webSearch(query: string, serperApiKey: string): Promise<st
const results = parsedData?.organic_results; const results = parsedData?.organic_results;
if (!results) { if (!results) {
resolve(''); resolve([]);
return; return;
} }
const formattedResults = results.map((item: SearchResult) => { resolve(results);
return `title: ${item.title}\nurl: ${item.link}\nsnippet: ${item.snippet}\n`;
}).join('\n\n');
resolve(formattedResults); // const formattedResults = results.map((item: SearchResult) => {
// return `title: ${item.title}\nurl: ${item.link}\nsnippet: ${item.snippet}\n`;
// }).join('\n\n');
// resolve(formattedResults);
} catch (error) { } catch (error) {
reject(error); reject(error);
} }
@ -57,7 +72,40 @@ export async function webSearch(query: string, serperApiKey: string): Promise<st
}); });
} }
async function getWebsiteContent(url: string): Promise<string> { async function filterByEmbedding(query: string, results: SearchResult[], ragEngine: RAGEngine): Promise<SearchResult[]> {
// 如果没有结果,直接返回空数组
if (results.length === 0) {
return [];
}
// 获取查询的嵌入向量
const queryEmbedding = await ragEngine.getEmbedding(query);
// 并行处理所有结果的嵌入向量计算
const processedResults = await Promise.all(
results.map(async (result) => {
const resultEmbedding = await ragEngine.getEmbedding(result.snippet);
const similarity = cosineSimilarity(queryEmbedding, resultEmbedding);
return {
...result,
similarity,
snippet_embedding: resultEmbedding
};
})
);
// 根据相似度过滤和排序结果
const filteredResults = processedResults
.filter(result => result.similarity > 0.5)
.sort((a, b) => b.similarity - a.similarity)
.slice(0, 5);
return filteredResults;
}
async function fetchByLocalTool(url: string): Promise<string> {
if (isYoutubeUrl(url)) { if (isYoutubeUrl(url)) {
// TODO: pass language based on user preferences // TODO: pass language based on user preferences
const { title, transcript } = const { title, transcript } =
@ -73,30 +121,7 @@ ${transcript.map((t) => `${t.offset}: ${t.text}`).join('\n')}`
return htmlToMarkdown(response.text) return htmlToMarkdown(response.text)
} }
async function fetchByJina(url: string, apiKey: string): Promise<string> {
export async function fetchUrlsContent(urls: string[], apiKey: string): Promise<string> {
const use_jina = apiKey && apiKey != '' ? true : false
return new Promise((resolve) => {
const results = urls.map(async (url) => {
try {
const content = use_jina ? await fetchJina(url, apiKey) : await getWebsiteContent(url);
return `<url_content url="${url}">\n${content}\n</url_content>`;
} catch (error) {
console.error(`Failed to fetch URL content: ${url}`, error);
return `<url_content url="${url}">\n fetch content error: ${error}\n</url_content>`;
}
});
Promise.all(results).then((texts) => {
resolve(texts.join('\n\n'));
}).catch((error) => {
console.error('fetch urls content error', error);
resolve('fetch urls content error'); // even if error, return some content
});
});
}
function fetchJina(url: string, apiKey: string): Promise<string> {
return new Promise((resolve) => { return new Promise((resolve) => {
const jinaUrl = `${JINA_BASE_URL}/${url}`; const jinaUrl = `${JINA_BASE_URL}/${url}`;
@ -141,4 +166,68 @@ function fetchJina(url: string, apiKey: string): Promise<string> {
req.end(); req.end();
}); });
} }
export async function fetchUrlContent(url: string, apiKey: string): Promise<string | null> {
try {
if (isYoutubeUrl(url)) {
return await fetchByLocalTool(url);
}
let content: string | null = null;
const validJinaKey = apiKey && apiKey !== '';
if (validJinaKey) {
try {
content = await fetchByJina(url, apiKey);
} catch (error) {
console.error(`Failed to fetch URL by jina: ${url}`, error);
content = await fetchByLocalTool(url);
}
} else {
content = await fetchByLocalTool(url);
}
return content.replaceAll(/\n{2,}/g, '\n');
} catch (error) {
console.error(`Failed to fetch URL content: ${url}`, error);
return null;
}
}
export async function webSearch(query: string, serperApiKey: string, jinaApiKey: string, ragEngine: RAGEngine): Promise<string> {
try {
const results = await serperSearch(query, serperApiKey);
const filteredResults = await filterByEmbedding(query, results, ragEngine);
const filteredResultsWithContent = await Promise.all(filteredResults.map(async (result) => {
let content = await fetchUrlContent(result.link, jinaApiKey);
if (content.length === 0) {
content = result.snippet;
}
return `<url_content url="${result.link}">\n${content}\n</url_content>`;
}));
return filteredResultsWithContent.join('\n\n');
} catch (error) {
console.error(`Failed to web search: ${query}`, error);
return "web search error";
}
}
// todo: update
export async function fetchUrlsContent(urls: string[], apiKey: string): Promise<string> {
return new Promise((resolve) => {
const results = urls.map(async (url) => {
try {
const content = await fetchUrlContent(url, apiKey);
return `<url_content url="${url}">\n${content}\n</url_content>`;
} catch (error) {
console.error(`Failed to fetch URL content: ${url}`, error);
return `<url_content url="${url}">\n fetch content error: ${error}\n</url_content>`;
}
});
Promise.all(results).then((texts) => {
resolve(texts.join('\n\n'));
}).catch((error) => {
console.error('fetch urls content error', error);
resolve('fetch urls content error'); // even if error, return some content
});
});
}