add tool use, update system prompt

This commit is contained in:
duanfuxiang 2025-03-12 21:39:29 +08:00
parent cabf2d5fa4
commit b0fbbb22d3
36 changed files with 7149 additions and 430 deletions

View File

@ -9,6 +9,7 @@ export type ApplyViewState = {
file: TFile
originalContent: string
newContent: string
onClose: (applied: boolean) => void
}
export class ApplyView extends View {

View File

@ -37,10 +37,16 @@ export default function ApplyViewRoot({
.map((change) => change.value)
.join('')
await app.vault.modify(state.file, newContent)
if (state.onClose) {
state.onClose(true)
}
close()
}
const handleReject = async () => {
if (state.onClose) {
state.onClose(false)
}
close()
}

View File

@ -1,127 +0,0 @@
import { $generateNodesFromSerializedNodes } from '@lexical/clipboard'
import { BaseSerializedNode } from '@lexical/clipboard/clipboard'
import { InitialEditorStateType } from '@lexical/react/LexicalComposer'
import * as Dialog from '@radix-ui/react-dialog'
import { $insertNodes, LexicalEditor } from 'lexical'
import { X } from 'lucide-react'
import { Notice } from 'obsidian'
import { useRef, useState } from 'react'
import { useDatabase } from '../../contexts/DatabaseContext'
import { useDialogContainer } from '../../contexts/DialogContext'
import { DuplicateTemplateException } from '../../database/exception'
import LexicalContentEditable from './chat-input/LexicalContentEditable'
/*
* This component must be used inside <Dialog.Root modal={false}>
* The modal={false} prop is required because modal mode blocks pointer events across the entire page,
* which would conflict with lexical editor popovers
*/
export default function CreateTemplateDialogContent({
selectedSerializedNodes,
onClose,
}: {
selectedSerializedNodes?: BaseSerializedNode[] | null
onClose: () => void
}) {
const container = useDialogContainer()
const { getTemplateManager } = useDatabase()
const [templateName, setTemplateName] = useState('')
const editorRef = useRef<LexicalEditor | null>(null)
const contentEditableRef = useRef<HTMLDivElement>(null)
const initialEditorState: InitialEditorStateType = (
editor: LexicalEditor,
) => {
if (!selectedSerializedNodes) return
editor.update(() => {
const parsedNodes = $generateNodesFromSerializedNodes(
selectedSerializedNodes,
)
$insertNodes(parsedNodes)
})
}
const onSubmit = async () => {
try {
if (!editorRef.current) return
const serializedEditorState = editorRef.current.toJSON()
const nodes = serializedEditorState.editorState.root.children
if (nodes.length === 0) {
new Notice('Please enter a content for your template')
return
}
if (templateName.trim().length === 0) {
new Notice('Please enter a name for your template')
return
}
await (
await getTemplateManager()
).createTemplate({
name: templateName,
content: { nodes },
})
new Notice(`Template created: ${templateName}`)
setTemplateName('')
onClose()
} catch (error) {
if (error instanceof DuplicateTemplateException) {
new Notice('A template with this name already exists')
} else {
console.error(error)
new Notice('Failed to create template')
}
}
}
return (
<Dialog.Portal container={container}>
<Dialog.Content className="infio-chat-dialog-content">
<div className="infio-dialog-header">
<Dialog.Title className="infio-dialog-title">
Create template
</Dialog.Title>
<Dialog.Description className="infio-dialog-description">
Create template from selected content
</Dialog.Description>
</div>
<div className="infio-dialog-input">
<label>Name</label>
<input
type="text"
value={templateName}
onChange={(e) => setTemplateName(e.target.value)}
onKeyDown={(e) => {
if (e.key === 'Enter') {
e.stopPropagation()
e.preventDefault()
onSubmit()
}
}}
/>
</div>
<div className="infio-chat-user-input-container">
<LexicalContentEditable
initialEditorState={initialEditorState}
editorRef={editorRef}
contentEditableRef={contentEditableRef}
onEnter={onSubmit}
/>
</div>
<div className="infio-dialog-footer">
<button onClick={onSubmit}>Create template</button>
</div>
<Dialog.Close className="infio-dialog-close" asChild>
<X size={16} />
</Dialog.Close>
</Dialog.Content>
</Dialog.Portal>
)
}

View File

@ -1,11 +1,12 @@
import { MarkdownView, Plugin, Platform } from 'obsidian';
import React, { useEffect, useMemo, useRef, useState } from 'react';
import { CornerDownLeft } from 'lucide-react';
import { MarkdownView, Plugin } from 'obsidian';
import React, { useEffect, useRef, useState } from 'react';
import { APPLY_VIEW_TYPE } from '../../constants';
import LLMManager from '../../core/llm/manager';
import { InfioSettings } from '../../types/settings';
import { GetProviderModelIds } from '../../utils/api';
import { manualApplyChangesToFile } from '../../utils/apply';
import { ApplyEditToFile } from '../../utils/apply';
import { removeAITags } from '../../utils/content-filter';
import { PromptGenerator } from '../../utils/prompt-generator';
@ -239,10 +240,10 @@ export const InlineEdit: React.FC<InlineEditProps> = ({
const startLine = parsedBlock?.startLine || defaultStartLine;
const endLine = parsedBlock?.endLine || defaultEndLine;
const updatedContent = await manualApplyChangesToFile(
finalContent,
const updatedContent = await ApplyEditToFile(
activeFile,
await plugin.app.vault.read(activeFile),
finalContent,
startLine,
endLine
);

View File

@ -0,0 +1,22 @@
import type { DiffStrategy } from "./types"
import { UnifiedDiffStrategy } from "./strategies/unified"
import { SearchReplaceDiffStrategy } from "./strategies/search-replace"
import { NewUnifiedDiffStrategy } from "./strategies/new-unified"
/**
 * Select the diff strategy to use for applying model-generated edits.
 *
 * @param model The name of the model being used (e.g., 'gpt-4', 'claude-3-opus').
 *   NOTE(review): not currently consulted by the selection logic — kept for API
 *   compatibility / future per-model routing.
 * @param fuzzyMatchThreshold Optional fuzzy-match threshold, forwarded unchanged
 *   to whichever strategy is constructed.
 * @param experimentalDiffStrategy When true, opt in to the experimental
 *   new-unified strategy instead of the default search/replace one.
 * @returns The diff strategy instance to apply changes with.
 */
export function getDiffStrategy(
	model: string,
	fuzzyMatchThreshold?: number,
	experimentalDiffStrategy: boolean = false,
): DiffStrategy {
	// The experimental flag is the only selector today; both branches receive
	// the same threshold so callers can switch strategies without re-tuning.
	return experimentalDiffStrategy
		? new NewUnifiedDiffStrategy(fuzzyMatchThreshold)
		: new SearchReplaceDiffStrategy(fuzzyMatchThreshold)
}
export type { DiffStrategy }
export { UnifiedDiffStrategy, SearchReplaceDiffStrategy }

View File

@ -0,0 +1,31 @@
/**
* Inserts multiple groups of elements at specified indices in an array
* @param original Array to insert into, split by lines
* @param insertGroups Array of groups to insert, each with an index and elements to insert
* @returns New array with all insertions applied
*/
/**
 * One group of lines to splice into an array at a given position.
 */
export interface InsertGroup {
	// Position in the ORIGINAL array before which the elements are inserted.
	index: number
	// Lines to insert at that position, in order.
	elements: string[]
}

/**
 * Inserts multiple groups of elements at specified indices in an array.
 *
 * Indices are interpreted against the original array (not shifted by earlier
 * insertions), so two groups with the same index end up adjacent, in sorted
 * order of appearance.
 *
 * @param original Array to insert into, split by lines
 * @param insertGroups Array of groups to insert, each with an index and elements to insert
 * @returns New array with all insertions applied; neither input is mutated
 */
export function insertGroups(original: string[], insertGroups: InsertGroup[]): string[] {
	// Sort a copy by index to maintain order — Array.prototype.sort is
	// in-place, and we must not reorder the caller's array as a side effect.
	const sortedGroups = [...insertGroups].sort((a, b) => a.index - b.index)

	const result: string[] = []
	let lastIndex = 0
	for (const { index, elements } of sortedGroups) {
		// Copy untouched original elements up to this insertion point,
		// then the group itself.
		result.push(...original.slice(lastIndex, index))
		result.push(...elements)
		lastIndex = index
	}
	// Append whatever remains of the original after the last insertion point.
	result.push(...original.slice(lastIndex))
	return result
}

View File

@ -0,0 +1,738 @@
import { NewUnifiedDiffStrategy } from "../new-unified"
describe("main", () => {
let strategy: NewUnifiedDiffStrategy
beforeEach(() => {
strategy = new NewUnifiedDiffStrategy(0.97)
})
describe("constructor", () => {
it("should use default confidence threshold when not provided", () => {
const defaultStrategy = new NewUnifiedDiffStrategy()
expect(defaultStrategy["confidenceThreshold"]).toBe(1)
})
it("should use provided confidence threshold", () => {
const customStrategy = new NewUnifiedDiffStrategy(0.85)
expect(customStrategy["confidenceThreshold"]).toBe(0.85)
})
it("should enforce minimum confidence threshold", () => {
const lowStrategy = new NewUnifiedDiffStrategy(0.7) // Below minimum of 0.8
expect(lowStrategy["confidenceThreshold"]).toBe(0.8)
})
})
describe("getToolDescription", () => {
it("should return tool description with correct cwd", () => {
const cwd = "/test/path"
const description = strategy.getToolDescription({ cwd })
expect(description).toContain("apply_diff Tool - Generate Precise Code Changes")
expect(description).toContain(cwd)
expect(description).toContain("Step-by-Step Instructions")
expect(description).toContain("Requirements")
expect(description).toContain("Examples")
expect(description).toContain("Parameters:")
})
})
it("should apply simple diff correctly", async () => {
const original = `line1
line2
line3`
const diff = `--- a/file.txt
+++ b/file.txt
@@ ... @@
line1
+new line
line2
-line3
+modified line3`
const result = await strategy.applyDiff(original, diff)
expect(result.success).toBe(true)
if (result.success) {
expect(result.content).toBe(`line1
new line
line2
modified line3`)
}
})
it("should handle multiple hunks", async () => {
const original = `line1
line2
line3
line4
line5`
const diff = `--- a/file.txt
+++ b/file.txt
@@ ... @@
line1
+new line
line2
-line3
+modified line3
@@ ... @@
line4
-line5
+modified line5
+new line at end`
const result = await strategy.applyDiff(original, diff)
expect(result.success).toBe(true)
if (result.success) {
expect(result.content).toBe(`line1
new line
line2
modified line3
line4
modified line5
new line at end`)
}
})
it("should handle complex large", async () => {
const original = `line1
line2
line3
line4
line5
line6
line7
line8
line9
line10`
const diff = `--- a/file.txt
+++ b/file.txt
@@ ... @@
line1
+header line
+another header
line2
-line3
-line4
+modified line3
+modified line4
+extra line
@@ ... @@
line6
+middle section
line7
-line8
+changed line8
+bonus line
@@ ... @@
line9
-line10
+final line
+very last line`
const result = await strategy.applyDiff(original, diff)
expect(result.success).toBe(true)
if (result.success) {
expect(result.content).toBe(`line1
header line
another header
line2
modified line3
modified line4
extra line
line5
line6
middle section
line7
changed line8
bonus line
line9
final line
very last line`)
}
})
it("should handle indentation changes", async () => {
const original = `first line
indented line
double indented line
back to single indent
no indent
indented again
double indent again
triple indent
back to single
last line`
const diff = `--- original
+++ modified
@@ ... @@
first line
indented line
+ tab indented line
+ new indented line
double indented line
back to single indent
no indent
indented again
double indent again
- triple indent
+ hi there mate
back to single
last line`
const expected = `first line
indented line
tab indented line
new indented line
double indented line
back to single indent
no indent
indented again
double indent again
hi there mate
back to single
last line`
const result = await strategy.applyDiff(original, diff)
expect(result.success).toBe(true)
if (result.success) {
expect(result.content).toBe(expected)
}
})
it("should handle high level edits", async () => {
const original = `def factorial(n):
if n == 0:
return 1
else:
return n * factorial(n-1)`
const diff = `@@ ... @@
-def factorial(n):
- if n == 0:
- return 1
- else:
- return n * factorial(n-1)
+def factorial(number):
+ if number == 0:
+ return 1
+ else:
+ return number * factorial(number-1)`
const expected = `def factorial(number):
if number == 0:
return 1
else:
return number * factorial(number-1)`
const result = await strategy.applyDiff(original, diff)
expect(result.success).toBe(true)
if (result.success) {
expect(result.content).toBe(expected)
}
})
it("it should handle very complex edits", async () => {
const original = `//Initialize the array that will hold the primes
var primeArray = [];
/*Write a function that checks for primeness and
pushes those values to t*he array*/
function PrimeCheck(candidate){
isPrime = true;
for(var i = 2; i < candidate && isPrime; i++){
if(candidate%i === 0){
isPrime = false;
} else {
isPrime = true;
}
}
if(isPrime){
primeArray.push(candidate);
}
return primeArray;
}
/*Write the code that runs the above until the
l ength of the array equa*ls the number of primes
desired*/
var numPrimes = prompt("How many primes?");
//Display the finished array of primes
//for loop starting at 2 as that is the lowest prime number keep going until the array is as long as we requested
for (var i = 2; primeArray.length < numPrimes; i++) {
PrimeCheck(i); //
}
console.log(primeArray);
`
const diff = `--- test_diff.js
+++ test_diff.js
@@ ... @@
-//Initialize the array that will hold the primes
var primeArray = [];
-/*Write a function that checks for primeness and
- pushes those values to t*he array*/
function PrimeCheck(candidate){
isPrime = true;
for(var i = 2; i < candidate && isPrime; i++){
@@ ... @@
return primeArray;
}
-/*Write the code that runs the above until the
- l ength of the array equa*ls the number of primes
- desired*/
var numPrimes = prompt("How many primes?");
-//Display the finished array of primes
-
-//for loop starting at 2 as that is the lowest prime number keep going until the array is as long as we requested
for (var i = 2; primeArray.length < numPrimes; i++) {
- PrimeCheck(i); //
+ PrimeCheck(i);
}
console.log(primeArray);`
const expected = `var primeArray = [];
function PrimeCheck(candidate){
isPrime = true;
for(var i = 2; i < candidate && isPrime; i++){
if(candidate%i === 0){
isPrime = false;
} else {
isPrime = true;
}
}
if(isPrime){
primeArray.push(candidate);
}
return primeArray;
}
var numPrimes = prompt("How many primes?");
for (var i = 2; primeArray.length < numPrimes; i++) {
PrimeCheck(i);
}
console.log(primeArray);
`
const result = await strategy.applyDiff(original, diff)
expect(result.success).toBe(true)
if (result.success) {
expect(result.content).toBe(expected)
}
})
describe("error handling and edge cases", () => {
it("should reject completely invalid diff format", async () => {
const original = "line1\nline2\nline3"
const invalidDiff = "this is not a diff at all"
const result = await strategy.applyDiff(original, invalidDiff)
expect(result.success).toBe(false)
})
it("should reject diff with invalid hunk format", async () => {
const original = "line1\nline2\nline3"
const invalidHunkDiff = `--- a/file.txt
+++ b/file.txt
invalid hunk header
line1
-line2
+new line`
const result = await strategy.applyDiff(original, invalidHunkDiff)
expect(result.success).toBe(false)
})
it("should fail when diff tries to modify non-existent content", async () => {
const original = "line1\nline2\nline3"
const nonMatchingDiff = `--- a/file.txt
+++ b/file.txt
@@ ... @@
line1
-nonexistent line
+new line
line3`
const result = await strategy.applyDiff(original, nonMatchingDiff)
expect(result.success).toBe(false)
})
it("should handle overlapping hunks", async () => {
const original = `line1
line2
line3
line4
line5`
const overlappingDiff = `--- a/file.txt
+++ b/file.txt
@@ ... @@
line1
line2
-line3
+modified3
line4
@@ ... @@
line2
-line3
-line4
+modified3and4
line5`
const result = await strategy.applyDiff(original, overlappingDiff)
expect(result.success).toBe(false)
})
it("should handle empty lines modifications", async () => {
const original = `line1
line3
line5`
const emptyLinesDiff = `--- a/file.txt
+++ b/file.txt
@@ ... @@
line1
-line3
+line3modified
line5`
const result = await strategy.applyDiff(original, emptyLinesDiff)
expect(result.success).toBe(true)
if (result.success) {
expect(result.content).toBe(`line1
line3modified
line5`)
}
})
it("should handle mixed line endings in diff", async () => {
const original = "line1\r\nline2\nline3\r\n"
const mixedEndingsDiff = `--- a/file.txt
+++ b/file.txt
@@ ... @@
line1\r
-line2
+modified2\r
line3`
const result = await strategy.applyDiff(original, mixedEndingsDiff)
expect(result.success).toBe(true)
if (result.success) {
expect(result.content).toBe("line1\r\nmodified2\r\nline3\r\n")
}
})
it("should handle partial line modifications", async () => {
const original = "const value = oldValue + 123;"
const partialDiff = `--- a/file.txt
+++ b/file.txt
@@ ... @@
-const value = oldValue + 123;
+const value = newValue + 123;`
const result = await strategy.applyDiff(original, partialDiff)
expect(result.success).toBe(true)
if (result.success) {
expect(result.content).toBe("const value = newValue + 123;")
}
})
it("should handle slightly malformed but recoverable diff", async () => {
const original = "line1\nline2\nline3"
// Missing space after --- and +++
const slightlyBadDiff = `---a/file.txt
+++b/file.txt
@@ ... @@
line1
-line2
+new line
line3`
const result = await strategy.applyDiff(original, slightlyBadDiff)
expect(result.success).toBe(true)
if (result.success) {
expect(result.content).toBe("line1\nnew line\nline3")
}
})
})
describe("similar code sections", () => {
it("should correctly modify the right section when similar code exists", async () => {
const original = `function add(a, b) {
return a + b;
}
function subtract(a, b) {
return a - b;
}
function multiply(a, b) {
return a + b; // Bug here
}`
const diff = `--- a/math.js
+++ b/math.js
@@ ... @@
function multiply(a, b) {
- return a + b; // Bug here
+ return a * b;
}`
const result = await strategy.applyDiff(original, diff)
expect(result.success).toBe(true)
if (result.success) {
expect(result.content).toBe(`function add(a, b) {
return a + b;
}
function subtract(a, b) {
return a - b;
}
function multiply(a, b) {
return a * b;
}`)
}
})
it("should handle multiple similar sections with correct context", async () => {
const original = `if (condition) {
doSomething();
doSomething();
doSomething();
}
if (otherCondition) {
doSomething();
doSomething();
doSomething();
}`
const diff = `--- a/file.js
+++ b/file.js
@@ ... @@
if (otherCondition) {
doSomething();
- doSomething();
+ doSomethingElse();
doSomething();
}`
const result = await strategy.applyDiff(original, diff)
expect(result.success).toBe(true)
if (result.success) {
expect(result.content).toBe(`if (condition) {
doSomething();
doSomething();
doSomething();
}
if (otherCondition) {
doSomething();
doSomethingElse();
doSomething();
}`)
}
})
})
describe("hunk splitting", () => {
it("should handle large diffs with multiple non-contiguous changes", async () => {
const original = `import { readFile } from 'fs';
import { join } from 'path';
import { Logger } from './logger';
const logger = new Logger();
async function processFile(filePath: string) {
try {
const data = await readFile(filePath, 'utf8');
logger.info('File read successfully');
return data;
} catch (error) {
logger.error('Failed to read file:', error);
throw error;
}
}
function validateInput(input: string): boolean {
if (!input) {
logger.warn('Empty input received');
return false;
}
return input.length > 0;
}
async function writeOutput(data: string) {
logger.info('Processing output');
// TODO: Implement output writing
return Promise.resolve();
}
function parseConfig(configPath: string) {
logger.debug('Reading config from:', configPath);
// Basic config parsing
return {
enabled: true,
maxRetries: 3
};
}
export {
processFile,
validateInput,
writeOutput,
parseConfig
};`
const diff = `--- a/file.ts
+++ b/file.ts
@@ ... @@
-import { readFile } from 'fs';
+import { readFile, writeFile } from 'fs';
import { join } from 'path';
-import { Logger } from './logger';
+import { Logger } from './utils/logger';
+import { Config } from './types';
-const logger = new Logger();
+const logger = new Logger('FileProcessor');
async function processFile(filePath: string) {
try {
const data = await readFile(filePath, 'utf8');
- logger.info('File read successfully');
+ logger.info(\`File \${filePath} read successfully\`);
return data;
} catch (error) {
- logger.error('Failed to read file:', error);
+ logger.error(\`Failed to read file \${filePath}:\`, error);
throw error;
}
}
function validateInput(input: string): boolean {
if (!input) {
- logger.warn('Empty input received');
+ logger.warn('Validation failed: Empty input received');
return false;
}
- return input.length > 0;
+ return input.trim().length > 0;
}
-async function writeOutput(data: string) {
- logger.info('Processing output');
- // TODO: Implement output writing
- return Promise.resolve();
+async function writeOutput(data: string, outputPath: string) {
+ try {
+ await writeFile(outputPath, data, 'utf8');
+ logger.info(\`Output written to \${outputPath}\`);
+ } catch (error) {
+ logger.error(\`Failed to write output to \${outputPath}:\`, error);
+ throw error;
+ }
}
-function parseConfig(configPath: string) {
- logger.debug('Reading config from:', configPath);
- // Basic config parsing
- return {
- enabled: true,
- maxRetries: 3
- };
+async function parseConfig(configPath: string): Promise<Config> {
+ try {
+ const configData = await readFile(configPath, 'utf8');
+ logger.debug(\`Reading config from \${configPath}\`);
+ return JSON.parse(configData);
+ } catch (error) {
+ logger.error(\`Failed to parse config from \${configPath}:\`, error);
+ throw error;
+ }
}
export {
processFile,
validateInput,
writeOutput,
- parseConfig
+ parseConfig,
+ type Config
};`
const expected = `import { readFile, writeFile } from 'fs';
import { join } from 'path';
import { Logger } from './utils/logger';
import { Config } from './types';
const logger = new Logger('FileProcessor');
async function processFile(filePath: string) {
try {
const data = await readFile(filePath, 'utf8');
logger.info(\`File \${filePath} read successfully\`);
return data;
} catch (error) {
logger.error(\`Failed to read file \${filePath}:\`, error);
throw error;
}
}
function validateInput(input: string): boolean {
if (!input) {
logger.warn('Validation failed: Empty input received');
return false;
}
return input.trim().length > 0;
}
async function writeOutput(data: string, outputPath: string) {
try {
await writeFile(outputPath, data, 'utf8');
logger.info(\`Output written to \${outputPath}\`);
} catch (error) {
logger.error(\`Failed to write output to \${outputPath}:\`, error);
throw error;
}
}
async function parseConfig(configPath: string): Promise<Config> {
try {
const configData = await readFile(configPath, 'utf8');
logger.debug(\`Reading config from \${configPath}\`);
return JSON.parse(configData);
} catch (error) {
logger.error(\`Failed to parse config from \${configPath}:\`, error);
throw error;
}
}
export {
processFile,
validateInput,
writeOutput,
parseConfig,
type Config
};`
const result = await strategy.applyDiff(original, diff)
expect(result.success).toBe(true)
if (result.success) {
expect(result.content).toBe(expected)
}
})
})
})

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,228 @@
import { UnifiedDiffStrategy } from "../unified"
describe("UnifiedDiffStrategy", () => {
let strategy: UnifiedDiffStrategy
beforeEach(() => {
strategy = new UnifiedDiffStrategy()
})
describe("getToolDescription", () => {
it("should return tool description with correct cwd", () => {
const cwd = "/test/path"
const description = strategy.getToolDescription({ cwd })
expect(description).toContain("apply_diff")
expect(description).toContain(cwd)
expect(description).toContain("Parameters:")
expect(description).toContain("Format Requirements:")
})
})
describe("applyDiff", () => {
it("should successfully apply a function modification diff", async () => {
const originalContent = `import { Logger } from '../logger';
function calculateTotal(items: number[]): number {
return items.reduce((sum, item) => {
return sum + item;
}, 0);
}
export { calculateTotal };`
const diffContent = `--- src/utils/helper.ts
+++ src/utils/helper.ts
@@ -1,9 +1,10 @@
import { Logger } from '../logger';
function calculateTotal(items: number[]): number {
- return items.reduce((sum, item) => {
- return sum + item;
+ const total = items.reduce((sum, item) => {
+ return sum + item * 1.1; // Add 10% markup
}, 0);
+ return Math.round(total * 100) / 100; // Round to 2 decimal places
}
export { calculateTotal };`
const expected = `import { Logger } from '../logger';
function calculateTotal(items: number[]): number {
const total = items.reduce((sum, item) => {
return sum + item * 1.1; // Add 10% markup
}, 0);
return Math.round(total * 100) / 100; // Round to 2 decimal places
}
export { calculateTotal };`
const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true)
if (result.success) {
expect(result.content).toBe(expected)
}
})
it("should successfully apply a diff adding a new method", async () => {
const originalContent = `class Calculator {
add(a: number, b: number): number {
return a + b;
}
}`
const diffContent = `--- src/Calculator.ts
+++ src/Calculator.ts
@@ -1,5 +1,9 @@
class Calculator {
add(a: number, b: number): number {
return a + b;
}
+
+ multiply(a: number, b: number): number {
+ return a * b;
+ }
}`
const expected = `class Calculator {
add(a: number, b: number): number {
return a + b;
}
multiply(a: number, b: number): number {
return a * b;
}
}`
const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true)
if (result.success) {
expect(result.content).toBe(expected)
}
})
it("should successfully apply a diff modifying imports", async () => {
const originalContent = `import { useState } from 'react';
import { Button } from './components';
function App() {
const [count, setCount] = useState(0);
return <Button onClick={() => setCount(count + 1)}>{count}</Button>;
}`
const diffContent = `--- src/App.tsx
+++ src/App.tsx
@@ -1,7 +1,8 @@
-import { useState } from 'react';
+import { useState, useEffect } from 'react';
import { Button } from './components';
function App() {
const [count, setCount] = useState(0);
+ useEffect(() => { document.title = \`Count: \${count}\` }, [count]);
return <Button onClick={() => setCount(count + 1)}>{count}</Button>;
}`
const expected = `import { useState, useEffect } from 'react';
import { Button } from './components';
function App() {
const [count, setCount] = useState(0);
useEffect(() => { document.title = \`Count: \${count}\` }, [count]);
return <Button onClick={() => setCount(count + 1)}>{count}</Button>;
}`
const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true)
if (result.success) {
expect(result.content).toBe(expected)
}
})
it("should successfully apply a diff with multiple hunks", async () => {
const originalContent = `import { readFile, writeFile } from 'fs';
function processFile(path: string) {
readFile(path, 'utf8', (err, data) => {
if (err) throw err;
const processed = data.toUpperCase();
writeFile(path, processed, (err) => {
if (err) throw err;
});
});
}
export { processFile };`
const diffContent = `--- src/file-processor.ts
+++ src/file-processor.ts
@@ -1,12 +1,14 @@
-import { readFile, writeFile } from 'fs';
+import { promises as fs } from 'fs';
+import { join } from 'path';
-function processFile(path: string) {
- readFile(path, 'utf8', (err, data) => {
- if (err) throw err;
+async function processFile(path: string) {
+ try {
+ const data = await fs.readFile(join(__dirname, path), 'utf8');
const processed = data.toUpperCase();
- writeFile(path, processed, (err) => {
- if (err) throw err;
- });
- });
+ await fs.writeFile(join(__dirname, path), processed);
+ } catch (error) {
+ console.error('Failed to process file:', error);
+ throw error;
+ }
}
export { processFile };`
const expected = `import { promises as fs } from 'fs';
import { join } from 'path';
async function processFile(path: string) {
try {
const data = await fs.readFile(join(__dirname, path), 'utf8');
const processed = data.toUpperCase();
await fs.writeFile(join(__dirname, path), processed);
} catch (error) {
console.error('Failed to process file:', error);
throw error;
}
}
export { processFile };`
const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true)
if (result.success) {
expect(result.content).toBe(expected)
}
})
it("should handle empty original content", async () => {
const originalContent = ""
const diffContent = `--- empty.ts
+++ empty.ts
@@ -0,0 +1,3 @@
+export function greet(name: string): string {
+ return \`Hello, \${name}!\`;
+}`
const expected = `export function greet(name: string): string {
return \`Hello, \${name}!\`;
}\n`
const result = await strategy.applyDiff(originalContent, diffContent)
expect(result.success).toBe(true)
if (result.success) {
expect(result.content).toBe(expected)
}
})
})
})

View File

@ -0,0 +1,295 @@
import { applyContextMatching, applyDMP, applyGitFallback } from "../edit-strategies"
import { Hunk } from "../types"
const testCases = [
{
name: "should return original content if no match is found",
hunk: {
changes: [
{ type: "context", content: "line1" },
{ type: "add", content: "line2" },
],
} as Hunk,
content: ["line1", "line3"],
matchPosition: -1,
expected: {
confidence: 0,
result: ["line1", "line3"],
},
expectedResult: "line1\nline3",
strategies: ["context", "dmp"],
},
{
name: "should apply a simple add change",
hunk: {
changes: [
{ type: "context", content: "line1" },
{ type: "add", content: "line2" },
],
} as Hunk,
content: ["line1", "line3"],
matchPosition: 0,
expected: {
confidence: 1,
result: ["line1", "line2", "line3"],
},
expectedResult: "line1\nline2\nline3",
strategies: ["context", "dmp"],
},
{
name: "should apply a simple remove change",
hunk: {
changes: [
{ type: "context", content: "line1" },
{ type: "remove", content: "line2" },
],
} as Hunk,
content: ["line1", "line2", "line3"],
matchPosition: 0,
expected: {
confidence: 1,
result: ["line1", "line3"],
},
expectedResult: "line1\nline3",
strategies: ["context", "dmp"],
},
{
name: "should apply a simple context change",
hunk: {
changes: [{ type: "context", content: "line1" }],
} as Hunk,
content: ["line1", "line2", "line3"],
matchPosition: 0,
expected: {
confidence: 1,
result: ["line1", "line2", "line3"],
},
expectedResult: "line1\nline2\nline3",
strategies: ["context", "dmp"],
},
{
name: "should apply a multi-line add change",
hunk: {
changes: [
{ type: "context", content: "line1" },
{ type: "add", content: "line2\nline3" },
],
} as Hunk,
content: ["line1", "line4"],
matchPosition: 0,
expected: {
confidence: 1,
result: ["line1", "line2\nline3", "line4"],
},
expectedResult: "line1\nline2\nline3\nline4",
strategies: ["context", "dmp"],
},
{
name: "should apply a multi-line remove change",
hunk: {
changes: [
{ type: "context", content: "line1" },
{ type: "remove", content: "line2\nline3" },
],
} as Hunk,
content: ["line1", "line2", "line3", "line4"],
matchPosition: 0,
expected: {
confidence: 1,
result: ["line1", "line4"],
},
expectedResult: "line1\nline4",
strategies: ["context", "dmp"],
},
{
name: "should apply a multi-line context change",
hunk: {
changes: [
{ type: "context", content: "line1" },
{ type: "context", content: "line2\nline3" },
],
} as Hunk,
content: ["line1", "line2", "line3", "line4"],
matchPosition: 0,
expected: {
confidence: 1,
result: ["line1", "line2\nline3", "line4"],
},
expectedResult: "line1\nline2\nline3\nline4",
strategies: ["context", "dmp"],
},
{
name: "should apply a change with indentation",
hunk: {
changes: [
{ type: "context", content: " line1" },
{ type: "add", content: " line2" },
],
} as Hunk,
content: [" line1", " line3"],
matchPosition: 0,
expected: {
confidence: 1,
result: [" line1", " line2", " line3"],
},
expectedResult: " line1\n line2\n line3",
strategies: ["context", "dmp"],
},
{
name: "should apply a change with mixed indentation",
hunk: {
changes: [
{ type: "context", content: "\tline1" },
{ type: "add", content: " line2" },
],
} as Hunk,
content: ["\tline1", " line3"],
matchPosition: 0,
expected: {
confidence: 1,
result: ["\tline1", " line2", " line3"],
},
expectedResult: "\tline1\n line2\n line3",
strategies: ["context", "dmp"],
},
{
name: "should apply a change with mixed indentation and multi-line",
hunk: {
changes: [
{ type: "context", content: " line1" },
{ type: "add", content: "\tline2\n line3" },
],
} as Hunk,
content: [" line1", " line4"],
matchPosition: 0,
expected: {
confidence: 1,
result: [" line1", "\tline2\n line3", " line4"],
},
expectedResult: " line1\n\tline2\n line3\n line4",
strategies: ["context", "dmp"],
},
{
name: "should apply a complex change with mixed indentation and multi-line",
hunk: {
changes: [
{ type: "context", content: " line1" },
{ type: "remove", content: " line2" },
{ type: "add", content: "\tline3\n line4" },
{ type: "context", content: " line5" },
],
} as Hunk,
content: [" line1", " line2", " line5", " line6"],
matchPosition: 0,
expected: {
confidence: 1,
result: [" line1", "\tline3\n line4", " line5", " line6"],
},
expectedResult: " line1\n\tline3\n line4\n line5\n line6",
strategies: ["context", "dmp"],
},
{
name: "should apply a complex change with mixed indentation and multi-line and context",
hunk: {
changes: [
{ type: "context", content: " line1" },
{ type: "remove", content: " line2" },
{ type: "add", content: "\tline3\n line4" },
{ type: "context", content: " line5" },
{ type: "context", content: " line6" },
],
} as Hunk,
content: [" line1", " line2", " line5", " line6", " line7"],
matchPosition: 0,
expected: {
confidence: 1,
result: [" line1", "\tline3\n line4", " line5", " line6", " line7"],
},
expectedResult: " line1\n\tline3\n line4\n line5\n line6\n line7",
strategies: ["context", "dmp"],
},
{
name: "should apply a complex change with mixed indentation and multi-line and context and a different match position",
hunk: {
changes: [
{ type: "context", content: " line1" },
{ type: "remove", content: " line2" },
{ type: "add", content: "\tline3\n line4" },
{ type: "context", content: " line5" },
{ type: "context", content: " line6" },
],
} as Hunk,
content: [" line0", " line1", " line2", " line5", " line6", " line7"],
matchPosition: 1,
expected: {
confidence: 1,
result: [" line0", " line1", "\tline3\n line4", " line5", " line6", " line7"],
},
expectedResult: " line0\n line1\n\tline3\n line4\n line5\n line6\n line7",
strategies: ["context", "dmp"],
},
]
describe("applyContextMatching", () => {
testCases.forEach(({ name, hunk, content, matchPosition, expected, strategies, expectedResult }) => {
if (!strategies?.includes("context")) {
return
}
it(name, () => {
const result = applyContextMatching(hunk, content, matchPosition)
expect(result.result.join("\n")).toEqual(expectedResult)
expect(result.confidence).toBeGreaterThanOrEqual(expected.confidence)
expect(result.strategy).toBe("context")
})
})
})
describe("applyDMP", () => {
testCases.forEach(({ name, hunk, content, matchPosition, expected, strategies, expectedResult }) => {
if (!strategies?.includes("dmp")) {
return
}
it(name, () => {
const result = applyDMP(hunk, content, matchPosition)
expect(result.result.join("\n")).toEqual(expectedResult)
expect(result.confidence).toBeGreaterThanOrEqual(expected.confidence)
expect(result.strategy).toBe("dmp")
})
})
})
describe("applyGitFallback", () => {
it("should successfully apply changes using git operations", async () => {
const hunk = {
changes: [
{ type: "context", content: "line1", indent: "" },
{ type: "remove", content: "line2", indent: "" },
{ type: "add", content: "new line2", indent: "" },
{ type: "context", content: "line3", indent: "" },
],
} as Hunk
const content = ["line1", "line2", "line3"]
const result = await applyGitFallback(hunk, content)
expect(result.result.join("\n")).toEqual("line1\nnew line2\nline3")
expect(result.confidence).toBe(1)
expect(result.strategy).toBe("git-fallback")
})
it("should return original content with 0 confidence when changes cannot be applied", async () => {
const hunk = {
changes: [
{ type: "context", content: "nonexistent", indent: "" },
{ type: "add", content: "new line", indent: "" },
],
} as Hunk
const content = ["line1", "line2", "line3"]
const result = await applyGitFallback(hunk, content)
expect(result.result).toEqual(content)
expect(result.confidence).toBe(0)
expect(result.strategy).toBe("git-fallback")
})
})

View File

@ -0,0 +1,262 @@
import { findAnchorMatch, findExactMatch, findSimilarityMatch, findLevenshteinMatch } from "../search-strategies"
// Common signature of the search-strategy functions under test: given a
// (possibly multi-line) search string, the file as an array of lines, and an
// optional line to start from, report where the text was found (-1 = no match)
// and how confident the match is.
// NOTE(review): this alias is not referenced in the visible part of this file —
// confirm it is used elsewhere before removing it.
type SearchStrategy = (
searchStr: string,
content: string[],
startIndex?: number,
) => {
index: number
confidence: number
strategy: string
}
// Shared fixture table for the search-strategy suites below. Each case lists:
// - searchStr: the (possibly multi-line, possibly indented) text to locate
// - content: the file under search, as an array of lines
// - startIndex: optional line index to start searching from
// - expected: the expected match index (-1 = not found) and a minimum confidence
// - strategies: which suites ("exact" | "similarity" | "levenshtein") run the case
const testCases = [
{
name: "should return no match if the search string is not found",
searchStr: "not found",
content: ["line1", "line2", "line3"],
expected: { index: -1, confidence: 0 },
strategies: ["exact", "similarity", "levenshtein"],
},
{
name: "should return a match if the search string is found",
searchStr: "line2",
content: ["line1", "line2", "line3"],
expected: { index: 1, confidence: 1 },
strategies: ["exact", "similarity", "levenshtein"],
},
{
name: "should return a match with correct index when startIndex is provided",
searchStr: "line3",
content: ["line1", "line2", "line3", "line4", "line3"],
startIndex: 3,
expected: { index: 4, confidence: 1 },
strategies: ["exact", "similarity", "levenshtein"],
},
{
name: "should return a match even if there are more lines in content",
searchStr: "line2",
content: ["line1", "line2", "line3", "line4", "line5"],
expected: { index: 1, confidence: 1 },
strategies: ["exact", "similarity", "levenshtein"],
},
{
name: "should return a match even if the search string is at the beginning of the content",
searchStr: "line1",
content: ["line1", "line2", "line3"],
expected: { index: 0, confidence: 1 },
strategies: ["exact", "similarity", "levenshtein"],
},
{
name: "should return a match even if the search string is at the end of the content",
searchStr: "line3",
content: ["line1", "line2", "line3"],
expected: { index: 2, confidence: 1 },
strategies: ["exact", "similarity", "levenshtein"],
},
{
name: "should return a match for a multi-line search string",
searchStr: "line2\nline3",
content: ["line1", "line2", "line3", "line4"],
expected: { index: 1, confidence: 1 },
strategies: ["exact", "similarity", "levenshtein"],
},
{
name: "should return no match if a multi-line search string is not found",
searchStr: "line2\nline4",
content: ["line1", "line2", "line3", "line4"],
expected: { index: -1, confidence: 0 },
strategies: ["exact", "similarity"],
},
{
name: "should return a match with indentation",
searchStr: " line2",
content: ["line1", " line2", "line3"],
expected: { index: 1, confidence: 1 },
strategies: ["exact", "similarity", "levenshtein"],
},
{
name: "should return a match with more complex indentation",
searchStr: " line3",
content: [" line1", " line2", " line3", " line4"],
expected: { index: 2, confidence: 1 },
strategies: ["exact", "similarity", "levenshtein"],
},
{
name: "should return a match with mixed indentation",
searchStr: "\tline2",
content: [" line1", "\tline2", " line3"],
expected: { index: 1, confidence: 1 },
strategies: ["exact", "similarity", "levenshtein"],
},
{
name: "should return a match with mixed indentation and multi-line",
searchStr: " line2\n\tline3",
content: ["line1", " line2", "\tline3", " line4"],
expected: { index: 1, confidence: 1 },
strategies: ["exact", "similarity", "levenshtein"],
},
{
name: "should return no match if mixed indentation and multi-line is not found",
searchStr: " line2\n line4",
content: ["line1", " line2", "\tline3", " line4"],
expected: { index: -1, confidence: 0 },
strategies: ["exact", "similarity"],
},
{
name: "should return a match with leading and trailing spaces",
searchStr: " line2 ",
content: ["line1", " line2 ", "line3"],
expected: { index: 1, confidence: 1 },
strategies: ["exact", "similarity", "levenshtein"],
},
{
name: "should return a match with leading and trailing tabs",
searchStr: "\tline2\t",
content: ["line1", "\tline2\t", "line3"],
expected: { index: 1, confidence: 1 },
strategies: ["exact", "similarity", "levenshtein"],
},
{
name: "should return a match with mixed leading and trailing spaces and tabs",
searchStr: " \tline2\t ",
content: ["line1", " \tline2\t ", "line3"],
expected: { index: 1, confidence: 1 },
strategies: ["exact", "similarity", "levenshtein"],
},
{
name: "should return a match with mixed leading and trailing spaces and tabs and multi-line",
searchStr: " \tline2\t \n line3 ",
content: ["line1", " \tline2\t ", " line3 ", "line4"],
expected: { index: 1, confidence: 1 },
strategies: ["exact", "similarity", "levenshtein"],
},
{
name: "should return no match if mixed leading and trailing spaces and tabs and multi-line is not found",
searchStr: " \tline2\t \n line4 ",
content: ["line1", " \tline2\t ", " line3 ", "line4"],
expected: { index: -1, confidence: 0 },
strategies: ["exact", "similarity"],
},
]
describe("findExactMatch", () => {
testCases.forEach(({ name, searchStr, content, startIndex, expected, strategies }) => {
if (!strategies?.includes("exact")) {
return
}
it(name, () => {
const result = findExactMatch(searchStr, content, startIndex)
expect(result.index).toBe(expected.index)
expect(result.confidence).toBeGreaterThanOrEqual(expected.confidence)
expect(result.strategy).toMatch(/exact(-overlapping)?/)
})
})
})
describe("findAnchorMatch", () => {
const anchorTestCases = [
{
name: "should return no match if no anchors are found",
searchStr: " \n \n ",
content: ["line1", "line2", "line3"],
expected: { index: -1, confidence: 0 },
},
{
name: "should return no match if anchor positions cannot be validated",
searchStr: "unique line\ncontext line 1\ncontext line 2",
content: [
"different line 1",
"different line 2",
"different line 3",
"another unique line",
"context line 1",
"context line 2",
],
expected: { index: -1, confidence: 0 },
},
{
name: "should return a match if anchor positions can be validated",
searchStr: "unique line\ncontext line 1\ncontext line 2",
content: ["line1", "line2", "unique line", "context line 1", "context line 2", "line 6"],
expected: { index: 2, confidence: 1 },
},
{
name: "should return a match with correct index when startIndex is provided",
searchStr: "unique line\ncontext line 1\ncontext line 2",
content: ["line1", "line2", "line3", "unique line", "context line 1", "context line 2", "line 7"],
startIndex: 3,
expected: { index: 3, confidence: 1 },
},
{
name: "should return a match even if there are more lines in content",
searchStr: "unique line\ncontext line 1\ncontext line 2",
content: [
"line1",
"line2",
"unique line",
"context line 1",
"context line 2",
"line 6",
"extra line 1",
"extra line 2",
],
expected: { index: 2, confidence: 1 },
},
{
name: "should return a match even if the anchor is at the beginning of the content",
searchStr: "unique line\ncontext line 1\ncontext line 2",
content: ["unique line", "context line 1", "context line 2", "line 6"],
expected: { index: 0, confidence: 1 },
},
{
name: "should return a match even if the anchor is at the end of the content",
searchStr: "unique line\ncontext line 1\ncontext line 2",
content: ["line1", "line2", "unique line", "context line 1", "context line 2"],
expected: { index: 2, confidence: 1 },
},
{
name: "should return no match if no valid anchor is found",
searchStr: "non-unique line\ncontext line 1\ncontext line 2",
content: ["line1", "line2", "non-unique line", "context line 1", "context line 2", "non-unique line"],
expected: { index: -1, confidence: 0 },
},
]
anchorTestCases.forEach(({ name, searchStr, content, startIndex, expected }) => {
it(name, () => {
const result = findAnchorMatch(searchStr, content, startIndex)
expect(result.index).toBe(expected.index)
expect(result.confidence).toBeGreaterThanOrEqual(expected.confidence)
expect(result.strategy).toBe("anchor")
})
})
})
describe("findSimilarityMatch", () => {
testCases.forEach(({ name, searchStr, content, startIndex, expected, strategies }) => {
if (!strategies?.includes("similarity")) {
return
}
it(name, () => {
const result = findSimilarityMatch(searchStr, content, startIndex)
expect(result.index).toBe(expected.index)
expect(result.confidence).toBeGreaterThanOrEqual(expected.confidence)
expect(result.strategy).toBe("similarity")
})
})
})
describe("findLevenshteinMatch", () => {
testCases.forEach(({ name, searchStr, content, startIndex, expected, strategies }) => {
if (!strategies?.includes("levenshtein")) {
return
}
it(name, () => {
const result = findLevenshteinMatch(searchStr, content, startIndex)
expect(result.index).toBe(expected.index)
expect(result.confidence).toBeGreaterThanOrEqual(expected.confidence)
expect(result.strategy).toBe("levenshtein")
})
})
})

View File

@ -0,0 +1,297 @@
import { diff_match_patch } from "diff-match-patch"
import { EditResult, Hunk } from "./types"
import { getDMPSimilarity, validateEditResult } from "./search-strategies"
import * as path from "path"
import simpleGit, { SimpleGit } from "simple-git"
import * as tmp from "tmp"
import * as fs from "fs"
// Helper function to infer indentation - simplified version.
// Preference order: the line's own leading whitespace, then the first context
// line's leading whitespace, then the caller-supplied previous indent.
function inferIndentation(line: string, contextLines: string[], previousIndent: string = ""): string {
  // The line's explicit indentation wins when present.
  const ownIndent = /^(\s+)/.exec(line)
  if (ownIndent) {
    return ownIndent[1]
  }
  // Otherwise borrow the first context line's indentation, if it has any.
  const firstContext = contextLines[0]
  const contextIndent = firstContext ? /^(\s+)/.exec(firstContext) : null
  if (contextIndent) {
    return contextIndent[1]
  }
  // Fall back to whatever indent was in effect before.
  return previousIndent
}
// Context matching edit strategy.
// Rebuilds the file around matchPosition by walking the hunk's changes:
// context lines are copied from the original file when available, additions
// are inserted using the hunk's recorded indentation, and removals advance the
// source cursor without emitting anything. Confidence is computed by
// validateEditResult over the rebuilt region only.
export function applyContextMatching(hunk: Hunk, content: string[], matchPosition: number): EditResult {
  // A failed search (-1) means there is nothing safe to edit.
  if (matchPosition === -1) {
    return { confidence: 0, result: content, strategy: "context" }
  }
  const newResult = [...content.slice(0, matchPosition)]
  let sourceIndex = matchPosition
  for (const change of hunk.changes) {
    if (change.type === "context") {
      // Prefer the file's own line so unrelated formatting is preserved.
      if (sourceIndex < content.length) {
        newResult.push(content[sourceIndex])
      } else {
        const line = change.indent ? change.indent + change.content : change.content
        newResult.push(line)
      }
      sourceIndex++
    } else if (change.type === "add") {
      // Use exactly the indentation recorded in the change.
      const baseIndent = change.indent || ""
      // Handle multi-line additions: a line that already carries indentation is
      // kept verbatim; otherwise the base indent is prefixed.
      const lines = change.content.split("\n").map((line) => {
        const lineIndentMatch = line.match(/^(\s*)(.*)/)
        if (lineIndentMatch) {
          // Fix: the second group was previously destructured as `content`,
          // shadowing the `content: string[]` parameter of this function.
          const [, lineIndent, restOfLine] = lineIndentMatch
          return lineIndent ? line : baseIndent + restOfLine
        }
        return baseIndent + line
      })
      newResult.push(...lines)
    } else if (change.type === "remove") {
      // Multi-line removals consume one source line per embedded line.
      const removedLines = change.content.split("\n").length
      sourceIndex += removedLines
    }
  }
  // Append the untouched tail of the file.
  newResult.push(...content.slice(sourceIndex))
  // Score only the rebuilt region against what the hunk says it should contain.
  const afterText = newResult.slice(matchPosition, newResult.length - (content.length - sourceIndex)).join("\n")
  const confidence = validateEditResult(hunk, afterText)
  return {
    confidence,
    result: newResult,
    strategy: "context",
  }
}
// DMP edit strategy.
// Builds the hunk's "before" (context + removals) and "after" (context +
// additions) texts, turns their difference into a diff-match-patch patch set,
// and applies that patch to the matched slice of the file.
export function applyDMP(hunk: Hunk, content: string[], matchPosition: number): EditResult {
  if (matchPosition === -1) {
    return { confidence: 0, result: content, strategy: "dmp" }
  }
  const dmp = new diff_match_patch()
  const isBeforeSide = (change: Hunk["changes"][number]) =>
    change.type === "context" || change.type === "remove"
  const isAfterSide = (change: Hunk["changes"][number]) =>
    change.type === "context" || change.type === "add"
  // Render one change back into an original-style line, preferring the
  // verbatim diff line when it was recorded.
  const renderLine = (change: Hunk["changes"][number]) =>
    change.originalLine ? change.originalLine : change.indent ? change.indent + change.content : change.content
  // Total original lines covered by the hunk, counting embedded newlines.
  let beforeLineCount = 0
  for (const change of hunk.changes) {
    if (isBeforeSide(change)) {
      beforeLineCount += change.content.split("\n").length
    }
  }
  const beforeText = hunk.changes.filter(isBeforeSide).map(renderLine).join("\n")
  const afterText = hunk.changes.filter(isAfterSide).map(renderLine).join("\n")
  // Patch the matched window of the file from before -> after.
  const patchSet = dmp.patch_make(beforeText, afterText)
  const targetText = content.slice(matchPosition, matchPosition + beforeLineCount).join("\n")
  const [patchedText] = dmp.patch_apply(patchSet, targetText)
  const patchedLines = patchedText.split("\n")
  // Splice the patched window back into the untouched surroundings.
  const newResult = content
    .slice(0, matchPosition)
    .concat(patchedLines, content.slice(matchPosition + beforeLineCount))
  return {
    confidence: validateEditResult(hunk, patchedText),
    result: newResult,
    strategy: "dmp",
  }
}
// Git fallback strategy that works with full content
// Replays the hunk as commits in a throw-away git repository and lets git's
// cherry-pick machinery perform the merge. Two commit orderings are tried; a
// clean cherry-pick yields confidence 1, any conflict or error yields 0 with
// the content returned unchanged.
export async function applyGitFallback(hunk: Hunk, content: string[]): Promise<EditResult> {
let tmpDir: tmp.DirResult | undefined
try {
// Scratch repository; unsafeCleanup allows removing a non-empty directory.
tmpDir = tmp.dirSync({ unsafeCleanup: true })
const git: SimpleGit = simpleGit(tmpDir.name)
await git.init()
await git.addConfig("user.name", "Temp")
await git.addConfig("user.email", "temp@example.com")
const filePath = path.join(tmpDir.name, "file.txt")
// "search" = what the hunk expects to find (context + removals);
// "replace" = what it should become (context + additions).
const searchLines = hunk.changes
.filter((change) => change.type === "context" || change.type === "remove")
.map((change) => change.originalLine || change.indent + change.content)
const replaceLines = hunk.changes
.filter((change) => change.type === "context" || change.type === "add")
.map((change) => change.originalLine || change.indent + change.content)
const searchText = searchLines.join("\n")
const replaceText = replaceLines.join("\n")
const originalText = content.join("\n")
// Strategy 1: commit original -> search -> replace, check out the original
// commit, then cherry-pick the replace commit onto it.
try {
fs.writeFileSync(filePath, originalText)
await git.add("file.txt")
const originalCommit = await git.commit("original")
console.log("Strategy 1 - Original commit:", originalCommit.commit)
fs.writeFileSync(filePath, searchText)
await git.add("file.txt")
const searchCommit1 = await git.commit("search")
console.log("Strategy 1 - Search commit:", searchCommit1.commit)
fs.writeFileSync(filePath, replaceText)
await git.add("file.txt")
const replaceCommit = await git.commit("replace")
console.log("Strategy 1 - Replace commit:", replaceCommit.commit)
console.log("Strategy 1 - Attempting checkout of:", originalCommit.commit)
await git.raw(["checkout", originalCommit.commit])
try {
console.log("Strategy 1 - Attempting cherry-pick of:", replaceCommit.commit)
await git.raw(["cherry-pick", "--minimal", replaceCommit.commit])
// Clean pick: the working file now holds the merged result.
const newText = fs.readFileSync(filePath, "utf-8")
const newLines = newText.split("\n")
return {
confidence: 1,
result: newLines,
strategy: "git-fallback",
}
} catch (cherryPickError) {
console.error("Strategy 1 failed with merge conflict")
}
} catch (error) {
console.error("Strategy 1 failed:", error)
}
// Strategy 2: build history in the opposite order — commit search first,
// then the real original content — and cherry-pick the replace commit onto it.
// NOTE(review): this re-runs `git init` in the directory already used by
// Strategy 1, so the earlier history (and any conflicted state) is still
// present — confirm this is intended.
try {
await git.init()
await git.addConfig("user.name", "Temp")
await git.addConfig("user.email", "temp@example.com")
fs.writeFileSync(filePath, searchText)
await git.add("file.txt")
const searchCommit = await git.commit("search")
const searchHash = searchCommit.commit.replace(/^HEAD /, "")
console.log("Strategy 2 - Search commit:", searchHash)
fs.writeFileSync(filePath, replaceText)
await git.add("file.txt")
const replaceCommit = await git.commit("replace")
const replaceHash = replaceCommit.commit.replace(/^HEAD /, "")
console.log("Strategy 2 - Replace commit:", replaceHash)
console.log("Strategy 2 - Attempting checkout of:", searchHash)
await git.raw(["checkout", searchHash])
fs.writeFileSync(filePath, originalText)
await git.add("file.txt")
const originalCommit2 = await git.commit("original")
console.log("Strategy 2 - Original commit:", originalCommit2.commit)
try {
console.log("Strategy 2 - Attempting cherry-pick of:", replaceHash)
await git.raw(["cherry-pick", "--minimal", replaceHash])
const newText = fs.readFileSync(filePath, "utf-8")
const newLines = newText.split("\n")
return {
confidence: 1,
result: newLines,
strategy: "git-fallback",
}
} catch (cherryPickError) {
console.error("Strategy 2 failed with merge conflict")
}
} catch (error) {
console.error("Strategy 2 failed:", error)
}
console.error("Git fallback failed")
return { confidence: 0, result: content, strategy: "git-fallback" }
} catch (error) {
console.error("Git fallback strategy failed:", error)
return { confidence: 0, result: content, strategy: "git-fallback" }
} finally {
// Always remove the scratch repository, even on failure.
if (tmpDir) {
tmpDir.removeCallback()
}
}
}
// Main edit function that tries strategies sequentially.
// Order: DMP, then context matching, then the git fallback. The first result
// that reaches confidenceThreshold wins; when the search confidence is already
// below threshold the in-place strategies are skipped entirely.
export async function applyEdit(
  hunk: Hunk,
  content: string[],
  matchPosition: number,
  confidence: number,
  confidenceThreshold: number = 0.97,
): Promise<EditResult> {
  // Low search confidence: go straight to the git-based fallback.
  if (confidence < confidenceThreshold) {
    console.log(
      `Search confidence (${confidence}) below minimum threshold (${confidenceThreshold}), trying git fallback...`,
    )
    return applyGitFallback(hunk, content)
  }
  const attempts: Array<() => EditResult | Promise<EditResult>> = [
    () => applyDMP(hunk, content, matchPosition),
    () => applyContextMatching(hunk, content, matchPosition),
    () => applyGitFallback(hunk, content),
  ]
  for (const attempt of attempts) {
    const candidate = await attempt()
    if (candidate.confidence >= confidenceThreshold) {
      return candidate
    }
  }
  // Nothing reached the bar: report failure with the content untouched.
  return { confidence: 0, result: content, strategy: "none" }
}

View File

@ -0,0 +1,350 @@
import { Diff, Hunk, Change } from "./types"
import { findBestMatch, prepareSearchString } from "./search-strategies"
import { applyEdit } from "./edit-strategies"
import { DiffResult, DiffStrategy } from "../../types"
// Diff strategy that applies model-generated unified diffs with fuzzy
// matching: each hunk is located in the target file via search strategies
// (findBestMatch) and applied through a cascade of edit strategies
// (applyEdit), with sub-hunk splitting as a fallback when a whole hunk cannot
// be placed confidently.
export class NewUnifiedDiffStrategy implements DiffStrategy {
private readonly confidenceThreshold: number
// The threshold is floored at 0.8 so callers cannot request arbitrarily weak matches.
constructor(confidenceThreshold: number = 1) {
this.confidenceThreshold = Math.max(confidenceThreshold, 0.8)
}
// Parses unified-diff text into hunks. Hunks containing no add/remove lines
// are dropped, and each kept hunk is trimmed to at most MAX_CONTEXT_LINES of
// context before the first and after the last real change.
private parseUnifiedDiff(diff: string): Diff {
const MAX_CONTEXT_LINES = 6 // Number of context lines to keep before/after changes
const lines = diff.split("\n")
const hunks: Hunk[] = []
let currentHunk: Hunk | null = null
let i = 0
// Skip everything before the first "@@" hunk header (file headers etc.).
while (i < lines.length && !lines[i].startsWith("@@")) {
i++
}
for (; i < lines.length; i++) {
const line = lines[i]
if (line.startsWith("@@")) {
// New hunk header: flush the previous hunk if it held real changes.
if (
currentHunk &&
currentHunk.changes.length > 0 &&
currentHunk.changes.some((change) => change.type === "add" || change.type === "remove")
) {
const changes = currentHunk.changes
let startIdx = 0
let endIdx = changes.length - 1
// Trim leading context down to MAX_CONTEXT_LINES before the first change...
for (let j = 0; j < changes.length; j++) {
if (changes[j].type !== "context") {
startIdx = Math.max(0, j - MAX_CONTEXT_LINES)
break
}
}
// ...and trailing context down to MAX_CONTEXT_LINES after the last change.
for (let j = changes.length - 1; j >= 0; j--) {
if (changes[j].type !== "context") {
endIdx = Math.min(changes.length - 1, j + MAX_CONTEXT_LINES)
break
}
}
currentHunk.changes = changes.slice(startIdx, endIdx + 1)
hunks.push(currentHunk)
}
currentHunk = { changes: [] }
continue
}
if (!currentHunk) {
continue
}
// Split each diff line into marker (first char), indentation, and content.
const content = line.slice(1)
const indentMatch = content.match(/^(\s*)/)
const indent = indentMatch ? indentMatch[0] : ""
const trimmedContent = content.slice(indent.length)
if (line.startsWith(" ")) {
currentHunk.changes.push({
type: "context",
content: trimmedContent,
indent,
originalLine: content,
})
} else if (line.startsWith("+")) {
currentHunk.changes.push({
type: "add",
content: trimmedContent,
indent,
originalLine: content,
})
} else if (line.startsWith("-")) {
currentHunk.changes.push({
type: "remove",
content: trimmedContent,
indent,
originalLine: content,
})
} else {
// NOTE(review): for lines without a diff marker, line.slice(1) above has
// already dropped the line's first character before a space is re-prefixed
// here — confirm this is the intended handling of unprefixed lines.
const finalContent = trimmedContent ? " " + trimmedContent : " "
currentHunk.changes.push({
type: "context",
content: finalContent,
indent,
originalLine: content,
})
}
}
// Flush the final hunk if it holds real changes.
// NOTE(review): unlike the mid-loop flush, the final hunk is NOT trimmed to
// MAX_CONTEXT_LINES — confirm whether this asymmetry is intended.
if (
currentHunk &&
currentHunk.changes.length > 0 &&
currentHunk.changes.some((change) => change.type === "add" || change.type === "remove")
) {
hunks.push(currentHunk)
}
return { hunks }
}
// Returns the prompt text that instructs the model how to produce diffs for
// this strategy; args.cwd is interpolated into the parameter description.
getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string {
return `# apply_diff Tool - Generate Precise Code Changes
Generate a unified diff that can be cleanly applied to modify code files.
## Step-by-Step Instructions:
1. Start with file headers:
- First line: "--- {original_file_path}"
- Second line: "+++ {new_file_path}"
2. For each change section:
- Begin with "@@ ... @@" separator line without line numbers
- Include 2-3 lines of context before and after changes
- Mark removed lines with "-"
- Mark added lines with "+"
- Preserve exact indentation
3. Group related changes:
- Keep related modifications in the same hunk
- Start new hunks for logically separate changes
- When modifying functions/methods, include the entire block
## Requirements:
1. MUST include exact indentation
2. MUST include sufficient context for unique matching
3. MUST group related changes together
4. MUST use proper unified diff format
5. MUST NOT include timestamps in file headers
6. MUST NOT include line numbers in the @@ header
## Examples:
Good diff (follows all requirements):
\`\`\`diff
--- src/utils.ts
+++ src/utils.ts
@@ ... @@
def calculate_total(items):
- total = 0
- for item in items:
- total += item.price
+ return sum(item.price for item in items)
\`\`\`
Bad diff (violates requirements #1 and #2):
\`\`\`diff
--- src/utils.ts
+++ src/utils.ts
@@ ... @@
-total = 0
-for item in items:
+return sum(item.price for item in items)
\`\`\`
Parameters:
- path: (required) File path relative to ${args.cwd}
- diff: (required) Unified diff content in unified format to apply to the file.
Usage:
<apply_diff>
<path>path/to/file.ext</path>
<diff>
Your diff here
</diff>
</apply_diff>`
}
// Helper function to split a hunk into smaller hunks based on contiguous changes
// Each produced sub-hunk carries up to MAX_CONTEXT_LINES of surrounding context
// so it can be searched for independently.
private splitHunk(hunk: Hunk): Hunk[] {
const result: Hunk[] = []
let currentHunk: Hunk | null = null
let contextBefore: Change[] = []
let contextAfter: Change[] = []
const MAX_CONTEXT_LINES = 3 // Keep 3 lines of context before/after changes
for (let i = 0; i < hunk.changes.length; i++) {
const change = hunk.changes[i]
if (change.type === "context") {
if (!currentHunk) {
// Not inside a sub-hunk yet: remember a sliding window of leading context.
contextBefore.push(change)
if (contextBefore.length > MAX_CONTEXT_LINES) {
contextBefore.shift()
}
} else {
contextAfter.push(change)
if (contextAfter.length > MAX_CONTEXT_LINES) {
// We've collected enough context after changes, create a new hunk
currentHunk.changes.push(...contextAfter)
result.push(currentHunk)
currentHunk = null
// Keep the last few context lines for the next hunk
contextBefore = contextAfter
contextAfter = []
}
}
} else {
if (!currentHunk) {
// First real change: open a sub-hunk seeded with the buffered context.
currentHunk = { changes: [...contextBefore] }
contextAfter = []
} else if (contextAfter.length > 0) {
// Add accumulated context to current hunk
currentHunk.changes.push(...contextAfter)
contextAfter = []
}
currentHunk.changes.push(change)
}
}
// Add any remaining changes
if (currentHunk) {
if (contextAfter.length > 0) {
currentHunk.changes.push(...contextAfter)
}
result.push(currentHunk)
}
return result
}
// Applies diffContent to originalContent, hunk by hunk. When a whole hunk
// cannot be located or applied confidently it is split into sub-hunks and
// retried; if that also fails, a detailed diagnostic error is returned.
async applyDiff(
originalContent: string,
diffContent: string,
startLine?: number,
endLine?: number,
): Promise<DiffResult> {
const parsedDiff = this.parseUnifiedDiff(diffContent)
const originalLines = originalContent.split("\n")
let result = [...originalLines]
if (!parsedDiff.hunks.length) {
return {
success: false,
error: "No hunks found in diff. Please ensure your diff includes actual changes and follows the unified diff format.",
}
}
for (const hunk of parsedDiff.hunks) {
// Locate the hunk in the current (possibly already-edited) content.
const contextStr = prepareSearchString(hunk.changes)
const {
index: matchPosition,
confidence,
strategy,
} = findBestMatch(contextStr, result, 0, this.confidenceThreshold)
if (confidence < this.confidenceThreshold) {
console.log("Full hunk application failed, trying sub-hunks strategy")
// Try splitting the hunk into smaller hunks
const subHunks = this.splitHunk(hunk)
let subHunkSuccess = true
let subHunkResult = [...result]
for (const subHunk of subHunks) {
const subContextStr = prepareSearchString(subHunk.changes)
const subSearchResult = findBestMatch(subContextStr, subHunkResult, 0, this.confidenceThreshold)
if (subSearchResult.confidence >= this.confidenceThreshold) {
const subEditResult = await applyEdit(
subHunk,
subHunkResult,
subSearchResult.index,
subSearchResult.confidence,
this.confidenceThreshold,
)
if (subEditResult.confidence >= this.confidenceThreshold) {
subHunkResult = subEditResult.result
continue
}
}
subHunkSuccess = false
break
}
if (subHunkSuccess) {
result = subHunkResult
continue
}
// If sub-hunks also failed, return the original error
const contextLines = hunk.changes.filter((c) => c.type === "context").length
const totalLines = hunk.changes.length
const contextRatio = contextLines / totalLines
let errorMsg = `Failed to find a matching location in the file (${Math.floor(
confidence * 100,
)}% confidence, needs ${Math.floor(this.confidenceThreshold * 100)}%)\n\n`
errorMsg += "Debug Info:\n"
errorMsg += `- Search Strategy Used: ${strategy}\n`
errorMsg += `- Context Lines: ${contextLines} out of ${totalLines} total lines (${Math.floor(
contextRatio * 100,
)}%)\n`
errorMsg += `- Attempted to split into ${subHunks.length} sub-hunks but still failed\n`
// Tailor the advice to how much of the hunk was context.
if (contextRatio < 0.2) {
errorMsg += "\nPossible Issues:\n"
errorMsg += "- Not enough context lines to uniquely identify the location\n"
errorMsg += "- Add a few more lines of unchanged code around your changes\n"
} else if (contextRatio > 0.5) {
errorMsg += "\nPossible Issues:\n"
errorMsg += "- Too many context lines may reduce search accuracy\n"
errorMsg += "- Try to keep only 2-3 lines of context before and after changes\n"
} else {
errorMsg += "\nPossible Issues:\n"
errorMsg += "- The diff may be targeting a different version of the file\n"
errorMsg +=
"- There may be too many changes in a single hunk, try splitting the changes into multiple hunks\n"
}
if (startLine && endLine) {
errorMsg += `\nSearch Range: lines ${startLine}-${endLine}\n`
}
return { success: false, error: errorMsg }
}
const editResult = await applyEdit(hunk, result, matchPosition, confidence, this.confidenceThreshold)
if (editResult.confidence >= this.confidenceThreshold) {
result = editResult.result
} else {
// Edit failure - likely due to content mismatch
let errorMsg = `Failed to apply the edit using ${editResult.strategy} strategy (${Math.floor(
editResult.confidence * 100,
)}% confidence)\n\n`
errorMsg += "Debug Info:\n"
errorMsg += "- The location was found but the content didn't match exactly\n"
errorMsg += "- This usually means the file has been modified since the diff was created\n"
errorMsg += "- Or the diff may be targeting a different version of the file\n"
errorMsg += "\nPossible Solutions:\n"
errorMsg += "1. Refresh your view of the file and create a new diff\n"
errorMsg += "2. Double-check that the removed lines (-) match the current file content\n"
errorMsg += "3. Ensure your diff targets the correct version of the file"
return { success: false, error: errorMsg }
}
}
return { success: true, content: result.join("\n") }
}
}

View File

@ -0,0 +1,408 @@
import { compareTwoStrings } from "string-similarity"
import { closest } from "fastest-levenshtein"
import { diff_match_patch } from "diff-match-patch"
import { Change, Hunk } from "./types"
// Outcome of a single search attempt: the matched line index (-1 when no
// match), a 0-1 confidence score, and the name of the strategy that produced it.
export type SearchResult = {
index: number
confidence: number
strategy: string
}
// Tuning knobs shared by the search helpers below.
const LARGE_FILE_THRESHOLD = 1000 // lines
const UNIQUE_CONTENT_BOOST = 0.05
const DEFAULT_OVERLAP_SIZE = 3 // lines of overlap between windows
const MAX_WINDOW_SIZE = 500 // maximum lines in a window
// Helper function to calculate adaptive confidence threshold based on file size.
// Files at or under LARGE_FILE_THRESHOLD lines keep the caller's threshold;
// larger files get a 0.07 discount, floored at 0.8.
function getAdaptiveThreshold(contentLength: number, baseThreshold: number): number {
  const isLargeFile = contentLength > LARGE_FILE_THRESHOLD
  if (!isLargeFile) {
    return baseThreshold
  }
  // Reduce threshold for large files but keep minimum at 80%.
  return Math.max(baseThreshold - 0.07, 0.8)
}
// Helper function to evaluate content uniqueness.
// Returns the fraction of distinct search lines that occur at most twice in
// the content — a proxy for how distinctive (and therefore trustworthy) a
// match on this search string would be.
function evaluateContentUniqueness(searchStr: string, content: string[]): number {
  const distinctLines = new Set(searchStr.split("\n"))
  const haystack = content.join("\n")
  let rareLineCount = 0
  for (const line of distinctLines) {
    // Escape regex metacharacters so the line is matched literally.
    const escaped = line.replace(/[.*+?^${}()|[\]\\]/g, "\\$&")
    const occurrences = haystack.match(new RegExp(escaped, "g"))
    // A line counts as "unique enough" when it appears at most twice.
    if (occurrences && occurrences.length <= 2) {
      rareLineCount++
    }
  }
  return rareLineCount / distinctLines.size
}
// Helper function to prepare search string from context.
// Only lines that exist in the original file (context + removals) are
// searchable; additions are skipped.
export function prepareSearchString(changes: Change[]): string {
  return changes
    .flatMap((change) => (change.type === "context" || change.type === "remove" ? [change.originalLine] : []))
    .join("\n")
}
// Helper function to evaluate similarity between two texts.
// Thin wrapper over string-similarity's compareTwoStrings; higher scores mean
// more similar text.
export function evaluateSimilarity(original: string, modified: string): number {
return compareTwoStrings(original, modified)
}
// Helper function to validate using diff-match-patch.
// Diffs the two texts, converts the diff into a patch set, re-applies that
// patch to the original, and scores how close the round-tripped text lands
// to `modified`.
export function getDMPSimilarity(original: string, modified: string): number {
  const differ = new diff_match_patch()
  const edits = differ.diff_main(original, modified)
  differ.diff_cleanupSemantic(edits)
  const patchSet = differ.patch_make(original, edits)
  const [roundTripped] = differ.patch_apply(patchSet, original)
  return evaluateSimilarity(roundTripped, modified)
}
// Helper function to validate edit results using hunk information.
// Compares the produced text both to what the hunk says we should end up with
// (context + additions) and to what it started from (context + removals); a
// result that still looks like the original is discounted.
export function validateEditResult(hunk: Hunk, result: string): number {
  const render = (change: Hunk["changes"][number]) =>
    change.indent ? change.indent + change.content : change.content
  // Text the hunk intends to produce.
  const targetText = hunk.changes
    .filter((change) => change.type === "context" || change.type === "add")
    .map(render)
    .join("\n")
  // Text the hunk started from.
  const sourceText = hunk.changes
    .filter((change) => change.type === "context" || change.type === "remove")
    .map(render)
    .join("\n")
  const targetScore = getDMPSimilarity(targetText, result)
  const sourceScore = getDMPSimilarity(sourceText, result)
  // Result still resembling the original (without exactly matching the target)
  // means the edit did not really land: keep some confidence for having found
  // the right location, but discount it.
  if (sourceScore > 0.97 && targetScore !== 1) {
    return 0.8 * targetScore
  }
  return targetScore
}
/**
 * Compares the context portion of `searchStr` (every line that does not start
 * with "-") against `content`, boosted by how unique the search lines are in
 * the content. Scores below the adaptive threshold are dampened (x0.3).
 */
function validateContextLines(searchStr: string, content: string, confidenceThreshold: number): number {
  const contentLines = content.split("\n")
  const contextOnly = searchStr
    .split("\n")
    .filter((line) => !line.startsWith("-")) // removed lines are not expected to appear
    .join("\n")

  const similarity = evaluateSimilarity(contextOnly, content)
  const boost = evaluateContentUniqueness(searchStr, contentLines) * UNIQUE_CONTENT_BOOST

  if (similarity < getAdaptiveThreshold(contentLines.length, confidenceThreshold)) {
    return similarity * 0.3 + boost
  }
  return similarity + boost
}
// Helper function to create overlapping windows
function createOverlappingWindows(
content: string[],
searchSize: number,
overlapSize: number = DEFAULT_OVERLAP_SIZE,
): { window: string[]; startIndex: number }[] {
const windows: { window: string[]; startIndex: number }[] = []
// Ensure minimum window size is at least searchSize
const effectiveWindowSize = Math.max(searchSize, Math.min(searchSize * 2, MAX_WINDOW_SIZE))
// Ensure overlap size doesn't exceed window size
const effectiveOverlapSize = Math.min(overlapSize, effectiveWindowSize - 1)
// Calculate step size, ensure it's at least 1
const stepSize = Math.max(1, effectiveWindowSize - effectiveOverlapSize)
for (let i = 0; i < content.length; i += stepSize) {
const windowContent = content.slice(i, i + effectiveWindowSize)
if (windowContent.length >= searchSize) {
windows.push({ window: windowContent, startIndex: i })
}
}
return windows
}
/**
 * Merges per-window matches into final results. A match confirmed by an
 * adjacent, overlapping window (window indices differ by 1 and line indices
 * differ by at most `overlapSize`) is merged into one result whose confidence
 * is the average of the group plus a small boost (max +10%); all other
 * matches pass through unchanged. Results come out in descending-confidence
 * order.
 *
 * Fix: sorts a copy of `matches` instead of reordering the caller's array in
 * place (the original `matches.sort(...)` mutated the input parameter).
 */
function combineOverlappingMatches(
  matches: (SearchResult & { windowIndex: number })[],
  overlapSize: number = DEFAULT_OVERLAP_SIZE,
): SearchResult[] {
  if (matches.length === 0) {
    return []
  }
  // Sort matches by confidence — on a copy, so the input is not mutated.
  const sorted = [...matches].sort((a, b) => b.confidence - a.confidence)
  const combinedMatches: SearchResult[] = []
  const usedIndices = new Set<number>()
  for (const match of sorted) {
    if (usedIndices.has(match.windowIndex)) {
      continue
    }
    // Find overlapping matches from adjacent windows at (nearly) the same spot.
    const overlapping = sorted.filter(
      (m) =>
        Math.abs(m.windowIndex - match.windowIndex) === 1 &&
        Math.abs(m.index - match.index) <= overlapSize &&
        !usedIndices.has(m.windowIndex),
    )
    if (overlapping.length > 0) {
      // Boost confidence if we find the same match in overlapping windows:
      // independent confirmation is strong evidence the location is right.
      const avgConfidence =
        (match.confidence + overlapping.reduce((sum, m) => sum + m.confidence, 0)) / (overlapping.length + 1)
      const boost = Math.min(0.05 * overlapping.length, 0.1) // Max 10% boost
      combinedMatches.push({
        index: match.index,
        confidence: Math.min(1, avgConfidence + boost),
        strategy: `${match.strategy}-overlapping`,
      })
      usedIndices.add(match.windowIndex)
      overlapping.forEach((m) => usedIndices.add(m.windowIndex))
    } else {
      combinedMatches.push({
        index: match.index,
        confidence: match.confidence,
        strategy: match.strategy,
      })
      usedIndices.add(match.windowIndex)
    }
  }
  return combinedMatches
}
/**
 * Exact-substring strategy: looks for `searchStr` verbatim inside overlapping
 * windows of `content`, scores each hit with DMP + context-line similarity,
 * and returns the best combined match (index -1 when nothing is found).
 */
export function findExactMatch(
  searchStr: string,
  content: string[],
  startIndex: number = 0,
  confidenceThreshold: number = 0.97,
): SearchResult {
  const searchLines = searchStr.split("\n")
  const windows = createOverlappingWindows(content.slice(startIndex), searchLines.length)
  const matches: (SearchResult & { windowIndex: number })[] = []

  windows.forEach((windowData, windowIndex) => {
    const windowStr = windowData.window.join("\n")
    const hitOffset = windowStr.indexOf(searchStr)
    if (hitOffset === -1) {
      return
    }
    // Line offset of the hit inside this window (= newlines before the hit).
    const lineOffset = windowStr.slice(0, hitOffset).split("\n").length - 1
    const matchedContent = windowData.window.slice(lineOffset, lineOffset + searchLines.length).join("\n")
    const similarity = getDMPSimilarity(searchStr, matchedContent)
    const contextSimilarity = validateContextLines(searchStr, matchedContent, confidenceThreshold)
    matches.push({
      index: startIndex + windowData.startIndex + lineOffset,
      confidence: Math.min(similarity, contextSimilarity),
      strategy: "exact",
      windowIndex,
    })
  })

  const combined = combineOverlappingMatches(matches)
  return combined.length > 0 ? combined[0] : { index: -1, confidence: 0, strategy: "exact" }
}
/**
 * Sliding-window fuzzy strategy: scores every window of `content` with the
 * same line count as `searchStr` using string similarity, then re-scores
 * promising windows with DMP + context validation. Returns the best hit, or
 * index -1 when none reaches the threshold.
 */
export function findSimilarityMatch(
  searchStr: string,
  content: string[],
  startIndex: number = 0,
  confidenceThreshold: number = 0.97,
): SearchResult {
  const height = searchStr.split("\n").length
  let bestScore = 0
  let bestIndex = -1

  const lastStart = content.length - height
  for (let candidate = startIndex; candidate <= lastStart; candidate++) {
    const windowStr = content.slice(candidate, candidate + height).join("\n")
    const rawScore = compareTwoStrings(searchStr, windowStr)
    // Cheap raw score gates the expensive validation: the adjusted score can
    // never exceed the raw score, so lower-raw windows cannot win.
    if (rawScore <= bestScore || rawScore < confidenceThreshold) {
      continue
    }
    const adjusted =
      Math.min(
        getDMPSimilarity(searchStr, windowStr),
        validateContextLines(searchStr, windowStr, confidenceThreshold),
      ) * rawScore
    if (adjusted > bestScore) {
      bestScore = adjusted
      bestIndex = candidate
    }
  }

  return {
    index: bestIndex,
    confidence: bestIndex !== -1 ? bestScore : 0,
    strategy: "similarity",
  }
}
/**
 * Levenshtein strategy: picks the candidate window closest (by edit distance)
 * to `searchStr`, then validates it with DMP + context-line similarity.
 * Returns index -1 when there are no candidates or validation scores zero.
 */
export function findLevenshteinMatch(
  searchStr: string,
  content: string[],
  startIndex: number = 0,
  confidenceThreshold: number = 0.97,
): SearchResult {
  const height = searchStr.split("\n").length
  const candidates: string[] = []
  for (let i = startIndex; i <= content.length - height; i++) {
    candidates.push(content.slice(i, i + height).join("\n"))
  }
  if (candidates.length === 0) {
    return { index: -1, confidence: 0, strategy: "levenshtein" }
  }
  const closestMatch = closest(searchStr, candidates)
  const index = startIndex + candidates.indexOf(closestMatch)
  const confidence = Math.min(
    getDMPSimilarity(searchStr, closestMatch),
    validateContextLines(searchStr, closestMatch, confidenceThreshold),
  )
  return {
    index: confidence === 0 ? -1 : index,
    confidence: index !== -1 ? confidence : 0,
    strategy: "levenshtein",
  }
}
/**
 * Picks the first and last non-blank lines of `searchStr` to use as anchors
 * for anchor-based matching. Either is null when no non-blank line exists
 * (for a single non-blank line, first and last are the same line).
 */
function identifyAnchors(searchStr: string): { first: string | null; last: string | null } {
  const lines = searchStr.split("\n")
  const isNonBlank = (line: string) => line.trim().length > 0
  const first = lines.find(isNonBlank) ?? null
  const last = [...lines].reverse().find(isNonBlank) ?? null
  return { first, last }
}
// Anchor-based search strategy
/**
 * Finds the search block by locating its first and last non-blank lines
 * ("anchors") in `content`, then validating that the lines between them
 * resemble the search block's interior.
 *
 * Returns index -1 unless all of the following hold:
 *  - both anchors exist,
 *  - the first anchor occurs exactly once in `content`,
 *  - both anchors are found with first strictly before last, and
 *  - the in-between lines pass the adaptive similarity threshold.
 * On success the confidence is always 1.
 *
 * NOTE(review): only the FIRST anchor's uniqueness is verified; the last
 * anchor is taken as its final occurrence in the whole file — confirm this is
 * intended when the last line is common (e.g. a bare "}").
 */
export function findAnchorMatch(
  searchStr: string,
  content: string[],
  startIndex: number = 0,
  confidenceThreshold: number = 0.97,
): SearchResult {
  const searchLines = searchStr.split("\n")
  const { first, last } = identifyAnchors(searchStr)
  if (!first || !last) {
    return { index: -1, confidence: 0, strategy: "anchor" }
  }
  let firstIndex = -1
  let lastIndex = -1
  // Check if the first anchor is unique (counted over the whole file, not
  // just from startIndex).
  let firstOccurrences = 0
  for (const contentLine of content) {
    if (contentLine === first) {
      firstOccurrences++
    }
  }
  if (firstOccurrences !== 1) {
    return { index: -1, confidence: 0, strategy: "anchor" }
  }
  // Find the first anchor
  for (let i = startIndex; i < content.length; i++) {
    if (content[i] === first) {
      firstIndex = i
      break
    }
  }
  // Find the last anchor, scanning backwards from the end of the file
  for (let i = content.length - 1; i >= startIndex; i--) {
    if (content[i] === last) {
      lastIndex = i
      break
    }
  }
  if (firstIndex === -1 || lastIndex === -1 || lastIndex <= firstIndex) {
    return { index: -1, confidence: 0, strategy: "anchor" }
  }
  // Validate the context between the anchors against the search block's
  // interior. indexOf() assumes each anchor appears only once in searchStr —
  // NOTE(review): duplicated anchor lines inside searchStr would skew this.
  const expectedContext = searchLines.slice(searchLines.indexOf(first) + 1, searchLines.indexOf(last)).join("\n")
  const actualContext = content.slice(firstIndex + 1, lastIndex).join("\n")
  const contextSimilarity = evaluateSimilarity(expectedContext, actualContext)
  if (contextSimilarity < getAdaptiveThreshold(content.length, confidenceThreshold)) {
    return { index: -1, confidence: 0, strategy: "anchor" }
  }
  const confidence = 1
  return {
    index: firstIndex,
    confidence: confidence,
    strategy: "anchor",
  }
}
/**
 * Runs every search strategy (exact, anchor, similarity, Levenshtein) and
 * returns the single highest-confidence result; index -1 / strategy "none"
 * when no strategy finds anything.
 */
export function findBestMatch(
  searchStr: string,
  content: string[],
  startIndex: number = 0,
  confidenceThreshold: number = 0.97,
): SearchResult {
  const strategies = [findExactMatch, findAnchorMatch, findSimilarityMatch, findLevenshteinMatch]
  return strategies.reduce<SearchResult>(
    (best, strategy) => {
      const result = strategy(searchStr, content, startIndex, confidenceThreshold)
      return result.confidence > best.confidence ? result : best
    },
    { index: -1, confidence: 0, strategy: "none" },
  )
}

View File

@ -0,0 +1,20 @@
/** One line of a diff hunk. */
export type Change = {
  /** "context" = unchanged line, "add" = inserted line, "remove" = deleted line. */
  type: "context" | "add" | "remove"
  /** Line text; `indent + content` reconstructs the full line when indent is set. */
  content: string
  /** Leading whitespace of the line (may be empty). */
  indent: string
  /** The full original line, when available. */
  originalLine?: string
}
/** A contiguous group of changes from a diff. */
export type Hunk = {
  changes: Change[]
}
/** A parsed diff: an ordered list of hunks. */
export type Diff = {
  hunks: Hunk[]
}
/** Outcome of applying an edit: score, resulting lines, and which strategy matched. */
export type EditResult = {
  confidence: number
  result: string[]
  strategy: string
}

View File

@ -0,0 +1,302 @@
import { DiffStrategy, DiffResult } from "../types"
import { addLineNumbers, everyLineHasLineNumbers, stripLineNumbers } from "../../../integrations/misc/extract-text"
import { distance } from "fastest-levenshtein"
const BUFFER_LINES = 20 // Number of extra context lines to show before and after matches

/**
 * Similarity in [0, 1] between two strings based on Levenshtein distance,
 * after collapsing whitespace runs (case is preserved). An empty search
 * string matches anything with similarity 1.
 */
function getSimilarity(original: string, search: string): number {
  if (search === "") {
    return 1
  }
  // Collapse internal whitespace and trim, but keep the original casing.
  const normalize = (value: string) => value.replace(/\s+/g, " ").trim()
  const left = normalize(original)
  const right = normalize(search)
  if (left === right) {
    return 1
  }
  // distance() is the raw Levenshtein edit count; scale it into a 0..1 ratio.
  const longest = Math.max(left.length, right.length)
  return 1 - distance(left, right) / longest
}
/**
 * Diff strategy that applies a single <<<<<<< SEARCH / ======= / >>>>>>> REPLACE
 * block: finds the SEARCH text in the original content (exactly or fuzzily,
 * depending on `fuzzyThreshold`) and substitutes the REPLACE text while
 * re-indenting it to match the matched location.
 */
export class SearchReplaceDiffStrategy implements DiffStrategy {
  // Minimum similarity (0..1) a candidate match must reach; 1.0 = exact only.
  private fuzzyThreshold: number
  // Number of extra lines searched around a start_line/end_line hint.
  private bufferLines: number
  constructor(fuzzyThreshold?: number, bufferLines?: number) {
    // Use provided threshold or default to exact matching (1.0)
    // Note: fuzzyThreshold is inverted in UI (0% = 1.0, 10% = 0.9)
    // so we use it directly here
    this.fuzzyThreshold = fuzzyThreshold ?? 1.0
    this.bufferLines = bufferLines ?? BUFFER_LINES
  }
  /** Tool prompt shown to the model: format rules, a worked example, and usage. */
  getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string {
    return `## apply_diff
Description: Request to replace existing code using a search and replace block.
This tool allows for precise, surgical replaces to files by specifying exactly what content to search for and what to replace it with.
The tool will maintain proper indentation and formatting while making changes.
Only a single operation is allowed per tool use.
The SEARCH section must exactly match existing content including whitespace and indentation.
If you're not confident in the exact content to search for, use the read_file tool first to get the exact content.
When applying the diffs, be extra careful to remember to change any closing brackets or other syntax that may be affected by the diff farther down in the file.
Parameters:
- path: (required) The path of the file to modify (relative to the current working directory ${args.cwd})
- diff: (required) The search/replace block defining the changes.
- start_line: (required) The line number where the search block starts.
- end_line: (required) The line number where the search block ends.
Diff format:
\`\`\`
<<<<<<< SEARCH
[exact content to find including whitespace]
=======
[new content to replace with]
>>>>>>> REPLACE
\`\`\`
Example:
Original file:
\`\`\`
1 | def calculate_total(items):
2 |     total = 0
3 |     for item in items:
4 |         total += item
5 |     return total
\`\`\`
Search/Replace content:
\`\`\`
<<<<<<< SEARCH
def calculate_total(items):
    total = 0
    for item in items:
        total += item
    return total
=======
def calculate_total(items):
    """Calculate total with 10% markup"""
    return sum(item * 1.1 for item in items)
>>>>>>> REPLACE
\`\`\`
Usage:
<apply_diff>
<path>File path here</path>
<diff>
Your search/replace content here
</diff>
<start_line>1</start_line>
<end_line>5</end_line>
</apply_diff>`
  }
  /**
   * Applies one SEARCH/REPLACE block to `originalContent`.
   *
   * @param originalContent The current file content.
   * @param diffContent A single <<<<<<< SEARCH / ======= / >>>>>>> REPLACE block.
   * @param startLine Optional 1-based hint for where the SEARCH block starts.
   * @param endLine Optional 1-based hint for where the SEARCH block ends.
   * @returns The patched content on success, or a detailed diagnostic error.
   */
  async applyDiff(
    originalContent: string,
    diffContent: string,
    startLine?: number,
    endLine?: number,
  ): Promise<DiffResult> {
    // Extract the search and replace blocks
    const match = diffContent.match(/<<<<<<< SEARCH\n([\s\S]*?)\n?=======\n([\s\S]*?)\n?>>>>>>> REPLACE/)
    if (!match) {
      return {
        success: false,
        error: `Invalid diff format - missing required SEARCH/REPLACE sections\n\nDebug Info:\n- Expected Format: <<<<<<< SEARCH\\n[search content]\\n=======\\n[replace content]\\n>>>>>>> REPLACE\n- Tip: Make sure to include both SEARCH and REPLACE sections with correct markers`,
      }
    }
    let [_, searchContent, replaceContent] = match
    // Detect line ending from original content
    const lineEnding = originalContent.includes("\r\n") ? "\r\n" : "\n"
    // Strip line numbers from search and replace content if every line starts with a line number
    if (everyLineHasLineNumbers(searchContent) && everyLineHasLineNumbers(replaceContent)) {
      searchContent = stripLineNumbers(searchContent)
      replaceContent = stripLineNumbers(replaceContent)
    }
    // Split content into lines, handling both \n and \r\n
    const searchLines = searchContent === "" ? [] : searchContent.split(/\r?\n/)
    const replaceLines = replaceContent === "" ? [] : replaceContent.split(/\r?\n/)
    const originalLines = originalContent.split(/\r?\n/)
    // Validate that empty search requires start line
    if (searchLines.length === 0 && !startLine) {
      return {
        success: false,
        error: `Empty search content requires start_line to be specified\n\nDebug Info:\n- Empty search content is only valid for insertions at a specific line\n- For insertions, specify the line number where content should be inserted`,
      }
    }
    // Validate that empty search requires same start and end line
    if (searchLines.length === 0 && startLine && endLine && startLine !== endLine) {
      return {
        success: false,
        error: `Empty search content requires start_line and end_line to be the same (got ${startLine}-${endLine})\n\nDebug Info:\n- Empty search content is only valid for insertions at a specific line\n- For insertions, use the same line number for both start_line and end_line`,
      }
    }
    // Initialize search variables
    let matchIndex = -1
    let bestMatchScore = 0
    let bestMatchContent = ""
    const searchChunk = searchLines.join("\n")
    // Determine search bounds
    let searchStartIndex = 0
    let searchEndIndex = originalLines.length
    // Validate and handle line range if provided
    if (startLine && endLine) {
      // Convert to 0-based index
      const exactStartIndex = startLine - 1
      const exactEndIndex = endLine - 1
      if (exactStartIndex < 0 || exactEndIndex > originalLines.length || exactStartIndex > exactEndIndex) {
        return {
          success: false,
          error: `Line range ${startLine}-${endLine} is invalid (file has ${originalLines.length} lines)\n\nDebug Info:\n- Requested Range: lines ${startLine}-${endLine}\n- File Bounds: lines 1-${originalLines.length}`,
        }
      }
      // Try exact match first
      const originalChunk = originalLines.slice(exactStartIndex, exactEndIndex + 1).join("\n")
      const similarity = getSimilarity(originalChunk, searchChunk)
      if (similarity >= this.fuzzyThreshold) {
        matchIndex = exactStartIndex
        bestMatchScore = similarity
        bestMatchContent = originalChunk
      } else {
        // Set bounds for buffered search around the hinted range
        searchStartIndex = Math.max(0, startLine - (this.bufferLines + 1))
        searchEndIndex = Math.min(originalLines.length, endLine + this.bufferLines)
      }
    }
    // If no match found yet, try middle-out search within bounds: scan
    // outward from the midpoint so matches near the hint are found first.
    if (matchIndex === -1) {
      const midPoint = Math.floor((searchStartIndex + searchEndIndex) / 2)
      let leftIndex = midPoint
      let rightIndex = midPoint + 1
      // Search outward from the middle within bounds
      while (leftIndex >= searchStartIndex || rightIndex <= searchEndIndex - searchLines.length) {
        // Check left side if still in range
        if (leftIndex >= searchStartIndex) {
          const originalChunk = originalLines.slice(leftIndex, leftIndex + searchLines.length).join("\n")
          const similarity = getSimilarity(originalChunk, searchChunk)
          if (similarity > bestMatchScore) {
            bestMatchScore = similarity
            matchIndex = leftIndex
            bestMatchContent = originalChunk
          }
          leftIndex--
        }
        // Check right side if still in range
        if (rightIndex <= searchEndIndex - searchLines.length) {
          const originalChunk = originalLines.slice(rightIndex, rightIndex + searchLines.length).join("\n")
          const similarity = getSimilarity(originalChunk, searchChunk)
          if (similarity > bestMatchScore) {
            bestMatchScore = similarity
            matchIndex = rightIndex
            bestMatchContent = originalChunk
          }
          rightIndex++
        }
      }
    }
    // Require similarity to meet threshold
    if (matchIndex === -1 || bestMatchScore < this.fuzzyThreshold) {
      const searchChunk = searchLines.join("\n")
      const originalContentSection =
        startLine !== undefined && endLine !== undefined
          ? `\n\nOriginal Content:\n${addLineNumbers(
              originalLines
                .slice(
                  Math.max(0, startLine - 1 - this.bufferLines),
                  Math.min(originalLines.length, endLine + this.bufferLines),
                )
                .join("\n"),
              Math.max(1, startLine - this.bufferLines),
            )}`
          : `\n\nOriginal Content:\n${addLineNumbers(originalLines.join("\n"))}`
      const bestMatchSection = bestMatchContent
        ? `\n\nBest Match Found:\n${addLineNumbers(bestMatchContent, matchIndex + 1)}`
        : `\n\nBest Match Found:\n(no match)`
      const lineRange =
        startLine || endLine
          ? ` at ${startLine ? `start: ${startLine}` : "start"} to ${endLine ? `end: ${endLine}` : "end"}`
          : ""
      return {
        success: false,
        error: `No sufficiently similar match found${lineRange} (${Math.floor(bestMatchScore * 100)}% similar, needs ${Math.floor(this.fuzzyThreshold * 100)}%)\n\nDebug Info:\n- Similarity Score: ${Math.floor(bestMatchScore * 100)}%\n- Required Threshold: ${Math.floor(this.fuzzyThreshold * 100)}%\n- Search Range: ${startLine && endLine ? `lines ${startLine}-${endLine}` : "start to end"}\n- Tip: Use read_file to get the latest content of the file before attempting the diff again, as the file content may have changed\n\nSearch Content:\n${searchChunk}${bestMatchSection}${originalContentSection}`,
      }
    }
    // Get the matched lines from the original content
    const matchedLines = originalLines.slice(matchIndex, matchIndex + searchLines.length)
    // Get the exact indentation (preserving tabs/spaces) of each line
    const originalIndents = matchedLines.map((line) => {
      const match = line.match(/^[\t ]*/)
      return match ? match[0] : ""
    })
    // Get the exact indentation of each line in the search block
    const searchIndents = searchLines.map((line) => {
      const match = line.match(/^[\t ]*/)
      return match ? match[0] : ""
    })
    // Apply the replacement while preserving exact indentation: each replace
    // line keeps its indent relative to the search block's first line, rebased
    // onto the matched location's first-line indent.
    const indentedReplaceLines = replaceLines.map((line, i) => {
      // Get the matched line's exact indentation
      const matchedIndent = originalIndents[0] || ""
      // Get the current line's indentation relative to the search content
      const currentIndentMatch = line.match(/^[\t ]*/)
      const currentIndent = currentIndentMatch ? currentIndentMatch[0] : ""
      const searchBaseIndent = searchIndents[0] || ""
      // Calculate the relative indentation level
      const searchBaseLevel = searchBaseIndent.length
      const currentLevel = currentIndent.length
      const relativeLevel = currentLevel - searchBaseLevel
      // If relative level is negative, remove indentation from matched indent
      // If positive, add to matched indent
      const finalIndent =
        relativeLevel < 0
          ? matchedIndent.slice(0, Math.max(0, matchedIndent.length + relativeLevel))
          : matchedIndent + currentIndent.slice(searchBaseLevel)
      return finalIndent + line.trim()
    })
    // Construct the final content
    const beforeMatch = originalLines.slice(0, matchIndex)
    const afterMatch = originalLines.slice(matchIndex + searchLines.length)
    const finalContent = [...beforeMatch, ...indentedReplaceLines, ...afterMatch].join(lineEnding)
    return {
      success: true,
      content: finalContent,
    }
  }
}

View File

@ -0,0 +1,137 @@
import { applyPatch } from "diff"
import { DiffStrategy, DiffResult } from "../types"
/**
 * Diff strategy that applies a standard unified diff (diff -U3 format) via
 * the `diff` package's applyPatch. Unlike SearchReplaceDiffStrategy, this
 * performs no fuzzy matching of its own — the patch either applies or is
 * rejected.
 */
export class UnifiedDiffStrategy implements DiffStrategy {
  /** Tool prompt shown to the model: format rules, a worked example, and usage. */
  getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string {
    return `## apply_diff
Description: Apply a unified diff to a file at the specified path. This tool is useful when you need to make specific modifications to a file based on a set of changes provided in unified diff format (diff -U3).
Parameters:
- path: (required) The path of the file to apply the diff to (relative to the current working directory ${args.cwd})
- diff: (required) The diff content in unified format to apply to the file.
Format Requirements:
1. Header (REQUIRED):
\`\`\`
--- path/to/original/file
+++ path/to/modified/file
\`\`\`
- Must include both lines exactly as shown
- Use actual file paths
- NO timestamps after paths
2. Hunks:
\`\`\`
@@ -lineStart,lineCount +lineStart,lineCount @@
-removed line
+added line
\`\`\`
- Each hunk starts with @@ showing line numbers for changes
- Format: @@ -originalStart,originalCount +newStart,newCount @@
- Use - for removed/changed lines
- Use + for new/modified lines
- Indentation must match exactly
Complete Example:
Original file (with line numbers):
\`\`\`
1 | import { Logger } from '../logger';
2 |
3 | function calculateTotal(items: number[]): number {
4 |   return items.reduce((sum, item) => {
5 |     return sum + item;
6 |   }, 0);
7 | }
8 |
9 | export { calculateTotal };
\`\`\`
After applying the diff, the file would look like:
\`\`\`
1 | import { Logger } from '../logger';
2 |
3 | function calculateTotal(items: number[]): number {
4 |   const total = items.reduce((sum, item) => {
5 |     return sum + item * 1.1; // Add 10% markup
6 |   }, 0);
7 |   return Math.round(total * 100) / 100; // Round to 2 decimal places
8 | }
9 |
10 | export { calculateTotal };
\`\`\`
Diff to modify the file:
\`\`\`
--- src/utils/helper.ts
+++ src/utils/helper.ts
@@ -1,9 +1,10 @@
 import { Logger } from '../logger';
 function calculateTotal(items: number[]): number {
-  return items.reduce((sum, item) => {
-    return sum + item;
+  const total = items.reduce((sum, item) => {
+    return sum + item * 1.1; // Add 10% markup
   }, 0);
+  return Math.round(total * 100) / 100; // Round to 2 decimal places
 }
 export { calculateTotal };
\`\`\`
Common Pitfalls:
1. Missing or incorrect header lines
2. Incorrect line numbers in @@ lines
3. Wrong indentation in changed lines
4. Incomplete context (missing lines that need changing)
5. Not marking all modified lines with - and +
Best Practices:
1. Replace entire code blocks:
- Remove complete old version with - lines
- Add complete new version with + lines
- Include correct line numbers
2. Moving code requires two hunks:
- First hunk: Remove from old location
- Second hunk: Add to new location
3. One hunk per logical change
4. Verify line numbers match the line numbers you have in the file
Usage:
<apply_diff>
<path>File path here</path>
<diff>
Your diff here
</diff>
</apply_diff>`
  }
  /**
   * Applies `diffContent` (a unified diff) to `originalContent`.
   *
   * @returns `{ success: true, content }` with the patched text, or
   *   `{ success: false, error, details }` when the patch is rejected or
   *   malformed. Never throws.
   */
  async applyDiff(originalContent: string, diffContent: string): Promise<DiffResult> {
    try {
      const result = applyPatch(originalContent, diffContent)
      // applyPatch returns false (not an error) when the patch does not apply.
      if (result === false) {
        return {
          success: false,
          error: "Failed to apply unified diff - patch rejected",
          details: {
            searchContent: diffContent,
          },
        }
      }
      return {
        success: true,
        content: result,
      }
    } catch (error) {
      // `error` is unknown at the catch site; extract a message safely so a
      // non-Error throw still yields useful text instead of "undefined".
      const message = error instanceof Error ? error.message : String(error)
      return {
        success: false,
        error: `Error applying unified diff: ${message}`,
        details: {
          searchContent: diffContent,
        },
      }
    }
  }
}

36
src/core/diff/types.ts Normal file
View File

@ -0,0 +1,36 @@
/**
 * Interface for implementing different diff strategies
 */
/**
 * Result of applying a diff: either the new file content, or a failure with a
 * human-readable error message and optional diagnostic details.
 */
export type DiffResult =
	| { success: true; content: string }
	| {
			success: false
			error: string
			// Optional diagnostics a strategy may attach to a failure.
			details?: {
				similarity?: number
				threshold?: number
				matchedRange?: { start: number; end: number }
				searchContent?: string
				bestMatch?: string
			}
	  }
export interface DiffStrategy {
	/**
	 * Get the tool description for this diff strategy
	 * @param args The tool arguments including cwd and toolOptions
	 * @returns The complete tool description including format requirements and examples
	 */
	getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string
	/**
	 * Apply a diff to the original content
	 * @param originalContent The original file content
	 * @param diffContent The diff content in the strategy's format
	 * @param startLine Optional line number where the search block starts. If not provided, searches the entire file.
	 * @param endLine Optional line number where the search block ends. If not provided, searches the entire file.
	 * @returns A DiffResult object containing either the successful result or error details
	 */
	applyDiff(originalContent: string, diffContent: string, startLine?: number, endLine?: number): Promise<DiffResult>
}

746
src/core/mcp/McpHub.ts Normal file
View File

@ -0,0 +1,746 @@
import { Client } from "@modelcontextprotocol/sdk/client/index.js"
import { StdioClientTransport, StdioServerParameters } from "@modelcontextprotocol/sdk/client/stdio.js"
import {
CallToolResultSchema,
ListResourcesResultSchema,
ListResourceTemplatesResultSchema,
ListToolsResultSchema,
ReadResourceResultSchema,
} from "@modelcontextprotocol/sdk/types.js"
import chokidar, { FSWatcher } from "chokidar"
import delay from "delay"
import deepEqual from "fast-deep-equal"
import * as fs from "fs/promises"
import * as path from "path"
import * as vscode from "vscode"
import { z } from "zod"
import { ClineProvider } from "../../core/webview/ClineProvider"
import { GlobalFileNames } from "../../shared/globalFileNames"
import {
McpResource,
McpResourceResponse,
McpResourceTemplate,
McpServer,
McpTool,
McpToolCallResponse,
} from "../../shared/mcp"
import { fileExistsAtPath } from "../../utils/fs"
import { arePathsEqual } from "../../utils/path"
/** One live MCP server entry: UI-facing state plus its SDK client/transport pair. */
export type McpConnection = {
	server: McpServer
	client: Client
	transport: StdioClientTransport
}
// StdioServerParameters
// Presumably the tool names auto-approved for this server — confirm at call sites.
const AlwaysAllowSchema = z.array(z.string()).default([])
/** Schema for a single stdio MCP server entry in the settings file. */
export const StdioConfigSchema = z.object({
	command: z.string(),
	args: z.array(z.string()).optional(),
	env: z.record(z.string()).optional(),
	alwaysAllow: AlwaysAllowSchema.optional(),
	disabled: z.boolean().optional(),
	// Range 1..3600 with default 60 — presumably seconds; TODO confirm unit.
	timeout: z.number().min(1).max(3600).optional().default(60),
})
/** Schema for the whole MCP settings JSON file. */
const McpSettingsSchema = z.object({
	mcpServers: z.record(StdioConfigSchema),
})
/**
 * Manages the lifecycle of all configured MCP (Model Context Protocol)
 * servers: reads the settings file, keeps one stdio client per server,
 * reacts to settings saves, and exposes the connection list.
 */
export class McpHub {
	// WeakRef so this hub does not keep the provider alive.
	private providerRef: WeakRef<ClineProvider>
	private disposables: vscode.Disposable[] = []
	private settingsWatcher?: vscode.FileSystemWatcher
	// chokidar watchers — key presumably the server name; confirm where populated.
	private fileWatchers: Map<string, FSWatcher> = new Map()
	connections: McpConnection[] = []
	isConnecting: boolean = false
/**
 * Starts watching the MCP settings file and kicks off the initial server
 * connections. Both calls are async and intentionally not awaited
 * (fire-and-forget); their failures are reported via console/error UI,
 * not by this constructor.
 */
constructor(provider: ClineProvider) {
	this.providerRef = new WeakRef(provider)
	this.watchMcpSettingsFile()
	this.initializeMcpServers()
}
/** Returns only the servers that are not disabled. */
getServers(): McpServer[] {
	const enabled: McpServer[] = []
	for (const connection of this.connections) {
		if (!connection.server.disabled) {
			enabled.push(connection.server)
		}
	}
	return enabled
}
/** Returns every known server regardless of its disabled/connection state. */
getAllServers(): McpServer[] {
	return this.connections.map(({ server }) => server)
}
/**
 * Resolves the directory where MCP servers are kept.
 * @throws Error when the provider has been garbage-collected.
 */
async getMcpServersPath(): Promise<string> {
	const provider = this.providerRef.deref()
	if (!provider) {
		throw new Error("Provider not available")
	}
	return provider.ensureMcpServersDirectoryExists()
}
async getMcpSettingsFilePath(): Promise<string> {
const provider = this.providerRef.deref()
if (!provider) {
throw new Error("Provider not available")
}
const mcpSettingsFilePath = path.join(
await provider.ensureSettingsDirectoryExists(),
GlobalFileNames.mcpSettings,
)
const fileExists = await fileExistsAtPath(mcpSettingsFilePath)
if (!fileExists) {
await fs.writeFile(
mcpSettingsFilePath,
`{
"mcpServers": {
}
}`,
)
}
return mcpSettingsFilePath
}
/**
 * Registers a save listener that re-validates and re-applies MCP settings
 * whenever the settings document is saved in VS Code. Invalid JSON or a
 * schema failure shows an error message and leaves current connections
 * untouched. The listener is pushed onto `disposables` for cleanup.
 */
private async watchMcpSettingsFile(): Promise<void> {
	const settingsPath = await this.getMcpSettingsFilePath()
	this.disposables.push(
		vscode.workspace.onDidSaveTextDocument(async (document) => {
			if (arePathsEqual(document.uri.fsPath, settingsPath)) {
				// Read from disk rather than the document to get the saved bytes.
				const content = await fs.readFile(settingsPath, "utf-8")
				const errorMessage =
					"Invalid MCP settings format. Please ensure your settings follow the correct JSON format."
				let config: any
				try {
					config = JSON.parse(content)
				} catch (error) {
					vscode.window.showErrorMessage(errorMessage)
					return
				}
				const result = McpSettingsSchema.safeParse(config)
				if (!result.success) {
					vscode.window.showErrorMessage(errorMessage)
					return
				}
				try {
					await this.updateServerConnections(result.data.mcpServers || {})
				} catch (error) {
					console.error("Failed to process MCP settings change:", error)
				}
			}
		}),
	)
}
/**
 * Reads the settings file and connects to every configured server. Errors
 * are logged rather than thrown, since this runs fire-and-forget from the
 * constructor.
 */
private async initializeMcpServers(): Promise<void> {
	try {
		const settingsPath = await this.getMcpSettingsFilePath()
		const raw = await fs.readFile(settingsPath, "utf-8")
		const parsed = JSON.parse(raw)
		await this.updateServerConnections(parsed.mcpServers || {})
	} catch (error) {
		console.error("Failed to initialize MCP servers:", error)
	}
}
/**
 * Creates an SDK client + stdio transport for server `name`, connects, and
 * records the connection (including any error state) in `this.connections`.
 * On success, populates the server's tools, resources and resource
 * templates.
 *
 * NOTE(review): failures are both recorded on the connection AND rethrown —
 * callers are expected to handle the rejection; confirm at call sites.
 */
private async connectToServer(name: string, config: StdioServerParameters): Promise<void> {
	// Remove existing connection if it exists (should never happen, the connection should be deleted beforehand)
	this.connections = this.connections.filter((conn) => conn.server.name !== name)
	try {
		// Each MCP server requires its own transport connection and has unique capabilities, configurations, and error handling. Having separate clients also allows proper scoping of resources/tools and independent server management like reconnection.
		const client = new Client(
			{
				name: "Roo Code",
				version: this.providerRef.deref()?.context.extension?.packageJSON?.version ?? "1.0.0",
			},
			{
				capabilities: {},
			},
		)
		const transport = new StdioClientTransport({
			command: config.command,
			args: config.args,
			env: {
				...config.env,
				// Forward PATH so the spawned server can find its binaries.
				...(process.env.PATH ? { PATH: process.env.PATH } : {}),
				// ...(process.env.NODE_PATH ? { NODE_PATH: process.env.NODE_PATH } : {}),
			},
			stderr: "pipe", // necessary for stderr to be available
		})
		transport.onerror = async (error) => {
			console.error(`Transport error for "${name}":`, error)
			const connection = this.connections.find((conn) => conn.server.name === name)
			if (connection) {
				connection.server.status = "disconnected"
				this.appendErrorMessage(connection, error.message)
			}
			await this.notifyWebviewOfServerChanges()
		}
		transport.onclose = async () => {
			const connection = this.connections.find((conn) => conn.server.name === name)
			if (connection) {
				connection.server.status = "disconnected"
			}
			await this.notifyWebviewOfServerChanges()
		}
		// If the config is invalid, show an error
		if (!StdioConfigSchema.safeParse(config).success) {
			console.error(`Invalid config for "${name}": missing or invalid parameters`)
			// Record a disconnected placeholder so the bad server is still
			// visible in the UI with its error message.
			const connection: McpConnection = {
				server: {
					name,
					config: JSON.stringify(config),
					status: "disconnected",
					error: "Invalid config: missing or invalid parameters",
				},
				client,
				transport,
			}
			this.connections.push(connection)
			return
		}
		// valid schema
		const parsedConfig = StdioConfigSchema.parse(config)
		const connection: McpConnection = {
			server: {
				name,
				config: JSON.stringify(config),
				status: "connecting",
				disabled: parsedConfig.disabled,
			},
			client,
			transport,
		}
		this.connections.push(connection)
		// transport.stderr is only available after the process has been started. However we can't start it separately from the .connect() call because it also starts the transport. And we can't place this after the connect call since we need to capture the stderr stream before the connection is established, in order to capture errors during the connection process.
		// As a workaround, we start the transport ourselves, and then monkey-patch the start method to no-op so that .connect() doesn't try to start it again.
		await transport.start()
		const stderrStream = transport.stderr
		if (stderrStream) {
			stderrStream.on("data", async (data: Buffer) => {
				const errorOutput = data.toString()
				console.error(`Server "${name}" stderr:`, errorOutput)
				const connection = this.connections.find((conn) => conn.server.name === name)
				if (connection) {
					// NOTE: we do not set server status to "disconnected" because stderr logs do not necessarily mean the server crashed or disconnected, it could just be informational. In fact when the server first starts up, it immediately logs "<name> server running on stdio" to stderr.
					this.appendErrorMessage(connection, errorOutput)
					// Only need to update webview right away if it's already disconnected
					if (connection.server.status === "disconnected") {
						await this.notifyWebviewOfServerChanges()
					}
				}
			})
		} else {
			console.error(`No stderr stream for ${name}`)
		}
		transport.start = async () => {} // No-op now, .connect() won't fail
		// Connect
		await client.connect(transport)
		connection.server.status = "connected"
		connection.server.error = ""
		// Initial fetch of tools and resources
		connection.server.tools = await this.fetchToolsList(name)
		connection.server.resources = await this.fetchResourcesList(name)
		connection.server.resourceTemplates = await this.fetchResourceTemplatesList(name)
	} catch (error) {
		// Update status with error
		const connection = this.connections.find((conn) => conn.server.name === name)
		if (connection) {
			connection.server.status = "disconnected"
			this.appendErrorMessage(connection, error instanceof Error ? error.message : String(error))
		}
		throw error
	}
}
/** Appends a new error line to the connection's accumulated error text. */
private appendErrorMessage(connection: McpConnection, error: string) {
	const existing = connection.server.error
	connection.server.error = existing ? [existing, error].join("\n") : error
}
/**
 * Fetches the tool list from a connected server and marks each tool's
 * always-allow flag from the user's settings file.
 * Returns [] on any failure (server disconnected, unreadable settings, …).
 */
private async fetchToolsList(serverName: string): Promise<McpTool[]> {
	try {
		const response = await this.connections
			.find((conn) => conn.server.name === serverName)
			?.client.request({ method: "tools/list" }, ListToolsResultSchema)
		// Get always allow settings
		const settingsPath = await this.getMcpSettingsFilePath()
		const content = await fs.readFile(settingsPath, "utf-8")
		const config = JSON.parse(content)
		// Optional-chain the whole path: a settings file without an `mcpServers`
		// section previously threw a TypeError here and silently returned [].
		const alwaysAllowConfig = config?.mcpServers?.[serverName]?.alwaysAllow ?? []
		// Mark tools as always allowed based on settings
		const tools = (response?.tools || []).map((tool) => ({
			...tool,
			alwaysAllow: alwaysAllowConfig.includes(tool.name),
		}))
		console.log(`[MCP] Fetched tools for ${serverName}:`, tools)
		return tools
	} catch (error) {
		// Best-effort: tool listing failures are non-fatal; callers treat [] as "no tools".
		return []
	}
}
/** Fetches the resource list for a server; resolves to [] when the request fails. */
private async fetchResourcesList(serverName: string): Promise<McpResource[]> {
	const connection = this.connections.find((conn) => conn.server.name === serverName)
	try {
		const response = await connection?.client.request({ method: "resources/list" }, ListResourcesResultSchema)
		return response?.resources ?? []
	} catch {
		// Best-effort: a server without resource support simply yields an empty list.
		return []
	}
}
/** Fetches resource templates for a server; resolves to [] when the request fails. */
private async fetchResourceTemplatesList(serverName: string): Promise<McpResourceTemplate[]> {
	const connection = this.connections.find((conn) => conn.server.name === serverName)
	try {
		const response = await connection?.client.request(
			{ method: "resources/templates/list" },
			ListResourceTemplatesResultSchema,
		)
		return response?.resourceTemplates ?? []
	} catch {
		// Best-effort: treat failures as "no templates".
		return []
	}
}
async deleteConnection(name: string): Promise<void> {
const connection = this.connections.find((conn) => conn.server.name === name)
if (connection) {
try {
await connection.transport.close()
await connection.client.close()
} catch (error) {
console.error(`Failed to close transport for ${name}:`, error)
}
this.connections = this.connections.filter((conn) => conn.server.name !== name)
}
}
async updateServerConnections(newServers: Record<string, any>): Promise<void> {
this.isConnecting = true
this.removeAllFileWatchers()
const currentNames = new Set(this.connections.map((conn) => conn.server.name))
const newNames = new Set(Object.keys(newServers))
// Delete removed servers
for (const name of currentNames) {
if (!newNames.has(name)) {
await this.deleteConnection(name)
console.log(`Deleted MCP server: ${name}`)
}
}
// Update or add servers
for (const [name, config] of Object.entries(newServers)) {
const currentConnection = this.connections.find((conn) => conn.server.name === name)
if (!currentConnection) {
// New server
try {
this.setupFileWatcher(name, config)
await this.connectToServer(name, config)
} catch (error) {
console.error(`Failed to connect to new MCP server ${name}:`, error)
}
} else if (!deepEqual(JSON.parse(currentConnection.server.config), config)) {
// Existing server with changed config
try {
this.setupFileWatcher(name, config)
await this.deleteConnection(name)
await this.connectToServer(name, config)
console.log(`Reconnected MCP server with updated config: ${name}`)
} catch (error) {
console.error(`Failed to reconnect MCP server ${name}:`, error)
}
}
// If server exists with same config, do nothing
}
await this.notifyWebviewOfServerChanges()
this.isConnecting = false
}
/**
 * Watches the server's build artifact (build/index.js) and restarts the
 * server whenever it changes. Configs whose args contain no such path get
 * no watcher.
 */
private setupFileWatcher(name: string, config: any) {
	const buildScript = config.args?.find((arg: string) => arg.includes("build/index.js"))
	if (!buildScript) {
		return
	}
	// chokidar fires on file-system saves even when the file is not open in an
	// editor, unlike onDidSaveTextDocument.
	const watcher = chokidar.watch(buildScript, {})
	watcher.on("change", () => {
		console.log(`Detected change in ${buildScript}. Restarting server ${name}...`)
		this.restartConnection(name)
	})
	this.fileWatchers.set(name, watcher)
}
/** Closes every build-file watcher and empties the registry. */
private removeAllFileWatchers() {
	for (const watcher of this.fileWatchers.values()) {
		watcher.close()
	}
	this.fileWatchers.clear()
}
/**
 * Tears down and re-establishes the named server's connection using its
 * stored config, surfacing progress/result messages in the editor UI.
 * Status is flipped to "connecting" and pushed to the webview BEFORE the
 * old connection is deleted, so the UI shows a restarting state instead
 * of a vanished server.
 */
async restartConnection(serverName: string): Promise<void> {
	this.isConnecting = true
	try {
		const provider = this.providerRef.deref()
		if (!provider) {
			return
		}
		// Get existing connection and update its status
		const connection = this.connections.find((conn) => conn.server.name === serverName)
		const config = connection?.server.config
		if (config) {
			vscode.window.showInformationMessage(`Restarting ${serverName} MCP server...`)
			connection.server.status = "connecting"
			connection.server.error = ""
			await this.notifyWebviewOfServerChanges()
			await delay(500) // artificial delay to show user that server is restarting
			try {
				await this.deleteConnection(serverName)
				// Try to connect again using existing config
				await this.connectToServer(serverName, JSON.parse(config))
				vscode.window.showInformationMessage(`${serverName} MCP server connected`)
			} catch (error) {
				console.error(`Failed to restart connection for ${serverName}:`, error)
				vscode.window.showErrorMessage(`Failed to connect to ${serverName} MCP server`)
			}
		}
		await this.notifyWebviewOfServerChanges()
	} finally {
		// Previously the early provider-gone return left isConnecting stuck at true.
		this.isConnecting = false
	}
}
/**
 * Pushes the current server list to the webview, ordered as the servers
 * appear in the settings file. Servers absent from the settings file are
 * placed at the end — a raw indexOf() of -1 previously sorted them to the
 * front of the list.
 */
private async notifyWebviewOfServerChanges(): Promise<void> {
	// servers should always be sorted in the order they are defined in the settings file
	const settingsPath = await this.getMcpSettingsFilePath()
	const content = await fs.readFile(settingsPath, "utf-8")
	const config = JSON.parse(content)
	const serverOrder = Object.keys(config.mcpServers || {})
	const rank = (name: string) => {
		const index = serverOrder.indexOf(name)
		return index === -1 ? Number.MAX_SAFE_INTEGER : index
	}
	await this.providerRef.deref()?.postMessageToWebview({
		type: "mcpServers",
		mcpServers: [...this.connections]
			.sort((a, b) => rank(a.server.name) - rank(b.server.name))
			.map((connection) => connection.server),
	})
}
/**
 * Persists a server's enabled/disabled flag to the MCP settings file and
 * mirrors the change onto the live connection, refreshing its capability
 * lists only when it is currently connected.
 * A serverName with no entry in the settings file is silently a no-op.
 * @throws rethrows after logging and showing an error message to the user.
 */
public async toggleServerDisabled(serverName: string, disabled: boolean): Promise<void> {
	let settingsPath: string
	try {
		settingsPath = await this.getMcpSettingsFilePath()
		// Ensure the settings file exists and is accessible
		try {
			await fs.access(settingsPath)
		} catch (error) {
			console.error("Settings file not accessible:", error)
			throw new Error("Settings file not accessible")
		}
		const content = await fs.readFile(settingsPath, "utf-8")
		const config = JSON.parse(content)
		// Validate the config structure
		if (!config || typeof config !== "object") {
			throw new Error("Invalid config structure")
		}
		if (!config.mcpServers || typeof config.mcpServers !== "object") {
			config.mcpServers = {}
		}
		if (config.mcpServers[serverName]) {
			// Create a new server config object to ensure clean structure
			const serverConfig = {
				...config.mcpServers[serverName],
				disabled,
			}
			// Ensure required fields exist
			if (!serverConfig.alwaysAllow) {
				serverConfig.alwaysAllow = []
			}
			config.mcpServers[serverName] = serverConfig
			// Write the entire config back (only the mcpServers section is kept —
			// any other top-level keys in the settings file are dropped here).
			const updatedConfig = {
				mcpServers: config.mcpServers,
			}
			await fs.writeFile(settingsPath, JSON.stringify(updatedConfig, null, 2))
			const connection = this.connections.find((conn) => conn.server.name === serverName)
			if (connection) {
				try {
					connection.server.disabled = disabled
					// Only refresh capabilities if connected
					if (connection.server.status === "connected") {
						connection.server.tools = await this.fetchToolsList(serverName)
						connection.server.resources = await this.fetchResourcesList(serverName)
						connection.server.resourceTemplates = await this.fetchResourceTemplatesList(serverName)
					}
				} catch (error) {
					// Capability refresh is best-effort; the disabled flag itself is already saved.
					console.error(`Failed to refresh capabilities for ${serverName}:`, error)
				}
			}
			await this.notifyWebviewOfServerChanges()
		}
	} catch (error) {
		console.error("Failed to update server disabled state:", error)
		if (error instanceof Error) {
			console.error("Error details:", error.message, error.stack)
		}
		vscode.window.showErrorMessage(
			`Failed to update server state: ${error instanceof Error ? error.message : String(error)}`,
		)
		throw error
	}
}
/**
 * Persists a per-server request timeout (seconds — see callTool, which
 * multiplies by 1000) to the MCP settings file and notifies the webview.
 * A serverName with no entry in the settings file is silently a no-op.
 * @throws rethrows after logging and showing an error message to the user.
 */
public async updateServerTimeout(serverName: string, timeout: number): Promise<void> {
	let settingsPath: string
	try {
		settingsPath = await this.getMcpSettingsFilePath()
		// Ensure the settings file exists and is accessible
		try {
			await fs.access(settingsPath)
		} catch (error) {
			console.error("Settings file not accessible:", error)
			throw new Error("Settings file not accessible")
		}
		const content = await fs.readFile(settingsPath, "utf-8")
		const config = JSON.parse(content)
		// Validate the config structure
		if (!config || typeof config !== "object") {
			throw new Error("Invalid config structure")
		}
		if (!config.mcpServers || typeof config.mcpServers !== "object") {
			config.mcpServers = {}
		}
		if (config.mcpServers[serverName]) {
			// Create a new server config object to ensure clean structure
			const serverConfig = {
				...config.mcpServers[serverName],
				timeout,
			}
			config.mcpServers[serverName] = serverConfig
			// Write the entire config back (only the mcpServers section is kept).
			const updatedConfig = {
				mcpServers: config.mcpServers,
			}
			await fs.writeFile(settingsPath, JSON.stringify(updatedConfig, null, 2))
			await this.notifyWebviewOfServerChanges()
		}
	} catch (error) {
		console.error("Failed to update server timeout:", error)
		if (error instanceof Error) {
			console.error("Error details:", error.message, error.stack)
		}
		vscode.window.showErrorMessage(
			`Failed to update server timeout: ${error instanceof Error ? error.message : String(error)}`,
		)
		throw error
	}
}
/**
 * Removes a server from the settings file, then reconciles live
 * connections via updateServerConnections (which disconnects it).
 * Shows a warning (without throwing) when the server is not in the file.
 * @throws rethrows after logging and showing an error message to the user.
 */
public async deleteServer(serverName: string): Promise<void> {
	try {
		const settingsPath = await this.getMcpSettingsFilePath()
		// Ensure the settings file exists and is accessible
		try {
			await fs.access(settingsPath)
		} catch (error) {
			throw new Error("Settings file not accessible")
		}
		const content = await fs.readFile(settingsPath, "utf-8")
		const config = JSON.parse(content)
		// Validate the config structure
		if (!config || typeof config !== "object") {
			throw new Error("Invalid config structure")
		}
		if (!config.mcpServers || typeof config.mcpServers !== "object") {
			config.mcpServers = {}
		}
		// Remove the server from the settings
		if (config.mcpServers[serverName]) {
			delete config.mcpServers[serverName]
			// Write the entire config back (only the mcpServers section is kept).
			const updatedConfig = {
				mcpServers: config.mcpServers,
			}
			await fs.writeFile(settingsPath, JSON.stringify(updatedConfig, null, 2))
			// Update server connections so the deleted server is disconnected.
			await this.updateServerConnections(config.mcpServers)
			vscode.window.showInformationMessage(`Deleted MCP server: ${serverName}`)
		} else {
			vscode.window.showWarningMessage(`Server "${serverName}" not found in configuration`)
		}
	} catch (error) {
		console.error("Failed to delete MCP server:", error)
		if (error instanceof Error) {
			console.error("Error details:", error.message, error.stack)
		}
		vscode.window.showErrorMessage(
			`Failed to delete MCP server: ${error instanceof Error ? error.message : String(error)}`,
		)
		throw error
	}
}
/**
 * Reads a resource by URI from a connected MCP server.
 * @throws when the server is unknown or disabled.
 */
async readResource(serverName: string, uri: string): Promise<McpResourceResponse> {
	const connection = this.connections.find((conn) => conn.server.name === serverName)
	if (!connection) {
		throw new Error(`No connection found for server: ${serverName}`)
	}
	if (connection.server.disabled) {
		throw new Error(`Server "${serverName}" is disabled`)
	}
	return await connection.client.request(
		{ method: "resources/read", params: { uri } },
		ReadResourceResultSchema,
	)
}
/**
 * Invokes a tool on a connected MCP server, using the per-server timeout
 * from its stored config (in seconds, default 60) for the request.
 * @throws when the server is unknown or disabled.
 */
async callTool(
	serverName: string,
	toolName: string,
	toolArguments?: Record<string, unknown>,
): Promise<McpToolCallResponse> {
	const connection = this.connections.find((conn) => conn.server.name === serverName)
	if (!connection) {
		throw new Error(
			`No connection found for server: ${serverName}. Please make sure to use MCP servers available under 'Connected MCP Servers'.`,
		)
	}
	if (connection.server.disabled) {
		throw new Error(`Server "${serverName}" is disabled and cannot be used`)
	}
	// Fall back to 60 seconds when the stored config cannot be parsed.
	let timeoutSeconds = 60
	try {
		const parsedConfig = StdioConfigSchema.parse(JSON.parse(connection.server.config))
		timeoutSeconds = parsedConfig.timeout ?? 60
	} catch (error) {
		console.error("Failed to parse server config for timeout:", error)
	}
	return await connection.client.request(
		{
			method: "tools/call",
			params: {
				name: toolName,
				arguments: toolArguments,
			},
		},
		CallToolResultSchema,
		{ timeout: timeoutSeconds * 1000 },
	)
}
/**
 * Adds or removes a tool from a server's always-allow list in the
 * settings file, then refreshes the cached tool list and notifies the
 * webview.
 * @throws rethrows any failure after surfacing it to the user.
 */
async toggleToolAlwaysAllow(serverName: string, toolName: string, shouldAllow: boolean): Promise<void> {
	try {
		const settingsPath = await this.getMcpSettingsFilePath()
		const content = await fs.readFile(settingsPath, "utf-8")
		const config = JSON.parse(content)
		// Guard: indexing config.mcpServers[serverName] previously crashed with
		// a bare TypeError when the section or the server entry was missing.
		if (!config.mcpServers?.[serverName]) {
			throw new Error(`Server "${serverName}" not found in MCP settings`)
		}
		// Initialize alwaysAllow if it doesn't exist
		if (!config.mcpServers[serverName].alwaysAllow) {
			config.mcpServers[serverName].alwaysAllow = []
		}
		const alwaysAllow = config.mcpServers[serverName].alwaysAllow
		const toolIndex = alwaysAllow.indexOf(toolName)
		if (shouldAllow && toolIndex === -1) {
			// Add tool to always allow list
			alwaysAllow.push(toolName)
		} else if (!shouldAllow && toolIndex !== -1) {
			// Remove tool from always allow list
			alwaysAllow.splice(toolIndex, 1)
		}
		// Write updated config back to file
		await fs.writeFile(settingsPath, JSON.stringify(config, null, 2))
		// Update the tools list to reflect the change
		const connection = this.connections.find((conn) => conn.server.name === serverName)
		if (connection) {
			connection.server.tools = await this.fetchToolsList(serverName)
			await this.notifyWebviewOfServerChanges()
		}
	} catch (error) {
		console.error("Failed to update always allow settings:", error)
		vscode.window.showErrorMessage("Failed to update always allow settings")
		throw error // Re-throw to ensure the error is properly handled
	}
}
/**
 * Shuts the hub down: stops all file watchers, closes every server
 * connection, and disposes registered listeners.
 * Iterating `this.connections` while deleteConnection() reassigns it is
 * safe here: for..of captures the array object the property held when the
 * loop started, and deleteConnection replaces the property with a new
 * filtered array rather than mutating the old one.
 */
async dispose(): Promise<void> {
	this.removeAllFileWatchers()
	for (const connection of this.connections) {
		try {
			await this.deleteConnection(connection.server.name)
		} catch (error) {
			console.error(`Failed to close connection for ${connection.server.name}:`, error)
		}
	}
	this.connections = []
	if (this.settingsWatcher) {
		this.settingsWatcher.dispose()
	}
	this.disposables.forEach((d) => d.dispose())
}
}

View File

@ -0,0 +1,83 @@
import * as vscode from "vscode"
import { ClineProvider } from "../../core/webview/ClineProvider"
import { McpHub } from "./McpHub"
/**
* Singleton manager for MCP server instances.
* Ensures only one set of MCP servers runs across all webviews.
*/
/**
 * Singleton manager for MCP server instances.
 * Ensures only one set of MCP servers runs across all webviews.
 */
export class McpServerManager {
	// The single shared hub; null until the first getInstance() call.
	private static instance: McpHub | null = null
	// globalState key recording when the primary hub instance was created.
	private static readonly GLOBAL_STATE_KEY = "mcpHubInstanceId"
	// Every webview provider that should receive server-change notifications.
	private static providers: Set<ClineProvider> = new Set()
	// In-flight creation promise; acts as a lock so concurrent callers await one init.
	private static initializationPromise: Promise<McpHub> | null = null
	/**
	 * Get the singleton McpHub instance.
	 * Creates a new instance if one doesn't exist.
	 * Thread-safe implementation using a promise-based lock.
	 */
	static async getInstance(context: vscode.ExtensionContext, provider: ClineProvider): Promise<McpHub> {
		// Register the provider
		this.providers.add(provider)
		// If we already have an instance, return it
		if (this.instance) {
			return this.instance
		}
		// If initialization is in progress, wait for it
		if (this.initializationPromise) {
			return this.initializationPromise
		}
		// Create a new initialization promise
		this.initializationPromise = (async () => {
			try {
				// Double-check instance in case it was created while we were waiting
				if (!this.instance) {
					// NOTE: the hub is constructed with whichever provider asked first.
					this.instance = new McpHub(provider)
					// Store a unique identifier in global state to track the primary instance
					await context.globalState.update(this.GLOBAL_STATE_KEY, Date.now().toString())
				}
				return this.instance
			} finally {
				// Clear the initialization promise after completion or error,
				// so a failed init can be retried by a later caller.
				this.initializationPromise = null
			}
		})()
		return this.initializationPromise
	}
	/**
	 * Remove a provider from the tracked set.
	 * This is called when a webview is disposed.
	 */
	static unregisterProvider(provider: ClineProvider): void {
		this.providers.delete(provider)
	}
	/**
	 * Notify all registered providers of server state changes.
	 * Delivery is fire-and-forget; individual failures are logged only.
	 */
	static notifyProviders(message: any): void {
		this.providers.forEach((provider) => {
			provider.postMessageToWebview(message).catch((error) => {
				console.error("Failed to notify provider:", error)
			})
		})
	}
	/**
	 * Clean up the singleton instance and all its resources.
	 */
	static async cleanup(context: vscode.ExtensionContext): Promise<void> {
		if (this.instance) {
			await this.instance.dispose()
			this.instance = null
			await context.globalState.update(this.GLOBAL_STATE_KEY, undefined)
		}
		this.providers.clear()
	}
}

View File

@ -0,0 +1,206 @@
// import * as vscode from "vscode"
import * as childProcess from "child_process"
import * as fs from "fs"
import * as path from "path"
import * as readline from "readline"
const isWindows = /^win/.test(process.platform)
const binName = isWindows ? "rg.exe" : "rg"
interface SearchResult {
file: string
line: number
column: number
match: string
beforeContext: string[]
afterContext: string[]
}
// Constants
const MAX_RESULTS = 300
const MAX_LINE_LENGTH = 500
/**
 * Caps a line at maxLength characters, appending a truncation marker when
 * anything was cut. Lines at or under the limit pass through unchanged.
 * @param line The line to truncate
 * @param maxLength The maximum allowed length (defaults to MAX_LINE_LENGTH)
 * @returns The truncated line, or the original line if within the limit
 */
export function truncateLine(line: string, maxLength: number = MAX_LINE_LENGTH): string {
	if (line.length <= maxLength) {
		return line
	}
	return `${line.substring(0, maxLength)} [truncated...]`
}
/**
 * Locates the ripgrep binary. Probes a list of common install prefixes in
 * order (the original code checked only the Apple-Silicon Homebrew path)
 * and resolves to undefined when none exists, letting callers raise a
 * clear error.
 */
async function getBinPath(): Promise<string | undefined> {
	const candidateDirs = [
		"/opt/homebrew/bin/", // Homebrew on Apple Silicon (the original, sole location)
		"/usr/local/bin/", // Homebrew on Intel macOS / common manual-install prefix
		"/usr/bin/", // distro package managers
	]
	for (const dir of candidateDirs) {
		const binPath = path.join(dir, binName)
		if (await pathExists(binPath)) {
			return binPath
		}
	}
	return undefined
}
/**
 * Resolves true when the given filesystem path is accessible, false
 * otherwise. The parameter is named filePath (not `path`) so it no longer
 * shadows the `path` module import inside this function; positional
 * callers are unaffected.
 */
async function pathExists(filePath: string): Promise<boolean> {
	return new Promise((resolve) => {
		fs.access(filePath, (err) => {
			resolve(err === null)
		})
	})
}
/**
 * Spawns ripgrep with the given args and resolves with its stdout, capped
 * at MAX_RESULTS * 5 lines. Rejects when ripgrep writes anything to
 * stderr or fails to spawn.
 */
async function execRipgrep(bin: string, args: string[]): Promise<string> {
	return new Promise((resolve, reject) => {
		const rgProcess = childProcess.spawn(bin, args)
		// cross-platform alternative to head, which is ripgrep author's recommendation for limiting output.
		const rl = readline.createInterface({
			input: rgProcess.stdout,
			crlfDelay: Infinity, // treat \r\n as a single line break even if it's split across chunks. This ensures consistent behavior across different operating systems.
		})
		let output = ""
		let lineCount = 0
		const maxLines = MAX_RESULTS * 5 // limiting ripgrep output with max lines since there's no other way to limit results. it's okay that we're outputting as json, since we're parsing it line by line and ignore anything that's not part of a match. This assumes each result is at most 5 lines.
		rl.on("line", (line) => {
			if (lineCount < maxLines) {
				output += line + "\n"
				lineCount++
			} else {
				// Cap reached: stop reading and terminate ripgrep early; the
				// subsequent "close" event resolves with what we have so far.
				rl.close()
				rgProcess.kill()
			}
		})
		let errorOutput = ""
		rgProcess.stderr.on("data", (data) => {
			errorOutput += data.toString()
		})
		rl.on("close", () => {
			// NOTE(review): ANY stderr output rejects the whole run, even
			// non-fatal warnings — confirm this strictness is intended.
			if (errorOutput) {
				reject(new Error(`ripgrep process error: ${errorOutput}`))
			} else {
				resolve(output)
			}
		})
		rgProcess.on("error", (error) => {
			reject(new Error(`ripgrep process error: ${error.message}`))
		})
	})
}
/**
 * Runs ripgrep over directoryPath with the given regex and returns a
 * human-readable, file-grouped summary of matches with one line of
 * surrounding context. Returns the literal string "No results found"
 * when ripgrep itself fails (callers display it verbatim).
 * Leftover debug console.log calls from development have been removed.
 * @throws when no ripgrep binary can be located.
 */
export async function regexSearchFiles(
	directoryPath: string,
	regex: string,
): Promise<string> {
	const rgPath = await getBinPath()
	if (!rgPath) {
		throw new Error("Could not find ripgrep binary")
	}
	// Exclude vault-internal directories via --glob filters.
	const args = [
		"--json",
		"-e",
		regex,
		"--glob",
		"!.obsidian/**", // skip the .obsidian directory and everything below it
		"--glob",
		"!.git/**",
		"--context",
		"1",
		directoryPath
	]
	let output: string
	try {
		output = await execRipgrep(rgPath, args)
	} catch (error) {
		console.error("Error executing ripgrep:", error)
		return "No results found"
	}
	const results: SearchResult[] = []
	let currentResult: Partial<SearchResult> | null = null
	output.split("\n").forEach((line) => {
		if (line) {
			try {
				const parsed = JSON.parse(line)
				if (parsed.type === "match") {
					// A new match closes out the previous one.
					if (currentResult) {
						results.push(currentResult as SearchResult)
					}
					// Safety check: truncate extremely long lines to prevent excessive output
					const matchText = parsed.data.lines.text
					const truncatedMatch = truncateLine(matchText)
					currentResult = {
						file: parsed.data.path.text,
						line: parsed.data.line_number,
						column: parsed.data.submatches[0].start,
						match: truncatedMatch,
						beforeContext: [],
						afterContext: [],
					}
				} else if (parsed.type === "context" && currentResult) {
					// Apply the same truncation logic to context lines
					const contextText = parsed.data.lines.text
					const truncatedContext = truncateLine(contextText)
					if (parsed.data.line_number < currentResult.line!) {
						currentResult.beforeContext!.push(truncatedContext)
					} else {
						currentResult.afterContext!.push(truncatedContext)
					}
				}
			} catch (error) {
				console.error("Error parsing ripgrep output:", error)
			}
		}
	})
	// Flush the final in-progress match.
	if (currentResult) {
		results.push(currentResult as SearchResult)
	}
	return formatResults(results, directoryPath)
}
function formatResults(results: SearchResult[], cwd: string): string {
const groupedResults: { [key: string]: SearchResult[] } = {}
let output = ""
if (results.length >= MAX_RESULTS) {
output += `Showing first ${MAX_RESULTS} of ${MAX_RESULTS}+ results. Use a more specific search if necessary.\n\n`
} else {
output += `Found ${results.length === 1 ? "1 result" : `${results.length.toLocaleString()} results`}.\n\n`
}
// Group results by file name
results.slice(0, MAX_RESULTS).forEach((result) => {
const relativeFilePath = path.relative(cwd, result.file)
if (!groupedResults[relativeFilePath]) {
groupedResults[relativeFilePath] = []
}
groupedResults[relativeFilePath].push(result)
})
for (const [filePath, fileResults] of Object.entries(groupedResults)) {
output += `${filePath.toPosix()}\n│----\n`
fileResults.forEach((result, index) => {
const allLines = [...result.beforeContext, result.match, ...result.afterContext]
allLines.forEach((line) => {
output += `${line?.trimEnd() ?? ""}\n`
})
if (index < fileResults.length - 1) {
output += "│----\n"
}
})
output += "│----\n\n"
}
return output.trim()
}

View File

View File

@ -94,6 +94,7 @@ export class DBManager {
// return drizzle(this.pgClient)
} catch (error) {
console.error('Error loading database:', error)
console.log(this.dbPath)
return null
}
}

View File

@ -1,6 +1,7 @@
import { SerializedLexicalNode } from 'lexical'
import { SUPPORT_EMBEDDING_SIMENTION } from '../constants'
import { ApplyStatus } from '../types/apply'
// import { EmbeddingModelId } from '../types/embedding'
// PostgreSQL column types
@ -123,6 +124,7 @@ export type Conversation = {
export type Message = {
id: string // uuid
conversationId: string // uuid
applyStatus: number
role: 'user' | 'assistant'
content: string | null
reasoningContent?: string | null
@ -151,6 +153,7 @@ export type InsertMessage = {
id: string
conversationId: string
role: 'user' | 'assistant'
apply_status: number
content: string | null
reasoningContent?: string | null
promptContent?: string | null
@ -163,6 +166,7 @@ export type InsertMessage = {
export type SelectMessage = {
id: string // uuid
conversation_id: string // uuid
apply_status: number
role: 'user' | 'assistant'
content: string | null
reasoning_content?: string | null

View File

@ -102,20 +102,10 @@ export const migrations: Record<string, SqlMigration> = {
"updated_at" timestamp DEFAULT now() NOT NULL
);
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'messages'
AND column_name = 'reasoning_content'
) THEN
ALTER TABLE "messages" ADD COLUMN "reasoning_content" text;
END IF;
END $$;
CREATE TABLE IF NOT EXISTS "messages" (
"id" uuid PRIMARY KEY NOT NULL,
"conversation_id" uuid NOT NULL REFERENCES "conversations"("id") ON DELETE CASCADE,
"apply_status" integer NOT NULL DEFAULT 0,
"role" text NOT NULL,
"content" text,
"reasoning_content" text,

View File

@ -25,6 +25,7 @@ import {
InfioSettings,
parseInfioSettings,
} from './types/settings'
import './utils/path'
import { getMentionableBlockData } from './utils/obsidian'
// Remember to rename these classes and interfaces!

66
src/types/apply.ts Normal file
View File

@ -0,0 +1,66 @@
/**
 * Lifecycle of a tool-proposed edit: Idle until the user acts on it, then
 * Applied, Failed, or Rejected.
 */
export enum ApplyStatus {
	Idle = 0,
	Applied = 1,
	Failed = 2,
	Rejected = 3,
}
// Arguments for the read_file tool: read one file's contents.
export type ReadFileToolArgs = {
	type: 'read_file';
	filepath?: string;
}
// Arguments for the list_files tool; `recursive` descends into subfolders.
export type ListFilesToolArgs = {
	type: 'list_files';
	filepath?: string;
	recursive?: boolean;
}
// Arguments for ripgrep-style regex search under a path, optionally
// filtered by a file glob. `finish` marks the streaming call as complete.
export type RegexSearchFilesToolArgs = {
	type: 'regex_search_files';
	filepath?: string;
	regex?: string;
	file_pattern?: string;
	finish?: boolean;
}
// Arguments for embedding-based semantic search under a path.
export type SemanticSearchFilesToolArgs = {
	type: 'semantic_search_files';
	filepath?: string;
	query?: string;
	finish?: boolean;
}
// Arguments for writing content to a file; startLine/endLine are 1-based
// and select the span to replace when present.
export type WriteToFileToolArgs = {
	type: 'write_to_file';
	filepath?: string;
	content?: string;
	startLine?: number;
	endLine?: number;
}
// Arguments for inserting content into a file at a line position.
export type InsertContentToolArgs = {
	type: 'insert_content';
	filepath?: string;
	content?: string;
	startLine?: number;
	endLine?: number;
}
// Arguments for sequential search/replace operations on one file. Each
// operation may be literal or regex, optionally case-insensitive, and
// optionally restricted to a 1-based [startLine, endLine] window.
export type SearchAndReplaceToolArgs = {
	type: 'search_and_replace';
	filepath: string;
	operations: {
		search: string;
		replace: string;
		startLine?: number;
		endLine?: number;
		useRegex?: boolean;
		ignoreCase?: boolean;
		regexFlags?: string;
	}[];
}
// Discriminated union (on `type`) of every tool-call argument shape.
export type ToolArgs = ReadFileToolArgs | WriteToFileToolArgs | InsertContentToolArgs | SearchAndReplaceToolArgs | ListFilesToolArgs | RegexSearchFilesToolArgs | SemanticSearchFilesToolArgs;

View File

@ -2,6 +2,7 @@ import { SerializedEditorState } from 'lexical'
import { SelectVector } from '../database/schema'
import { ApplyStatus } from './apply'
import { LLMModel } from './llm/model'
import { ContentPart } from './llm/request'
import { ResponseUsage } from './llm/response'
@ -9,6 +10,7 @@ import { Mentionable, SerializedMentionable } from './mentionable'
export type ChatUserMessage = {
role: 'user'
applyStatus: ApplyStatus
content: SerializedEditorState | null
promptContent: string | ContentPart[] | null
id: string
@ -20,6 +22,7 @@ export type ChatUserMessage = {
export type ChatAssistantMessage = {
role: 'assistant'
applyStatus: ApplyStatus
content: string
reasoningContent: string
id: string
@ -33,6 +36,7 @@ export type ChatMessage = ChatUserMessage | ChatAssistantMessage
export type SerializedChatUserMessage = {
role: 'user'
applyStatus: ApplyStatus
content: SerializedEditorState | null
promptContent: string | ContentPart[] | null
id: string
@ -44,6 +48,7 @@ export type SerializedChatUserMessage = {
export type SerializedChatAssistantMessage = {
role: 'assistant'
applyStatus: ApplyStatus
content: string
reasoningContent: string
id: string

View File

@ -1,4 +1,6 @@
import { TFile } from 'obsidian'
import { TFile } from 'obsidian';
import { SearchAndReplaceToolArgs } from '../types/apply';
/**
* Applies changes to a file by replacing content within specified line range
@ -9,36 +11,95 @@ import { TFile } from 'obsidian'
* @param endLine - Ending line number (1-based indexing, optional)
* @returns Promise resolving to the modified content or null if operation fails
*/
export const manualApplyChangesToFile = async (
content: string,
currentFile: TFile,
currentFileContent: string,
startLine?: number,
endLine?: number,
export const ApplyEditToFile = async (
currentFile: TFile,
currentFileContent: string,
content: string,
startLine?: number,
endLine?: number,
): Promise<string | null> => {
try {
// Input validation
if (!content || !currentFileContent) {
throw new Error('Content cannot be empty')
}
try {
// 如果文件为空,直接返回新内容
if (!currentFileContent || currentFileContent.trim() === '') {
return content;
}
const lines = currentFileContent.split('\n')
const effectiveStartLine = Math.max(1, startLine ?? 1)
const effectiveEndLine = Math.min(endLine ?? lines.length, lines.length)
// 如果要清空文件,直接返回空字符串
if (content === '') {
return '';
}
// Validate line numbers
if (effectiveStartLine > effectiveEndLine) {
throw new Error('Start line cannot be greater than end line')
}
const lines = currentFileContent.split('\n')
const effectiveStartLine = Math.max(1, startLine ?? 1)
const effectiveEndLine = Math.min(endLine ?? lines.length, lines.length)
// Construct new content
return [
...lines.slice(0, effectiveStartLine - 1),
content,
...lines.slice(effectiveEndLine)
].join('\n')
} catch (error) {
console.error('Error applying changes:', error instanceof Error ? error.message : 'Unknown error')
return null
}
// Validate line numbers
if (effectiveStartLine > effectiveEndLine) {
throw new Error('Start line cannot be greater than end line')
}
// Construct new content
return [
...lines.slice(0, effectiveStartLine - 1),
content,
...lines.slice(effectiveEndLine)
].join('\n')
} catch (error) {
console.error('Error applying changes:', error instanceof Error ? error.message : 'Unknown error')
return null
}
}
/** Backslash-escapes every regex metacharacter so the text matches literally. */
function escapeRegExp(string: string): string {
	const metaChars = /[.*+?^${}()|[\]\\]/g
	return string.replace(metaChars, "\\$&")
}
/**
 * Applies a sequence of search/replace operations to the file content.
 * Each operation may be a literal or regex search, optionally
 * case-insensitive, optionally restricted to a 1-based
 * [startLine, endLine] window.
 * @param currentFile - the file being edited (not read here; kept for the caller's API)
 * @param currentFileContent - the full current text of the file
 * @param operations - ordered list of replacements to apply
 * @returns the new file content after all operations
 */
export const SearchAndReplace = async (
	currentFile: TFile,
	currentFileContent: string,
	operations: SearchAndReplaceToolArgs['operations']
) => {
	let lines = currentFileContent.split("\n")
	for (const op of operations) {
		// Explicit regexFlags win; otherwise global, plus case-insensitive on request.
		const flags = op.regexFlags ?? (op.ignoreCase ? "gi" : "g")
		// Force multiline so ^/$ anchor per line rather than per document.
		const multilineFlags = flags.includes("m") ? flags : flags + "m"
		const searchPattern = op.useRegex
			? new RegExp(op.search, multilineFlags)
			: new RegExp(escapeRegExp(op.search), multilineFlags)
		if (op.startLine || op.endLine) {
			// Replace only inside the requested (clamped) line window.
			const startLine = Math.max((op.startLine ?? 1) - 1, 0)
			const endLine = Math.min((op.endLine ?? lines.length) - 1, lines.length - 1)
			const beforeLines = lines.slice(0, startLine)
			const afterLines = lines.slice(endLine + 1)
			const windowText = lines.slice(startLine, endLine + 1).join("\n")
			const replacedWindow = windowText.replace(searchPattern, op.replace)
			lines = [...beforeLines, ...replacedWindow.split("\n"), ...afterLines]
		} else {
			// Whole-document replacement.
			lines = lines.join("\n").replace(searchPattern, op.replace).split("\n")
		}
	}
	return lines.join("\n")
}

View File

@ -1,5 +1,5 @@
import { minimatch } from 'minimatch'
import { Vault } from 'obsidian'
import { TFile, TFolder, Vault } from 'obsidian'
export const findFilesMatchingPatterns = async (
patterns: string[],
@ -10,3 +10,24 @@ export const findFilesMatchingPatterns = async (
return patterns.some((pattern) => minimatch(file.path, pattern))
})
}
export const listFilesAndFolders = async (vault: Vault, path: string) => {
const folder = vault.getAbstractFileByPath(path)
const childrenFiles: string[] = []
const childrenFolders: string[] = []
if (folder instanceof TFolder) {
folder.children.forEach((child) => {
if (child instanceof TFile) {
childrenFiles.push(child.path)
} else if (child instanceof TFolder) {
childrenFolders.push(child.path + "/")
}
})
return [...childrenFolders, ...childrenFiles]
}
return []
}
// TODO(review): unimplemented stub — currently always resolves to undefined.
// The signature suggests it should search files under `path` (filtered by
// `file_pattern`) for `regex` matches; confirm the intended return shape
// before wiring it into the tool pipeline.
export const regexSearchFiles = async (vault: Vault, path: string, regex: string, file_pattern: string) => {
}

312
src/utils/modes.ts Normal file
View File

@ -0,0 +1,312 @@
import { addCustomInstructions } from "../core/prompts/sections/custom-instructions"
import { ALWAYS_AVAILABLE_TOOLS, TOOL_GROUPS, ToolGroup } from "./tool-groups"
// Mode types
// A mode is identified by its slug (e.g. "write", "ask").
export type Mode = string
// Group options type
// Optional restrictions attached to a tool group within a mode.
export type GroupOptions = {
	fileRegex?: string // Regular expression pattern
	description?: string // Human-readable description of the pattern
}
// Group entry can be either a string or tuple with options
export type GroupEntry = ToolGroup | readonly [ToolGroup, GroupOptions]
// Mode configuration type
// Full definition of a mode, whether built-in or user-provided.
export type ModeConfig = {
	slug: string
	name: string
	roleDefinition: string
	customInstructions?: string
	groups: readonly GroupEntry[] // Now supports both simple strings and tuples with options
	source?: "global" | "project" // Where this mode was loaded from
}
// Mode-specific prompts only
// Partial overrides for a mode's prompt text, keyed separately from ModeConfig.
export type PromptComponent = {
	roleDefinition?: string
	customInstructions?: string
}
// Maps a mode slug to its prompt overrides (if any).
export type CustomModePrompts = {
	[key: string]: PromptComponent | undefined
}
/**
 * Extracts the tool-group name from either entry shape: a bare group name
 * or a [name, options] tuple.
 */
export function getGroupName(group: GroupEntry): ToolGroup {
	return typeof group === "string" ? group : group[0]
}
/** Returns the options of a [name, options] entry, or undefined for a bare group name. */
function getGroupOptions(group: GroupEntry): GroupOptions | undefined {
	if (typeof group === "string") {
		return undefined
	}
	return group[1]
}
/**
 * Tests `filePath` against a regex `pattern`.
 * An invalid pattern is logged and treated as a non-match rather than thrown.
 */
export function doesFileMatchRegex(filePath: string, pattern: string): boolean {
	try {
		return new RegExp(pattern).test(filePath)
	} catch (error) {
		console.error(`Invalid regex pattern: ${pattern}`, error)
		return false
	}
}
// Helper to get all tools for a mode
export function getToolsForMode(groups: readonly GroupEntry[]): string[] {
const tools = new Set<string>()
// Add tools from each group
groups.forEach((group) => {
const groupName = getGroupName(group)
const groupConfig = TOOL_GROUPS[groupName]
groupConfig.tools.forEach((tool: string) => tools.add(tool))
})
// Always add required tools
ALWAYS_AVAILABLE_TOOLS.forEach((tool) => tools.add(tool))
return Array.from(tools)
}
// Main modes configuration as an ordered array
// Order matters: the first entry is the default mode (see defaultModeSlug below).
export const modes: readonly ModeConfig[] = [
	{
		slug: "write",
		name: "Write",
		roleDefinition:
			"You are Infio, a versatile content creator skilled in composing, editing, and organizing various text-based documents. You excel at structuring information clearly, creating well-formatted content, and helping users express their ideas effectively.",
		groups: ["read", "edit"],
		customInstructions:
			"You can create and modify any text-based files, with particular expertise in Markdown formatting. Help users organize their thoughts, create documentation, take notes, or draft any written content they need. When appropriate, suggest structural improvements and formatting enhancements that make content more readable and accessible. Consider the purpose and audience of each document to provide the most relevant assistance."
	},
	{
		slug: "ask",
		name: "Ask",
		roleDefinition:
			"You are Infio, a versatile assistant dedicated to providing informative responses, thoughtful explanations, and practical guidance on virtually any topic or challenge you face.",
		groups: ["read"],
		customInstructions:
			"You can analyze information, explain concepts across various domains, and access external resources when helpful. Make sure to address the user's questions thoroughly with thoughtful explanations and practical guidance. Use visual aids like Mermaid diagrams when they help make complex topics clearer. Offer solutions to challenges from diverse fields, not just technical ones, and provide context that helps users better understand the subject matter.",
	},
] as const
// Export the default mode slug
// The first configured mode ("write") doubles as the default.
export const defaultModeSlug = modes[0].slug
// Helper functions
/**
 * Looks up a mode by slug. Custom modes take precedence over built-in modes;
 * returns undefined when no mode matches.
 */
export function getModeBySlug(slug: string, customModes?: ModeConfig[]): ModeConfig | undefined {
	const fromCustom = customModes?.find((mode) => mode.slug === slug)
	return fromCustom ?? modes.find((mode) => mode.slug === slug)
}
/**
 * Like getModeBySlug, but throws instead of returning undefined when the
 * slug does not resolve to any mode.
 */
export function getModeConfig(slug: string, customModes?: ModeConfig[]): ModeConfig {
	const found = getModeBySlug(slug, customModes)
	if (found === undefined) {
		throw new Error(`No mode found for slug: ${slug}`)
	}
	return found
}
/**
 * Returns every available mode. Custom modes whose slug matches a built-in
 * mode replace it in place (preserving order); the rest are appended.
 */
export function getAllModes(customModes?: ModeConfig[]): ModeConfig[] {
	const merged = [...modes]
	if (!customModes?.length) {
		return merged
	}
	for (const custom of customModes) {
		const existing = merged.findIndex((mode) => mode.slug === custom.slug)
		if (existing === -1) {
			// Brand-new custom mode: append after the built-ins.
			merged.push(custom)
		} else {
			// Same slug as an existing mode: the custom definition wins.
			merged[existing] = custom
		}
	}
	return merged
}
/** True when `slug` belongs to a user-provided custom mode (or override). */
export function isCustomMode(slug: string, customModes?: ModeConfig[]): boolean {
	if (!customModes) {
		return false
	}
	return customModes.some((mode) => mode.slug === slug)
}
// Custom error class for file restrictions
/** Thrown when a mode attempts to edit a file outside its allowed fileRegex. */
export class FileRestrictionError extends Error {
	constructor(mode: string, pattern: string, description: string | undefined, filePath: string) {
		// The optional human-readable description is appended in parentheses.
		const details = description ? ` (${description})` : ""
		super(`This mode (${mode}) can only edit files matching pattern: ${pattern}${details}. Got: ${filePath}`)
		this.name = "FileRestrictionError"
	}
}
/**
 * Decides whether `tool` may be used while operating in the mode `modeSlug`.
 *
 * Check order: always-available tools pass unconditionally; then an enabled
 * experiments/requirements gate can veto; then the tool must belong to at
 * least one of the mode's groups, honoring any per-group options.
 *
 * @param tool - tool name being requested
 * @param modeSlug - slug of the active mode
 * @param customModes - user-defined modes (searched before built-ins)
 * @param toolRequirements - optional per-tool enable flags; false blocks the tool
 * @param toolParams - the tool call's parameters (used for fileRegex checks)
 * @param experiments - optional experiment flags; false blocks the tool
 * @returns true when the tool is permitted
 * @throws FileRestrictionError when an edit-group tool targets a file that
 *         does not match the group's fileRegex
 */
export function isToolAllowedForMode(
	tool: string,
	modeSlug: string,
	customModes: ModeConfig[],
	toolRequirements?: Record<string, boolean>,
	toolParams?: Record<string, any>, // All tool parameters
	experiments?: Record<string, boolean>,
): boolean {
	// Always allow these tools
	if (ALWAYS_AVAILABLE_TOOLS.includes(tool as any)) {
		return true
	}
	// An experiment flag that exists and is false blocks the tool outright.
	if (experiments && tool in experiments) {
		if (!experiments[tool]) {
			return false
		}
	}
	// Check tool requirements if any exist
	if (toolRequirements && tool in toolRequirements) {
		if (!toolRequirements[tool]) {
			return false
		}
	}
	// Unknown mode: deny everything except the always-available tools above.
	const mode = getModeBySlug(modeSlug, customModes)
	if (!mode) {
		return false
	}
	// Check if tool is in any of the mode's groups and respects any group options
	for (const group of mode.groups) {
		const groupName = getGroupName(group)
		const options = getGroupOptions(group)
		const groupConfig = TOOL_GROUPS[groupName]
		// If the tool isn't in this group's tools, continue to next group
		if (!groupConfig.tools.includes(tool)) {
			continue
		}
		// If there are no options, allow the tool
		if (!options) {
			return true
		}
		// For the edit group, check file regex if specified
		if (groupName === "edit" && options.fileRegex) {
			const filePath = toolParams?.path
			// Only enforce the regex when the call actually carries content to
			// write (diff / content / operations); a bare path is not an edit.
			if (
				filePath &&
				(toolParams.diff || toolParams.content || toolParams.operations) &&
				!doesFileMatchRegex(filePath, options.fileRegex)
			) {
				throw new FileRestrictionError(mode.name, options.fileRegex, options.description, filePath)
			}
		}
		return true
	}
	// Tool not found in any of the mode's groups.
	return false
}
// Create the mode-specific default prompts
// Keyed by mode slug; frozen so the built-in defaults cannot be mutated at
// runtime. Per-user overrides are applied elsewhere (see getAllModesWithPrompts).
export const defaultPrompts: Readonly<CustomModePrompts> = Object.freeze(
	Object.fromEntries(
		modes.map((mode) => [
			mode.slug,
			{
				roleDefinition: mode.roleDefinition,
				customInstructions: mode.customInstructions,
			},
		]),
	),
)
// Helper function to get all modes with their prompt overrides from extension state
// NOTE(review): `vscode` is referenced here but never imported in this file, and
// the surrounding project targets Obsidian, not VS Code — this helper appears to
// have been ported verbatim and will not compile as-is. Confirm whether it should
// be removed or rewritten against the plugin's own settings store.
export async function getAllModesWithPrompts(context: vscode.ExtensionContext): Promise<ModeConfig[]> {
	const customModes = (await context.globalState.get<ModeConfig[]>("customModes")) || []
	const customModePrompts = (await context.globalState.get<CustomModePrompts>("customModePrompts")) || {}
	const allModes = getAllModes(customModes)
	// Layer per-mode prompt overrides on top of each mode's defaults.
	return allModes.map((mode) => ({
		...mode,
		roleDefinition: customModePrompts[mode.slug]?.roleDefinition ?? mode.roleDefinition,
		customInstructions: customModePrompts[mode.slug]?.customInstructions ?? mode.customInstructions,
	}))
}
// Helper function to get complete mode details with all overrides
export async function getFullModeDetails(
modeSlug: string,
customModes?: ModeConfig[],
customModePrompts?: CustomModePrompts,
options?: {
cwd?: string
globalCustomInstructions?: string
preferredLanguage?: string
},
): Promise<ModeConfig> {
// First get the base mode config from custom modes or built-in modes
const baseMode = getModeBySlug(modeSlug, customModes) || modes.find((m) => m.slug === modeSlug) || modes[0]
// Check for any prompt component overrides
const promptComponent = customModePrompts?.[modeSlug]
// Get the base custom instructions
const baseCustomInstructions = promptComponent?.customInstructions || baseMode.customInstructions || ""
// If we have cwd, load and combine all custom instructions
let fullCustomInstructions = baseCustomInstructions
if (options?.cwd) {
fullCustomInstructions = await addCustomInstructions(
baseCustomInstructions,
options.globalCustomInstructions || "",
options.cwd,
modeSlug,
{ preferredLanguage: options.preferredLanguage },
)
}
// Return mode with any overrides applied
return {
...baseMode,
roleDefinition: promptComponent?.roleDefinition || baseMode.roleDefinition,
customInstructions: fullCustomInstructions,
}
}
/**
 * Returns the role definition for `modeSlug`, or an empty string (with a
 * warning) when the slug resolves to no mode.
 */
export function getRoleDefinition(modeSlug: string, customModes?: ModeConfig[]): string {
	const mode = getModeBySlug(modeSlug, customModes)
	if (mode) {
		return mode.roleDefinition
	}
	console.warn(`No mode found for slug: ${modeSlug}`)
	return ""
}
/**
 * Returns the custom instructions for `modeSlug`. Yields an empty string both
 * for unknown modes (with a warning) and for modes that define none.
 */
export function getCustomInstructions(modeSlug: string, customModes?: ModeConfig[]): string {
	const mode = getModeBySlug(modeSlug, customModes)
	if (mode) {
		return mode.customInstructions ?? ""
	}
	console.warn(`No mode found for slug: ${modeSlug}`)
	return ""
}

View File

@ -1,4 +1,4 @@
import { InfioBlockAction, ParsedInfioBlock, parseinfioBlocks } from './parse-infio-block'
import { InfioBlockAction, ParsedMsgBlock, parseMsgBlocks } from './parse-infio-block'
describe('parseinfioBlocks', () => {
it('should parse a string with infio_block elements', () => {
@ -22,7 +22,7 @@ print("Hello, world!")
</infio_block>
Some text after`
const expected: ParsedInfioBlock[] = [
const expected: ParsedMsgBlock[] = [
{ type: 'string', content: 'Some text before\n' },
{
type: 'infio_block',
@ -49,7 +49,7 @@ print("Hello, world!")
{ type: 'string', content: '\nSome text after' },
]
const result = parseinfioBlocks(input)
const result = parseMsgBlocks(input)
expect(result).toEqual(expected)
})
@ -58,7 +58,7 @@ print("Hello, world!")
<infio_block language="python"></infio_block>
`
const expected: ParsedInfioBlock[] = [
const expected: ParsedMsgBlock[] = [
{ type: 'string', content: '\n ' },
{
type: 'infio_block',
@ -69,16 +69,16 @@ print("Hello, world!")
{ type: 'string', content: '\n ' },
]
const result = parseinfioBlocks(input)
const result = parseMsgBlocks(input)
expect(result).toEqual(expected)
})
it('should handle input without infio_block elements', () => {
const input = 'Just a regular string without any infio_block elements.'
const expected: ParsedInfioBlock[] = [{ type: 'string', content: input }]
const expected: ParsedMsgBlock[] = [{ type: 'string', content: input }]
const result = parseinfioBlocks(input)
const result = parseMsgBlocks(input)
expect(result).toEqual(expected)
})
@ -100,7 +100,7 @@ print("Hello, world!")
</infio_block>
End`
const expected: ParsedInfioBlock[] = [
const expected: ParsedMsgBlock[] = [
{ type: 'string', content: 'Start\n' },
{
type: 'infio_block',
@ -129,7 +129,7 @@ print("Hello, world!")
{ type: 'string', content: '\nEnd' },
]
const result = parseinfioBlocks(input)
const result = parseMsgBlocks(input)
expect(result).toEqual(expected)
})
@ -139,7 +139,7 @@ print("Hello, world!")
# Unfinished infio_block
Some text after without closing tag`
const expected: ParsedInfioBlock[] = [
const expected: ParsedMsgBlock[] = [
{ type: 'string', content: 'Start\n' },
{
type: 'infio_block',
@ -152,13 +152,13 @@ Some text after without closing tag`,
},
]
const result = parseinfioBlocks(input)
const result = parseMsgBlocks(input)
expect(result).toEqual(expected)
})
it('should handle infio_block with startline and endline attributes', () => {
const input = `<infio_block language="markdown" startline="2" endline="5"></infio_block>`
const expected: ParsedInfioBlock[] = [
const expected: ParsedMsgBlock[] = [
{
type: 'infio_block',
content: '',
@ -168,13 +168,13 @@ Some text after without closing tag`,
},
]
const result = parseinfioBlocks(input)
const result = parseMsgBlocks(input)
expect(result).toEqual(expected)
})
it('should parse infio_block with action attribute', () => {
const input = `<infio_block type="edit"></infio_block>`
const expected: ParsedInfioBlock[] = [
const expected: ParsedMsgBlock[] = [
{
type: 'infio_block',
content: '',
@ -182,13 +182,13 @@ Some text after without closing tag`,
},
]
const result = parseinfioBlocks(input)
const result = parseMsgBlocks(input)
expect(result).toEqual(expected)
})
it('should handle invalid action attribute', () => {
const input = `<infio_block type="invalid"></infio_block>`
const expected: ParsedInfioBlock[] = [
const expected: ParsedMsgBlock[] = [
{
type: 'infio_block',
content: '',
@ -196,7 +196,7 @@ Some text after without closing tag`,
},
]
const result = parseinfioBlocks(input)
const result = parseMsgBlocks(input)
expect(result).toEqual(expected)
})
@ -208,7 +208,7 @@ It might contain multiple lines of text.
</think>
Some text after`
const expected: ParsedInfioBlock[] = [
const expected: ParsedMsgBlock[] = [
{ type: 'string', content: 'Some text before\n' },
{
type: 'think',
@ -220,7 +220,7 @@ It might contain multiple lines of text.
{ type: 'string', content: '\nSome text after' },
]
const result = parseinfioBlocks(input)
const result = parseMsgBlocks(input)
expect(result).toEqual(expected)
})
@ -229,7 +229,7 @@ It might contain multiple lines of text.
<think></think>
`
const expected: ParsedInfioBlock[] = [
const expected: ParsedMsgBlock[] = [
{ type: 'string', content: '\n ' },
{
type: 'think',
@ -238,7 +238,7 @@ It might contain multiple lines of text.
{ type: 'string', content: '\n ' },
]
const result = parseinfioBlocks(input)
const result = parseMsgBlocks(input)
expect(result).toEqual(expected)
})
@ -255,7 +255,7 @@ I need to consider several approaches.
</think>
End`
const expected: ParsedInfioBlock[] = [
const expected: ParsedMsgBlock[] = [
{ type: 'string', content: 'Start\n' },
{
type: 'infio_block',
@ -277,7 +277,7 @@ I need to consider several approaches.
{ type: 'string', content: '\nEnd' },
]
const result = parseinfioBlocks(input)
const result = parseMsgBlocks(input)
expect(result).toEqual(expected)
})
@ -286,7 +286,7 @@ I need to consider several approaches.
<think>
Some unfinished thought
without closing tag`
const expected: ParsedInfioBlock[] = [
const expected: ParsedMsgBlock[] = [
{ type: 'string', content: 'Start\n' },
{
type: 'think',
@ -296,7 +296,7 @@ without closing tag`,
},
]
const result = parseinfioBlocks(input)
const result = parseMsgBlocks(input)
expect(result).toEqual(expected)
})
})

View File

@ -1,128 +1,433 @@
import JSON5 from 'json5'
import { parseFragment } from 'parse5'
export enum InfioBlockAction {
Edit = 'edit',
New = 'new',
Reference = 'reference'
}
export type ParsedInfioBlock =
| { type: 'string'; content: string }
export type ParsedMsgBlock =
| {
type: 'infio_block'
type: 'string'
content: string
language?: string
filename?: string
startLine?: number
endLine?: number
action?: InfioBlockAction
}
| { type: 'think'; content: string }
function isInfioBlockAction(value: string): value is InfioBlockAction {
return Object.values<string>(InfioBlockAction).includes(value)
}
export function parseinfioBlocks(input: string): ParsedInfioBlock[] {
const parsedResult: ParsedInfioBlock[] = []
const fragment = parseFragment(input, {
sourceCodeLocationInfo: true,
})
let lastEndOffset = 0
for (const node of fragment.childNodes) {
if (node.nodeName === 'infio_block') {
if (!node.sourceCodeLocation) {
throw new Error('sourceCodeLocation is undefined')
}
const startOffset = node.sourceCodeLocation.startOffset
const endOffset = node.sourceCodeLocation.endOffset
if (startOffset > lastEndOffset) {
parsedResult.push({
type: 'string',
content: input.slice(lastEndOffset, startOffset),
})
}
const language = node.attrs.find((attr) => attr.name === 'language')?.value
const filename = node.attrs.find((attr) => attr.name === 'filename')?.value
const startLine = node.attrs.find((attr) => attr.name === 'startline')?.value
const endLine = node.attrs.find((attr) => attr.name === 'endline')?.value
const actionValue = node.attrs.find((attr) => attr.name === 'type')?.value
const action = actionValue && isInfioBlockAction(actionValue)
? actionValue
: undefined
const children = node.childNodes
if (children.length === 0) {
parsedResult.push({
type: 'infio_block',
content: '',
language,
filename,
startLine: startLine ? parseInt(startLine) : undefined,
endLine: endLine ? parseInt(endLine) : undefined,
action: action,
})
} else {
const innerContentStartOffset =
children[0].sourceCodeLocation?.startOffset
const innerContentEndOffset =
children[children.length - 1].sourceCodeLocation?.endOffset
if (!innerContentStartOffset || !innerContentEndOffset) {
throw new Error('sourceCodeLocation is undefined')
}
parsedResult.push({
type: 'infio_block',
content: input.slice(innerContentStartOffset, innerContentEndOffset),
language,
filename,
startLine: startLine ? parseInt(startLine) : undefined,
endLine: endLine ? parseInt(endLine) : undefined,
action: action,
})
}
lastEndOffset = endOffset
} else if (node.nodeName === 'think') {
if (!node.sourceCodeLocation) {
throw new Error('sourceCodeLocation is undefined')
}
const startOffset = node.sourceCodeLocation.startOffset
const endOffset = node.sourceCodeLocation.endOffset
if (startOffset > lastEndOffset) {
parsedResult.push({
type: 'string',
content: input.slice(lastEndOffset, startOffset),
})
}
const children = node.childNodes
if (children.length === 0) {
parsedResult.push({
type: 'think',
content: '',
})
} else {
const innerContentStartOffset =
children[0].sourceCodeLocation?.startOffset
const innerContentEndOffset =
children[children.length - 1].sourceCodeLocation?.endOffset
if (!innerContentStartOffset || !innerContentEndOffset) {
throw new Error('sourceCodeLocation is undefined')
}
parsedResult.push({
type: 'think',
content: input.slice(innerContentStartOffset, innerContentEndOffset),
})
}
lastEndOffset = endOffset
}
| {
type: 'think'
content: string
} | {
type: 'thinking'
content: string
} | {
type: 'write_to_file'
path: string
content: string
lineCount?: number
} | {
type: 'insert_content'
path: string
startLine: number
content: string
} | {
type: 'read_file'
path: string
finish: boolean
} | {
type: 'attempt_completion'
result: string
} | {
type: 'search_and_replace'
path: string
operations: {
search: string
replace: string
start_line?: number
end_line?: number
use_regex?: boolean
ignore_case?: boolean
regex_flags?: string
}[]
finish: boolean
} | {
type: 'ask_followup_question'
question: string
} | {
type: 'list_files'
path: string
recursive?: boolean
finish: boolean
} | {
type: 'regex_search_files'
path: string
regex: string
finish: boolean
} | {
type: 'semantic_search_files'
path: string
query: string
finish: boolean
}
if (lastEndOffset < input.length) {
parsedResult.push({
type: 'string',
content: input.slice(lastEndOffset),
export function parseMsgBlocks(
input: string,
): ParsedMsgBlock[] {
try {
const parsedResult: ParsedMsgBlock[] = []
const fragment = parseFragment(input, {
sourceCodeLocationInfo: true,
})
let lastEndOffset = 0
for (const node of fragment.childNodes) {
if (node.nodeName === 'thinking') {
if (!node.sourceCodeLocation) {
throw new Error('sourceCodeLocation is undefined')
}
const startOffset = node.sourceCodeLocation.startOffset
const endOffset = node.sourceCodeLocation.endOffset
if (startOffset > lastEndOffset) {
parsedResult.push({
type: 'string',
content: input.slice(lastEndOffset, startOffset),
})
}
const children = node.childNodes
if (children.length === 0) {
parsedResult.push({
type: 'thinking',
content: '',
})
} else {
const innerContentStartOffset =
children[0].sourceCodeLocation?.startOffset
const innerContentEndOffset =
children[children.length - 1].sourceCodeLocation?.endOffset
if (!innerContentStartOffset || !innerContentEndOffset) {
throw new Error('sourceCodeLocation is undefined')
}
parsedResult.push({
type: 'thinking',
content: input.slice(innerContentStartOffset, innerContentEndOffset),
})
}
lastEndOffset = endOffset
} else if (node.nodeName === 'think') {
if (!node.sourceCodeLocation) {
throw new Error('sourceCodeLocation is undefined')
}
const startOffset = node.sourceCodeLocation.startOffset
const endOffset = node.sourceCodeLocation.endOffset
if (startOffset > lastEndOffset) {
parsedResult.push({
type: 'string',
content: input.slice(lastEndOffset, startOffset),
})
}
const children = node.childNodes
if (children.length === 0) {
parsedResult.push({
type: 'think',
content: '',
})
} else {
const innerContentStartOffset =
children[0].sourceCodeLocation?.startOffset
const innerContentEndOffset =
children[children.length - 1].sourceCodeLocation?.endOffset
if (!innerContentStartOffset || !innerContentEndOffset) {
throw new Error('sourceCodeLocation is undefined')
}
parsedResult.push({
type: 'think',
content: input.slice(innerContentStartOffset, innerContentEndOffset),
})
}
lastEndOffset = endOffset
} else if (node.nodeName === 'list_files') {
if (!node.sourceCodeLocation) {
throw new Error('sourceCodeLocation is undefined')
}
const startOffset = node.sourceCodeLocation.startOffset
const endOffset = node.sourceCodeLocation.endOffset
if (startOffset > lastEndOffset) {
parsedResult.push({
type: 'string',
content: input.slice(lastEndOffset, startOffset),
})
}
let path: string | undefined
let recursive: boolean | undefined
for (const childNode of node.childNodes) {
if (childNode.nodeName === 'path' && childNode.childNodes.length > 0) {
path = childNode.childNodes[0].value
} else if (childNode.nodeName === 'recursive' && childNode.childNodes.length > 0) {
const recursiveValue = childNode.childNodes[0].value
recursive = recursiveValue ? recursiveValue.toLowerCase() === 'true' : false
}
}
parsedResult.push({
type: 'list_files',
path: path || '/',
recursive,
finish: node.sourceCodeLocation.endTag !== undefined
})
lastEndOffset = endOffset
} else if (node.nodeName === 'read_file') {
if (!node.sourceCodeLocation) {
throw new Error('sourceCodeLocation is undefined')
}
const startOffset = node.sourceCodeLocation.startOffset
const endOffset = node.sourceCodeLocation.endOffset
if (startOffset > lastEndOffset) {
parsedResult.push({
type: 'string',
content: input.slice(lastEndOffset, startOffset),
})
}
let path: string | undefined
for (const childNode of node.childNodes) {
if (childNode.nodeName === 'path' && childNode.childNodes.length > 0) {
path = childNode.childNodes[0].value
}
}
parsedResult.push({
type: 'read_file',
path,
// Check if the tag is completely parsed with proper closing tag
// In parse5, when a tag is properly closed, its sourceCodeLocation will include endTag
finish: node.sourceCodeLocation.endTag !== undefined
})
lastEndOffset = endOffset
} else if (node.nodeName === 'regex_search_files') {
if (!node.sourceCodeLocation) {
throw new Error('sourceCodeLocation is undefined')
}
const startOffset = node.sourceCodeLocation.startOffset
const endOffset = node.sourceCodeLocation.endOffset
if (startOffset > lastEndOffset) {
parsedResult.push({
type: 'string',
content: input.slice(lastEndOffset, startOffset),
})
}
let path: string | undefined
let regex: string | undefined
for (const childNode of node.childNodes) {
if (childNode.nodeName === 'path' && childNode.childNodes.length > 0) {
path = childNode.childNodes[0].value
} else if (childNode.nodeName === 'regex' && childNode.childNodes.length > 0) {
regex = childNode.childNodes[0].value
}
}
parsedResult.push({
type: 'regex_search_files',
path: path,
regex: regex,
finish: node.sourceCodeLocation.endTag !== undefined
})
lastEndOffset = endOffset
} else if (node.nodeName === 'semantic_search_files') {
if (!node.sourceCodeLocation) {
throw new Error('sourceCodeLocation is undefined')
}
const startOffset = node.sourceCodeLocation.startOffset
const endOffset = node.sourceCodeLocation.endOffset
if (startOffset > lastEndOffset) {
parsedResult.push({
type: 'string',
content: input.slice(lastEndOffset, startOffset),
})
}
let path: string | undefined
let query: string | undefined
for (const childNode of node.childNodes) {
if (childNode.nodeName === 'path' && childNode.childNodes.length > 0) {
path = childNode.childNodes[0].value
} else if (childNode.nodeName === 'query' && childNode.childNodes.length > 0) {
query = childNode.childNodes[0].value
}
}
parsedResult.push({
type: 'semantic_search_files',
path: path,
query: query,
finish: node.sourceCodeLocation.endTag !== undefined
})
lastEndOffset = endOffset
} else if (node.nodeName === 'write_to_file') {
if (!node.sourceCodeLocation) {
throw new Error('sourceCodeLocation is undefined')
}
const startOffset = node.sourceCodeLocation.startOffset
const endOffset = node.sourceCodeLocation.endOffset
if (startOffset > lastEndOffset) {
parsedResult.push({
type: 'string',
content: input.slice(lastEndOffset, startOffset),
})
}
let path: string | undefined
let content: string = ''
let lineCount: number | undefined
// 处理子标签
for (const childNode of node.childNodes) {
if (childNode.nodeName === 'path' && childNode.childNodes.length > 0) {
path = childNode.childNodes[0].value
} else if (childNode.nodeName === 'content' && childNode.childNodes.length > 0) {
// 如果内容有多个子节点,需要合并它们
content = childNode.childNodes.map(n => n.value || '').join('')
} else if (childNode.nodeName === 'line_count' && childNode.childNodes.length > 0) {
const lineCountStr = childNode.childNodes[0].value
lineCount = lineCountStr ? parseInt(lineCountStr) : undefined
}
}
parsedResult.push({
type: 'write_to_file',
content,
path,
lineCount
})
lastEndOffset = endOffset
} else if (node.nodeName === 'insert_content') {
if (!node.sourceCodeLocation) {
throw new Error('sourceCodeLocation is undefined')
}
const startOffset = node.sourceCodeLocation.startOffset
const endOffset = node.sourceCodeLocation.endOffset
if (startOffset > lastEndOffset) {
parsedResult.push({
type: 'string',
content: input.slice(lastEndOffset, startOffset),
})
}
let path: string | undefined
let content: string = ''
let startLine: number = 0
// 处理子标签
for (const childNode of node.childNodes) {
if (childNode.nodeName === 'path' && childNode.childNodes.length > 0) {
path = childNode.childNodes[0].value
} else if (childNode.nodeName === 'operations' && childNode.childNodes.length > 0) {
try {
const operationsJson = childNode.childNodes[0].value
const operations = JSON5.parse(operationsJson)
if (Array.isArray(operations) && operations.length > 0) {
const operation = operations[0]
startLine = operation.start_line || 1
content = operation.content || ''
}
} catch (error) {
console.error('Failed to parse operations JSON', error)
}
}
}
parsedResult.push({
type: 'insert_content',
path,
startLine,
content
})
lastEndOffset = endOffset
} else if (node.nodeName === 'search_and_replace') {
if (!node.sourceCodeLocation) {
throw new Error('sourceCodeLocation is undefined')
}
const startOffset = node.sourceCodeLocation.startOffset
const endOffset = node.sourceCodeLocation.endOffset
if (startOffset > lastEndOffset) {
parsedResult.push({
type: 'string',
content: input.slice(lastEndOffset, startOffset),
})
}
let path: string | undefined
let operations = []
// 处理子标签
for (const childNode of node.childNodes) {
if (childNode.nodeName === 'path' && childNode.childNodes.length > 0) {
path = childNode.childNodes[0].value
} else if (childNode.nodeName === 'operations' && childNode.childNodes.length > 0) {
try {
const operationsJson = childNode.childNodes[0].value
operations = JSON5.parse(operationsJson)
} catch (error) {
console.error('Failed to parse operations JSON', error)
}
}
}
parsedResult.push({
type: 'search_and_replace',
path,
operations,
finish: node.sourceCodeLocation.endTag !== undefined
})
lastEndOffset = endOffset
} else if (node.nodeName === 'attempt_completion') {
if (!node.sourceCodeLocation) {
throw new Error('sourceCodeLocation is undefined')
}
const startOffset = node.sourceCodeLocation.startOffset
const endOffset = node.sourceCodeLocation.endOffset
if (startOffset > lastEndOffset) {
parsedResult.push({
type: 'string',
content: input.slice(lastEndOffset, startOffset),
})
}
let result: string | undefined
for (const childNode of node.childNodes) {
if (childNode.nodeName === 'result' && childNode.childNodes.length > 0) {
result = childNode.childNodes[0].value
}
}
parsedResult.push({
type: 'attempt_completion',
result,
})
lastEndOffset = endOffset
} else if (node.nodeName === 'ask_followup_question') {
if (!node.sourceCodeLocation) {
throw new Error('sourceCodeLocation is undefined')
}
const startOffset = node.sourceCodeLocation.startOffset
const endOffset = node.sourceCodeLocation.endOffset
if (startOffset > lastEndOffset) {
parsedResult.push({
type: 'string',
content: input.slice(lastEndOffset, startOffset),
})
}
let question: string | undefined
for (const childNode of node.childNodes) {
if (childNode.nodeName === 'question' && childNode.childNodes.length > 0) {
question = childNode.childNodes[0].value
}
}
parsedResult.push({
type: 'ask_followup_question',
question,
})
lastEndOffset = endOffset
}
}
// handle the last part of the input
if (lastEndOffset < input.length) {
parsedResult.push({
type: 'string',
content: input.slice(lastEndOffset),
})
}
return parsedResult
} catch (error) {
console.error('Failed to parse infio block', error)
throw error
}
return parsedResult
}

107
src/utils/path.ts Normal file
View File

@ -0,0 +1,107 @@
import os from "os"
import * as path from "path"
/*
The Node.js 'path' module resolves and normalizes paths differently depending on the platform:
- On Windows, it uses backslashes (\) as the default path separator.
- On POSIX-compliant systems (Linux, macOS), it uses forward slashes (/) as the default path separator.
While modules like 'upath' can be used to normalize paths to use forward slashes consistently,
this can create inconsistencies when interfacing with other modules (like vscode.fs) that use
backslashes on Windows.
Our approach:
1. We present paths with forward slashes to the AI and user for consistency.
2. We use the 'arePathsEqual' function for safe path comparisons.
3. Internally, Node.js gracefully handles both backslashes and forward slashes.
This strategy ensures consistent path presentation while leveraging Node.js's built-in
path handling capabilities across different platforms.
Note: When interacting with the file system or VS Code APIs, we still use the native path module
to ensure correct behavior on all platforms. The toPosixPath and arePathsEqual functions are
primarily used for presentation and comparison purposes, not for actual file system operations.
Observations:
- Macos isn't so flexible with mixed separators, whereas windows can handle both. ("Node.js does automatically handle path separators on Windows, converting forward slashes to backslashes as needed. However, on macOS and other Unix-like systems, the path separator is always a forward slash (/), and backslashes are treated as regular characters.")
*/
/**
 * Convert every backslash in a path to a forward slash for display purposes.
 * Windows extended-length paths (prefixed with "\\?\") rely on their exact
 * syntax to bypass normal path parsing, so they are returned unmodified.
 */
function toPosixPath(p: string): string {
	if (p.startsWith("\\\\?\\")) {
		// Extended-length path: altering separators would break its semantics.
		return p
	}
	return p.split("\\").join("/")
}
// Declaration merging allows us to add a new method to the String type.
// NOTE: this file must be imported from the entry point (extension.ts) so the
// prototype assignment below actually runs before any caller uses .toPosix().
declare global {
	interface String {
		// Returns this string with backslashes converted to forward slashes;
		// extended-length Windows paths pass through untouched (see toPosixPath).
		toPosix(): string
	}
}

// Runtime counterpart of the ambient declaration above.
String.prototype.toPosix = function (this: string): string {
	return toPosixPath(this)
}
/**
 * Platform-aware path equality.
 * Two missing/empty paths count as equal; exactly one missing does not.
 * Both paths are normalized first, and on Windows the comparison is
 * case-insensitive.
 */
export function arePathsEqual(path1?: string, path2?: string): boolean {
	if (!path1 || !path2) {
		// Equal only when BOTH are absent/empty.
		return !path1 && !path2
	}
	const a = normalizePath(path1)
	const b = normalizePath(path2)
	if (process.platform === "win32") {
		return a.toLowerCase() === b.toLowerCase()
	}
	return a === b
}

/**
 * Normalize a path for comparison: path.normalize resolves "."/".." segments
 * and collapses duplicate separators, but keeps a trailing separator — strip
 * that too, except when the path is just a root ("/" or "\").
 */
function normalizePath(p: string): string {
	const normalized = path.normalize(p)
	const last = normalized.charAt(normalized.length - 1)
	if (normalized.length > 1 && (last === "/" || last === "\\")) {
		return normalized.slice(0, -1)
	}
	return normalized
}
/**
 * Produce a user-friendly (posix-style) rendering of a path for display.
 * - When cwd is the Desktop (no real workspace), always show the absolute path
 *   so the user stays aware of where files are created.
 * - When the path IS the cwd, show just its basename.
 * - When the path is inside the cwd, show it relative to the cwd; otherwise
 *   fall back to the absolute path (useful when "../../" escapes the cwd).
 */
export function getReadablePath(cwd: string, relPath?: string): string {
	// path.resolve accepts relative segments like "../.." and simply ignores
	// cwd when relPath is already absolute.
	const absolutePath = path.resolve(cwd, relPath ?? "")
	if (arePathsEqual(cwd, path.join(os.homedir(), "Desktop"))) {
		return absolutePath.toPosix()
	}
	if (arePathsEqual(path.normalize(absolutePath), path.normalize(cwd))) {
		return path.basename(absolutePath).toPosix()
	}
	const relativeToCwd = path.relative(cwd, absolutePath)
	return absolutePath.includes(cwd)
		? relativeToCwd.toPosix()
		: absolutePath.toPosix()
}
/**
 * Convert an absolute/arbitrary path to a posix-style path relative to cwd,
 * preserving a trailing slash when the input had one (directory paths).
 */
export const toRelativePath = (filePath: string, cwd: string) => {
	const relativePath = path.relative(cwd, filePath).toPosix()
	if (filePath.endsWith("/")) {
		return relativePath + "/"
	}
	return relativePath
}

View File

@ -1,28 +1,107 @@
import { App, TFile, htmlToMarkdown, requestUrl } from 'obsidian'
import { App, MarkdownView, TAbstractFile, TFile, TFolder, Vault, htmlToMarkdown, requestUrl } from 'obsidian'
import { editorStateToPlainText } from '../components/chat-view/chat-input/utils/editor-state-to-plain-text'
import { QueryProgressState } from '../components/chat-view/QueryProgress'
import { SYSTEM_PROMPT } from '../core/prompts/system'
import { RAGEngine } from '../core/rag/rag-engine'
import { SelectVector } from '../database/schema'
import { ChatMessage, ChatUserMessage } from '../types/chat'
import { ContentPart, RequestMessage } from '../types/llm/request'
import {
MentionableBlock, MentionableCurrentFile, MentionableFile,
MentionableBlock,
MentionableFile,
MentionableFolder,
MentionableImage,
MentionableUrl,
MentionableVault
} from '../types/mentionable'
import { InfioSettings } from '../types/settings'
import { defaultModeSlug, getFullModeDetails } from "../utils/modes"
import { listFilesAndFolders } from './glob-utils'
import {
getNestedFiles,
readMultipleTFiles,
readTFileContent,
readTFileContent
} from './obsidian'
import { tokenCount } from './token'
import { YoutubeTranscript, isYoutubeUrl } from './youtube-transcript'
/**
 * Prefix every line of `content` with a right-aligned line number and " | ".
 * Numbering starts at `startLine`; the number column is padded with spaces to
 * the width of the largest line number so the separators line up.
 */
export function addLineNumbers(content: string, startLine: number = 1): string {
	const rows = content.split("\n")
	const width = String(startLine + rows.length - 1).length
	const numbered: string[] = []
	for (let i = 0; i < rows.length; i++) {
		const label = String(startLine + i).padStart(width, " ")
		numbered.push(`${label} | ${rows[i]}`)
	}
	return numbered.join("\n")
}
/**
 * Render a folder's immediate children as a one-level tree listing.
 * Subfolders get a trailing "/"; the last entry uses "└── " and all others
 * "├── ". Children of subfolders are not recursed into.
 */
async function getFolderTreeContent(path: TFolder): Promise<string> {
	try {
		const entries = path.children
		const rows: string[] = []
		for (const [index, entry] of entries.entries()) {
			const prefix = index === entries.length - 1 ? "└── " : "├── "
			const suffix = entry instanceof TFolder ? "/" : ""
			rows.push(`${prefix}${entry.name}${suffix}\n`)
		}
		return rows.join("")
	} catch (error) {
		throw new Error(`Failed to access path "${path.path}": ${error.message}`)
	}
}
/**
 * Read a vault entry for inclusion in a prompt.
 * - Markdown file: its content with line numbers prepended.
 * - Non-markdown file: a "(Binary file …)" placeholder.
 * - Folder: a one-level tree of its children, followed by the numbered
 *   contents of each markdown child wrapped in <file_content> tags; children
 *   that are not markdown or fail to read are silently skipped.
 * Throws a wrapped Error when the entry cannot be accessed at all.
 */
async function getFileOrFolderContent(path: TAbstractFile, vault: Vault): Promise<string> {
	try {
		if (path instanceof TFile) {
			if (path.extension != 'md') {
				return "(Binary file, unable to display content)"
			}
			return addLineNumbers(await readTFileContent(path, vault))
		}
		if (!(path instanceof TFolder)) {
			return `(Failed to read contents of ${path.path})`
		}
		const entries = path.children
		const treeRows: string[] = []
		const contentPromises: Promise<string | undefined>[] = []
		for (const [index, entry] of entries.entries()) {
			const prefix = index === entries.length - 1 ? "└── " : "├── "
			if (entry instanceof TFile) {
				treeRows.push(`${prefix}${entry.name}\n`)
				// Kick off reads concurrently; non-markdown children and read
				// failures resolve to undefined and are filtered out below.
				contentPromises.push(
					(async () => {
						if (entry.extension != 'md') {
							return undefined
						}
						try {
							const content = addLineNumbers(await readTFileContent(entry, vault))
							return `<file_content path="${entry.path}">\n${content}\n</file_content>`
						} catch (error) {
							return undefined
						}
					})(),
				)
			} else if (entry instanceof TFolder) {
				treeRows.push(`${prefix}${entry.name}/\n`)
			} else {
				treeRows.push(`${prefix}${entry.name}\n`)
			}
		}
		const fileContents = (await Promise.all(contentPromises)).filter((content) => content)
		return `${treeRows.join("")}\n${fileContents.join("\n\n")}`.trim()
	} catch (error) {
		throw new Error(`Failed to access path "${path.path}": ${error.message}`)
	}
}
export class PromptGenerator {
private getRagEngine: () => Promise<RAGEngine>
private app: App
@ -47,12 +126,10 @@ export class PromptGenerator {
messages,
useVaultSearch,
onQueryProgressChange,
type,
}: {
messages: ChatMessage[]
useVaultSearch?: boolean
onQueryProgressChange?: (queryProgress: QueryProgressState) => void
type?: string
}): Promise<{
requestMessages: RequestMessage[]
compiledMessages: ChatMessage[]
@ -64,14 +141,16 @@ export class PromptGenerator {
if (lastUserMessage.role !== 'user') {
throw new Error('Last message is not a user message')
}
const isNewChat = messages.filter(message => message.role === 'user').length === 1
const { promptContent, shouldUseRAG, similaritySearchResults } =
const { promptContent, similaritySearchResults } =
await this.compileUserMessagePrompt({
isNewChat,
message: lastUserMessage,
useVaultSearch,
onQueryProgressChange,
})
let compiledMessages = [
const compiledMessages = [
...messages.slice(0, -1),
{
...lastUserMessage,
@ -80,39 +159,10 @@ export class PromptGenerator {
},
]
// Safeguard: ensure all user messages have parsed content
compiledMessages = await Promise.all(
compiledMessages.map(async (message) => {
if (message.role === 'user' && !message.promptContent) {
const { promptContent, similaritySearchResults } =
await this.compileUserMessagePrompt({
message,
})
return {
...message,
promptContent,
similaritySearchResults,
}
}
return message
}),
)
const systemMessage = this.getSystemMessage(shouldUseRAG, type)
const customInstructionMessage = this.getCustomInstructionMessage()
const currentFile = lastUserMessage.mentionables.find(
(m): m is MentionableCurrentFile => m.type === 'current-file',
)?.file
const currentFileMessage = currentFile
? await this.getCurrentFileMessage(currentFile)
: undefined
const systemMessage = await this.getSystemMessageNew()
const requestMessages: RequestMessage[] = [
systemMessage,
...(customInstructionMessage ? [customInstructionMessage, PromptGenerator.EMPTY_ASSISTANT_MESSAGE] : []),
...(currentFileMessage ? [currentFileMessage, PromptGenerator.EMPTY_ASSISTANT_MESSAGE] : []),
...compiledMessages.slice(-19).map((message): RequestMessage => {
if (message.role === 'user') {
return {
@ -126,7 +176,6 @@ export class PromptGenerator {
}
}
}),
...(shouldUseRAG ? [this.getRagInstructionMessage()] : []),
]
return {
@ -135,27 +184,101 @@ export class PromptGenerator {
}
}
/**
 * Collect a snapshot of the user's Obsidian environment for the prompt:
 * the active file path, open markdown tabs, local time with timezone offset,
 * and the currently active assistant mode. Returned wrapped in
 * <environment_details> tags.
 */
private async getEnvironmentDetails() {
	let details = ""

	// Currently active file (if any)
	details += "\n\n# Obsidian Current File"
	const currentFile = this.app.workspace.getActiveFile()
	if (currentFile) {
		details += `\n${currentFile?.path}`
	} else {
		details += "\n(No current file)"
	}

	// Paths of all open markdown tabs (non-markdown views are skipped)
	details += "\n\n# Obsidian Open Tabs"
	const openTabs: string[] = [];
	this.app.workspace.iterateAllLeaves(leaf => {
		if (leaf.view instanceof MarkdownView && leaf.view.file) {
			openTabs.push(leaf.view.file?.path);
		}
	});
	if (openTabs.length === 0) {
		details += "\n(No open tabs)"
	} else {
		details += `\n${openTabs.join("\n")}`
	}

	// Add current time information with timezone
	const now = new Date()
	const formatter = new Intl.DateTimeFormat(undefined, {
		year: "numeric",
		month: "numeric",
		day: "numeric",
		hour: "numeric",
		minute: "numeric",
		second: "numeric",
		hour12: true,
	})
	const timeZone = formatter.resolvedOptions().timeZone
	const timeZoneOffset = -now.getTimezoneOffset() / 60 // Convert to hours and invert sign to match conventional notation
	// NOTE(review): half-hour zones produce fractional strings, e.g. "+5.5:00"
	// rather than "+05:30" — confirm whether that rendering is acceptable.
	const timeZoneOffsetStr = `${timeZoneOffset >= 0 ? "+" : ""}${timeZoneOffset}:00`
	details += `\n\n# Current Time\n${formatter.format(now)} (${timeZone}, UTC${timeZoneOffsetStr})`

	// Add current mode details (slug + human-readable name)
	const currentMode = defaultModeSlug
	const modeDetails = await getFullModeDetails(currentMode)
	details += `\n\n# Current Mode\n`
	details += `<slug>${currentMode}</slug>\n`
	details += `<name>${modeDetails.name}</name>\n`

	// // Obsidian Current Folder
	// const currentFolder = this.app.workspace.getActiveFile() ? this.app.workspace.getActiveFile()?.parent?.path : "/"
	// // Obsidian Vault Files and Folders
	// if (currentFolder) {
	// 	details += `\n\n# Obsidian Current Folder (${currentFolder}) Files`
	// 	const filesAndFolders = await listFilesAndFolders(this.app.vault, currentFolder)
	// 	if (filesAndFolders.length > 0) {
	// 		details += `\n${filesAndFolders.filter(Boolean).join("\n")}`
	// 	} else {
	// 		details += "\n(No Markdown files in current folder)"
	// 	}
	// } else {
	// 	details += "\n(No current folder)"
	// }

	return `<environment_details>\n${details.trim()}\n</environment_details>`
}
private async compileUserMessagePrompt({
isNewChat,
message,
useVaultSearch,
onQueryProgressChange,
}: {
isNewChat: boolean
message: ChatUserMessage
useVaultSearch?: boolean
onQueryProgressChange?: (queryProgress: QueryProgressState) => void
}): Promise<{
promptContent: ChatUserMessage['promptContent']
shouldUseRAG: boolean
similaritySearchResults?: (Omit<SelectVector, 'embedding'> & {
similarity: number
})[]
}> {
if (!message.content) {
// Add environment details
const environmentDetails = isNewChat
? await this.getEnvironmentDetails()
: undefined
// if isToolCallReturn, add read_file_content to promptContent
if (message.content === null) {
return {
promptContent: '',
shouldUseRAG: false,
promptContent: message.promptContent,
similaritySearchResults: undefined,
}
}
const query = editorStateToPlainText(message.content)
let similaritySearchResults = undefined
@ -169,33 +292,94 @@ export class PromptGenerator {
onQueryProgressChange?.({
type: 'reading-mentionables',
})
const taskPrompt = isNewChat ? `<task>\n${query}\n</task>` : `<feedback>\n${query}\n</feedback>`
// user mention files
const files = message.mentionables
.filter((m): m is MentionableFile => m.type === 'file')
.map((m) => m.file)
let fileContentsPrompts = files.length > 0
? (await Promise.all(files.map(async (file) => {
const content = await getFileOrFolderContent(file, this.app.vault)
return `<file_content path="${file.path}">\n${content}\n</file_content>`
}))).join('\n')
: undefined
// user mention folders
const folders = message.mentionables
.filter((m): m is MentionableFolder => m.type === 'folder')
.map((m) => m.folder)
const nestedFiles = folders.flatMap((folder) =>
getNestedFiles(folder, this.app.vault),
let folderContentsPrompts = folders.length > 0
? (await Promise.all(folders.map(async (folder) => {
const content = await getFileOrFolderContent(folder, this.app.vault)
return `<folder_content path="${folder.path}">\n${content}\n</folder_content>`
}))).join('\n')
: undefined
// user mention blocks
const blocks = message.mentionables.filter(
(m): m is MentionableBlock => m.type === 'block',
)
const allFiles = [...files, ...nestedFiles]
const fileContents = await readMultipleTFiles(allFiles, this.app.vault)
const blockContentsPrompt = blocks.length > 0
? blocks
.map(({ file, content, startLine, endLine }) => {
const content_with_line_numbers = addLineNumbers(content, startLine)
return `<file_block_content location="${file.path}#L${startLine}-${endLine}">\n${content_with_line_numbers}\n</file_block_content>`
})
.join('\n')
: undefined
// Count tokens incrementally to avoid long processing times on large content sets
const exceedsTokenThreshold = async () => {
let accTokenCount = 0
for (const content of fileContents) {
const count = await tokenCount(content)
accTokenCount += count
if (accTokenCount > this.settings.ragOptions.thresholdTokens) {
return true
}
// user mention urls
const urls = message.mentionables.filter(
(m): m is MentionableUrl => m.type === 'url',
)
const urlContents = await Promise.all(
urls.map(async ({ url }) => ({
url,
content: await this.getWebsiteContent(url)
}))
)
const urlContentsPrompt = urlContents.length > 0
? urlContents
.map(({ url, content }) => (
`<url_content url="${url}">\n${content}\n</url_content>`
))
.join('\n') : undefined
const currentFile = message.mentionables
.filter((m): m is MentionableFile => m.type === 'current-file')
.first()
const currentFileContent = currentFile && currentFile.file != null
? await getFileOrFolderContent(currentFile.file, this.app.vault)
: undefined
const currentFileContentPrompt = isNewChat && currentFileContent
? `<current_file_content path="${currentFile.file.path}">\n${currentFileContent}\n</current_file_content>`
: undefined
// Count file and folder tokens
let accTokenCount = 0
let isOverThreshold = false
for (const content of [fileContentsPrompts, folderContentsPrompts].filter(Boolean)) {
const count = await tokenCount(content)
accTokenCount += count
if (accTokenCount > this.settings.ragOptions.thresholdTokens) {
isOverThreshold = true
}
return false
}
const shouldUseRAG = useVaultSearch || (await exceedsTokenThreshold())
if (isOverThreshold) {
fileContentsPrompts = files.map((file) => {
return `<file_content path="${file.path}">\n(Content omitted due to token limit. Relevant sections will be provided by semantic search below.)\n</file_content>`
}).join('\n')
folderContentsPrompts = folders.map(async (folder) => {
const tree_content = await getFolderTreeContent(folder)
return `<folder_content path="${folder.path}">\n${tree_content}\n(Content omitted due to token limit. Relevant sections will be provided by semantic search below.)\n</folder_content>`
}).join('\n')
}
let filePrompt: string
const shouldUseRAG = useVaultSearch || isOverThreshold
let similaritySearchContents
if (shouldUseRAG) {
similaritySearchResults = useVaultSearch
? await (
@ -203,7 +387,7 @@ export class PromptGenerator {
).processQuery({
query,
onQueryProgressChange: onQueryProgressChange,
}) // TODO: Add similarity boosting for mentioned files or folders
})
: await (
await this.getRagEngine()
).processQuery({
@ -214,60 +398,42 @@ export class PromptGenerator {
},
onQueryProgressChange: onQueryProgressChange,
})
filePrompt = `## Potentially relevant snippets from the current vault
${similaritySearchResults
.map(({ path, content, metadata }) => {
const contentWithLineNumbers = this.addLineNumbersToContent({
content,
startLine: metadata.startLine,
})
return `\`\`\`${path}\n${contentWithLineNumbers}\n\`\`\`\n`
})
.join('')}\n`
} else {
filePrompt = allFiles
.map((file, index) => {
return `\`\`\`${file.path}\n${fileContents[index]}\n\`\`\`\n`
const snippets = similaritySearchResults.map(({ path, content, metadata }) => {
const contentWithLineNumbers = this.addLineNumbersToContent({
content,
startLine: metadata.startLine,
})
.join('')
return `<file_block_content location="${path}#L${metadata.startLine}-${metadata.endLine}">\n${contentWithLineNumbers}\n</file_block_content>`
}).join('\n')
similaritySearchContents = snippets.length > 0
? `<similarity_search_results>\n${snippets}\n</similarity_search_results>`
: '<similarity_search_results>\n(No relevant results found)\n</similarity_search_results>'
} else {
similaritySearchContents = undefined
}
const blocks = message.mentionables.filter(
(m): m is MentionableBlock => m.type === 'block',
)
const blockPrompt = blocks
.map(({ file, content, startLine, endLine }) => {
return `\`\`\`${file.path}#L${startLine}-${endLine}\n${content}\n\`\`\`\n`
})
.join('')
const urls = message.mentionables.filter(
(m): m is MentionableUrl => m.type === 'url',
)
const urlPrompt =
urls.length > 0
? `## Potentially relevant web search results
${(
await Promise.all(
urls.map(
async ({ url }) => `\`\`\`
Website URL: ${url}
Website Content:
${await this.getWebsiteContent(url)}
\`\`\``,
),
)
).join('\n')}
`
: ''
const parsedText = [
taskPrompt,
blockContentsPrompt,
fileContentsPrompts,
folderContentsPrompts,
urlContentsPrompt,
similaritySearchContents,
currentFileContentPrompt,
environmentDetails,
].filter(Boolean).join('\n\n')
// user mention images
const imageDataUrls = message.mentionables
.filter((m): m is MentionableImage => m.type === 'image')
.map(({ data }) => data)
return {
promptContent: [
{
type: 'text',
text: parsedText,
},
...imageDataUrls.map(
(data): ContentPart => ({
type: 'image_url',
@ -275,14 +441,18 @@ ${await this.getWebsiteContent(url)}
url: data,
},
}),
),
{
type: 'text',
text: `${filePrompt}${blockPrompt}${urlPrompt}\n\n${query}\n\n`,
},
)
],
shouldUseRAG,
similaritySearchResults: similaritySearchResults,
similaritySearchResults,
}
}
/**
 * Build the system message from the SYSTEM_PROMPT template, anchored at the
 * vault root path.
 * NOTE(review): the second argument (`false`) is an opaque flag —
 * presumably disabling some optional capability section of the prompt;
 * confirm against SYSTEM_PROMPT's signature in core/prompts/system.
 */
private async getSystemMessageNew(): Promise<RequestMessage> {
	const systemPrompt = await SYSTEM_PROMPT(this.app.vault.getRoot().path, false)
	return {
		role: 'system',
		content: systemPrompt,
	}
}
@ -392,7 +562,7 @@ ${customInstruction}
return {
role: 'user',
content: `# Inputs
## Current file
## Current Open File
Here is the file I'm looking at.
\`\`\`${currentFile.path}
${fileContent}

77
src/utils/tool-groups.ts Normal file
View File

@ -0,0 +1,77 @@
// Configuration for one group of related tools.
export type ToolGroupConfig = {
	// Tool slugs belonging to this group (see TOOL_DISPLAY_NAMES for names).
	tools: readonly string[]
	alwaysAvailable?: boolean // Whether this group is always available and shouldn't show in prompts view
}
// Map of tool slugs to their human-readable display names (used in the UI and
// in generated prompts). `as const` keeps the keys as literal types so
// ToolName below is a precise union.
// NOTE(review): TOOL_GROUPS.edit also lists "insert_content" and
// "search_and_replace", which have no entry here — confirm whether they need
// display names.
export const TOOL_DISPLAY_NAMES = {
	execute_command: "run commands",
	read_file: "read files",
	write_to_file: "write files",
	apply_diff: "apply changes",
	search_files: "search files",
	list_files: "list files",
	// list_code_definition_names: "list definitions",
	browser_action: "use a browser",
	use_mcp_tool: "use mcp tools",
	access_mcp_resource: "access mcp resources",
	ask_followup_question: "ask questions",
	attempt_completion: "complete tasks",
	switch_mode: "switch modes",
	new_task: "create new task",
} as const
// Define available tool groups.
// NOTE(review): the explicit `Record<string, ToolGroupConfig>` annotation
// widens the key type, so `ToolGroup` below resolves to plain `string` rather
// than the intended "read" | "edit" | "mcp" | "modes" union. Using
// `satisfies Record<string, ToolGroupConfig>` would preserve the literal keys,
// but consumers that index with extra keys (e.g. GROUP_DISPLAY_NAMES's
// browser/command entries) would then need updating as well.
export const TOOL_GROUPS: Record<string, ToolGroupConfig> = {
	read: {
		tools: ["read_file", "list_files", "search_files"],
	},
	edit: {
		tools: ["apply_diff", "write_to_file", "insert_content", "search_and_replace"],
	},
	// browser: {
	// 	tools: ["browser_action"],
	// },
	// command: {
	// 	tools: ["execute_command"],
	// },
	mcp: {
		tools: ["use_mcp_tool", "access_mcp_resource"],
	},
	modes: {
		tools: ["switch_mode",],
		alwaysAvailable: true,
	},
}

// Intended as the union of group slugs; currently `string` (see note above).
export type ToolGroup = keyof typeof TOOL_GROUPS
// Tools that are always available to all modes, regardless of which tool
// groups a mode enables.
export const ALWAYS_AVAILABLE_TOOLS = [
	"ask_followup_question",
	"attempt_completion",
	"switch_mode",
	"new_task",
] as const
// Union of valid tool slugs, derived from the display-name map for type safety.
export type ToolName = keyof typeof TOOL_DISPLAY_NAMES

/** Extract the tool slug from either a bare slug string or a [slug, options] tuple. */
export function getToolName(toolConfig: string | readonly [ToolName, ...any[]]): ToolName {
	if (typeof toolConfig === "string") {
		return toolConfig as ToolName
	}
	return toolConfig[0]
}

/** Extract the options value from a [slug, options] tuple; bare slugs carry none. */
export function getToolOptions(toolConfig: string | readonly [ToolName, ...any[]]): any {
	if (typeof toolConfig === "string") {
		return undefined
	}
	return toolConfig[1]
}
// Display names for tool groups in the UI.
// Fix: the "modes" group (defined in TOOL_GROUPS, alwaysAvailable) had no
// display name here. Adding it is backward-compatible; the existing
// browser/command entries are kept for when those (currently commented-out)
// groups are re-enabled — ToolGroup is effectively `string` (see TOOL_GROUPS),
// so the extra keys type-check.
export const GROUP_DISPLAY_NAMES: Record<ToolGroup, string> = {
	read: "Read Files",
	edit: "Edit Files",
	browser: "Use Browser",
	command: "Run Commands",
	mcp: "Use MCP",
	modes: "Switch Modes",
}