merge web_search and fetch_urls_contents
This commit is contained in:
parent
679d7142eb
commit
76ecca0da9
@ -535,7 +535,7 @@ const Chat = forwardRef<ChatRef, ChatProps>((props, ref) => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if (toolArgs.type === 'search_web') {
|
} else if (toolArgs.type === 'search_web') {
|
||||||
const results = await webSearch(toolArgs.query, settings.serperApiKey)
|
const results = await webSearch(toolArgs.query, settings.serperApiKey, settings.jinaApiKey, (await getRAGEngine()))
|
||||||
const formattedContent = `[search_web for '${toolArgs.query}'] Result:\n${results}\n`;
|
const formattedContent = `[search_web for '${toolArgs.query}'] Result:\n${results}\n`;
|
||||||
return {
|
return {
|
||||||
type: 'search_web',
|
type: 'search_web',
|
||||||
|
|||||||
@ -97,7 +97,7 @@ export class RAGEngine {
|
|||||||
if (!this.initialized) {
|
if (!this.initialized) {
|
||||||
await this.updateVaultIndex({ reindexAll: false }, onQueryProgressChange)
|
await this.updateVaultIndex({ reindexAll: false }, onQueryProgressChange)
|
||||||
}
|
}
|
||||||
const queryEmbedding = await this.getQueryEmbedding(query)
|
const queryEmbedding = await this.getEmbedding(query)
|
||||||
onQueryProgressChange?.({
|
onQueryProgressChange?.({
|
||||||
type: 'querying',
|
type: 'querying',
|
||||||
})
|
})
|
||||||
@ -123,7 +123,7 @@ export class RAGEngine {
|
|||||||
return queryResult
|
return queryResult
|
||||||
}
|
}
|
||||||
|
|
||||||
private async getQueryEmbedding(query: string): Promise<number[]> {
|
async getEmbedding(query: string): Promise<number[]> {
|
||||||
if (!this.embeddingModel) {
|
if (!this.embeddingModel) {
|
||||||
throw new Error('Embedding model is not set')
|
throw new Error('Embedding model is not set')
|
||||||
}
|
}
|
||||||
|
|||||||
@ -3,20 +3,33 @@ import https from 'https';
|
|||||||
import { htmlToMarkdown, requestUrl } from 'obsidian';
|
import { htmlToMarkdown, requestUrl } from 'obsidian';
|
||||||
|
|
||||||
import { JINA_BASE_URL, SERPER_BASE_URL } from '../constants';
|
import { JINA_BASE_URL, SERPER_BASE_URL } from '../constants';
|
||||||
|
import { RAGEngine } from '../core/rag/rag-engine';
|
||||||
|
|
||||||
import { YoutubeTranscript, isYoutubeUrl } from './youtube-transcript';
|
import { YoutubeTranscript, isYoutubeUrl } from './youtube-transcript';
|
||||||
|
|
||||||
|
|
||||||
interface SearchResult {
|
interface SearchResult {
|
||||||
title: string;
|
title: string;
|
||||||
link: string;
|
link: string;
|
||||||
snippet: string;
|
snippet: string;
|
||||||
|
snippet_embedding: number[];
|
||||||
|
content?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface SearchResponse {
|
interface SearchResponse {
|
||||||
organic_results?: SearchResult[];
|
organic_results?: SearchResult[];
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function webSearch(query: string, serperApiKey: string): Promise<string> {
|
/**
 * Computes the cosine similarity between two numeric vectors.
 *
 * @param vecA - First vector; its length drives the dot-product accumulation.
 * @param vecB - Second vector; callers are expected to pass a vector of the
 *   same length as `vecA` (a shorter `vecB` yields `NaN`, as before).
 * @returns The cosine of the angle between the two vectors, or `0` when either
 *   vector has zero magnitude (including empty vectors), where the ratio would
 *   otherwise be `0 / 0`.
 */
function cosineSimilarity(vecA: number[], vecB: number[]): number {
  const dotProduct = vecA.reduce((sum, a, i) => sum + a * vecB[i], 0);
  const magnitudeA = Math.sqrt(vecA.reduce((sum, a) => sum + a * a, 0));
  const magnitudeB = Math.sqrt(vecB.reduce((sum, b) => sum + b * b, 0));

  const denominator = magnitudeA * magnitudeB;
  // A zero-magnitude vector has no direction: report "no similarity" rather than NaN.
  if (denominator === 0) {
    return 0;
  }

  return dotProduct / denominator;
}
|
||||||
|
|
||||||
|
async function serperSearch(query: string, serperApiKey: string): Promise<SearchResult[]> {
|
||||||
return new Promise((resolve, reject) => {
|
return new Promise((resolve, reject) => {
|
||||||
const url = `${SERPER_BASE_URL}?q=${encodeURIComponent(query)}&engine=google&api_key=${serperApiKey}&num=20`;
|
const url = `${SERPER_BASE_URL}?q=${encodeURIComponent(query)}&engine=google&api_key=${serperApiKey}&num=20`;
|
||||||
|
|
||||||
@ -38,15 +51,17 @@ export async function webSearch(query: string, serperApiKey: string): Promise<st
|
|||||||
const results = parsedData?.organic_results;
|
const results = parsedData?.organic_results;
|
||||||
|
|
||||||
if (!results) {
|
if (!results) {
|
||||||
resolve('');
|
resolve([]);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const formattedResults = results.map((item: SearchResult) => {
|
resolve(results);
|
||||||
return `title: ${item.title}\nurl: ${item.link}\nsnippet: ${item.snippet}\n`;
|
|
||||||
}).join('\n\n');
|
|
||||||
|
|
||||||
resolve(formattedResults);
|
// const formattedResults = results.map((item: SearchResult) => {
|
||||||
|
// return `title: ${item.title}\nurl: ${item.link}\nsnippet: ${item.snippet}\n`;
|
||||||
|
// }).join('\n\n');
|
||||||
|
|
||||||
|
// resolve(formattedResults);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
reject(error);
|
reject(error);
|
||||||
}
|
}
|
||||||
@ -57,7 +72,40 @@ export async function webSearch(query: string, serperApiKey: string): Promise<st
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
async function getWebsiteContent(url: string): Promise<string> {
|
/**
 * Ranks search results by how similar their snippet is to the query and keeps
 * only the best matches.
 *
 * Each snippet is embedded with `ragEngine.getEmbedding` and compared to the
 * query embedding using cosine similarity.
 *
 * @param query - The search query to compare snippets against.
 * @param results - Candidate search results; only `snippet` is embedded.
 * @param ragEngine - Engine used to produce embeddings for the query and snippets.
 * @param minSimilarity - Results must score strictly above this value to be kept.
 *   Defaults to `0.5`.
 * @param maxResults - Maximum number of results returned. Defaults to `5`.
 * @returns The kept results ordered by descending similarity. Each returned
 *   object also carries its `similarity` score and `snippet_embedding`.
 */
async function filterByEmbedding(
  query: string,
  results: SearchResult[],
  ragEngine: RAGEngine,
  minSimilarity = 0.5,
  maxResults = 5,
): Promise<SearchResult[]> {
  // Nothing to rank: return early and skip the query-embedding call.
  if (results.length === 0) {
    return [];
  }

  // Embed the query once; it is compared against every snippet.
  const queryEmbedding = await ragEngine.getEmbedding(query);

  // Embed and score all snippets concurrently.
  const scoredResults = await Promise.all(
    results.map(async (result) => {
      const resultEmbedding = await ragEngine.getEmbedding(result.snippet);
      const similarity = cosineSimilarity(queryEmbedding, resultEmbedding);

      return {
        ...result,
        similarity,
        snippet_embedding: resultEmbedding,
      };
    }),
  );

  // Drop weak matches, then keep the top scorers in descending order.
  return scoredResults
    .filter((result) => result.similarity > minSimilarity)
    .sort((a, b) => b.similarity - a.similarity)
    .slice(0, maxResults);
}
|
||||||
|
|
||||||
|
async function fetchByLocalTool(url: string): Promise<string> {
|
||||||
if (isYoutubeUrl(url)) {
|
if (isYoutubeUrl(url)) {
|
||||||
// TODO: pass language based on user preferences
|
// TODO: pass language based on user preferences
|
||||||
const { title, transcript } =
|
const { title, transcript } =
|
||||||
@ -73,30 +121,7 @@ ${transcript.map((t) => `${t.offset}: ${t.text}`).join('\n')}`
|
|||||||
return htmlToMarkdown(response.text)
|
return htmlToMarkdown(response.text)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async function fetchByJina(url: string, apiKey: string): Promise<string> {
|
||||||
export async function fetchUrlsContent(urls: string[], apiKey: string): Promise<string> {
|
|
||||||
const use_jina = apiKey && apiKey != '' ? true : false
|
|
||||||
return new Promise((resolve) => {
|
|
||||||
const results = urls.map(async (url) => {
|
|
||||||
try {
|
|
||||||
const content = use_jina ? await fetchJina(url, apiKey) : await getWebsiteContent(url);
|
|
||||||
return `<url_content url="${url}">\n${content}\n</url_content>`;
|
|
||||||
} catch (error) {
|
|
||||||
console.error(`Failed to fetch URL content: ${url}`, error);
|
|
||||||
return `<url_content url="${url}">\n fetch content error: ${error}\n</url_content>`;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
Promise.all(results).then((texts) => {
|
|
||||||
resolve(texts.join('\n\n'));
|
|
||||||
}).catch((error) => {
|
|
||||||
console.error('fetch urls content error', error);
|
|
||||||
resolve('fetch urls content error'); // even if error, return some content
|
|
||||||
});
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
function fetchJina(url: string, apiKey: string): Promise<string> {
|
|
||||||
return new Promise((resolve) => {
|
return new Promise((resolve) => {
|
||||||
const jinaUrl = `${JINA_BASE_URL}/${url}`;
|
const jinaUrl = `${JINA_BASE_URL}/${url}`;
|
||||||
|
|
||||||
@ -141,4 +166,68 @@ function fetchJina(url: string, apiKey: string): Promise<string> {
|
|||||||
|
|
||||||
req.end();
|
req.end();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Fetches the content of a single URL as text.
 *
 * YouTube URLs are always handled by `fetchByLocalTool` and returned unchanged.
 * For any other URL, `fetchByJina` is tried first when a non-empty `apiKey` is
 * given, falling back to `fetchByLocalTool` if it throws; without a key,
 * `fetchByLocalTool` is used directly. Runs of consecutive newlines in the
 * result are collapsed into a single newline.
 *
 * @param url - The URL to fetch.
 * @param apiKey - Jina API key; an empty string disables the Jina path.
 * @returns The fetched content, or `null` if fetching failed (the error is
 *   logged, not thrown).
 */
export async function fetchUrlContent(url: string, apiKey: string): Promise<string | null> {
  try {
    if (isYoutubeUrl(url)) {
      return await fetchByLocalTool(url);
    }

    let content: string;
    if (apiKey) {
      try {
        content = await fetchByJina(url, apiKey);
      } catch (error) {
        console.error(`Failed to fetch URL by jina: ${url}`, error);
        content = await fetchByLocalTool(url);
      }
    } else {
      content = await fetchByLocalTool(url);
    }

    // `replace` with a global regex is equivalent to `replaceAll` here and
    // does not depend on the ES2021 string library.
    return content.replace(/\n{2,}/g, '\n');
  } catch (error) {
    console.error(`Failed to fetch URL content: ${url}`, error);
    return null;
  }
}
|
||||||
|
|
||||||
|
/**
 * Runs a web search, keeps the results most relevant to the query, and returns
 * their page contents as a single string.
 *
 * The results from `serperSearch` are narrowed with `filterByEmbedding`, then
 * each remaining result's page is fetched with `fetchUrlContent`. When a page
 * cannot be fetched (`null`) or is empty, the result's search snippet is used
 * instead, so one unreachable page does not discard the whole search.
 *
 * @param query - The search query.
 * @param serperApiKey - API key passed to `serperSearch`.
 * @param jinaApiKey - API key passed to `fetchUrlContent`.
 * @param ragEngine - Engine used by `filterByEmbedding` to compute embeddings.
 * @returns `<url_content url="...">` blocks joined by blank lines, or the string
 *   `"web search error"` if searching or filtering fails (the error is logged).
 */
export async function webSearch(
  query: string,
  serperApiKey: string,
  jinaApiKey: string,
  ragEngine: RAGEngine,
): Promise<string> {
  try {
    const results = await serperSearch(query, serperApiKey);
    const filteredResults = await filterByEmbedding(query, results, ragEngine);

    const filteredResultsWithContent = await Promise.all(
      filteredResults.map(async (result) => {
        const fetched = await fetchUrlContent(result.link, jinaApiKey);
        // fetchUrlContent reports failure as null; treat that like empty content.
        const content = fetched !== null && fetched.length > 0 ? fetched : result.snippet;
        return `<url_content url="${result.link}">\n${content}\n</url_content>`;
      }),
    );

    return filteredResultsWithContent.join('\n\n');
  } catch (error) {
    console.error(`Failed to web search: ${query}`, error);
    return "web search error";
  }
}
|
||||||
|
|
||||||
|
// todo: update
|
||||||
|
/**
 * Fetches several URLs concurrently and returns their contents as one string.
 *
 * Each URL is fetched with `fetchUrlContent` and wrapped in a
 * `<url_content url="...">` block. A URL whose content could not be retrieved
 * gets a block containing a `fetch content error` message instead, so the
 * remaining URLs are still reported.
 *
 * @param urls - The URLs to fetch.
 * @param apiKey - Jina API key forwarded to `fetchUrlContent`.
 * @returns The per-URL blocks joined by blank lines. This function does not
 *   reject: an unexpected failure is logged and reported as the string
 *   `'fetch urls content error'`.
 */
export async function fetchUrlsContent(urls: string[], apiKey: string): Promise<string> {
  try {
    const texts = await Promise.all(
      urls.map(async (url) => {
        try {
          const content = await fetchUrlContent(url, apiKey);
          // fetchUrlContent reports failure as null rather than throwing.
          if (content === null) {
            return `<url_content url="${url}">\n fetch content error: no content could be retrieved\n</url_content>`;
          }
          return `<url_content url="${url}">\n${content}\n</url_content>`;
        } catch (error) {
          console.error(`Failed to fetch URL content: ${url}`, error);
          return `<url_content url="${url}">\n fetch content error: ${error}\n</url_content>`;
        }
      }),
    );

    return texts.join('\n\n');
  } catch (error) {
    console.error('fetch urls content error', error);
    return 'fetch urls content error'; // even if error, return some content
  }
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user