merge web_search and fetch_urls_contents
This commit is contained in:
parent
679d7142eb
commit
76ecca0da9
@ -535,7 +535,7 @@ const Chat = forwardRef<ChatRef, ChatProps>((props, ref) => {
|
||||
}
|
||||
}
|
||||
} else if (toolArgs.type === 'search_web') {
|
||||
const results = await webSearch(toolArgs.query, settings.serperApiKey)
|
||||
const results = await webSearch(toolArgs.query, settings.serperApiKey, settings.jinaApiKey, (await getRAGEngine()))
|
||||
const formattedContent = `[search_web for '${toolArgs.query}'] Result:\n${results}\n`;
|
||||
return {
|
||||
type: 'search_web',
|
||||
|
||||
@ -97,7 +97,7 @@ export class RAGEngine {
|
||||
if (!this.initialized) {
|
||||
await this.updateVaultIndex({ reindexAll: false }, onQueryProgressChange)
|
||||
}
|
||||
const queryEmbedding = await this.getQueryEmbedding(query)
|
||||
const queryEmbedding = await this.getEmbedding(query)
|
||||
onQueryProgressChange?.({
|
||||
type: 'querying',
|
||||
})
|
||||
@ -123,7 +123,7 @@ export class RAGEngine {
|
||||
return queryResult
|
||||
}
|
||||
|
||||
private async getQueryEmbedding(query: string): Promise<number[]> {
|
||||
async getEmbedding(query: string): Promise<number[]> {
|
||||
if (!this.embeddingModel) {
|
||||
throw new Error('Embedding model is not set')
|
||||
}
|
||||
|
||||
@ -3,20 +3,33 @@ import https from 'https';
|
||||
import { htmlToMarkdown, requestUrl } from 'obsidian';
|
||||
|
||||
import { JINA_BASE_URL, SERPER_BASE_URL } from '../constants';
|
||||
import { RAGEngine } from '../core/rag/rag-engine';
|
||||
|
||||
import { YoutubeTranscript, isYoutubeUrl } from './youtube-transcript';
|
||||
|
||||
|
||||
/** One organic search hit, optionally enriched with locally computed data. */
interface SearchResult {
  /** Title of the result page. */
  title: string;
  /** URL of the result page. */
  link: string;
  /** Short text excerpt that accompanies the result. */
  snippet: string;
  /**
   * Embedding of `snippet`. Optional because parsed search results are passed
   * through as-is and this field is only assigned by `filterByEmbedding`.
   */
  snippet_embedding?: number[];
  /** Fetched page content, when present. */
  content?: string;
}

/** The part of the parsed search API response that this module reads. */
interface SearchResponse {
  /** Organic hits; absent when the response carries none. */
  organic_results?: SearchResult[];
}
|
||||
|
||||
export async function webSearch(query: string, serperApiKey: string): Promise<string> {
|
||||
/**
 * Computes the cosine similarity of two numeric vectors.
 *
 * @param vecA - First vector.
 * @param vecB - Second vector; expected to have the same length as `vecA`.
 * @returns The cosine of the angle between the vectors, in [-1, 1]. Returns 0
 *   when the similarity is undefined: empty input, vectors of different
 *   lengths, or a vector whose magnitude is zero.
 */
function cosineSimilarity(vecA: number[], vecB: number[]): number {
  // Vectors of different dimensionality are not comparable.
  if (vecA.length === 0 || vecA.length !== vecB.length) {
    return 0;
  }

  let dotProduct = 0;
  let squaredNormA = 0;
  let squaredNormB = 0;
  for (let i = 0; i < vecA.length; i++) {
    const a = vecA[i] ?? 0;
    const b = vecB[i] ?? 0;
    dotProduct += a * b;
    squaredNormA += a * a;
    squaredNormB += b * b;
  }

  const denominator = Math.sqrt(squaredNormA) * Math.sqrt(squaredNormB);
  // A zero-magnitude vector has no direction; avoid 0 / 0 producing NaN.
  return denominator === 0 ? 0 : dotProduct / denominator;
}
|
||||
|
||||
async function serperSearch(query: string, serperApiKey: string): Promise<SearchResult[]> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const url = `${SERPER_BASE_URL}?q=${encodeURIComponent(query)}&engine=google&api_key=${serperApiKey}&num=20`;
|
||||
|
||||
@ -38,15 +51,17 @@ export async function webSearch(query: string, serperApiKey: string): Promise<st
|
||||
const results = parsedData?.organic_results;
|
||||
|
||||
if (!results) {
|
||||
resolve('');
|
||||
resolve([]);
|
||||
return;
|
||||
}
|
||||
|
||||
const formattedResults = results.map((item: SearchResult) => {
|
||||
return `title: ${item.title}\nurl: ${item.link}\nsnippet: ${item.snippet}\n`;
|
||||
}).join('\n\n');
|
||||
resolve(results);
|
||||
|
||||
resolve(formattedResults);
|
||||
// const formattedResults = results.map((item: SearchResult) => {
|
||||
// return `title: ${item.title}\nurl: ${item.link}\nsnippet: ${item.snippet}\n`;
|
||||
// }).join('\n\n');
|
||||
|
||||
// resolve(formattedResults);
|
||||
} catch (error) {
|
||||
reject(error);
|
||||
}
|
||||
@ -57,7 +72,40 @@ export async function webSearch(query: string, serperApiKey: string): Promise<st
|
||||
});
|
||||
}
|
||||
|
||||
async function getWebsiteContent(url: string): Promise<string> {
|
||||
/**
 * Ranks search results by how similar their snippet is to the query and keeps
 * only the best matches.
 *
 * The query and every snippet are embedded through `ragEngine.getEmbedding`
 * and compared with `cosineSimilarity`.
 *
 * @param query - The search query.
 * @param results - Candidate results to rank.
 * @param ragEngine - Engine used to compute the embeddings.
 * @param minSimilarity - Results must score strictly above this value to be kept.
 * @param maxResults - Maximum number of results to return.
 * @returns The kept results, most similar first, each extended with its
 *   `similarity` score and `snippet_embedding`.
 */
async function filterByEmbedding(
  query: string,
  results: SearchResult[],
  ragEngine: RAGEngine,
  minSimilarity = 0.5,
  maxResults = 5,
): Promise<SearchResult[]> {
  // Nothing to rank: skip the query embedding call entirely.
  if (results.length === 0) {
    return [];
  }

  const queryEmbedding = await ragEngine.getEmbedding(query);

  // Embed all snippets in parallel and score each one against the query.
  const scoredResults = await Promise.all(
    results.map(async (result) => {
      const snippetEmbedding = await ragEngine.getEmbedding(result.snippet);
      return {
        ...result,
        similarity: cosineSimilarity(queryEmbedding, snippetEmbedding),
        snippet_embedding: snippetEmbedding,
      };
    }),
  );

  // Drop weak matches, then order by descending similarity and cap the count.
  return scoredResults
    .filter((result) => result.similarity > minSimilarity)
    .sort((a, b) => b.similarity - a.similarity)
    .slice(0, maxResults);
}
|
||||
|
||||
async function fetchByLocalTool(url: string): Promise<string> {
|
||||
if (isYoutubeUrl(url)) {
|
||||
// TODO: pass language based on user preferences
|
||||
const { title, transcript } =
|
||||
@ -73,30 +121,7 @@ ${transcript.map((t) => `${t.offset}: ${t.text}`).join('\n')}`
|
||||
return htmlToMarkdown(response.text)
|
||||
}
|
||||
|
||||
|
||||
export async function fetchUrlsContent(urls: string[], apiKey: string): Promise<string> {
|
||||
const use_jina = apiKey && apiKey != '' ? true : false
|
||||
return new Promise((resolve) => {
|
||||
const results = urls.map(async (url) => {
|
||||
try {
|
||||
const content = use_jina ? await fetchJina(url, apiKey) : await getWebsiteContent(url);
|
||||
return `<url_content url="${url}">\n${content}\n</url_content>`;
|
||||
} catch (error) {
|
||||
console.error(`Failed to fetch URL content: ${url}`, error);
|
||||
return `<url_content url="${url}">\n fetch content error: ${error}\n</url_content>`;
|
||||
}
|
||||
});
|
||||
|
||||
Promise.all(results).then((texts) => {
|
||||
resolve(texts.join('\n\n'));
|
||||
}).catch((error) => {
|
||||
console.error('fetch urls content error', error);
|
||||
resolve('fetch urls content error'); // even if error, return some content
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
function fetchJina(url: string, apiKey: string): Promise<string> {
|
||||
async function fetchByJina(url: string, apiKey: string): Promise<string> {
|
||||
return new Promise((resolve) => {
|
||||
const jinaUrl = `${JINA_BASE_URL}/${url}`;
|
||||
|
||||
@ -141,4 +166,68 @@ function fetchJina(url: string, apiKey: string): Promise<string> {
|
||||
|
||||
req.end();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Fetches the content of a single URL.
 *
 * YouTube URLs always go through `fetchByLocalTool` and are returned as-is.
 * Other URLs go through `fetchByJina` when an API key is supplied, falling
 * back to `fetchByLocalTool` if that call throws, and through
 * `fetchByLocalTool` directly when no key is supplied. For these URLs, runs
 * of consecutive newlines are collapsed to a single newline.
 *
 * @param url - The URL to fetch.
 * @param apiKey - Jina API key; an empty value selects the local fetcher.
 * @returns The fetched content, or `null` if fetching failed.
 */
export async function fetchUrlContent(url: string, apiKey: string): Promise<string | null> {
  // Picks the fetcher for non-YouTube URLs: Jina first when a key exists.
  const loadPage = async (): Promise<string> => {
    if (!apiKey) {
      return fetchByLocalTool(url);
    }
    try {
      return await fetchByJina(url, apiKey);
    } catch (jinaError) {
      console.error(`Failed to fetch URL by jina: ${url}`, jinaError);
      return fetchByLocalTool(url);
    }
  };

  try {
    if (isYoutubeUrl(url)) {
      return await fetchByLocalTool(url);
    }
    const page = await loadPage();
    return page.replaceAll(/\n{2,}/g, '\n');
  } catch (fetchError) {
    console.error(`Failed to fetch URL content: ${url}`, fetchError);
    return null;
  }
}
|
||||
|
||||
/**
 * Runs a web search and returns the content of the most relevant result pages.
 *
 * Results from `serperSearch` are narrowed with `filterByEmbedding`, then each
 * remaining page is fetched with `fetchUrlContent` and wrapped in a
 * `<url_content>` element.
 *
 * @param query - The search query.
 * @param serperApiKey - API key passed to `serperSearch`.
 * @param jinaApiKey - API key passed to `fetchUrlContent`.
 * @param ragEngine - Engine passed to `filterByEmbedding` for embeddings.
 * @returns The `<url_content>` sections joined by blank lines, or the string
 *   "web search error" if the search as a whole fails.
 */
export async function webSearch(query: string, serperApiKey: string, jinaApiKey: string, ragEngine: RAGEngine): Promise<string> {
  try {
    const results = await serperSearch(query, serperApiKey);
    const relevantResults = await filterByEmbedding(query, results, ragEngine);

    const sections = await Promise.all(
      relevantResults.map(async (result) => {
        const fetched = await fetchUrlContent(result.link, jinaApiKey);
        // fetchUrlContent returns null when the fetch fails; fall back to the
        // snippet for both null and empty content so one bad page does not
        // abort the whole search.
        const content = fetched ? fetched : result.snippet;
        return `<url_content url="${result.link}">\n${content}\n</url_content>`;
      }),
    );

    return sections.join('\n\n');
  } catch (error) {
    console.error(`Failed to web search: ${query}`, error);
    return "web search error";
  }
}
|
||||
|
||||
// todo: update
|
||||
/**
 * Fetches several URLs in parallel and returns their content as
 * `<url_content>` sections joined by blank lines.
 *
 * This function never rejects: a URL that cannot be fetched yields a section
 * holding an error message instead of content.
 *
 * @param urls - URLs to fetch.
 * @param apiKey - Jina API key forwarded to `fetchUrlContent`.
 * @returns The joined sections, or "fetch urls content error" if collecting
 *   the results fails unexpectedly.
 */
export async function fetchUrlsContent(urls: string[], apiKey: string): Promise<string> {
  try {
    const sections = await Promise.all(
      urls.map(async (url) => {
        try {
          const content = await fetchUrlContent(url, apiKey);
          // fetchUrlContent signals failure with null; report it explicitly
          // rather than interpolating the literal text "null".
          if (content === null) {
            return `<url_content url="${url}">\n fetch content error: no content could be retrieved\n</url_content>`;
          }
          return `<url_content url="${url}">\n${content}\n</url_content>`;
        } catch (error) {
          console.error(`Failed to fetch URL content: ${url}`, error);
          return `<url_content url="${url}">\n fetch content error: ${error}\n</url_content>`;
        }
      }),
    );
    return sections.join('\n\n');
  } catch (error) {
    console.error('fetch urls content error', error);
    return 'fetch urls content error'; // even if error, return some content
  }
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user