add tool use, update system prompt
This commit is contained in:
parent
cabf2d5fa4
commit
b0fbbb22d3
@ -9,6 +9,7 @@ export type ApplyViewState = {
|
|||||||
file: TFile
|
file: TFile
|
||||||
originalContent: string
|
originalContent: string
|
||||||
newContent: string
|
newContent: string
|
||||||
|
onClose: (applied: boolean) => void
|
||||||
}
|
}
|
||||||
|
|
||||||
export class ApplyView extends View {
|
export class ApplyView extends View {
|
||||||
|
|||||||
@ -37,10 +37,16 @@ export default function ApplyViewRoot({
|
|||||||
.map((change) => change.value)
|
.map((change) => change.value)
|
||||||
.join('')
|
.join('')
|
||||||
await app.vault.modify(state.file, newContent)
|
await app.vault.modify(state.file, newContent)
|
||||||
|
if (state.onClose) {
|
||||||
|
state.onClose(true)
|
||||||
|
}
|
||||||
close()
|
close()
|
||||||
}
|
}
|
||||||
|
|
||||||
const handleReject = async () => {
|
const handleReject = async () => {
|
||||||
|
if (state.onClose) {
|
||||||
|
state.onClose(false)
|
||||||
|
}
|
||||||
close()
|
close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,127 +0,0 @@
|
|||||||
import { $generateNodesFromSerializedNodes } from '@lexical/clipboard'
|
|
||||||
import { BaseSerializedNode } from '@lexical/clipboard/clipboard'
|
|
||||||
import { InitialEditorStateType } from '@lexical/react/LexicalComposer'
|
|
||||||
import * as Dialog from '@radix-ui/react-dialog'
|
|
||||||
import { $insertNodes, LexicalEditor } from 'lexical'
|
|
||||||
import { X } from 'lucide-react'
|
|
||||||
import { Notice } from 'obsidian'
|
|
||||||
import { useRef, useState } from 'react'
|
|
||||||
|
|
||||||
import { useDatabase } from '../../contexts/DatabaseContext'
|
|
||||||
import { useDialogContainer } from '../../contexts/DialogContext'
|
|
||||||
import { DuplicateTemplateException } from '../../database/exception'
|
|
||||||
|
|
||||||
import LexicalContentEditable from './chat-input/LexicalContentEditable'
|
|
||||||
|
|
||||||
/*
|
|
||||||
* This component must be used inside <Dialog.Root modal={false}>
|
|
||||||
* The modal={false} prop is required because modal mode blocks pointer events across the entire page,
|
|
||||||
* which would conflict with lexical editor popovers
|
|
||||||
*/
|
|
||||||
export default function CreateTemplateDialogContent({
|
|
||||||
selectedSerializedNodes,
|
|
||||||
onClose,
|
|
||||||
}: {
|
|
||||||
selectedSerializedNodes?: BaseSerializedNode[] | null
|
|
||||||
onClose: () => void
|
|
||||||
}) {
|
|
||||||
const container = useDialogContainer()
|
|
||||||
const { getTemplateManager } = useDatabase()
|
|
||||||
|
|
||||||
const [templateName, setTemplateName] = useState('')
|
|
||||||
const editorRef = useRef<LexicalEditor | null>(null)
|
|
||||||
const contentEditableRef = useRef<HTMLDivElement>(null)
|
|
||||||
|
|
||||||
const initialEditorState: InitialEditorStateType = (
|
|
||||||
editor: LexicalEditor,
|
|
||||||
) => {
|
|
||||||
if (!selectedSerializedNodes) return
|
|
||||||
editor.update(() => {
|
|
||||||
const parsedNodes = $generateNodesFromSerializedNodes(
|
|
||||||
selectedSerializedNodes,
|
|
||||||
)
|
|
||||||
$insertNodes(parsedNodes)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
const onSubmit = async () => {
|
|
||||||
try {
|
|
||||||
if (!editorRef.current) return
|
|
||||||
const serializedEditorState = editorRef.current.toJSON()
|
|
||||||
const nodes = serializedEditorState.editorState.root.children
|
|
||||||
if (nodes.length === 0) {
|
|
||||||
new Notice('Please enter a content for your template')
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if (templateName.trim().length === 0) {
|
|
||||||
new Notice('Please enter a name for your template')
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
await (
|
|
||||||
await getTemplateManager()
|
|
||||||
).createTemplate({
|
|
||||||
name: templateName,
|
|
||||||
content: { nodes },
|
|
||||||
})
|
|
||||||
new Notice(`Template created: ${templateName}`)
|
|
||||||
setTemplateName('')
|
|
||||||
onClose()
|
|
||||||
} catch (error) {
|
|
||||||
if (error instanceof DuplicateTemplateException) {
|
|
||||||
new Notice('A template with this name already exists')
|
|
||||||
} else {
|
|
||||||
console.error(error)
|
|
||||||
new Notice('Failed to create template')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Dialog.Portal container={container}>
|
|
||||||
<Dialog.Content className="infio-chat-dialog-content">
|
|
||||||
<div className="infio-dialog-header">
|
|
||||||
<Dialog.Title className="infio-dialog-title">
|
|
||||||
Create template
|
|
||||||
</Dialog.Title>
|
|
||||||
<Dialog.Description className="infio-dialog-description">
|
|
||||||
Create template from selected content
|
|
||||||
</Dialog.Description>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div className="infio-dialog-input">
|
|
||||||
<label>Name</label>
|
|
||||||
<input
|
|
||||||
type="text"
|
|
||||||
value={templateName}
|
|
||||||
onChange={(e) => setTemplateName(e.target.value)}
|
|
||||||
onKeyDown={(e) => {
|
|
||||||
if (e.key === 'Enter') {
|
|
||||||
e.stopPropagation()
|
|
||||||
e.preventDefault()
|
|
||||||
onSubmit()
|
|
||||||
}
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div className="infio-chat-user-input-container">
|
|
||||||
<LexicalContentEditable
|
|
||||||
initialEditorState={initialEditorState}
|
|
||||||
editorRef={editorRef}
|
|
||||||
contentEditableRef={contentEditableRef}
|
|
||||||
onEnter={onSubmit}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div className="infio-dialog-footer">
|
|
||||||
<button onClick={onSubmit}>Create template</button>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<Dialog.Close className="infio-dialog-close" asChild>
|
|
||||||
<X size={16} />
|
|
||||||
</Dialog.Close>
|
|
||||||
</Dialog.Content>
|
|
||||||
</Dialog.Portal>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
@ -1,11 +1,12 @@
|
|||||||
import { MarkdownView, Plugin, Platform } from 'obsidian';
|
|
||||||
import React, { useEffect, useMemo, useRef, useState } from 'react';
|
|
||||||
import { CornerDownLeft } from 'lucide-react';
|
import { CornerDownLeft } from 'lucide-react';
|
||||||
|
import { MarkdownView, Plugin } from 'obsidian';
|
||||||
|
import React, { useEffect, useRef, useState } from 'react';
|
||||||
|
|
||||||
import { APPLY_VIEW_TYPE } from '../../constants';
|
import { APPLY_VIEW_TYPE } from '../../constants';
|
||||||
import LLMManager from '../../core/llm/manager';
|
import LLMManager from '../../core/llm/manager';
|
||||||
import { InfioSettings } from '../../types/settings';
|
import { InfioSettings } from '../../types/settings';
|
||||||
import { GetProviderModelIds } from '../../utils/api';
|
import { GetProviderModelIds } from '../../utils/api';
|
||||||
import { manualApplyChangesToFile } from '../../utils/apply';
|
import { ApplyEditToFile } from '../../utils/apply';
|
||||||
import { removeAITags } from '../../utils/content-filter';
|
import { removeAITags } from '../../utils/content-filter';
|
||||||
import { PromptGenerator } from '../../utils/prompt-generator';
|
import { PromptGenerator } from '../../utils/prompt-generator';
|
||||||
|
|
||||||
@ -239,10 +240,10 @@ export const InlineEdit: React.FC<InlineEditProps> = ({
|
|||||||
const startLine = parsedBlock?.startLine || defaultStartLine;
|
const startLine = parsedBlock?.startLine || defaultStartLine;
|
||||||
const endLine = parsedBlock?.endLine || defaultEndLine;
|
const endLine = parsedBlock?.endLine || defaultEndLine;
|
||||||
|
|
||||||
const updatedContent = await manualApplyChangesToFile(
|
const updatedContent = await ApplyEditToFile(
|
||||||
finalContent,
|
|
||||||
activeFile,
|
activeFile,
|
||||||
await plugin.app.vault.read(activeFile),
|
await plugin.app.vault.read(activeFile),
|
||||||
|
finalContent,
|
||||||
startLine,
|
startLine,
|
||||||
endLine
|
endLine
|
||||||
);
|
);
|
||||||
|
|||||||
22
src/core/diff/DiffStrategy.ts
Normal file
22
src/core/diff/DiffStrategy.ts
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
import type { DiffStrategy } from "./types"
|
||||||
|
import { UnifiedDiffStrategy } from "./strategies/unified"
|
||||||
|
import { SearchReplaceDiffStrategy } from "./strategies/search-replace"
|
||||||
|
import { NewUnifiedDiffStrategy } from "./strategies/new-unified"
|
||||||
|
/**
|
||||||
|
* Get the appropriate diff strategy for the given model
|
||||||
|
* @param model The name of the model being used (e.g., 'gpt-4', 'claude-3-opus')
|
||||||
|
* @returns The appropriate diff strategy for the model
|
||||||
|
*/
|
||||||
|
export function getDiffStrategy(
|
||||||
|
model: string,
|
||||||
|
fuzzyMatchThreshold?: number,
|
||||||
|
experimentalDiffStrategy: boolean = false,
|
||||||
|
): DiffStrategy {
|
||||||
|
if (experimentalDiffStrategy) {
|
||||||
|
return new NewUnifiedDiffStrategy(fuzzyMatchThreshold)
|
||||||
|
}
|
||||||
|
return new SearchReplaceDiffStrategy(fuzzyMatchThreshold)
|
||||||
|
}
|
||||||
|
|
||||||
|
export type { DiffStrategy }
|
||||||
|
export { UnifiedDiffStrategy, SearchReplaceDiffStrategy }
|
||||||
31
src/core/diff/insert-groups.ts
Normal file
31
src/core/diff/insert-groups.ts
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
/**
|
||||||
|
* Inserts multiple groups of elements at specified indices in an array
|
||||||
|
* @param original Array to insert into, split by lines
|
||||||
|
* @param insertGroups Array of groups to insert, each with an index and elements to insert
|
||||||
|
* @returns New array with all insertions applied
|
||||||
|
*/
|
||||||
|
export interface InsertGroup {
|
||||||
|
index: number
|
||||||
|
elements: string[]
|
||||||
|
}
|
||||||
|
|
||||||
|
export function insertGroups(original: string[], insertGroups: InsertGroup[]): string[] {
|
||||||
|
// Sort groups by index to maintain order
|
||||||
|
insertGroups.sort((a, b) => a.index - b.index)
|
||||||
|
|
||||||
|
let result: string[] = []
|
||||||
|
let lastIndex = 0
|
||||||
|
|
||||||
|
insertGroups.forEach(({ index, elements }) => {
|
||||||
|
// Add elements from original array up to insertion point
|
||||||
|
result.push(...original.slice(lastIndex, index))
|
||||||
|
// Add the group of elements
|
||||||
|
result.push(...elements)
|
||||||
|
lastIndex = index
|
||||||
|
})
|
||||||
|
|
||||||
|
// Add remaining elements from original array
|
||||||
|
result.push(...original.slice(lastIndex))
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
738
src/core/diff/strategies/__tests__/new-unified.test.ts
Normal file
738
src/core/diff/strategies/__tests__/new-unified.test.ts
Normal file
@ -0,0 +1,738 @@
|
|||||||
|
import { NewUnifiedDiffStrategy } from "../new-unified"
|
||||||
|
|
||||||
|
describe("main", () => {
|
||||||
|
let strategy: NewUnifiedDiffStrategy
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
strategy = new NewUnifiedDiffStrategy(0.97)
|
||||||
|
})
|
||||||
|
|
||||||
|
describe("constructor", () => {
|
||||||
|
it("should use default confidence threshold when not provided", () => {
|
||||||
|
const defaultStrategy = new NewUnifiedDiffStrategy()
|
||||||
|
expect(defaultStrategy["confidenceThreshold"]).toBe(1)
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should use provided confidence threshold", () => {
|
||||||
|
const customStrategy = new NewUnifiedDiffStrategy(0.85)
|
||||||
|
expect(customStrategy["confidenceThreshold"]).toBe(0.85)
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should enforce minimum confidence threshold", () => {
|
||||||
|
const lowStrategy = new NewUnifiedDiffStrategy(0.7) // Below minimum of 0.8
|
||||||
|
expect(lowStrategy["confidenceThreshold"]).toBe(0.8)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe("getToolDescription", () => {
|
||||||
|
it("should return tool description with correct cwd", () => {
|
||||||
|
const cwd = "/test/path"
|
||||||
|
const description = strategy.getToolDescription({ cwd })
|
||||||
|
|
||||||
|
expect(description).toContain("apply_diff Tool - Generate Precise Code Changes")
|
||||||
|
expect(description).toContain(cwd)
|
||||||
|
expect(description).toContain("Step-by-Step Instructions")
|
||||||
|
expect(description).toContain("Requirements")
|
||||||
|
expect(description).toContain("Examples")
|
||||||
|
expect(description).toContain("Parameters:")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should apply simple diff correctly", async () => {
|
||||||
|
const original = `line1
|
||||||
|
line2
|
||||||
|
line3`
|
||||||
|
|
||||||
|
const diff = `--- a/file.txt
|
||||||
|
+++ b/file.txt
|
||||||
|
@@ ... @@
|
||||||
|
line1
|
||||||
|
+new line
|
||||||
|
line2
|
||||||
|
-line3
|
||||||
|
+modified line3`
|
||||||
|
|
||||||
|
const result = await strategy.applyDiff(original, diff)
|
||||||
|
expect(result.success).toBe(true)
|
||||||
|
if (result.success) {
|
||||||
|
expect(result.content).toBe(`line1
|
||||||
|
new line
|
||||||
|
line2
|
||||||
|
modified line3`)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should handle multiple hunks", async () => {
|
||||||
|
const original = `line1
|
||||||
|
line2
|
||||||
|
line3
|
||||||
|
line4
|
||||||
|
line5`
|
||||||
|
|
||||||
|
const diff = `--- a/file.txt
|
||||||
|
+++ b/file.txt
|
||||||
|
@@ ... @@
|
||||||
|
line1
|
||||||
|
+new line
|
||||||
|
line2
|
||||||
|
-line3
|
||||||
|
+modified line3
|
||||||
|
@@ ... @@
|
||||||
|
line4
|
||||||
|
-line5
|
||||||
|
+modified line5
|
||||||
|
+new line at end`
|
||||||
|
|
||||||
|
const result = await strategy.applyDiff(original, diff)
|
||||||
|
expect(result.success).toBe(true)
|
||||||
|
if (result.success) {
|
||||||
|
expect(result.content).toBe(`line1
|
||||||
|
new line
|
||||||
|
line2
|
||||||
|
modified line3
|
||||||
|
line4
|
||||||
|
modified line5
|
||||||
|
new line at end`)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should handle complex large", async () => {
|
||||||
|
const original = `line1
|
||||||
|
line2
|
||||||
|
line3
|
||||||
|
line4
|
||||||
|
line5
|
||||||
|
line6
|
||||||
|
line7
|
||||||
|
line8
|
||||||
|
line9
|
||||||
|
line10`
|
||||||
|
|
||||||
|
const diff = `--- a/file.txt
|
||||||
|
+++ b/file.txt
|
||||||
|
@@ ... @@
|
||||||
|
line1
|
||||||
|
+header line
|
||||||
|
+another header
|
||||||
|
line2
|
||||||
|
-line3
|
||||||
|
-line4
|
||||||
|
+modified line3
|
||||||
|
+modified line4
|
||||||
|
+extra line
|
||||||
|
@@ ... @@
|
||||||
|
line6
|
||||||
|
+middle section
|
||||||
|
line7
|
||||||
|
-line8
|
||||||
|
+changed line8
|
||||||
|
+bonus line
|
||||||
|
@@ ... @@
|
||||||
|
line9
|
||||||
|
-line10
|
||||||
|
+final line
|
||||||
|
+very last line`
|
||||||
|
|
||||||
|
const result = await strategy.applyDiff(original, diff)
|
||||||
|
expect(result.success).toBe(true)
|
||||||
|
if (result.success) {
|
||||||
|
expect(result.content).toBe(`line1
|
||||||
|
header line
|
||||||
|
another header
|
||||||
|
line2
|
||||||
|
modified line3
|
||||||
|
modified line4
|
||||||
|
extra line
|
||||||
|
line5
|
||||||
|
line6
|
||||||
|
middle section
|
||||||
|
line7
|
||||||
|
changed line8
|
||||||
|
bonus line
|
||||||
|
line9
|
||||||
|
final line
|
||||||
|
very last line`)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should handle indentation changes", async () => {
|
||||||
|
const original = `first line
|
||||||
|
indented line
|
||||||
|
double indented line
|
||||||
|
back to single indent
|
||||||
|
no indent
|
||||||
|
indented again
|
||||||
|
double indent again
|
||||||
|
triple indent
|
||||||
|
back to single
|
||||||
|
last line`
|
||||||
|
|
||||||
|
const diff = `--- original
|
||||||
|
+++ modified
|
||||||
|
@@ ... @@
|
||||||
|
first line
|
||||||
|
indented line
|
||||||
|
+ tab indented line
|
||||||
|
+ new indented line
|
||||||
|
double indented line
|
||||||
|
back to single indent
|
||||||
|
no indent
|
||||||
|
indented again
|
||||||
|
double indent again
|
||||||
|
- triple indent
|
||||||
|
+ hi there mate
|
||||||
|
back to single
|
||||||
|
last line`
|
||||||
|
|
||||||
|
const expected = `first line
|
||||||
|
indented line
|
||||||
|
tab indented line
|
||||||
|
new indented line
|
||||||
|
double indented line
|
||||||
|
back to single indent
|
||||||
|
no indent
|
||||||
|
indented again
|
||||||
|
double indent again
|
||||||
|
hi there mate
|
||||||
|
back to single
|
||||||
|
last line`
|
||||||
|
|
||||||
|
const result = await strategy.applyDiff(original, diff)
|
||||||
|
expect(result.success).toBe(true)
|
||||||
|
if (result.success) {
|
||||||
|
expect(result.content).toBe(expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should handle high level edits", async () => {
|
||||||
|
const original = `def factorial(n):
|
||||||
|
if n == 0:
|
||||||
|
return 1
|
||||||
|
else:
|
||||||
|
return n * factorial(n-1)`
|
||||||
|
const diff = `@@ ... @@
|
||||||
|
-def factorial(n):
|
||||||
|
- if n == 0:
|
||||||
|
- return 1
|
||||||
|
- else:
|
||||||
|
- return n * factorial(n-1)
|
||||||
|
+def factorial(number):
|
||||||
|
+ if number == 0:
|
||||||
|
+ return 1
|
||||||
|
+ else:
|
||||||
|
+ return number * factorial(number-1)`
|
||||||
|
|
||||||
|
const expected = `def factorial(number):
|
||||||
|
if number == 0:
|
||||||
|
return 1
|
||||||
|
else:
|
||||||
|
return number * factorial(number-1)`
|
||||||
|
|
||||||
|
const result = await strategy.applyDiff(original, diff)
|
||||||
|
expect(result.success).toBe(true)
|
||||||
|
if (result.success) {
|
||||||
|
expect(result.content).toBe(expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
it("it should handle very complex edits", async () => {
|
||||||
|
const original = `//Initialize the array that will hold the primes
|
||||||
|
var primeArray = [];
|
||||||
|
/*Write a function that checks for primeness and
|
||||||
|
pushes those values to t*he array*/
|
||||||
|
function PrimeCheck(candidate){
|
||||||
|
isPrime = true;
|
||||||
|
for(var i = 2; i < candidate && isPrime; i++){
|
||||||
|
if(candidate%i === 0){
|
||||||
|
isPrime = false;
|
||||||
|
} else {
|
||||||
|
isPrime = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if(isPrime){
|
||||||
|
primeArray.push(candidate);
|
||||||
|
}
|
||||||
|
return primeArray;
|
||||||
|
}
|
||||||
|
/*Write the code that runs the above until the
|
||||||
|
l ength of the array equa*ls the number of primes
|
||||||
|
desired*/
|
||||||
|
|
||||||
|
var numPrimes = prompt("How many primes?");
|
||||||
|
|
||||||
|
//Display the finished array of primes
|
||||||
|
|
||||||
|
//for loop starting at 2 as that is the lowest prime number keep going until the array is as long as we requested
|
||||||
|
for (var i = 2; primeArray.length < numPrimes; i++) {
|
||||||
|
PrimeCheck(i); //
|
||||||
|
}
|
||||||
|
console.log(primeArray);
|
||||||
|
`
|
||||||
|
|
||||||
|
const diff = `--- test_diff.js
|
||||||
|
+++ test_diff.js
|
||||||
|
@@ ... @@
|
||||||
|
-//Initialize the array that will hold the primes
|
||||||
|
var primeArray = [];
|
||||||
|
-/*Write a function that checks for primeness and
|
||||||
|
- pushes those values to t*he array*/
|
||||||
|
function PrimeCheck(candidate){
|
||||||
|
isPrime = true;
|
||||||
|
for(var i = 2; i < candidate && isPrime; i++){
|
||||||
|
@@ ... @@
|
||||||
|
return primeArray;
|
||||||
|
}
|
||||||
|
-/*Write the code that runs the above until the
|
||||||
|
- l ength of the array equa*ls the number of primes
|
||||||
|
- desired*/
|
||||||
|
|
||||||
|
var numPrimes = prompt("How many primes?");
|
||||||
|
|
||||||
|
-//Display the finished array of primes
|
||||||
|
-
|
||||||
|
-//for loop starting at 2 as that is the lowest prime number keep going until the array is as long as we requested
|
||||||
|
for (var i = 2; primeArray.length < numPrimes; i++) {
|
||||||
|
- PrimeCheck(i); //
|
||||||
|
+ PrimeCheck(i);
|
||||||
|
}
|
||||||
|
console.log(primeArray);`
|
||||||
|
|
||||||
|
const expected = `var primeArray = [];
|
||||||
|
function PrimeCheck(candidate){
|
||||||
|
isPrime = true;
|
||||||
|
for(var i = 2; i < candidate && isPrime; i++){
|
||||||
|
if(candidate%i === 0){
|
||||||
|
isPrime = false;
|
||||||
|
} else {
|
||||||
|
isPrime = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if(isPrime){
|
||||||
|
primeArray.push(candidate);
|
||||||
|
}
|
||||||
|
return primeArray;
|
||||||
|
}
|
||||||
|
|
||||||
|
var numPrimes = prompt("How many primes?");
|
||||||
|
|
||||||
|
for (var i = 2; primeArray.length < numPrimes; i++) {
|
||||||
|
PrimeCheck(i);
|
||||||
|
}
|
||||||
|
console.log(primeArray);
|
||||||
|
`
|
||||||
|
|
||||||
|
const result = await strategy.applyDiff(original, diff)
|
||||||
|
expect(result.success).toBe(true)
|
||||||
|
if (result.success) {
|
||||||
|
expect(result.content).toBe(expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
describe("error handling and edge cases", () => {
|
||||||
|
it("should reject completely invalid diff format", async () => {
|
||||||
|
const original = "line1\nline2\nline3"
|
||||||
|
const invalidDiff = "this is not a diff at all"
|
||||||
|
|
||||||
|
const result = await strategy.applyDiff(original, invalidDiff)
|
||||||
|
expect(result.success).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should reject diff with invalid hunk format", async () => {
|
||||||
|
const original = "line1\nline2\nline3"
|
||||||
|
const invalidHunkDiff = `--- a/file.txt
|
||||||
|
+++ b/file.txt
|
||||||
|
invalid hunk header
|
||||||
|
line1
|
||||||
|
-line2
|
||||||
|
+new line`
|
||||||
|
|
||||||
|
const result = await strategy.applyDiff(original, invalidHunkDiff)
|
||||||
|
expect(result.success).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should fail when diff tries to modify non-existent content", async () => {
|
||||||
|
const original = "line1\nline2\nline3"
|
||||||
|
const nonMatchingDiff = `--- a/file.txt
|
||||||
|
+++ b/file.txt
|
||||||
|
@@ ... @@
|
||||||
|
line1
|
||||||
|
-nonexistent line
|
||||||
|
+new line
|
||||||
|
line3`
|
||||||
|
|
||||||
|
const result = await strategy.applyDiff(original, nonMatchingDiff)
|
||||||
|
expect(result.success).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should handle overlapping hunks", async () => {
|
||||||
|
const original = `line1
|
||||||
|
line2
|
||||||
|
line3
|
||||||
|
line4
|
||||||
|
line5`
|
||||||
|
const overlappingDiff = `--- a/file.txt
|
||||||
|
+++ b/file.txt
|
||||||
|
@@ ... @@
|
||||||
|
line1
|
||||||
|
line2
|
||||||
|
-line3
|
||||||
|
+modified3
|
||||||
|
line4
|
||||||
|
@@ ... @@
|
||||||
|
line2
|
||||||
|
-line3
|
||||||
|
-line4
|
||||||
|
+modified3and4
|
||||||
|
line5`
|
||||||
|
|
||||||
|
const result = await strategy.applyDiff(original, overlappingDiff)
|
||||||
|
expect(result.success).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should handle empty lines modifications", async () => {
|
||||||
|
const original = `line1
|
||||||
|
|
||||||
|
line3
|
||||||
|
|
||||||
|
line5`
|
||||||
|
const emptyLinesDiff = `--- a/file.txt
|
||||||
|
+++ b/file.txt
|
||||||
|
@@ ... @@
|
||||||
|
line1
|
||||||
|
|
||||||
|
-line3
|
||||||
|
+line3modified
|
||||||
|
|
||||||
|
line5`
|
||||||
|
|
||||||
|
const result = await strategy.applyDiff(original, emptyLinesDiff)
|
||||||
|
expect(result.success).toBe(true)
|
||||||
|
if (result.success) {
|
||||||
|
expect(result.content).toBe(`line1
|
||||||
|
|
||||||
|
line3modified
|
||||||
|
|
||||||
|
line5`)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should handle mixed line endings in diff", async () => {
|
||||||
|
const original = "line1\r\nline2\nline3\r\n"
|
||||||
|
const mixedEndingsDiff = `--- a/file.txt
|
||||||
|
+++ b/file.txt
|
||||||
|
@@ ... @@
|
||||||
|
line1\r
|
||||||
|
-line2
|
||||||
|
+modified2\r
|
||||||
|
line3`
|
||||||
|
|
||||||
|
const result = await strategy.applyDiff(original, mixedEndingsDiff)
|
||||||
|
expect(result.success).toBe(true)
|
||||||
|
if (result.success) {
|
||||||
|
expect(result.content).toBe("line1\r\nmodified2\r\nline3\r\n")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should handle partial line modifications", async () => {
|
||||||
|
const original = "const value = oldValue + 123;"
|
||||||
|
const partialDiff = `--- a/file.txt
|
||||||
|
+++ b/file.txt
|
||||||
|
@@ ... @@
|
||||||
|
-const value = oldValue + 123;
|
||||||
|
+const value = newValue + 123;`
|
||||||
|
|
||||||
|
const result = await strategy.applyDiff(original, partialDiff)
|
||||||
|
expect(result.success).toBe(true)
|
||||||
|
if (result.success) {
|
||||||
|
expect(result.content).toBe("const value = newValue + 123;")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should handle slightly malformed but recoverable diff", async () => {
|
||||||
|
const original = "line1\nline2\nline3"
|
||||||
|
// Missing space after --- and +++
|
||||||
|
const slightlyBadDiff = `---a/file.txt
|
||||||
|
+++b/file.txt
|
||||||
|
@@ ... @@
|
||||||
|
line1
|
||||||
|
-line2
|
||||||
|
+new line
|
||||||
|
line3`
|
||||||
|
|
||||||
|
const result = await strategy.applyDiff(original, slightlyBadDiff)
|
||||||
|
expect(result.success).toBe(true)
|
||||||
|
if (result.success) {
|
||||||
|
expect(result.content).toBe("line1\nnew line\nline3")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe("similar code sections", () => {
|
||||||
|
it("should correctly modify the right section when similar code exists", async () => {
|
||||||
|
const original = `function add(a, b) {
|
||||||
|
return a + b;
|
||||||
|
}
|
||||||
|
|
||||||
|
function subtract(a, b) {
|
||||||
|
return a - b;
|
||||||
|
}
|
||||||
|
|
||||||
|
function multiply(a, b) {
|
||||||
|
return a + b; // Bug here
|
||||||
|
}`
|
||||||
|
|
||||||
|
const diff = `--- a/math.js
|
||||||
|
+++ b/math.js
|
||||||
|
@@ ... @@
|
||||||
|
function multiply(a, b) {
|
||||||
|
- return a + b; // Bug here
|
||||||
|
+ return a * b;
|
||||||
|
}`
|
||||||
|
|
||||||
|
const result = await strategy.applyDiff(original, diff)
|
||||||
|
expect(result.success).toBe(true)
|
||||||
|
if (result.success) {
|
||||||
|
expect(result.content).toBe(`function add(a, b) {
|
||||||
|
return a + b;
|
||||||
|
}
|
||||||
|
|
||||||
|
function subtract(a, b) {
|
||||||
|
return a - b;
|
||||||
|
}
|
||||||
|
|
||||||
|
function multiply(a, b) {
|
||||||
|
return a * b;
|
||||||
|
}`)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should handle multiple similar sections with correct context", async () => {
|
||||||
|
const original = `if (condition) {
|
||||||
|
doSomething();
|
||||||
|
doSomething();
|
||||||
|
doSomething();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (otherCondition) {
|
||||||
|
doSomething();
|
||||||
|
doSomething();
|
||||||
|
doSomething();
|
||||||
|
}`
|
||||||
|
|
||||||
|
const diff = `--- a/file.js
|
||||||
|
+++ b/file.js
|
||||||
|
@@ ... @@
|
||||||
|
if (otherCondition) {
|
||||||
|
doSomething();
|
||||||
|
- doSomething();
|
||||||
|
+ doSomethingElse();
|
||||||
|
doSomething();
|
||||||
|
}`
|
||||||
|
|
||||||
|
const result = await strategy.applyDiff(original, diff)
|
||||||
|
expect(result.success).toBe(true)
|
||||||
|
if (result.success) {
|
||||||
|
expect(result.content).toBe(`if (condition) {
|
||||||
|
doSomething();
|
||||||
|
doSomething();
|
||||||
|
doSomething();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (otherCondition) {
|
||||||
|
doSomething();
|
||||||
|
doSomethingElse();
|
||||||
|
doSomething();
|
||||||
|
}`)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe("hunk splitting", () => {
|
||||||
|
it("should handle large diffs with multiple non-contiguous changes", async () => {
|
||||||
|
const original = `import { readFile } from 'fs';
|
||||||
|
import { join } from 'path';
|
||||||
|
import { Logger } from './logger';
|
||||||
|
|
||||||
|
const logger = new Logger();
|
||||||
|
|
||||||
|
async function processFile(filePath: string) {
|
||||||
|
try {
|
||||||
|
const data = await readFile(filePath, 'utf8');
|
||||||
|
logger.info('File read successfully');
|
||||||
|
return data;
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Failed to read file:', error);
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function validateInput(input: string): boolean {
|
||||||
|
if (!input) {
|
||||||
|
logger.warn('Empty input received');
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return input.length > 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function writeOutput(data: string) {
|
||||||
|
logger.info('Processing output');
|
||||||
|
// TODO: Implement output writing
|
||||||
|
return Promise.resolve();
|
||||||
|
}
|
||||||
|
|
||||||
|
function parseConfig(configPath: string) {
|
||||||
|
logger.debug('Reading config from:', configPath);
|
||||||
|
// Basic config parsing
|
||||||
|
return {
|
||||||
|
enabled: true,
|
||||||
|
maxRetries: 3
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
export {
|
||||||
|
processFile,
|
||||||
|
validateInput,
|
||||||
|
writeOutput,
|
||||||
|
parseConfig
|
||||||
|
};`
|
||||||
|
|
||||||
|
const diff = `--- a/file.ts
|
||||||
|
+++ b/file.ts
|
||||||
|
@@ ... @@
|
||||||
|
-import { readFile } from 'fs';
|
||||||
|
+import { readFile, writeFile } from 'fs';
|
||||||
|
import { join } from 'path';
|
||||||
|
-import { Logger } from './logger';
|
||||||
|
+import { Logger } from './utils/logger';
|
||||||
|
+import { Config } from './types';
|
||||||
|
|
||||||
|
-const logger = new Logger();
|
||||||
|
+const logger = new Logger('FileProcessor');
|
||||||
|
|
||||||
|
async function processFile(filePath: string) {
|
||||||
|
try {
|
||||||
|
const data = await readFile(filePath, 'utf8');
|
||||||
|
- logger.info('File read successfully');
|
||||||
|
+ logger.info(\`File \${filePath} read successfully\`);
|
||||||
|
return data;
|
||||||
|
} catch (error) {
|
||||||
|
- logger.error('Failed to read file:', error);
|
||||||
|
+ logger.error(\`Failed to read file \${filePath}:\`, error);
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function validateInput(input: string): boolean {
|
||||||
|
if (!input) {
|
||||||
|
- logger.warn('Empty input received');
|
||||||
|
+ logger.warn('Validation failed: Empty input received');
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
- return input.length > 0;
|
||||||
|
+ return input.trim().length > 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
-async function writeOutput(data: string) {
|
||||||
|
- logger.info('Processing output');
|
||||||
|
- // TODO: Implement output writing
|
||||||
|
- return Promise.resolve();
|
||||||
|
+async function writeOutput(data: string, outputPath: string) {
|
||||||
|
+ try {
|
||||||
|
+ await writeFile(outputPath, data, 'utf8');
|
||||||
|
+ logger.info(\`Output written to \${outputPath}\`);
|
||||||
|
+ } catch (error) {
|
||||||
|
+ logger.error(\`Failed to write output to \${outputPath}:\`, error);
|
||||||
|
+ throw error;
|
||||||
|
+ }
|
||||||
|
}
|
||||||
|
|
||||||
|
-function parseConfig(configPath: string) {
|
||||||
|
- logger.debug('Reading config from:', configPath);
|
||||||
|
- // Basic config parsing
|
||||||
|
- return {
|
||||||
|
- enabled: true,
|
||||||
|
- maxRetries: 3
|
||||||
|
- };
|
||||||
|
+async function parseConfig(configPath: string): Promise<Config> {
|
||||||
|
+ try {
|
||||||
|
+ const configData = await readFile(configPath, 'utf8');
|
||||||
|
+ logger.debug(\`Reading config from \${configPath}\`);
|
||||||
|
+ return JSON.parse(configData);
|
||||||
|
+ } catch (error) {
|
||||||
|
+ logger.error(\`Failed to parse config from \${configPath}:\`, error);
|
||||||
|
+ throw error;
|
||||||
|
+ }
|
||||||
|
}
|
||||||
|
|
||||||
|
export {
|
||||||
|
processFile,
|
||||||
|
validateInput,
|
||||||
|
writeOutput,
|
||||||
|
- parseConfig
|
||||||
|
+ parseConfig,
|
||||||
|
+ type Config
|
||||||
|
};`
|
||||||
|
|
||||||
|
const expected = `import { readFile, writeFile } from 'fs';
|
||||||
|
import { join } from 'path';
|
||||||
|
import { Logger } from './utils/logger';
|
||||||
|
import { Config } from './types';
|
||||||
|
|
||||||
|
const logger = new Logger('FileProcessor');
|
||||||
|
|
||||||
|
async function processFile(filePath: string) {
|
||||||
|
try {
|
||||||
|
const data = await readFile(filePath, 'utf8');
|
||||||
|
logger.info(\`File \${filePath} read successfully\`);
|
||||||
|
return data;
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(\`Failed to read file \${filePath}:\`, error);
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function validateInput(input: string): boolean {
|
||||||
|
if (!input) {
|
||||||
|
logger.warn('Validation failed: Empty input received');
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return input.trim().length > 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function writeOutput(data: string, outputPath: string) {
|
||||||
|
try {
|
||||||
|
await writeFile(outputPath, data, 'utf8');
|
||||||
|
logger.info(\`Output written to \${outputPath}\`);
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(\`Failed to write output to \${outputPath}:\`, error);
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function parseConfig(configPath: string): Promise<Config> {
|
||||||
|
try {
|
||||||
|
const configData = await readFile(configPath, 'utf8');
|
||||||
|
logger.debug(\`Reading config from \${configPath}\`);
|
||||||
|
return JSON.parse(configData);
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(\`Failed to parse config from \${configPath}:\`, error);
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export {
|
||||||
|
processFile,
|
||||||
|
validateInput,
|
||||||
|
writeOutput,
|
||||||
|
parseConfig,
|
||||||
|
type Config
|
||||||
|
};`
|
||||||
|
|
||||||
|
const result = await strategy.applyDiff(original, diff)
|
||||||
|
expect(result.success).toBe(true)
|
||||||
|
if (result.success) {
|
||||||
|
expect(result.content).toBe(expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
1557
src/core/diff/strategies/__tests__/search-replace.test.ts
Normal file
1557
src/core/diff/strategies/__tests__/search-replace.test.ts
Normal file
File diff suppressed because it is too large
Load Diff
228
src/core/diff/strategies/__tests__/unified.test.ts
Normal file
228
src/core/diff/strategies/__tests__/unified.test.ts
Normal file
@ -0,0 +1,228 @@
|
|||||||
|
import { UnifiedDiffStrategy } from "../unified"
|
||||||
|
|
||||||
|
describe("UnifiedDiffStrategy", () => {
|
||||||
|
let strategy: UnifiedDiffStrategy
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
strategy = new UnifiedDiffStrategy()
|
||||||
|
})
|
||||||
|
|
||||||
|
describe("getToolDescription", () => {
|
||||||
|
it("should return tool description with correct cwd", () => {
|
||||||
|
const cwd = "/test/path"
|
||||||
|
const description = strategy.getToolDescription({ cwd })
|
||||||
|
|
||||||
|
expect(description).toContain("apply_diff")
|
||||||
|
expect(description).toContain(cwd)
|
||||||
|
expect(description).toContain("Parameters:")
|
||||||
|
expect(description).toContain("Format Requirements:")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe("applyDiff", () => {
|
||||||
|
it("should successfully apply a function modification diff", async () => {
|
||||||
|
const originalContent = `import { Logger } from '../logger';
|
||||||
|
|
||||||
|
function calculateTotal(items: number[]): number {
|
||||||
|
return items.reduce((sum, item) => {
|
||||||
|
return sum + item;
|
||||||
|
}, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
export { calculateTotal };`
|
||||||
|
|
||||||
|
const diffContent = `--- src/utils/helper.ts
|
||||||
|
+++ src/utils/helper.ts
|
||||||
|
@@ -1,9 +1,10 @@
|
||||||
|
import { Logger } from '../logger';
|
||||||
|
|
||||||
|
function calculateTotal(items: number[]): number {
|
||||||
|
- return items.reduce((sum, item) => {
|
||||||
|
- return sum + item;
|
||||||
|
+ const total = items.reduce((sum, item) => {
|
||||||
|
+ return sum + item * 1.1; // Add 10% markup
|
||||||
|
}, 0);
|
||||||
|
+ return Math.round(total * 100) / 100; // Round to 2 decimal places
|
||||||
|
}
|
||||||
|
|
||||||
|
export { calculateTotal };`
|
||||||
|
|
||||||
|
const expected = `import { Logger } from '../logger';
|
||||||
|
|
||||||
|
function calculateTotal(items: number[]): number {
|
||||||
|
const total = items.reduce((sum, item) => {
|
||||||
|
return sum + item * 1.1; // Add 10% markup
|
||||||
|
}, 0);
|
||||||
|
return Math.round(total * 100) / 100; // Round to 2 decimal places
|
||||||
|
}
|
||||||
|
|
||||||
|
export { calculateTotal };`
|
||||||
|
|
||||||
|
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||||
|
expect(result.success).toBe(true)
|
||||||
|
if (result.success) {
|
||||||
|
expect(result.content).toBe(expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should successfully apply a diff adding a new method", async () => {
|
||||||
|
const originalContent = `class Calculator {
|
||||||
|
add(a: number, b: number): number {
|
||||||
|
return a + b;
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
const diffContent = `--- src/Calculator.ts
|
||||||
|
+++ src/Calculator.ts
|
||||||
|
@@ -1,5 +1,9 @@
|
||||||
|
class Calculator {
|
||||||
|
add(a: number, b: number): number {
|
||||||
|
return a + b;
|
||||||
|
}
|
||||||
|
+
|
||||||
|
+ multiply(a: number, b: number): number {
|
||||||
|
+ return a * b;
|
||||||
|
+ }
|
||||||
|
}`
|
||||||
|
|
||||||
|
const expected = `class Calculator {
|
||||||
|
add(a: number, b: number): number {
|
||||||
|
return a + b;
|
||||||
|
}
|
||||||
|
|
||||||
|
multiply(a: number, b: number): number {
|
||||||
|
return a * b;
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||||
|
expect(result.success).toBe(true)
|
||||||
|
if (result.success) {
|
||||||
|
expect(result.content).toBe(expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should successfully apply a diff modifying imports", async () => {
|
||||||
|
const originalContent = `import { useState } from 'react';
|
||||||
|
import { Button } from './components';
|
||||||
|
|
||||||
|
function App() {
|
||||||
|
const [count, setCount] = useState(0);
|
||||||
|
return <Button onClick={() => setCount(count + 1)}>{count}</Button>;
|
||||||
|
}`
|
||||||
|
|
||||||
|
const diffContent = `--- src/App.tsx
|
||||||
|
+++ src/App.tsx
|
||||||
|
@@ -1,7 +1,8 @@
|
||||||
|
-import { useState } from 'react';
|
||||||
|
+import { useState, useEffect } from 'react';
|
||||||
|
import { Button } from './components';
|
||||||
|
|
||||||
|
function App() {
|
||||||
|
const [count, setCount] = useState(0);
|
||||||
|
+ useEffect(() => { document.title = \`Count: \${count}\` }, [count]);
|
||||||
|
return <Button onClick={() => setCount(count + 1)}>{count}</Button>;
|
||||||
|
}`
|
||||||
|
|
||||||
|
const expected = `import { useState, useEffect } from 'react';
|
||||||
|
import { Button } from './components';
|
||||||
|
|
||||||
|
function App() {
|
||||||
|
const [count, setCount] = useState(0);
|
||||||
|
useEffect(() => { document.title = \`Count: \${count}\` }, [count]);
|
||||||
|
return <Button onClick={() => setCount(count + 1)}>{count}</Button>;
|
||||||
|
}`
|
||||||
|
|
||||||
|
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||||
|
expect(result.success).toBe(true)
|
||||||
|
if (result.success) {
|
||||||
|
expect(result.content).toBe(expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should successfully apply a diff with multiple hunks", async () => {
|
||||||
|
const originalContent = `import { readFile, writeFile } from 'fs';
|
||||||
|
|
||||||
|
function processFile(path: string) {
|
||||||
|
readFile(path, 'utf8', (err, data) => {
|
||||||
|
if (err) throw err;
|
||||||
|
const processed = data.toUpperCase();
|
||||||
|
writeFile(path, processed, (err) => {
|
||||||
|
if (err) throw err;
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
export { processFile };`
|
||||||
|
|
||||||
|
const diffContent = `--- src/file-processor.ts
|
||||||
|
+++ src/file-processor.ts
|
||||||
|
@@ -1,12 +1,14 @@
|
||||||
|
-import { readFile, writeFile } from 'fs';
|
||||||
|
+import { promises as fs } from 'fs';
|
||||||
|
+import { join } from 'path';
|
||||||
|
|
||||||
|
-function processFile(path: string) {
|
||||||
|
- readFile(path, 'utf8', (err, data) => {
|
||||||
|
- if (err) throw err;
|
||||||
|
+async function processFile(path: string) {
|
||||||
|
+ try {
|
||||||
|
+ const data = await fs.readFile(join(__dirname, path), 'utf8');
|
||||||
|
const processed = data.toUpperCase();
|
||||||
|
- writeFile(path, processed, (err) => {
|
||||||
|
- if (err) throw err;
|
||||||
|
- });
|
||||||
|
- });
|
||||||
|
+ await fs.writeFile(join(__dirname, path), processed);
|
||||||
|
+ } catch (error) {
|
||||||
|
+ console.error('Failed to process file:', error);
|
||||||
|
+ throw error;
|
||||||
|
+ }
|
||||||
|
}
|
||||||
|
|
||||||
|
export { processFile };`
|
||||||
|
|
||||||
|
const expected = `import { promises as fs } from 'fs';
|
||||||
|
import { join } from 'path';
|
||||||
|
|
||||||
|
async function processFile(path: string) {
|
||||||
|
try {
|
||||||
|
const data = await fs.readFile(join(__dirname, path), 'utf8');
|
||||||
|
const processed = data.toUpperCase();
|
||||||
|
await fs.writeFile(join(__dirname, path), processed);
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Failed to process file:', error);
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export { processFile };`
|
||||||
|
|
||||||
|
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||||
|
expect(result.success).toBe(true)
|
||||||
|
if (result.success) {
|
||||||
|
expect(result.content).toBe(expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should handle empty original content", async () => {
|
||||||
|
const originalContent = ""
|
||||||
|
const diffContent = `--- empty.ts
|
||||||
|
+++ empty.ts
|
||||||
|
@@ -0,0 +1,3 @@
|
||||||
|
+export function greet(name: string): string {
|
||||||
|
+ return \`Hello, \${name}!\`;
|
||||||
|
+}`
|
||||||
|
|
||||||
|
const expected = `export function greet(name: string): string {
|
||||||
|
return \`Hello, \${name}!\`;
|
||||||
|
}\n`
|
||||||
|
|
||||||
|
const result = await strategy.applyDiff(originalContent, diffContent)
|
||||||
|
expect(result.success).toBe(true)
|
||||||
|
if (result.success) {
|
||||||
|
expect(result.content).toBe(expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
@ -0,0 +1,295 @@
|
|||||||
|
import { applyContextMatching, applyDMP, applyGitFallback } from "../edit-strategies"
|
||||||
|
import { Hunk } from "../types"
|
||||||
|
|
||||||
|
const testCases = [
|
||||||
|
{
|
||||||
|
name: "should return original content if no match is found",
|
||||||
|
hunk: {
|
||||||
|
changes: [
|
||||||
|
{ type: "context", content: "line1" },
|
||||||
|
{ type: "add", content: "line2" },
|
||||||
|
],
|
||||||
|
} as Hunk,
|
||||||
|
content: ["line1", "line3"],
|
||||||
|
matchPosition: -1,
|
||||||
|
expected: {
|
||||||
|
confidence: 0,
|
||||||
|
result: ["line1", "line3"],
|
||||||
|
},
|
||||||
|
expectedResult: "line1\nline3",
|
||||||
|
strategies: ["context", "dmp"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should apply a simple add change",
|
||||||
|
hunk: {
|
||||||
|
changes: [
|
||||||
|
{ type: "context", content: "line1" },
|
||||||
|
{ type: "add", content: "line2" },
|
||||||
|
],
|
||||||
|
} as Hunk,
|
||||||
|
content: ["line1", "line3"],
|
||||||
|
matchPosition: 0,
|
||||||
|
expected: {
|
||||||
|
confidence: 1,
|
||||||
|
result: ["line1", "line2", "line3"],
|
||||||
|
},
|
||||||
|
expectedResult: "line1\nline2\nline3",
|
||||||
|
strategies: ["context", "dmp"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should apply a simple remove change",
|
||||||
|
hunk: {
|
||||||
|
changes: [
|
||||||
|
{ type: "context", content: "line1" },
|
||||||
|
{ type: "remove", content: "line2" },
|
||||||
|
],
|
||||||
|
} as Hunk,
|
||||||
|
content: ["line1", "line2", "line3"],
|
||||||
|
matchPosition: 0,
|
||||||
|
expected: {
|
||||||
|
confidence: 1,
|
||||||
|
result: ["line1", "line3"],
|
||||||
|
},
|
||||||
|
expectedResult: "line1\nline3",
|
||||||
|
strategies: ["context", "dmp"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should apply a simple context change",
|
||||||
|
hunk: {
|
||||||
|
changes: [{ type: "context", content: "line1" }],
|
||||||
|
} as Hunk,
|
||||||
|
content: ["line1", "line2", "line3"],
|
||||||
|
matchPosition: 0,
|
||||||
|
expected: {
|
||||||
|
confidence: 1,
|
||||||
|
result: ["line1", "line2", "line3"],
|
||||||
|
},
|
||||||
|
expectedResult: "line1\nline2\nline3",
|
||||||
|
strategies: ["context", "dmp"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should apply a multi-line add change",
|
||||||
|
hunk: {
|
||||||
|
changes: [
|
||||||
|
{ type: "context", content: "line1" },
|
||||||
|
{ type: "add", content: "line2\nline3" },
|
||||||
|
],
|
||||||
|
} as Hunk,
|
||||||
|
content: ["line1", "line4"],
|
||||||
|
matchPosition: 0,
|
||||||
|
expected: {
|
||||||
|
confidence: 1,
|
||||||
|
result: ["line1", "line2\nline3", "line4"],
|
||||||
|
},
|
||||||
|
expectedResult: "line1\nline2\nline3\nline4",
|
||||||
|
strategies: ["context", "dmp"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should apply a multi-line remove change",
|
||||||
|
hunk: {
|
||||||
|
changes: [
|
||||||
|
{ type: "context", content: "line1" },
|
||||||
|
{ type: "remove", content: "line2\nline3" },
|
||||||
|
],
|
||||||
|
} as Hunk,
|
||||||
|
content: ["line1", "line2", "line3", "line4"],
|
||||||
|
matchPosition: 0,
|
||||||
|
expected: {
|
||||||
|
confidence: 1,
|
||||||
|
result: ["line1", "line4"],
|
||||||
|
},
|
||||||
|
expectedResult: "line1\nline4",
|
||||||
|
strategies: ["context", "dmp"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should apply a multi-line context change",
|
||||||
|
hunk: {
|
||||||
|
changes: [
|
||||||
|
{ type: "context", content: "line1" },
|
||||||
|
{ type: "context", content: "line2\nline3" },
|
||||||
|
],
|
||||||
|
} as Hunk,
|
||||||
|
content: ["line1", "line2", "line3", "line4"],
|
||||||
|
matchPosition: 0,
|
||||||
|
expected: {
|
||||||
|
confidence: 1,
|
||||||
|
result: ["line1", "line2\nline3", "line4"],
|
||||||
|
},
|
||||||
|
expectedResult: "line1\nline2\nline3\nline4",
|
||||||
|
strategies: ["context", "dmp"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should apply a change with indentation",
|
||||||
|
hunk: {
|
||||||
|
changes: [
|
||||||
|
{ type: "context", content: " line1" },
|
||||||
|
{ type: "add", content: " line2" },
|
||||||
|
],
|
||||||
|
} as Hunk,
|
||||||
|
content: [" line1", " line3"],
|
||||||
|
matchPosition: 0,
|
||||||
|
expected: {
|
||||||
|
confidence: 1,
|
||||||
|
result: [" line1", " line2", " line3"],
|
||||||
|
},
|
||||||
|
expectedResult: " line1\n line2\n line3",
|
||||||
|
strategies: ["context", "dmp"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should apply a change with mixed indentation",
|
||||||
|
hunk: {
|
||||||
|
changes: [
|
||||||
|
{ type: "context", content: "\tline1" },
|
||||||
|
{ type: "add", content: " line2" },
|
||||||
|
],
|
||||||
|
} as Hunk,
|
||||||
|
content: ["\tline1", " line3"],
|
||||||
|
matchPosition: 0,
|
||||||
|
expected: {
|
||||||
|
confidence: 1,
|
||||||
|
result: ["\tline1", " line2", " line3"],
|
||||||
|
},
|
||||||
|
expectedResult: "\tline1\n line2\n line3",
|
||||||
|
strategies: ["context", "dmp"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should apply a change with mixed indentation and multi-line",
|
||||||
|
hunk: {
|
||||||
|
changes: [
|
||||||
|
{ type: "context", content: " line1" },
|
||||||
|
{ type: "add", content: "\tline2\n line3" },
|
||||||
|
],
|
||||||
|
} as Hunk,
|
||||||
|
content: [" line1", " line4"],
|
||||||
|
matchPosition: 0,
|
||||||
|
expected: {
|
||||||
|
confidence: 1,
|
||||||
|
result: [" line1", "\tline2\n line3", " line4"],
|
||||||
|
},
|
||||||
|
expectedResult: " line1\n\tline2\n line3\n line4",
|
||||||
|
strategies: ["context", "dmp"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should apply a complex change with mixed indentation and multi-line",
|
||||||
|
hunk: {
|
||||||
|
changes: [
|
||||||
|
{ type: "context", content: " line1" },
|
||||||
|
{ type: "remove", content: " line2" },
|
||||||
|
{ type: "add", content: "\tline3\n line4" },
|
||||||
|
{ type: "context", content: " line5" },
|
||||||
|
],
|
||||||
|
} as Hunk,
|
||||||
|
content: [" line1", " line2", " line5", " line6"],
|
||||||
|
matchPosition: 0,
|
||||||
|
expected: {
|
||||||
|
confidence: 1,
|
||||||
|
result: [" line1", "\tline3\n line4", " line5", " line6"],
|
||||||
|
},
|
||||||
|
expectedResult: " line1\n\tline3\n line4\n line5\n line6",
|
||||||
|
strategies: ["context", "dmp"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should apply a complex change with mixed indentation and multi-line and context",
|
||||||
|
hunk: {
|
||||||
|
changes: [
|
||||||
|
{ type: "context", content: " line1" },
|
||||||
|
{ type: "remove", content: " line2" },
|
||||||
|
{ type: "add", content: "\tline3\n line4" },
|
||||||
|
{ type: "context", content: " line5" },
|
||||||
|
{ type: "context", content: " line6" },
|
||||||
|
],
|
||||||
|
} as Hunk,
|
||||||
|
content: [" line1", " line2", " line5", " line6", " line7"],
|
||||||
|
matchPosition: 0,
|
||||||
|
expected: {
|
||||||
|
confidence: 1,
|
||||||
|
result: [" line1", "\tline3\n line4", " line5", " line6", " line7"],
|
||||||
|
},
|
||||||
|
expectedResult: " line1\n\tline3\n line4\n line5\n line6\n line7",
|
||||||
|
strategies: ["context", "dmp"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should apply a complex change with mixed indentation and multi-line and context and a different match position",
|
||||||
|
hunk: {
|
||||||
|
changes: [
|
||||||
|
{ type: "context", content: " line1" },
|
||||||
|
{ type: "remove", content: " line2" },
|
||||||
|
{ type: "add", content: "\tline3\n line4" },
|
||||||
|
{ type: "context", content: " line5" },
|
||||||
|
{ type: "context", content: " line6" },
|
||||||
|
],
|
||||||
|
} as Hunk,
|
||||||
|
content: [" line0", " line1", " line2", " line5", " line6", " line7"],
|
||||||
|
matchPosition: 1,
|
||||||
|
expected: {
|
||||||
|
confidence: 1,
|
||||||
|
result: [" line0", " line1", "\tline3\n line4", " line5", " line6", " line7"],
|
||||||
|
},
|
||||||
|
expectedResult: " line0\n line1\n\tline3\n line4\n line5\n line6\n line7",
|
||||||
|
strategies: ["context", "dmp"],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
describe("applyContextMatching", () => {
|
||||||
|
testCases.forEach(({ name, hunk, content, matchPosition, expected, strategies, expectedResult }) => {
|
||||||
|
if (!strategies?.includes("context")) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
it(name, () => {
|
||||||
|
const result = applyContextMatching(hunk, content, matchPosition)
|
||||||
|
expect(result.result.join("\n")).toEqual(expectedResult)
|
||||||
|
expect(result.confidence).toBeGreaterThanOrEqual(expected.confidence)
|
||||||
|
expect(result.strategy).toBe("context")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe("applyDMP", () => {
|
||||||
|
testCases.forEach(({ name, hunk, content, matchPosition, expected, strategies, expectedResult }) => {
|
||||||
|
if (!strategies?.includes("dmp")) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
it(name, () => {
|
||||||
|
const result = applyDMP(hunk, content, matchPosition)
|
||||||
|
expect(result.result.join("\n")).toEqual(expectedResult)
|
||||||
|
expect(result.confidence).toBeGreaterThanOrEqual(expected.confidence)
|
||||||
|
expect(result.strategy).toBe("dmp")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe("applyGitFallback", () => {
|
||||||
|
it("should successfully apply changes using git operations", async () => {
|
||||||
|
const hunk = {
|
||||||
|
changes: [
|
||||||
|
{ type: "context", content: "line1", indent: "" },
|
||||||
|
{ type: "remove", content: "line2", indent: "" },
|
||||||
|
{ type: "add", content: "new line2", indent: "" },
|
||||||
|
{ type: "context", content: "line3", indent: "" },
|
||||||
|
],
|
||||||
|
} as Hunk
|
||||||
|
|
||||||
|
const content = ["line1", "line2", "line3"]
|
||||||
|
const result = await applyGitFallback(hunk, content)
|
||||||
|
|
||||||
|
expect(result.result.join("\n")).toEqual("line1\nnew line2\nline3")
|
||||||
|
expect(result.confidence).toBe(1)
|
||||||
|
expect(result.strategy).toBe("git-fallback")
|
||||||
|
})
|
||||||
|
|
||||||
|
it("should return original content with 0 confidence when changes cannot be applied", async () => {
|
||||||
|
const hunk = {
|
||||||
|
changes: [
|
||||||
|
{ type: "context", content: "nonexistent", indent: "" },
|
||||||
|
{ type: "add", content: "new line", indent: "" },
|
||||||
|
],
|
||||||
|
} as Hunk
|
||||||
|
|
||||||
|
const content = ["line1", "line2", "line3"]
|
||||||
|
const result = await applyGitFallback(hunk, content)
|
||||||
|
|
||||||
|
expect(result.result).toEqual(content)
|
||||||
|
expect(result.confidence).toBe(0)
|
||||||
|
expect(result.strategy).toBe("git-fallback")
|
||||||
|
})
|
||||||
|
})
|
||||||
@ -0,0 +1,262 @@
|
|||||||
|
import { findAnchorMatch, findExactMatch, findSimilarityMatch, findLevenshteinMatch } from "../search-strategies"
|
||||||
|
|
||||||
|
type SearchStrategy = (
|
||||||
|
searchStr: string,
|
||||||
|
content: string[],
|
||||||
|
startIndex?: number,
|
||||||
|
) => {
|
||||||
|
index: number
|
||||||
|
confidence: number
|
||||||
|
strategy: string
|
||||||
|
}
|
||||||
|
|
||||||
|
const testCases = [
|
||||||
|
{
|
||||||
|
name: "should return no match if the search string is not found",
|
||||||
|
searchStr: "not found",
|
||||||
|
content: ["line1", "line2", "line3"],
|
||||||
|
expected: { index: -1, confidence: 0 },
|
||||||
|
strategies: ["exact", "similarity", "levenshtein"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return a match if the search string is found",
|
||||||
|
searchStr: "line2",
|
||||||
|
content: ["line1", "line2", "line3"],
|
||||||
|
expected: { index: 1, confidence: 1 },
|
||||||
|
strategies: ["exact", "similarity", "levenshtein"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return a match with correct index when startIndex is provided",
|
||||||
|
searchStr: "line3",
|
||||||
|
content: ["line1", "line2", "line3", "line4", "line3"],
|
||||||
|
startIndex: 3,
|
||||||
|
expected: { index: 4, confidence: 1 },
|
||||||
|
strategies: ["exact", "similarity", "levenshtein"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return a match even if there are more lines in content",
|
||||||
|
searchStr: "line2",
|
||||||
|
content: ["line1", "line2", "line3", "line4", "line5"],
|
||||||
|
expected: { index: 1, confidence: 1 },
|
||||||
|
strategies: ["exact", "similarity", "levenshtein"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return a match even if the search string is at the beginning of the content",
|
||||||
|
searchStr: "line1",
|
||||||
|
content: ["line1", "line2", "line3"],
|
||||||
|
expected: { index: 0, confidence: 1 },
|
||||||
|
strategies: ["exact", "similarity", "levenshtein"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return a match even if the search string is at the end of the content",
|
||||||
|
searchStr: "line3",
|
||||||
|
content: ["line1", "line2", "line3"],
|
||||||
|
expected: { index: 2, confidence: 1 },
|
||||||
|
strategies: ["exact", "similarity", "levenshtein"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return a match for a multi-line search string",
|
||||||
|
searchStr: "line2\nline3",
|
||||||
|
content: ["line1", "line2", "line3", "line4"],
|
||||||
|
expected: { index: 1, confidence: 1 },
|
||||||
|
strategies: ["exact", "similarity", "levenshtein"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return no match if a multi-line search string is not found",
|
||||||
|
searchStr: "line2\nline4",
|
||||||
|
content: ["line1", "line2", "line3", "line4"],
|
||||||
|
expected: { index: -1, confidence: 0 },
|
||||||
|
strategies: ["exact", "similarity"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return a match with indentation",
|
||||||
|
searchStr: " line2",
|
||||||
|
content: ["line1", " line2", "line3"],
|
||||||
|
expected: { index: 1, confidence: 1 },
|
||||||
|
strategies: ["exact", "similarity", "levenshtein"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return a match with more complex indentation",
|
||||||
|
searchStr: " line3",
|
||||||
|
content: [" line1", " line2", " line3", " line4"],
|
||||||
|
expected: { index: 2, confidence: 1 },
|
||||||
|
strategies: ["exact", "similarity", "levenshtein"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return a match with mixed indentation",
|
||||||
|
searchStr: "\tline2",
|
||||||
|
content: [" line1", "\tline2", " line3"],
|
||||||
|
expected: { index: 1, confidence: 1 },
|
||||||
|
strategies: ["exact", "similarity", "levenshtein"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return a match with mixed indentation and multi-line",
|
||||||
|
searchStr: " line2\n\tline3",
|
||||||
|
content: ["line1", " line2", "\tline3", " line4"],
|
||||||
|
expected: { index: 1, confidence: 1 },
|
||||||
|
strategies: ["exact", "similarity", "levenshtein"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return no match if mixed indentation and multi-line is not found",
|
||||||
|
searchStr: " line2\n line4",
|
||||||
|
content: ["line1", " line2", "\tline3", " line4"],
|
||||||
|
expected: { index: -1, confidence: 0 },
|
||||||
|
strategies: ["exact", "similarity"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return a match with leading and trailing spaces",
|
||||||
|
searchStr: " line2 ",
|
||||||
|
content: ["line1", " line2 ", "line3"],
|
||||||
|
expected: { index: 1, confidence: 1 },
|
||||||
|
strategies: ["exact", "similarity", "levenshtein"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return a match with leading and trailing tabs",
|
||||||
|
searchStr: "\tline2\t",
|
||||||
|
content: ["line1", "\tline2\t", "line3"],
|
||||||
|
expected: { index: 1, confidence: 1 },
|
||||||
|
strategies: ["exact", "similarity", "levenshtein"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return a match with mixed leading and trailing spaces and tabs",
|
||||||
|
searchStr: " \tline2\t ",
|
||||||
|
content: ["line1", " \tline2\t ", "line3"],
|
||||||
|
expected: { index: 1, confidence: 1 },
|
||||||
|
strategies: ["exact", "similarity", "levenshtein"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return a match with mixed leading and trailing spaces and tabs and multi-line",
|
||||||
|
searchStr: " \tline2\t \n line3 ",
|
||||||
|
content: ["line1", " \tline2\t ", " line3 ", "line4"],
|
||||||
|
expected: { index: 1, confidence: 1 },
|
||||||
|
strategies: ["exact", "similarity", "levenshtein"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return no match if mixed leading and trailing spaces and tabs and multi-line is not found",
|
||||||
|
searchStr: " \tline2\t \n line4 ",
|
||||||
|
content: ["line1", " \tline2\t ", " line3 ", "line4"],
|
||||||
|
expected: { index: -1, confidence: 0 },
|
||||||
|
strategies: ["exact", "similarity"],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
describe("findExactMatch", () => {
|
||||||
|
testCases.forEach(({ name, searchStr, content, startIndex, expected, strategies }) => {
|
||||||
|
if (!strategies?.includes("exact")) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
it(name, () => {
|
||||||
|
const result = findExactMatch(searchStr, content, startIndex)
|
||||||
|
expect(result.index).toBe(expected.index)
|
||||||
|
expect(result.confidence).toBeGreaterThanOrEqual(expected.confidence)
|
||||||
|
expect(result.strategy).toMatch(/exact(-overlapping)?/)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe("findAnchorMatch", () => {
|
||||||
|
const anchorTestCases = [
|
||||||
|
{
|
||||||
|
name: "should return no match if no anchors are found",
|
||||||
|
searchStr: " \n \n ",
|
||||||
|
content: ["line1", "line2", "line3"],
|
||||||
|
expected: { index: -1, confidence: 0 },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return no match if anchor positions cannot be validated",
|
||||||
|
searchStr: "unique line\ncontext line 1\ncontext line 2",
|
||||||
|
content: [
|
||||||
|
"different line 1",
|
||||||
|
"different line 2",
|
||||||
|
"different line 3",
|
||||||
|
"another unique line",
|
||||||
|
"context line 1",
|
||||||
|
"context line 2",
|
||||||
|
],
|
||||||
|
expected: { index: -1, confidence: 0 },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return a match if anchor positions can be validated",
|
||||||
|
searchStr: "unique line\ncontext line 1\ncontext line 2",
|
||||||
|
content: ["line1", "line2", "unique line", "context line 1", "context line 2", "line 6"],
|
||||||
|
expected: { index: 2, confidence: 1 },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return a match with correct index when startIndex is provided",
|
||||||
|
searchStr: "unique line\ncontext line 1\ncontext line 2",
|
||||||
|
content: ["line1", "line2", "line3", "unique line", "context line 1", "context line 2", "line 7"],
|
||||||
|
startIndex: 3,
|
||||||
|
expected: { index: 3, confidence: 1 },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return a match even if there are more lines in content",
|
||||||
|
searchStr: "unique line\ncontext line 1\ncontext line 2",
|
||||||
|
content: [
|
||||||
|
"line1",
|
||||||
|
"line2",
|
||||||
|
"unique line",
|
||||||
|
"context line 1",
|
||||||
|
"context line 2",
|
||||||
|
"line 6",
|
||||||
|
"extra line 1",
|
||||||
|
"extra line 2",
|
||||||
|
],
|
||||||
|
expected: { index: 2, confidence: 1 },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return a match even if the anchor is at the beginning of the content",
|
||||||
|
searchStr: "unique line\ncontext line 1\ncontext line 2",
|
||||||
|
content: ["unique line", "context line 1", "context line 2", "line 6"],
|
||||||
|
expected: { index: 0, confidence: 1 },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return a match even if the anchor is at the end of the content",
|
||||||
|
searchStr: "unique line\ncontext line 1\ncontext line 2",
|
||||||
|
content: ["line1", "line2", "unique line", "context line 1", "context line 2"],
|
||||||
|
expected: { index: 2, confidence: 1 },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return no match if no valid anchor is found",
|
||||||
|
searchStr: "non-unique line\ncontext line 1\ncontext line 2",
|
||||||
|
content: ["line1", "line2", "non-unique line", "context line 1", "context line 2", "non-unique line"],
|
||||||
|
expected: { index: -1, confidence: 0 },
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
anchorTestCases.forEach(({ name, searchStr, content, startIndex, expected }) => {
|
||||||
|
it(name, () => {
|
||||||
|
const result = findAnchorMatch(searchStr, content, startIndex)
|
||||||
|
expect(result.index).toBe(expected.index)
|
||||||
|
expect(result.confidence).toBeGreaterThanOrEqual(expected.confidence)
|
||||||
|
expect(result.strategy).toBe("anchor")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe("findSimilarityMatch", () => {
|
||||||
|
testCases.forEach(({ name, searchStr, content, startIndex, expected, strategies }) => {
|
||||||
|
if (!strategies?.includes("similarity")) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
it(name, () => {
|
||||||
|
const result = findSimilarityMatch(searchStr, content, startIndex)
|
||||||
|
expect(result.index).toBe(expected.index)
|
||||||
|
expect(result.confidence).toBeGreaterThanOrEqual(expected.confidence)
|
||||||
|
expect(result.strategy).toBe("similarity")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe("findLevenshteinMatch", () => {
|
||||||
|
testCases.forEach(({ name, searchStr, content, startIndex, expected, strategies }) => {
|
||||||
|
if (!strategies?.includes("levenshtein")) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
it(name, () => {
|
||||||
|
const result = findLevenshteinMatch(searchStr, content, startIndex)
|
||||||
|
expect(result.index).toBe(expected.index)
|
||||||
|
expect(result.confidence).toBeGreaterThanOrEqual(expected.confidence)
|
||||||
|
expect(result.strategy).toBe("levenshtein")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
297
src/core/diff/strategies/new-unified/edit-strategies.ts
Normal file
297
src/core/diff/strategies/new-unified/edit-strategies.ts
Normal file
@ -0,0 +1,297 @@
|
|||||||
|
import { diff_match_patch } from "diff-match-patch"
|
||||||
|
import { EditResult, Hunk } from "./types"
|
||||||
|
import { getDMPSimilarity, validateEditResult } from "./search-strategies"
|
||||||
|
import * as path from "path"
|
||||||
|
import simpleGit, { SimpleGit } from "simple-git"
|
||||||
|
import * as tmp from "tmp"
|
||||||
|
import * as fs from "fs"
|
||||||
|
|
||||||
|
// Helper function to infer indentation - simplified version
|
||||||
|
function inferIndentation(line: string, contextLines: string[], previousIndent: string = ""): string {
|
||||||
|
// If the line has explicit indentation in the change, use it exactly
|
||||||
|
const lineMatch = line.match(/^(\s+)/)
|
||||||
|
if (lineMatch) {
|
||||||
|
return lineMatch[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we have context lines, use the indentation from the first context line
|
||||||
|
const contextLine = contextLines[0]
|
||||||
|
if (contextLine) {
|
||||||
|
const contextMatch = contextLine.match(/^(\s+)/)
|
||||||
|
if (contextMatch) {
|
||||||
|
return contextMatch[1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to previous indent
|
||||||
|
return previousIndent
|
||||||
|
}
|
||||||
|
|
||||||
|
// Context matching edit strategy
|
||||||
|
export function applyContextMatching(hunk: Hunk, content: string[], matchPosition: number): EditResult {
|
||||||
|
if (matchPosition === -1) {
|
||||||
|
return { confidence: 0, result: content, strategy: "context" }
|
||||||
|
}
|
||||||
|
|
||||||
|
const newResult = [...content.slice(0, matchPosition)]
|
||||||
|
let sourceIndex = matchPosition
|
||||||
|
|
||||||
|
for (const change of hunk.changes) {
|
||||||
|
if (change.type === "context") {
|
||||||
|
// Use the original line from content if available
|
||||||
|
if (sourceIndex < content.length) {
|
||||||
|
newResult.push(content[sourceIndex])
|
||||||
|
} else {
|
||||||
|
const line = change.indent ? change.indent + change.content : change.content
|
||||||
|
newResult.push(line)
|
||||||
|
}
|
||||||
|
sourceIndex++
|
||||||
|
} else if (change.type === "add") {
|
||||||
|
// Use exactly the indentation from the change
|
||||||
|
const baseIndent = change.indent || ""
|
||||||
|
|
||||||
|
// Handle multi-line additions
|
||||||
|
const lines = change.content.split("\n").map((line) => {
|
||||||
|
// If the line already has indentation, preserve it relative to the base indent
|
||||||
|
const lineIndentMatch = line.match(/^(\s*)(.*)/)
|
||||||
|
if (lineIndentMatch) {
|
||||||
|
const [, lineIndent, content] = lineIndentMatch
|
||||||
|
// Only add base indent if the line doesn't already have it
|
||||||
|
return lineIndent ? line : baseIndent + content
|
||||||
|
}
|
||||||
|
return baseIndent + line
|
||||||
|
})
|
||||||
|
|
||||||
|
newResult.push(...lines)
|
||||||
|
} else if (change.type === "remove") {
|
||||||
|
// Handle multi-line removes by incrementing sourceIndex for each line
|
||||||
|
const removedLines = change.content.split("\n").length
|
||||||
|
sourceIndex += removedLines
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append remaining content
|
||||||
|
newResult.push(...content.slice(sourceIndex))
|
||||||
|
|
||||||
|
// Calculate confidence based on the actual changes
|
||||||
|
const afterText = newResult.slice(matchPosition, newResult.length - (content.length - sourceIndex)).join("\n")
|
||||||
|
|
||||||
|
const confidence = validateEditResult(hunk, afterText)
|
||||||
|
|
||||||
|
return {
|
||||||
|
confidence,
|
||||||
|
result: newResult,
|
||||||
|
strategy: "context",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DMP edit strategy
|
||||||
|
export function applyDMP(hunk: Hunk, content: string[], matchPosition: number): EditResult {
|
||||||
|
if (matchPosition === -1) {
|
||||||
|
return { confidence: 0, result: content, strategy: "dmp" }
|
||||||
|
}
|
||||||
|
|
||||||
|
const dmp = new diff_match_patch()
|
||||||
|
|
||||||
|
// Calculate total lines in before block accounting for multi-line content
|
||||||
|
const beforeLineCount = hunk.changes
|
||||||
|
.filter((change) => change.type === "context" || change.type === "remove")
|
||||||
|
.reduce((count, change) => count + change.content.split("\n").length, 0)
|
||||||
|
|
||||||
|
// Build BEFORE block (context + removals)
|
||||||
|
const beforeLines = hunk.changes
|
||||||
|
.filter((change) => change.type === "context" || change.type === "remove")
|
||||||
|
.map((change) => {
|
||||||
|
if (change.originalLine) {
|
||||||
|
return change.originalLine
|
||||||
|
}
|
||||||
|
return change.indent ? change.indent + change.content : change.content
|
||||||
|
})
|
||||||
|
|
||||||
|
// Build AFTER block (context + additions)
|
||||||
|
const afterLines = hunk.changes
|
||||||
|
.filter((change) => change.type === "context" || change.type === "add")
|
||||||
|
.map((change) => {
|
||||||
|
if (change.originalLine) {
|
||||||
|
return change.originalLine
|
||||||
|
}
|
||||||
|
return change.indent ? change.indent + change.content : change.content
|
||||||
|
})
|
||||||
|
|
||||||
|
// Convert to text with proper line endings
|
||||||
|
const beforeText = beforeLines.join("\n")
|
||||||
|
const afterText = afterLines.join("\n")
|
||||||
|
|
||||||
|
// Create and apply patch
|
||||||
|
const patch = dmp.patch_make(beforeText, afterText)
|
||||||
|
const targetText = content.slice(matchPosition, matchPosition + beforeLineCount).join("\n")
|
||||||
|
const [patchedText] = dmp.patch_apply(patch, targetText)
|
||||||
|
|
||||||
|
// Split result and preserve line endings
|
||||||
|
const patchedLines = patchedText.split("\n")
|
||||||
|
|
||||||
|
// Construct final result
|
||||||
|
const newResult = [
|
||||||
|
...content.slice(0, matchPosition),
|
||||||
|
...patchedLines,
|
||||||
|
...content.slice(matchPosition + beforeLineCount),
|
||||||
|
]
|
||||||
|
|
||||||
|
const confidence = validateEditResult(hunk, patchedText)
|
||||||
|
|
||||||
|
return {
|
||||||
|
confidence,
|
||||||
|
result: newResult,
|
||||||
|
strategy: "dmp",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Git fallback strategy that works with full content
|
||||||
|
export async function applyGitFallback(hunk: Hunk, content: string[]): Promise<EditResult> {
|
||||||
|
let tmpDir: tmp.DirResult | undefined
|
||||||
|
|
||||||
|
try {
|
||||||
|
tmpDir = tmp.dirSync({ unsafeCleanup: true })
|
||||||
|
const git: SimpleGit = simpleGit(tmpDir.name)
|
||||||
|
|
||||||
|
await git.init()
|
||||||
|
await git.addConfig("user.name", "Temp")
|
||||||
|
await git.addConfig("user.email", "temp@example.com")
|
||||||
|
|
||||||
|
const filePath = path.join(tmpDir.name, "file.txt")
|
||||||
|
|
||||||
|
const searchLines = hunk.changes
|
||||||
|
.filter((change) => change.type === "context" || change.type === "remove")
|
||||||
|
.map((change) => change.originalLine || change.indent + change.content)
|
||||||
|
|
||||||
|
const replaceLines = hunk.changes
|
||||||
|
.filter((change) => change.type === "context" || change.type === "add")
|
||||||
|
.map((change) => change.originalLine || change.indent + change.content)
|
||||||
|
|
||||||
|
const searchText = searchLines.join("\n")
|
||||||
|
const replaceText = replaceLines.join("\n")
|
||||||
|
const originalText = content.join("\n")
|
||||||
|
|
||||||
|
try {
|
||||||
|
fs.writeFileSync(filePath, originalText)
|
||||||
|
await git.add("file.txt")
|
||||||
|
const originalCommit = await git.commit("original")
|
||||||
|
console.log("Strategy 1 - Original commit:", originalCommit.commit)
|
||||||
|
|
||||||
|
fs.writeFileSync(filePath, searchText)
|
||||||
|
await git.add("file.txt")
|
||||||
|
const searchCommit1 = await git.commit("search")
|
||||||
|
console.log("Strategy 1 - Search commit:", searchCommit1.commit)
|
||||||
|
|
||||||
|
fs.writeFileSync(filePath, replaceText)
|
||||||
|
await git.add("file.txt")
|
||||||
|
const replaceCommit = await git.commit("replace")
|
||||||
|
console.log("Strategy 1 - Replace commit:", replaceCommit.commit)
|
||||||
|
|
||||||
|
console.log("Strategy 1 - Attempting checkout of:", originalCommit.commit)
|
||||||
|
await git.raw(["checkout", originalCommit.commit])
|
||||||
|
try {
|
||||||
|
console.log("Strategy 1 - Attempting cherry-pick of:", replaceCommit.commit)
|
||||||
|
await git.raw(["cherry-pick", "--minimal", replaceCommit.commit])
|
||||||
|
|
||||||
|
const newText = fs.readFileSync(filePath, "utf-8")
|
||||||
|
const newLines = newText.split("\n")
|
||||||
|
return {
|
||||||
|
confidence: 1,
|
||||||
|
result: newLines,
|
||||||
|
strategy: "git-fallback",
|
||||||
|
}
|
||||||
|
} catch (cherryPickError) {
|
||||||
|
console.error("Strategy 1 failed with merge conflict")
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Strategy 1 failed:", error)
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
await git.init()
|
||||||
|
await git.addConfig("user.name", "Temp")
|
||||||
|
await git.addConfig("user.email", "temp@example.com")
|
||||||
|
|
||||||
|
fs.writeFileSync(filePath, searchText)
|
||||||
|
await git.add("file.txt")
|
||||||
|
const searchCommit = await git.commit("search")
|
||||||
|
const searchHash = searchCommit.commit.replace(/^HEAD /, "")
|
||||||
|
console.log("Strategy 2 - Search commit:", searchHash)
|
||||||
|
|
||||||
|
fs.writeFileSync(filePath, replaceText)
|
||||||
|
await git.add("file.txt")
|
||||||
|
const replaceCommit = await git.commit("replace")
|
||||||
|
const replaceHash = replaceCommit.commit.replace(/^HEAD /, "")
|
||||||
|
console.log("Strategy 2 - Replace commit:", replaceHash)
|
||||||
|
|
||||||
|
console.log("Strategy 2 - Attempting checkout of:", searchHash)
|
||||||
|
await git.raw(["checkout", searchHash])
|
||||||
|
fs.writeFileSync(filePath, originalText)
|
||||||
|
await git.add("file.txt")
|
||||||
|
const originalCommit2 = await git.commit("original")
|
||||||
|
console.log("Strategy 2 - Original commit:", originalCommit2.commit)
|
||||||
|
|
||||||
|
try {
|
||||||
|
console.log("Strategy 2 - Attempting cherry-pick of:", replaceHash)
|
||||||
|
await git.raw(["cherry-pick", "--minimal", replaceHash])
|
||||||
|
|
||||||
|
const newText = fs.readFileSync(filePath, "utf-8")
|
||||||
|
const newLines = newText.split("\n")
|
||||||
|
return {
|
||||||
|
confidence: 1,
|
||||||
|
result: newLines,
|
||||||
|
strategy: "git-fallback",
|
||||||
|
}
|
||||||
|
} catch (cherryPickError) {
|
||||||
|
console.error("Strategy 2 failed with merge conflict")
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Strategy 2 failed:", error)
|
||||||
|
}
|
||||||
|
|
||||||
|
console.error("Git fallback failed")
|
||||||
|
return { confidence: 0, result: content, strategy: "git-fallback" }
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Git fallback strategy failed:", error)
|
||||||
|
return { confidence: 0, result: content, strategy: "git-fallback" }
|
||||||
|
} finally {
|
||||||
|
if (tmpDir) {
|
||||||
|
tmpDir.removeCallback()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Main edit function that tries strategies sequentially
|
||||||
|
export async function applyEdit(
|
||||||
|
hunk: Hunk,
|
||||||
|
content: string[],
|
||||||
|
matchPosition: number,
|
||||||
|
confidence: number,
|
||||||
|
confidenceThreshold: number = 0.97,
|
||||||
|
): Promise<EditResult> {
|
||||||
|
// Don't attempt regular edits if confidence is too low
|
||||||
|
if (confidence < confidenceThreshold) {
|
||||||
|
console.log(
|
||||||
|
`Search confidence (${confidence}) below minimum threshold (${confidenceThreshold}), trying git fallback...`,
|
||||||
|
)
|
||||||
|
return applyGitFallback(hunk, content)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try each strategy in sequence until one succeeds
|
||||||
|
const strategies = [
|
||||||
|
{ name: "dmp", apply: () => applyDMP(hunk, content, matchPosition) },
|
||||||
|
{ name: "context", apply: () => applyContextMatching(hunk, content, matchPosition) },
|
||||||
|
{ name: "git-fallback", apply: () => applyGitFallback(hunk, content) },
|
||||||
|
]
|
||||||
|
|
||||||
|
// Try strategies sequentially until one succeeds
|
||||||
|
for (const strategy of strategies) {
|
||||||
|
const result = await strategy.apply()
|
||||||
|
if (result.confidence >= confidenceThreshold) {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return { confidence: 0, result: content, strategy: "none" }
|
||||||
|
}
|
||||||
350
src/core/diff/strategies/new-unified/index.ts
Normal file
350
src/core/diff/strategies/new-unified/index.ts
Normal file
@ -0,0 +1,350 @@
|
|||||||
|
import { Diff, Hunk, Change } from "./types"
|
||||||
|
import { findBestMatch, prepareSearchString } from "./search-strategies"
|
||||||
|
import { applyEdit } from "./edit-strategies"
|
||||||
|
import { DiffResult, DiffStrategy } from "../../types"
|
||||||
|
|
||||||
|
export class NewUnifiedDiffStrategy implements DiffStrategy {
|
||||||
|
private readonly confidenceThreshold: number
|
||||||
|
|
||||||
|
constructor(confidenceThreshold: number = 1) {
|
||||||
|
this.confidenceThreshold = Math.max(confidenceThreshold, 0.8)
|
||||||
|
}
|
||||||
|
|
||||||
|
private parseUnifiedDiff(diff: string): Diff {
|
||||||
|
const MAX_CONTEXT_LINES = 6 // Number of context lines to keep before/after changes
|
||||||
|
const lines = diff.split("\n")
|
||||||
|
const hunks: Hunk[] = []
|
||||||
|
let currentHunk: Hunk | null = null
|
||||||
|
|
||||||
|
let i = 0
|
||||||
|
while (i < lines.length && !lines[i].startsWith("@@")) {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
|
||||||
|
for (; i < lines.length; i++) {
|
||||||
|
const line = lines[i]
|
||||||
|
|
||||||
|
if (line.startsWith("@@")) {
|
||||||
|
if (
|
||||||
|
currentHunk &&
|
||||||
|
currentHunk.changes.length > 0 &&
|
||||||
|
currentHunk.changes.some((change) => change.type === "add" || change.type === "remove")
|
||||||
|
) {
|
||||||
|
const changes = currentHunk.changes
|
||||||
|
let startIdx = 0
|
||||||
|
let endIdx = changes.length - 1
|
||||||
|
|
||||||
|
for (let j = 0; j < changes.length; j++) {
|
||||||
|
if (changes[j].type !== "context") {
|
||||||
|
startIdx = Math.max(0, j - MAX_CONTEXT_LINES)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (let j = changes.length - 1; j >= 0; j--) {
|
||||||
|
if (changes[j].type !== "context") {
|
||||||
|
endIdx = Math.min(changes.length - 1, j + MAX_CONTEXT_LINES)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
currentHunk.changes = changes.slice(startIdx, endIdx + 1)
|
||||||
|
hunks.push(currentHunk)
|
||||||
|
}
|
||||||
|
currentHunk = { changes: [] }
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!currentHunk) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
const content = line.slice(1)
|
||||||
|
const indentMatch = content.match(/^(\s*)/)
|
||||||
|
const indent = indentMatch ? indentMatch[0] : ""
|
||||||
|
const trimmedContent = content.slice(indent.length)
|
||||||
|
|
||||||
|
if (line.startsWith(" ")) {
|
||||||
|
currentHunk.changes.push({
|
||||||
|
type: "context",
|
||||||
|
content: trimmedContent,
|
||||||
|
indent,
|
||||||
|
originalLine: content,
|
||||||
|
})
|
||||||
|
} else if (line.startsWith("+")) {
|
||||||
|
currentHunk.changes.push({
|
||||||
|
type: "add",
|
||||||
|
content: trimmedContent,
|
||||||
|
indent,
|
||||||
|
originalLine: content,
|
||||||
|
})
|
||||||
|
} else if (line.startsWith("-")) {
|
||||||
|
currentHunk.changes.push({
|
||||||
|
type: "remove",
|
||||||
|
content: trimmedContent,
|
||||||
|
indent,
|
||||||
|
originalLine: content,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
const finalContent = trimmedContent ? " " + trimmedContent : " "
|
||||||
|
currentHunk.changes.push({
|
||||||
|
type: "context",
|
||||||
|
content: finalContent,
|
||||||
|
indent,
|
||||||
|
originalLine: content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
currentHunk &&
|
||||||
|
currentHunk.changes.length > 0 &&
|
||||||
|
currentHunk.changes.some((change) => change.type === "add" || change.type === "remove")
|
||||||
|
) {
|
||||||
|
hunks.push(currentHunk)
|
||||||
|
}
|
||||||
|
|
||||||
|
return { hunks }
|
||||||
|
}
|
||||||
|
|
||||||
|
getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string {
|
||||||
|
return `# apply_diff Tool - Generate Precise Code Changes
|
||||||
|
|
||||||
|
Generate a unified diff that can be cleanly applied to modify code files.
|
||||||
|
|
||||||
|
## Step-by-Step Instructions:
|
||||||
|
|
||||||
|
1. Start with file headers:
|
||||||
|
- First line: "--- {original_file_path}"
|
||||||
|
- Second line: "+++ {new_file_path}"
|
||||||
|
|
||||||
|
2. For each change section:
|
||||||
|
- Begin with "@@ ... @@" separator line without line numbers
|
||||||
|
- Include 2-3 lines of context before and after changes
|
||||||
|
- Mark removed lines with "-"
|
||||||
|
- Mark added lines with "+"
|
||||||
|
- Preserve exact indentation
|
||||||
|
|
||||||
|
3. Group related changes:
|
||||||
|
- Keep related modifications in the same hunk
|
||||||
|
- Start new hunks for logically separate changes
|
||||||
|
- When modifying functions/methods, include the entire block
|
||||||
|
|
||||||
|
## Requirements:
|
||||||
|
|
||||||
|
1. MUST include exact indentation
|
||||||
|
2. MUST include sufficient context for unique matching
|
||||||
|
3. MUST group related changes together
|
||||||
|
4. MUST use proper unified diff format
|
||||||
|
5. MUST NOT include timestamps in file headers
|
||||||
|
6. MUST NOT include line numbers in the @@ header
|
||||||
|
|
||||||
|
## Examples:
|
||||||
|
|
||||||
|
✅ Good diff (follows all requirements):
|
||||||
|
\`\`\`diff
|
||||||
|
--- src/utils.ts
|
||||||
|
+++ src/utils.ts
|
||||||
|
@@ ... @@
|
||||||
|
def calculate_total(items):
|
||||||
|
- total = 0
|
||||||
|
- for item in items:
|
||||||
|
- total += item.price
|
||||||
|
+ return sum(item.price for item in items)
|
||||||
|
\`\`\`
|
||||||
|
|
||||||
|
❌ Bad diff (violates requirements #1 and #2):
|
||||||
|
\`\`\`diff
|
||||||
|
--- src/utils.ts
|
||||||
|
+++ src/utils.ts
|
||||||
|
@@ ... @@
|
||||||
|
-total = 0
|
||||||
|
-for item in items:
|
||||||
|
+return sum(item.price for item in items)
|
||||||
|
\`\`\`
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- path: (required) File path relative to ${args.cwd}
|
||||||
|
- diff: (required) Unified diff content in unified format to apply to the file.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
<apply_diff>
|
||||||
|
<path>path/to/file.ext</path>
|
||||||
|
<diff>
|
||||||
|
Your diff here
|
||||||
|
</diff>
|
||||||
|
</apply_diff>`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to split a hunk into smaller hunks based on contiguous changes
|
||||||
|
private splitHunk(hunk: Hunk): Hunk[] {
|
||||||
|
const result: Hunk[] = []
|
||||||
|
let currentHunk: Hunk | null = null
|
||||||
|
let contextBefore: Change[] = []
|
||||||
|
let contextAfter: Change[] = []
|
||||||
|
const MAX_CONTEXT_LINES = 3 // Keep 3 lines of context before/after changes
|
||||||
|
|
||||||
|
for (let i = 0; i < hunk.changes.length; i++) {
|
||||||
|
const change = hunk.changes[i]
|
||||||
|
|
||||||
|
if (change.type === "context") {
|
||||||
|
if (!currentHunk) {
|
||||||
|
contextBefore.push(change)
|
||||||
|
if (contextBefore.length > MAX_CONTEXT_LINES) {
|
||||||
|
contextBefore.shift()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
contextAfter.push(change)
|
||||||
|
if (contextAfter.length > MAX_CONTEXT_LINES) {
|
||||||
|
// We've collected enough context after changes, create a new hunk
|
||||||
|
currentHunk.changes.push(...contextAfter)
|
||||||
|
result.push(currentHunk)
|
||||||
|
currentHunk = null
|
||||||
|
// Keep the last few context lines for the next hunk
|
||||||
|
contextBefore = contextAfter
|
||||||
|
contextAfter = []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (!currentHunk) {
|
||||||
|
currentHunk = { changes: [...contextBefore] }
|
||||||
|
contextAfter = []
|
||||||
|
} else if (contextAfter.length > 0) {
|
||||||
|
// Add accumulated context to current hunk
|
||||||
|
currentHunk.changes.push(...contextAfter)
|
||||||
|
contextAfter = []
|
||||||
|
}
|
||||||
|
currentHunk.changes.push(change)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add any remaining changes
|
||||||
|
if (currentHunk) {
|
||||||
|
if (contextAfter.length > 0) {
|
||||||
|
currentHunk.changes.push(...contextAfter)
|
||||||
|
}
|
||||||
|
result.push(currentHunk)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
async applyDiff(
|
||||||
|
originalContent: string,
|
||||||
|
diffContent: string,
|
||||||
|
startLine?: number,
|
||||||
|
endLine?: number,
|
||||||
|
): Promise<DiffResult> {
|
||||||
|
const parsedDiff = this.parseUnifiedDiff(diffContent)
|
||||||
|
const originalLines = originalContent.split("\n")
|
||||||
|
let result = [...originalLines]
|
||||||
|
|
||||||
|
if (!parsedDiff.hunks.length) {
|
||||||
|
return {
|
||||||
|
success: false,
|
||||||
|
error: "No hunks found in diff. Please ensure your diff includes actual changes and follows the unified diff format.",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const hunk of parsedDiff.hunks) {
|
||||||
|
const contextStr = prepareSearchString(hunk.changes)
|
||||||
|
const {
|
||||||
|
index: matchPosition,
|
||||||
|
confidence,
|
||||||
|
strategy,
|
||||||
|
} = findBestMatch(contextStr, result, 0, this.confidenceThreshold)
|
||||||
|
|
||||||
|
if (confidence < this.confidenceThreshold) {
|
||||||
|
console.log("Full hunk application failed, trying sub-hunks strategy")
|
||||||
|
// Try splitting the hunk into smaller hunks
|
||||||
|
const subHunks = this.splitHunk(hunk)
|
||||||
|
let subHunkSuccess = true
|
||||||
|
let subHunkResult = [...result]
|
||||||
|
|
||||||
|
for (const subHunk of subHunks) {
|
||||||
|
const subContextStr = prepareSearchString(subHunk.changes)
|
||||||
|
const subSearchResult = findBestMatch(subContextStr, subHunkResult, 0, this.confidenceThreshold)
|
||||||
|
|
||||||
|
if (subSearchResult.confidence >= this.confidenceThreshold) {
|
||||||
|
const subEditResult = await applyEdit(
|
||||||
|
subHunk,
|
||||||
|
subHunkResult,
|
||||||
|
subSearchResult.index,
|
||||||
|
subSearchResult.confidence,
|
||||||
|
this.confidenceThreshold,
|
||||||
|
)
|
||||||
|
if (subEditResult.confidence >= this.confidenceThreshold) {
|
||||||
|
subHunkResult = subEditResult.result
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
subHunkSuccess = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if (subHunkSuccess) {
|
||||||
|
result = subHunkResult
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// If sub-hunks also failed, return the original error
|
||||||
|
const contextLines = hunk.changes.filter((c) => c.type === "context").length
|
||||||
|
const totalLines = hunk.changes.length
|
||||||
|
const contextRatio = contextLines / totalLines
|
||||||
|
|
||||||
|
let errorMsg = `Failed to find a matching location in the file (${Math.floor(
|
||||||
|
confidence * 100,
|
||||||
|
)}% confidence, needs ${Math.floor(this.confidenceThreshold * 100)}%)\n\n`
|
||||||
|
errorMsg += "Debug Info:\n"
|
||||||
|
errorMsg += `- Search Strategy Used: ${strategy}\n`
|
||||||
|
errorMsg += `- Context Lines: ${contextLines} out of ${totalLines} total lines (${Math.floor(
|
||||||
|
contextRatio * 100,
|
||||||
|
)}%)\n`
|
||||||
|
errorMsg += `- Attempted to split into ${subHunks.length} sub-hunks but still failed\n`
|
||||||
|
|
||||||
|
if (contextRatio < 0.2) {
|
||||||
|
errorMsg += "\nPossible Issues:\n"
|
||||||
|
errorMsg += "- Not enough context lines to uniquely identify the location\n"
|
||||||
|
errorMsg += "- Add a few more lines of unchanged code around your changes\n"
|
||||||
|
} else if (contextRatio > 0.5) {
|
||||||
|
errorMsg += "\nPossible Issues:\n"
|
||||||
|
errorMsg += "- Too many context lines may reduce search accuracy\n"
|
||||||
|
errorMsg += "- Try to keep only 2-3 lines of context before and after changes\n"
|
||||||
|
} else {
|
||||||
|
errorMsg += "\nPossible Issues:\n"
|
||||||
|
errorMsg += "- The diff may be targeting a different version of the file\n"
|
||||||
|
errorMsg +=
|
||||||
|
"- There may be too many changes in a single hunk, try splitting the changes into multiple hunks\n"
|
||||||
|
}
|
||||||
|
|
||||||
|
if (startLine && endLine) {
|
||||||
|
errorMsg += `\nSearch Range: lines ${startLine}-${endLine}\n`
|
||||||
|
}
|
||||||
|
|
||||||
|
return { success: false, error: errorMsg }
|
||||||
|
}
|
||||||
|
|
||||||
|
const editResult = await applyEdit(hunk, result, matchPosition, confidence, this.confidenceThreshold)
|
||||||
|
if (editResult.confidence >= this.confidenceThreshold) {
|
||||||
|
result = editResult.result
|
||||||
|
} else {
|
||||||
|
// Edit failure - likely due to content mismatch
|
||||||
|
let errorMsg = `Failed to apply the edit using ${editResult.strategy} strategy (${Math.floor(
|
||||||
|
editResult.confidence * 100,
|
||||||
|
)}% confidence)\n\n`
|
||||||
|
errorMsg += "Debug Info:\n"
|
||||||
|
errorMsg += "- The location was found but the content didn't match exactly\n"
|
||||||
|
errorMsg += "- This usually means the file has been modified since the diff was created\n"
|
||||||
|
errorMsg += "- Or the diff may be targeting a different version of the file\n"
|
||||||
|
errorMsg += "\nPossible Solutions:\n"
|
||||||
|
errorMsg += "1. Refresh your view of the file and create a new diff\n"
|
||||||
|
errorMsg += "2. Double-check that the removed lines (-) match the current file content\n"
|
||||||
|
errorMsg += "3. Ensure your diff targets the correct version of the file"
|
||||||
|
|
||||||
|
return { success: false, error: errorMsg }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return { success: true, content: result.join("\n") }
|
||||||
|
}
|
||||||
|
}
|
||||||
408
src/core/diff/strategies/new-unified/search-strategies.ts
Normal file
408
src/core/diff/strategies/new-unified/search-strategies.ts
Normal file
@ -0,0 +1,408 @@
|
|||||||
|
import { compareTwoStrings } from "string-similarity"
|
||||||
|
import { closest } from "fastest-levenshtein"
|
||||||
|
import { diff_match_patch } from "diff-match-patch"
|
||||||
|
import { Change, Hunk } from "./types"
|
||||||
|
|
||||||
|
export type SearchResult = {
|
||||||
|
index: number
|
||||||
|
confidence: number
|
||||||
|
strategy: string
|
||||||
|
}
|
||||||
|
|
||||||
|
const LARGE_FILE_THRESHOLD = 1000 // lines
|
||||||
|
const UNIQUE_CONTENT_BOOST = 0.05
|
||||||
|
const DEFAULT_OVERLAP_SIZE = 3 // lines of overlap between windows
|
||||||
|
const MAX_WINDOW_SIZE = 500 // maximum lines in a window
|
||||||
|
|
||||||
|
// Helper function to calculate adaptive confidence threshold based on file size
|
||||||
|
function getAdaptiveThreshold(contentLength: number, baseThreshold: number): number {
|
||||||
|
if (contentLength <= LARGE_FILE_THRESHOLD) {
|
||||||
|
return baseThreshold
|
||||||
|
}
|
||||||
|
return Math.max(baseThreshold - 0.07, 0.8) // Reduce threshold for large files but keep minimum at 80%
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to evaluate content uniqueness
|
||||||
|
function evaluateContentUniqueness(searchStr: string, content: string[]): number {
|
||||||
|
const searchLines = searchStr.split("\n")
|
||||||
|
const uniqueLines = new Set(searchLines)
|
||||||
|
const contentStr = content.join("\n")
|
||||||
|
|
||||||
|
// Calculate how many search lines are relatively unique in the content
|
||||||
|
let uniqueCount = 0
|
||||||
|
for (const line of uniqueLines) {
|
||||||
|
const regex = new RegExp(line.replace(/[.*+?^${}()|[\]\\]/g, "\\$&"), "g")
|
||||||
|
const matches = contentStr.match(regex)
|
||||||
|
if (matches && matches.length <= 2) {
|
||||||
|
// Line appears at most twice
|
||||||
|
uniqueCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return uniqueCount / uniqueLines.size
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to prepare search string from context
|
||||||
|
export function prepareSearchString(changes: Change[]): string {
|
||||||
|
const lines = changes.filter((c) => c.type === "context" || c.type === "remove").map((c) => c.originalLine)
|
||||||
|
return lines.join("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to evaluate similarity between two texts
|
||||||
|
export function evaluateSimilarity(original: string, modified: string): number {
|
||||||
|
return compareTwoStrings(original, modified)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to validate using diff-match-patch
|
||||||
|
export function getDMPSimilarity(original: string, modified: string): number {
|
||||||
|
const dmp = new diff_match_patch()
|
||||||
|
const diffs = dmp.diff_main(original, modified)
|
||||||
|
dmp.diff_cleanupSemantic(diffs)
|
||||||
|
const patches = dmp.patch_make(original, diffs)
|
||||||
|
const [expectedText] = dmp.patch_apply(patches, original)
|
||||||
|
|
||||||
|
const similarity = evaluateSimilarity(expectedText, modified)
|
||||||
|
return similarity
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to validate edit results using hunk information
|
||||||
|
export function validateEditResult(hunk: Hunk, result: string): number {
|
||||||
|
// Build the expected text from the hunk
|
||||||
|
const expectedText = hunk.changes
|
||||||
|
.filter((change) => change.type === "context" || change.type === "add")
|
||||||
|
.map((change) => (change.indent ? change.indent + change.content : change.content))
|
||||||
|
.join("\n")
|
||||||
|
|
||||||
|
// Calculate similarity between the result and expected text
|
||||||
|
const similarity = getDMPSimilarity(expectedText, result)
|
||||||
|
|
||||||
|
// If the result is unchanged from original, return low confidence
|
||||||
|
const originalText = hunk.changes
|
||||||
|
.filter((change) => change.type === "context" || change.type === "remove")
|
||||||
|
.map((change) => (change.indent ? change.indent + change.content : change.content))
|
||||||
|
.join("\n")
|
||||||
|
|
||||||
|
const originalSimilarity = getDMPSimilarity(originalText, result)
|
||||||
|
if (originalSimilarity > 0.97 && similarity !== 1) {
|
||||||
|
return 0.8 * similarity // Some confidence since we found the right location
|
||||||
|
}
|
||||||
|
|
||||||
|
// For partial matches, scale the confidence but keep it high if we're close
|
||||||
|
return similarity
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to validate context lines against original content
|
||||||
|
function validateContextLines(searchStr: string, content: string, confidenceThreshold: number): number {
|
||||||
|
// Extract just the context lines from the search string
|
||||||
|
const contextLines = searchStr.split("\n").filter((line) => !line.startsWith("-")) // Exclude removed lines
|
||||||
|
|
||||||
|
// Compare context lines with content
|
||||||
|
const similarity = evaluateSimilarity(contextLines.join("\n"), content)
|
||||||
|
|
||||||
|
// Get adaptive threshold based on content size
|
||||||
|
const threshold = getAdaptiveThreshold(content.split("\n").length, confidenceThreshold)
|
||||||
|
|
||||||
|
// Calculate uniqueness boost
|
||||||
|
const uniquenessScore = evaluateContentUniqueness(searchStr, content.split("\n"))
|
||||||
|
const uniquenessBoost = uniquenessScore * UNIQUE_CONTENT_BOOST
|
||||||
|
|
||||||
|
// Adjust confidence based on threshold and uniqueness
|
||||||
|
return similarity < threshold ? similarity * 0.3 + uniquenessBoost : similarity + uniquenessBoost
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to create overlapping windows
|
||||||
|
function createOverlappingWindows(
|
||||||
|
content: string[],
|
||||||
|
searchSize: number,
|
||||||
|
overlapSize: number = DEFAULT_OVERLAP_SIZE,
|
||||||
|
): { window: string[]; startIndex: number }[] {
|
||||||
|
const windows: { window: string[]; startIndex: number }[] = []
|
||||||
|
|
||||||
|
// Ensure minimum window size is at least searchSize
|
||||||
|
const effectiveWindowSize = Math.max(searchSize, Math.min(searchSize * 2, MAX_WINDOW_SIZE))
|
||||||
|
|
||||||
|
// Ensure overlap size doesn't exceed window size
|
||||||
|
const effectiveOverlapSize = Math.min(overlapSize, effectiveWindowSize - 1)
|
||||||
|
|
||||||
|
// Calculate step size, ensure it's at least 1
|
||||||
|
const stepSize = Math.max(1, effectiveWindowSize - effectiveOverlapSize)
|
||||||
|
|
||||||
|
for (let i = 0; i < content.length; i += stepSize) {
|
||||||
|
const windowContent = content.slice(i, i + effectiveWindowSize)
|
||||||
|
if (windowContent.length >= searchSize) {
|
||||||
|
windows.push({ window: windowContent, startIndex: i })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return windows
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to combine overlapping matches
|
||||||
|
function combineOverlappingMatches(
|
||||||
|
matches: (SearchResult & { windowIndex: number })[],
|
||||||
|
overlapSize: number = DEFAULT_OVERLAP_SIZE,
|
||||||
|
): SearchResult[] {
|
||||||
|
if (matches.length === 0) {
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort matches by confidence
|
||||||
|
matches.sort((a, b) => b.confidence - a.confidence)
|
||||||
|
|
||||||
|
const combinedMatches: SearchResult[] = []
|
||||||
|
const usedIndices = new Set<number>()
|
||||||
|
|
||||||
|
for (const match of matches) {
|
||||||
|
if (usedIndices.has(match.windowIndex)) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find overlapping matches
|
||||||
|
const overlapping = matches.filter(
|
||||||
|
(m) =>
|
||||||
|
Math.abs(m.windowIndex - match.windowIndex) === 1 &&
|
||||||
|
Math.abs(m.index - match.index) <= overlapSize &&
|
||||||
|
!usedIndices.has(m.windowIndex),
|
||||||
|
)
|
||||||
|
|
||||||
|
if (overlapping.length > 0) {
|
||||||
|
// Boost confidence if we find same match in overlapping windows
|
||||||
|
const avgConfidence =
|
||||||
|
(match.confidence + overlapping.reduce((sum, m) => sum + m.confidence, 0)) / (overlapping.length + 1)
|
||||||
|
const boost = Math.min(0.05 * overlapping.length, 0.1) // Max 10% boost
|
||||||
|
|
||||||
|
combinedMatches.push({
|
||||||
|
index: match.index,
|
||||||
|
confidence: Math.min(1, avgConfidence + boost),
|
||||||
|
strategy: `${match.strategy}-overlapping`,
|
||||||
|
})
|
||||||
|
|
||||||
|
usedIndices.add(match.windowIndex)
|
||||||
|
overlapping.forEach((m) => usedIndices.add(m.windowIndex))
|
||||||
|
} else {
|
||||||
|
combinedMatches.push({
|
||||||
|
index: match.index,
|
||||||
|
confidence: match.confidence,
|
||||||
|
strategy: match.strategy,
|
||||||
|
})
|
||||||
|
usedIndices.add(match.windowIndex)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return combinedMatches
|
||||||
|
}
|
||||||
|
|
||||||
|
export function findExactMatch(
|
||||||
|
searchStr: string,
|
||||||
|
content: string[],
|
||||||
|
startIndex: number = 0,
|
||||||
|
confidenceThreshold: number = 0.97,
|
||||||
|
): SearchResult {
|
||||||
|
const searchLines = searchStr.split("\n")
|
||||||
|
const windows = createOverlappingWindows(content.slice(startIndex), searchLines.length)
|
||||||
|
const matches: (SearchResult & { windowIndex: number })[] = []
|
||||||
|
|
||||||
|
windows.forEach((windowData, windowIndex) => {
|
||||||
|
const windowStr = windowData.window.join("\n")
|
||||||
|
const exactMatch = windowStr.indexOf(searchStr)
|
||||||
|
|
||||||
|
if (exactMatch !== -1) {
|
||||||
|
const matchedContent = windowData.window
|
||||||
|
.slice(
|
||||||
|
windowStr.slice(0, exactMatch).split("\n").length - 1,
|
||||||
|
windowStr.slice(0, exactMatch).split("\n").length - 1 + searchLines.length,
|
||||||
|
)
|
||||||
|
.join("\n")
|
||||||
|
|
||||||
|
const similarity = getDMPSimilarity(searchStr, matchedContent)
|
||||||
|
const contextSimilarity = validateContextLines(searchStr, matchedContent, confidenceThreshold)
|
||||||
|
const confidence = Math.min(similarity, contextSimilarity)
|
||||||
|
|
||||||
|
matches.push({
|
||||||
|
index: startIndex + windowData.startIndex + windowStr.slice(0, exactMatch).split("\n").length - 1,
|
||||||
|
confidence,
|
||||||
|
strategy: "exact",
|
||||||
|
windowIndex,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
const combinedMatches = combineOverlappingMatches(matches)
|
||||||
|
return combinedMatches.length > 0 ? combinedMatches[0] : { index: -1, confidence: 0, strategy: "exact" }
|
||||||
|
}
|
||||||
|
|
||||||
|
// String similarity strategy
|
||||||
|
export function findSimilarityMatch(
|
||||||
|
searchStr: string,
|
||||||
|
content: string[],
|
||||||
|
startIndex: number = 0,
|
||||||
|
confidenceThreshold: number = 0.97,
|
||||||
|
): SearchResult {
|
||||||
|
const searchLines = searchStr.split("\n")
|
||||||
|
let bestScore = 0
|
||||||
|
let bestIndex = -1
|
||||||
|
|
||||||
|
for (let i = startIndex; i < content.length - searchLines.length + 1; i++) {
|
||||||
|
const windowStr = content.slice(i, i + searchLines.length).join("\n")
|
||||||
|
const score = compareTwoStrings(searchStr, windowStr)
|
||||||
|
if (score > bestScore && score >= confidenceThreshold) {
|
||||||
|
const similarity = getDMPSimilarity(searchStr, windowStr)
|
||||||
|
const contextSimilarity = validateContextLines(searchStr, windowStr, confidenceThreshold)
|
||||||
|
const adjustedScore = Math.min(similarity, contextSimilarity) * score
|
||||||
|
|
||||||
|
if (adjustedScore > bestScore) {
|
||||||
|
bestScore = adjustedScore
|
||||||
|
bestIndex = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
index: bestIndex,
|
||||||
|
confidence: bestIndex !== -1 ? bestScore : 0,
|
||||||
|
strategy: "similarity",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Levenshtein strategy
|
||||||
|
export function findLevenshteinMatch(
|
||||||
|
searchStr: string,
|
||||||
|
content: string[],
|
||||||
|
startIndex: number = 0,
|
||||||
|
confidenceThreshold: number = 0.97,
|
||||||
|
): SearchResult {
|
||||||
|
const searchLines = searchStr.split("\n")
|
||||||
|
const candidates = []
|
||||||
|
|
||||||
|
for (let i = startIndex; i < content.length - searchLines.length + 1; i++) {
|
||||||
|
candidates.push(content.slice(i, i + searchLines.length).join("\n"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if (candidates.length > 0) {
|
||||||
|
const closestMatch = closest(searchStr, candidates)
|
||||||
|
const index = startIndex + candidates.indexOf(closestMatch)
|
||||||
|
const similarity = getDMPSimilarity(searchStr, closestMatch)
|
||||||
|
const contextSimilarity = validateContextLines(searchStr, closestMatch, confidenceThreshold)
|
||||||
|
const confidence = Math.min(similarity, contextSimilarity)
|
||||||
|
return {
|
||||||
|
index: confidence === 0 ? -1 : index,
|
||||||
|
confidence: index !== -1 ? confidence : 0,
|
||||||
|
strategy: "levenshtein",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return { index: -1, confidence: 0, strategy: "levenshtein" }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to identify anchor lines
|
||||||
|
function identifyAnchors(searchStr: string): { first: string | null; last: string | null } {
|
||||||
|
const searchLines = searchStr.split("\n")
|
||||||
|
let first: string | null = null
|
||||||
|
let last: string | null = null
|
||||||
|
|
||||||
|
// Find the first non-empty line
|
||||||
|
for (const line of searchLines) {
|
||||||
|
if (line.trim()) {
|
||||||
|
first = line
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the last non-empty line
|
||||||
|
for (let i = searchLines.length - 1; i >= 0; i--) {
|
||||||
|
if (searchLines[i].trim()) {
|
||||||
|
last = searchLines[i]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return { first, last }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Anchor-based search strategy
|
||||||
|
export function findAnchorMatch(
|
||||||
|
searchStr: string,
|
||||||
|
content: string[],
|
||||||
|
startIndex: number = 0,
|
||||||
|
confidenceThreshold: number = 0.97,
|
||||||
|
): SearchResult {
|
||||||
|
const searchLines = searchStr.split("\n")
|
||||||
|
const { first, last } = identifyAnchors(searchStr)
|
||||||
|
|
||||||
|
if (!first || !last) {
|
||||||
|
return { index: -1, confidence: 0, strategy: "anchor" }
|
||||||
|
}
|
||||||
|
|
||||||
|
let firstIndex = -1
|
||||||
|
let lastIndex = -1
|
||||||
|
|
||||||
|
// Check if the first anchor is unique
|
||||||
|
let firstOccurrences = 0
|
||||||
|
for (const contentLine of content) {
|
||||||
|
if (contentLine === first) {
|
||||||
|
firstOccurrences++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (firstOccurrences !== 1) {
|
||||||
|
return { index: -1, confidence: 0, strategy: "anchor" }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the first anchor
|
||||||
|
for (let i = startIndex; i < content.length; i++) {
|
||||||
|
if (content[i] === first) {
|
||||||
|
firstIndex = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the last anchor
|
||||||
|
for (let i = content.length - 1; i >= startIndex; i--) {
|
||||||
|
if (content[i] === last) {
|
||||||
|
lastIndex = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (firstIndex === -1 || lastIndex === -1 || lastIndex <= firstIndex) {
|
||||||
|
return { index: -1, confidence: 0, strategy: "anchor" }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate the context
|
||||||
|
const expectedContext = searchLines.slice(searchLines.indexOf(first) + 1, searchLines.indexOf(last)).join("\n")
|
||||||
|
const actualContext = content.slice(firstIndex + 1, lastIndex).join("\n")
|
||||||
|
const contextSimilarity = evaluateSimilarity(expectedContext, actualContext)
|
||||||
|
|
||||||
|
if (contextSimilarity < getAdaptiveThreshold(content.length, confidenceThreshold)) {
|
||||||
|
return { index: -1, confidence: 0, strategy: "anchor" }
|
||||||
|
}
|
||||||
|
|
||||||
|
const confidence = 1
|
||||||
|
|
||||||
|
return {
|
||||||
|
index: firstIndex,
|
||||||
|
confidence: confidence,
|
||||||
|
strategy: "anchor",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Main search function that tries all strategies
|
||||||
|
export function findBestMatch(
|
||||||
|
searchStr: string,
|
||||||
|
content: string[],
|
||||||
|
startIndex: number = 0,
|
||||||
|
confidenceThreshold: number = 0.97,
|
||||||
|
): SearchResult {
|
||||||
|
const strategies = [findExactMatch, findAnchorMatch, findSimilarityMatch, findLevenshteinMatch]
|
||||||
|
|
||||||
|
let bestResult: SearchResult = { index: -1, confidence: 0, strategy: "none" }
|
||||||
|
|
||||||
|
for (const strategy of strategies) {
|
||||||
|
const result = strategy(searchStr, content, startIndex, confidenceThreshold)
|
||||||
|
if (result.confidence > bestResult.confidence) {
|
||||||
|
bestResult = result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return bestResult
|
||||||
|
}
|
||||||
20
src/core/diff/strategies/new-unified/types.ts
Normal file
20
src/core/diff/strategies/new-unified/types.ts
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
export type Change = {
|
||||||
|
type: "context" | "add" | "remove"
|
||||||
|
content: string
|
||||||
|
indent: string
|
||||||
|
originalLine?: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export type Hunk = {
|
||||||
|
changes: Change[]
|
||||||
|
}
|
||||||
|
|
||||||
|
export type Diff = {
|
||||||
|
hunks: Hunk[]
|
||||||
|
}
|
||||||
|
|
||||||
|
export type EditResult = {
|
||||||
|
confidence: number
|
||||||
|
result: string[]
|
||||||
|
strategy: string
|
||||||
|
}
|
||||||
302
src/core/diff/strategies/search-replace.ts
Normal file
302
src/core/diff/strategies/search-replace.ts
Normal file
@ -0,0 +1,302 @@
|
|||||||
|
import { DiffStrategy, DiffResult } from "../types"
|
||||||
|
import { addLineNumbers, everyLineHasLineNumbers, stripLineNumbers } from "../../../integrations/misc/extract-text"
|
||||||
|
import { distance } from "fastest-levenshtein"
|
||||||
|
|
||||||
|
const BUFFER_LINES = 20 // Number of extra context lines to show before and after matches
|
||||||
|
|
||||||
|
function getSimilarity(original: string, search: string): number {
|
||||||
|
if (search === "") {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize strings by removing extra whitespace but preserve case
|
||||||
|
const normalizeStr = (str: string) => str.replace(/\s+/g, " ").trim()
|
||||||
|
|
||||||
|
const normalizedOriginal = normalizeStr(original)
|
||||||
|
const normalizedSearch = normalizeStr(search)
|
||||||
|
|
||||||
|
if (normalizedOriginal === normalizedSearch) {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate Levenshtein distance using fastest-levenshtein's distance function
|
||||||
|
const dist = distance(normalizedOriginal, normalizedSearch)
|
||||||
|
|
||||||
|
// Calculate similarity ratio (0 to 1, where 1 is an exact match)
|
||||||
|
const maxLength = Math.max(normalizedOriginal.length, normalizedSearch.length)
|
||||||
|
return 1 - dist / maxLength
|
||||||
|
}
|
||||||
|
|
||||||
|
export class SearchReplaceDiffStrategy implements DiffStrategy {
|
||||||
|
private fuzzyThreshold: number
|
||||||
|
private bufferLines: number
|
||||||
|
|
||||||
|
constructor(fuzzyThreshold?: number, bufferLines?: number) {
|
||||||
|
// Use provided threshold or default to exact matching (1.0)
|
||||||
|
// Note: fuzzyThreshold is inverted in UI (0% = 1.0, 10% = 0.9)
|
||||||
|
// so we use it directly here
|
||||||
|
this.fuzzyThreshold = fuzzyThreshold ?? 1.0
|
||||||
|
this.bufferLines = bufferLines ?? BUFFER_LINES
|
||||||
|
}
|
||||||
|
|
||||||
|
getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string {
|
||||||
|
return `## apply_diff
|
||||||
|
Description: Request to replace existing code using a search and replace block.
|
||||||
|
This tool allows for precise, surgical replaces to files by specifying exactly what content to search for and what to replace it with.
|
||||||
|
The tool will maintain proper indentation and formatting while making changes.
|
||||||
|
Only a single operation is allowed per tool use.
|
||||||
|
The SEARCH section must exactly match existing content including whitespace and indentation.
|
||||||
|
If you're not confident in the exact content to search for, use the read_file tool first to get the exact content.
|
||||||
|
When applying the diffs, be extra careful to remember to change any closing brackets or other syntax that may be affected by the diff farther down in the file.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- path: (required) The path of the file to modify (relative to the current working directory ${args.cwd})
|
||||||
|
- diff: (required) The search/replace block defining the changes.
|
||||||
|
- start_line: (required) The line number where the search block starts.
|
||||||
|
- end_line: (required) The line number where the search block ends.
|
||||||
|
|
||||||
|
Diff format:
|
||||||
|
\`\`\`
|
||||||
|
<<<<<<< SEARCH
|
||||||
|
[exact content to find including whitespace]
|
||||||
|
=======
|
||||||
|
[new content to replace with]
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
\`\`\`
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
Original file:
|
||||||
|
\`\`\`
|
||||||
|
1 | def calculate_total(items):
|
||||||
|
2 | total = 0
|
||||||
|
3 | for item in items:
|
||||||
|
4 | total += item
|
||||||
|
5 | return total
|
||||||
|
\`\`\`
|
||||||
|
|
||||||
|
Search/Replace content:
|
||||||
|
\`\`\`
|
||||||
|
<<<<<<< SEARCH
|
||||||
|
def calculate_total(items):
|
||||||
|
total = 0
|
||||||
|
for item in items:
|
||||||
|
total += item
|
||||||
|
return total
|
||||||
|
=======
|
||||||
|
def calculate_total(items):
|
||||||
|
"""Calculate total with 10% markup"""
|
||||||
|
return sum(item * 1.1 for item in items)
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
\`\`\`
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
<apply_diff>
|
||||||
|
<path>File path here</path>
|
||||||
|
<diff>
|
||||||
|
Your search/replace content here
|
||||||
|
</diff>
|
||||||
|
<start_line>1</start_line>
|
||||||
|
<end_line>5</end_line>
|
||||||
|
</apply_diff>`
|
||||||
|
}
|
||||||
|
|
||||||
|
async applyDiff(
|
||||||
|
originalContent: string,
|
||||||
|
diffContent: string,
|
||||||
|
startLine?: number,
|
||||||
|
endLine?: number,
|
||||||
|
): Promise<DiffResult> {
|
||||||
|
// Extract the search and replace blocks
|
||||||
|
const match = diffContent.match(/<<<<<<< SEARCH\n([\s\S]*?)\n?=======\n([\s\S]*?)\n?>>>>>>> REPLACE/)
|
||||||
|
if (!match) {
|
||||||
|
return {
|
||||||
|
success: false,
|
||||||
|
error: `Invalid diff format - missing required SEARCH/REPLACE sections\n\nDebug Info:\n- Expected Format: <<<<<<< SEARCH\\n[search content]\\n=======\\n[replace content]\\n>>>>>>> REPLACE\n- Tip: Make sure to include both SEARCH and REPLACE sections with correct markers`,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let [_, searchContent, replaceContent] = match
|
||||||
|
|
||||||
|
// Detect line ending from original content
|
||||||
|
const lineEnding = originalContent.includes("\r\n") ? "\r\n" : "\n"
|
||||||
|
|
||||||
|
// Strip line numbers from search and replace content if every line starts with a line number
|
||||||
|
if (everyLineHasLineNumbers(searchContent) && everyLineHasLineNumbers(replaceContent)) {
|
||||||
|
searchContent = stripLineNumbers(searchContent)
|
||||||
|
replaceContent = stripLineNumbers(replaceContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split content into lines, handling both \n and \r\n
|
||||||
|
const searchLines = searchContent === "" ? [] : searchContent.split(/\r?\n/)
|
||||||
|
const replaceLines = replaceContent === "" ? [] : replaceContent.split(/\r?\n/)
|
||||||
|
const originalLines = originalContent.split(/\r?\n/)
|
||||||
|
|
||||||
|
// Validate that empty search requires start line
|
||||||
|
if (searchLines.length === 0 && !startLine) {
|
||||||
|
return {
|
||||||
|
success: false,
|
||||||
|
error: `Empty search content requires start_line to be specified\n\nDebug Info:\n- Empty search content is only valid for insertions at a specific line\n- For insertions, specify the line number where content should be inserted`,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate that empty search requires same start and end line
|
||||||
|
if (searchLines.length === 0 && startLine && endLine && startLine !== endLine) {
|
||||||
|
return {
|
||||||
|
success: false,
|
||||||
|
error: `Empty search content requires start_line and end_line to be the same (got ${startLine}-${endLine})\n\nDebug Info:\n- Empty search content is only valid for insertions at a specific line\n- For insertions, use the same line number for both start_line and end_line`,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize search variables
|
||||||
|
let matchIndex = -1
|
||||||
|
let bestMatchScore = 0
|
||||||
|
let bestMatchContent = ""
|
||||||
|
const searchChunk = searchLines.join("\n")
|
||||||
|
|
||||||
|
// Determine search bounds
|
||||||
|
let searchStartIndex = 0
|
||||||
|
let searchEndIndex = originalLines.length
|
||||||
|
|
||||||
|
// Validate and handle line range if provided
|
||||||
|
if (startLine && endLine) {
|
||||||
|
// Convert to 0-based index
|
||||||
|
const exactStartIndex = startLine - 1
|
||||||
|
const exactEndIndex = endLine - 1
|
||||||
|
|
||||||
|
if (exactStartIndex < 0 || exactEndIndex > originalLines.length || exactStartIndex > exactEndIndex) {
|
||||||
|
return {
|
||||||
|
success: false,
|
||||||
|
error: `Line range ${startLine}-${endLine} is invalid (file has ${originalLines.length} lines)\n\nDebug Info:\n- Requested Range: lines ${startLine}-${endLine}\n- File Bounds: lines 1-${originalLines.length}`,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try exact match first
|
||||||
|
const originalChunk = originalLines.slice(exactStartIndex, exactEndIndex + 1).join("\n")
|
||||||
|
const similarity = getSimilarity(originalChunk, searchChunk)
|
||||||
|
if (similarity >= this.fuzzyThreshold) {
|
||||||
|
matchIndex = exactStartIndex
|
||||||
|
bestMatchScore = similarity
|
||||||
|
bestMatchContent = originalChunk
|
||||||
|
} else {
|
||||||
|
// Set bounds for buffered search
|
||||||
|
searchStartIndex = Math.max(0, startLine - (this.bufferLines + 1))
|
||||||
|
searchEndIndex = Math.min(originalLines.length, endLine + this.bufferLines)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no match found yet, try middle-out search within bounds
|
||||||
|
if (matchIndex === -1) {
|
||||||
|
const midPoint = Math.floor((searchStartIndex + searchEndIndex) / 2)
|
||||||
|
let leftIndex = midPoint
|
||||||
|
let rightIndex = midPoint + 1
|
||||||
|
|
||||||
|
// Search outward from the middle within bounds
|
||||||
|
while (leftIndex >= searchStartIndex || rightIndex <= searchEndIndex - searchLines.length) {
|
||||||
|
// Check left side if still in range
|
||||||
|
if (leftIndex >= searchStartIndex) {
|
||||||
|
const originalChunk = originalLines.slice(leftIndex, leftIndex + searchLines.length).join("\n")
|
||||||
|
const similarity = getSimilarity(originalChunk, searchChunk)
|
||||||
|
if (similarity > bestMatchScore) {
|
||||||
|
bestMatchScore = similarity
|
||||||
|
matchIndex = leftIndex
|
||||||
|
bestMatchContent = originalChunk
|
||||||
|
}
|
||||||
|
leftIndex--
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check right side if still in range
|
||||||
|
if (rightIndex <= searchEndIndex - searchLines.length) {
|
||||||
|
const originalChunk = originalLines.slice(rightIndex, rightIndex + searchLines.length).join("\n")
|
||||||
|
const similarity = getSimilarity(originalChunk, searchChunk)
|
||||||
|
if (similarity > bestMatchScore) {
|
||||||
|
bestMatchScore = similarity
|
||||||
|
matchIndex = rightIndex
|
||||||
|
bestMatchContent = originalChunk
|
||||||
|
}
|
||||||
|
rightIndex++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Require similarity to meet threshold
|
||||||
|
if (matchIndex === -1 || bestMatchScore < this.fuzzyThreshold) {
|
||||||
|
const searchChunk = searchLines.join("\n")
|
||||||
|
const originalContentSection =
|
||||||
|
startLine !== undefined && endLine !== undefined
|
||||||
|
? `\n\nOriginal Content:\n${addLineNumbers(
|
||||||
|
originalLines
|
||||||
|
.slice(
|
||||||
|
Math.max(0, startLine - 1 - this.bufferLines),
|
||||||
|
Math.min(originalLines.length, endLine + this.bufferLines),
|
||||||
|
)
|
||||||
|
.join("\n"),
|
||||||
|
Math.max(1, startLine - this.bufferLines),
|
||||||
|
)}`
|
||||||
|
: `\n\nOriginal Content:\n${addLineNumbers(originalLines.join("\n"))}`
|
||||||
|
|
||||||
|
const bestMatchSection = bestMatchContent
|
||||||
|
? `\n\nBest Match Found:\n${addLineNumbers(bestMatchContent, matchIndex + 1)}`
|
||||||
|
: `\n\nBest Match Found:\n(no match)`
|
||||||
|
|
||||||
|
const lineRange =
|
||||||
|
startLine || endLine
|
||||||
|
? ` at ${startLine ? `start: ${startLine}` : "start"} to ${endLine ? `end: ${endLine}` : "end"}`
|
||||||
|
: ""
|
||||||
|
return {
|
||||||
|
success: false,
|
||||||
|
error: `No sufficiently similar match found${lineRange} (${Math.floor(bestMatchScore * 100)}% similar, needs ${Math.floor(this.fuzzyThreshold * 100)}%)\n\nDebug Info:\n- Similarity Score: ${Math.floor(bestMatchScore * 100)}%\n- Required Threshold: ${Math.floor(this.fuzzyThreshold * 100)}%\n- Search Range: ${startLine && endLine ? `lines ${startLine}-${endLine}` : "start to end"}\n- Tip: Use read_file to get the latest content of the file before attempting the diff again, as the file content may have changed\n\nSearch Content:\n${searchChunk}${bestMatchSection}${originalContentSection}`,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the matched lines from the original content
|
||||||
|
const matchedLines = originalLines.slice(matchIndex, matchIndex + searchLines.length)
|
||||||
|
|
||||||
|
// Get the exact indentation (preserving tabs/spaces) of each line
|
||||||
|
const originalIndents = matchedLines.map((line) => {
|
||||||
|
const match = line.match(/^[\t ]*/)
|
||||||
|
return match ? match[0] : ""
|
||||||
|
})
|
||||||
|
|
||||||
|
// Get the exact indentation of each line in the search block
|
||||||
|
const searchIndents = searchLines.map((line) => {
|
||||||
|
const match = line.match(/^[\t ]*/)
|
||||||
|
return match ? match[0] : ""
|
||||||
|
})
|
||||||
|
|
||||||
|
// Apply the replacement while preserving exact indentation
|
||||||
|
const indentedReplaceLines = replaceLines.map((line, i) => {
|
||||||
|
// Get the matched line's exact indentation
|
||||||
|
const matchedIndent = originalIndents[0] || ""
|
||||||
|
|
||||||
|
// Get the current line's indentation relative to the search content
|
||||||
|
const currentIndentMatch = line.match(/^[\t ]*/)
|
||||||
|
const currentIndent = currentIndentMatch ? currentIndentMatch[0] : ""
|
||||||
|
const searchBaseIndent = searchIndents[0] || ""
|
||||||
|
|
||||||
|
// Calculate the relative indentation level
|
||||||
|
const searchBaseLevel = searchBaseIndent.length
|
||||||
|
const currentLevel = currentIndent.length
|
||||||
|
const relativeLevel = currentLevel - searchBaseLevel
|
||||||
|
|
||||||
|
// If relative level is negative, remove indentation from matched indent
|
||||||
|
// If positive, add to matched indent
|
||||||
|
const finalIndent =
|
||||||
|
relativeLevel < 0
|
||||||
|
? matchedIndent.slice(0, Math.max(0, matchedIndent.length + relativeLevel))
|
||||||
|
: matchedIndent + currentIndent.slice(searchBaseLevel)
|
||||||
|
|
||||||
|
return finalIndent + line.trim()
|
||||||
|
})
|
||||||
|
|
||||||
|
// Construct the final content
|
||||||
|
const beforeMatch = originalLines.slice(0, matchIndex)
|
||||||
|
const afterMatch = originalLines.slice(matchIndex + searchLines.length)
|
||||||
|
|
||||||
|
const finalContent = [...beforeMatch, ...indentedReplaceLines, ...afterMatch].join(lineEnding)
|
||||||
|
return {
|
||||||
|
success: true,
|
||||||
|
content: finalContent,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
137
src/core/diff/strategies/unified.ts
Normal file
137
src/core/diff/strategies/unified.ts
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
import { applyPatch } from "diff"
|
||||||
|
import { DiffStrategy, DiffResult } from "../types"
|
||||||
|
|
||||||
|
export class UnifiedDiffStrategy implements DiffStrategy {
|
||||||
|
getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string {
|
||||||
|
return `## apply_diff
|
||||||
|
Description: Apply a unified diff to a file at the specified path. This tool is useful when you need to make specific modifications to a file based on a set of changes provided in unified diff format (diff -U3).
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- path: (required) The path of the file to apply the diff to (relative to the current working directory ${args.cwd})
|
||||||
|
- diff: (required) The diff content in unified format to apply to the file.
|
||||||
|
|
||||||
|
Format Requirements:
|
||||||
|
|
||||||
|
1. Header (REQUIRED):
|
||||||
|
\`\`\`
|
||||||
|
--- path/to/original/file
|
||||||
|
+++ path/to/modified/file
|
||||||
|
\`\`\`
|
||||||
|
- Must include both lines exactly as shown
|
||||||
|
- Use actual file paths
|
||||||
|
- NO timestamps after paths
|
||||||
|
|
||||||
|
2. Hunks:
|
||||||
|
\`\`\`
|
||||||
|
@@ -lineStart,lineCount +lineStart,lineCount @@
|
||||||
|
-removed line
|
||||||
|
+added line
|
||||||
|
\`\`\`
|
||||||
|
- Each hunk starts with @@ showing line numbers for changes
|
||||||
|
- Format: @@ -originalStart,originalCount +newStart,newCount @@
|
||||||
|
- Use - for removed/changed lines
|
||||||
|
- Use + for new/modified lines
|
||||||
|
- Indentation must match exactly
|
||||||
|
|
||||||
|
Complete Example:
|
||||||
|
|
||||||
|
Original file (with line numbers):
|
||||||
|
\`\`\`
|
||||||
|
1 | import { Logger } from '../logger';
|
||||||
|
2 |
|
||||||
|
3 | function calculateTotal(items: number[]): number {
|
||||||
|
4 | return items.reduce((sum, item) => {
|
||||||
|
5 | return sum + item;
|
||||||
|
6 | }, 0);
|
||||||
|
7 | }
|
||||||
|
8 |
|
||||||
|
9 | export { calculateTotal };
|
||||||
|
\`\`\`
|
||||||
|
|
||||||
|
After applying the diff, the file would look like:
|
||||||
|
\`\`\`
|
||||||
|
1 | import { Logger } from '../logger';
|
||||||
|
2 |
|
||||||
|
3 | function calculateTotal(items: number[]): number {
|
||||||
|
4 | const total = items.reduce((sum, item) => {
|
||||||
|
5 | return sum + item * 1.1; // Add 10% markup
|
||||||
|
6 | }, 0);
|
||||||
|
7 | return Math.round(total * 100) / 100; // Round to 2 decimal places
|
||||||
|
8 | }
|
||||||
|
9 |
|
||||||
|
10 | export { calculateTotal };
|
||||||
|
\`\`\`
|
||||||
|
|
||||||
|
Diff to modify the file:
|
||||||
|
\`\`\`
|
||||||
|
--- src/utils/helper.ts
|
||||||
|
+++ src/utils/helper.ts
|
||||||
|
@@ -1,9 +1,10 @@
|
||||||
|
import { Logger } from '../logger';
|
||||||
|
|
||||||
|
function calculateTotal(items: number[]): number {
|
||||||
|
- return items.reduce((sum, item) => {
|
||||||
|
- return sum + item;
|
||||||
|
+ const total = items.reduce((sum, item) => {
|
||||||
|
+ return sum + item * 1.1; // Add 10% markup
|
||||||
|
}, 0);
|
||||||
|
+ return Math.round(total * 100) / 100; // Round to 2 decimal places
|
||||||
|
}
|
||||||
|
|
||||||
|
export { calculateTotal };
|
||||||
|
\`\`\`
|
||||||
|
|
||||||
|
Common Pitfalls:
|
||||||
|
1. Missing or incorrect header lines
|
||||||
|
2. Incorrect line numbers in @@ lines
|
||||||
|
3. Wrong indentation in changed lines
|
||||||
|
4. Incomplete context (missing lines that need changing)
|
||||||
|
5. Not marking all modified lines with - and +
|
||||||
|
|
||||||
|
Best Practices:
|
||||||
|
1. Replace entire code blocks:
|
||||||
|
- Remove complete old version with - lines
|
||||||
|
- Add complete new version with + lines
|
||||||
|
- Include correct line numbers
|
||||||
|
2. Moving code requires two hunks:
|
||||||
|
- First hunk: Remove from old location
|
||||||
|
- Second hunk: Add to new location
|
||||||
|
3. One hunk per logical change
|
||||||
|
4. Verify line numbers match the line numbers you have in the file
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
<apply_diff>
|
||||||
|
<path>File path here</path>
|
||||||
|
<diff>
|
||||||
|
Your diff here
|
||||||
|
</diff>
|
||||||
|
</apply_diff>`
|
||||||
|
}
|
||||||
|
|
||||||
|
async applyDiff(originalContent: string, diffContent: string): Promise<DiffResult> {
|
||||||
|
try {
|
||||||
|
const result = applyPatch(originalContent, diffContent)
|
||||||
|
if (result === false) {
|
||||||
|
return {
|
||||||
|
success: false,
|
||||||
|
error: "Failed to apply unified diff - patch rejected",
|
||||||
|
details: {
|
||||||
|
searchContent: diffContent,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
success: true,
|
||||||
|
content: result,
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
return {
|
||||||
|
success: false,
|
||||||
|
error: `Error applying unified diff: ${error.message}`,
|
||||||
|
details: {
|
||||||
|
searchContent: diffContent,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
36
src/core/diff/types.ts
Normal file
36
src/core/diff/types.ts
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
/**
|
||||||
|
* Interface for implementing different diff strategies
|
||||||
|
*/
|
||||||
|
|
||||||
|
export type DiffResult =
|
||||||
|
| { success: true; content: string }
|
||||||
|
| {
|
||||||
|
success: false
|
||||||
|
error: string
|
||||||
|
details?: {
|
||||||
|
similarity?: number
|
||||||
|
threshold?: number
|
||||||
|
matchedRange?: { start: number; end: number }
|
||||||
|
searchContent?: string
|
||||||
|
bestMatch?: string
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface DiffStrategy {
|
||||||
|
/**
|
||||||
|
* Get the tool description for this diff strategy
|
||||||
|
* @param args The tool arguments including cwd and toolOptions
|
||||||
|
* @returns The complete tool description including format requirements and examples
|
||||||
|
*/
|
||||||
|
getToolDescription(args: { cwd: string; toolOptions?: { [key: string]: string } }): string
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Apply a diff to the original content
|
||||||
|
* @param originalContent The original file content
|
||||||
|
* @param diffContent The diff content in the strategy's format
|
||||||
|
* @param startLine Optional line number where the search block starts. If not provided, searches the entire file.
|
||||||
|
* @param endLine Optional line number where the search block ends. If not provided, searches the entire file.
|
||||||
|
* @returns A DiffResult object containing either the successful result or error details
|
||||||
|
*/
|
||||||
|
applyDiff(originalContent: string, diffContent: string, startLine?: number, endLine?: number): Promise<DiffResult>
|
||||||
|
}
|
||||||
746
src/core/mcp/McpHub.ts
Normal file
746
src/core/mcp/McpHub.ts
Normal file
@ -0,0 +1,746 @@
|
|||||||
|
import { Client } from "@modelcontextprotocol/sdk/client/index.js"
|
||||||
|
import { StdioClientTransport, StdioServerParameters } from "@modelcontextprotocol/sdk/client/stdio.js"
|
||||||
|
import {
|
||||||
|
CallToolResultSchema,
|
||||||
|
ListResourcesResultSchema,
|
||||||
|
ListResourceTemplatesResultSchema,
|
||||||
|
ListToolsResultSchema,
|
||||||
|
ReadResourceResultSchema,
|
||||||
|
} from "@modelcontextprotocol/sdk/types.js"
|
||||||
|
import chokidar, { FSWatcher } from "chokidar"
|
||||||
|
import delay from "delay"
|
||||||
|
import deepEqual from "fast-deep-equal"
|
||||||
|
import * as fs from "fs/promises"
|
||||||
|
import * as path from "path"
|
||||||
|
import * as vscode from "vscode"
|
||||||
|
import { z } from "zod"
|
||||||
|
|
||||||
|
import { ClineProvider } from "../../core/webview/ClineProvider"
|
||||||
|
import { GlobalFileNames } from "../../shared/globalFileNames"
|
||||||
|
import {
|
||||||
|
McpResource,
|
||||||
|
McpResourceResponse,
|
||||||
|
McpResourceTemplate,
|
||||||
|
McpServer,
|
||||||
|
McpTool,
|
||||||
|
McpToolCallResponse,
|
||||||
|
} from "../../shared/mcp"
|
||||||
|
import { fileExistsAtPath } from "../../utils/fs"
|
||||||
|
import { arePathsEqual } from "../../utils/path"
|
||||||
|
|
||||||
|
export type McpConnection = {
|
||||||
|
server: McpServer
|
||||||
|
client: Client
|
||||||
|
transport: StdioClientTransport
|
||||||
|
}
|
||||||
|
|
||||||
|
// StdioServerParameters
|
||||||
|
const AlwaysAllowSchema = z.array(z.string()).default([])
|
||||||
|
|
||||||
|
export const StdioConfigSchema = z.object({
|
||||||
|
command: z.string(),
|
||||||
|
args: z.array(z.string()).optional(),
|
||||||
|
env: z.record(z.string()).optional(),
|
||||||
|
alwaysAllow: AlwaysAllowSchema.optional(),
|
||||||
|
disabled: z.boolean().optional(),
|
||||||
|
timeout: z.number().min(1).max(3600).optional().default(60),
|
||||||
|
})
|
||||||
|
|
||||||
|
const McpSettingsSchema = z.object({
|
||||||
|
mcpServers: z.record(StdioConfigSchema),
|
||||||
|
})
|
||||||
|
|
||||||
|
export class McpHub {
|
||||||
|
private providerRef: WeakRef<ClineProvider>
|
||||||
|
private disposables: vscode.Disposable[] = []
|
||||||
|
private settingsWatcher?: vscode.FileSystemWatcher
|
||||||
|
private fileWatchers: Map<string, FSWatcher> = new Map()
|
||||||
|
connections: McpConnection[] = []
|
||||||
|
isConnecting: boolean = false
|
||||||
|
|
||||||
|
constructor(provider: ClineProvider) {
|
||||||
|
this.providerRef = new WeakRef(provider)
|
||||||
|
this.watchMcpSettingsFile()
|
||||||
|
this.initializeMcpServers()
|
||||||
|
}
|
||||||
|
|
||||||
|
getServers(): McpServer[] {
|
||||||
|
// Only return enabled servers
|
||||||
|
return this.connections.filter((conn) => !conn.server.disabled).map((conn) => conn.server)
|
||||||
|
}
|
||||||
|
|
||||||
|
getAllServers(): McpServer[] {
|
||||||
|
// Return all servers regardless of state
|
||||||
|
return this.connections.map((conn) => conn.server)
|
||||||
|
}
|
||||||
|
|
||||||
|
async getMcpServersPath(): Promise<string> {
|
||||||
|
const provider = this.providerRef.deref()
|
||||||
|
if (!provider) {
|
||||||
|
throw new Error("Provider not available")
|
||||||
|
}
|
||||||
|
const mcpServersPath = await provider.ensureMcpServersDirectoryExists()
|
||||||
|
return mcpServersPath
|
||||||
|
}
|
||||||
|
|
||||||
|
async getMcpSettingsFilePath(): Promise<string> {
|
||||||
|
const provider = this.providerRef.deref()
|
||||||
|
if (!provider) {
|
||||||
|
throw new Error("Provider not available")
|
||||||
|
}
|
||||||
|
const mcpSettingsFilePath = path.join(
|
||||||
|
await provider.ensureSettingsDirectoryExists(),
|
||||||
|
GlobalFileNames.mcpSettings,
|
||||||
|
)
|
||||||
|
const fileExists = await fileExistsAtPath(mcpSettingsFilePath)
|
||||||
|
if (!fileExists) {
|
||||||
|
await fs.writeFile(
|
||||||
|
mcpSettingsFilePath,
|
||||||
|
`{
|
||||||
|
"mcpServers": {
|
||||||
|
|
||||||
|
}
|
||||||
|
}`,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return mcpSettingsFilePath
|
||||||
|
}
|
||||||
|
|
||||||
|
private async watchMcpSettingsFile(): Promise<void> {
|
||||||
|
const settingsPath = await this.getMcpSettingsFilePath()
|
||||||
|
this.disposables.push(
|
||||||
|
vscode.workspace.onDidSaveTextDocument(async (document) => {
|
||||||
|
if (arePathsEqual(document.uri.fsPath, settingsPath)) {
|
||||||
|
const content = await fs.readFile(settingsPath, "utf-8")
|
||||||
|
const errorMessage =
|
||||||
|
"Invalid MCP settings format. Please ensure your settings follow the correct JSON format."
|
||||||
|
let config: any
|
||||||
|
try {
|
||||||
|
config = JSON.parse(content)
|
||||||
|
} catch (error) {
|
||||||
|
vscode.window.showErrorMessage(errorMessage)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
const result = McpSettingsSchema.safeParse(config)
|
||||||
|
if (!result.success) {
|
||||||
|
vscode.window.showErrorMessage(errorMessage)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
await this.updateServerConnections(result.data.mcpServers || {})
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Failed to process MCP settings change:", error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
private async initializeMcpServers(): Promise<void> {
|
||||||
|
try {
|
||||||
|
const settingsPath = await this.getMcpSettingsFilePath()
|
||||||
|
const content = await fs.readFile(settingsPath, "utf-8")
|
||||||
|
const config = JSON.parse(content)
|
||||||
|
await this.updateServerConnections(config.mcpServers || {})
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Failed to initialize MCP servers:", error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private async connectToServer(name: string, config: StdioServerParameters): Promise<void> {
|
||||||
|
// Remove existing connection if it exists (should never happen, the connection should be deleted beforehand)
|
||||||
|
this.connections = this.connections.filter((conn) => conn.server.name !== name)
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Each MCP server requires its own transport connection and has unique capabilities, configurations, and error handling. Having separate clients also allows proper scoping of resources/tools and independent server management like reconnection.
|
||||||
|
const client = new Client(
|
||||||
|
{
|
||||||
|
name: "Roo Code",
|
||||||
|
version: this.providerRef.deref()?.context.extension?.packageJSON?.version ?? "1.0.0",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
capabilities: {},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
const transport = new StdioClientTransport({
|
||||||
|
command: config.command,
|
||||||
|
args: config.args,
|
||||||
|
env: {
|
||||||
|
...config.env,
|
||||||
|
...(process.env.PATH ? { PATH: process.env.PATH } : {}),
|
||||||
|
// ...(process.env.NODE_PATH ? { NODE_PATH: process.env.NODE_PATH } : {}),
|
||||||
|
},
|
||||||
|
stderr: "pipe", // necessary for stderr to be available
|
||||||
|
})
|
||||||
|
|
||||||
|
transport.onerror = async (error) => {
|
||||||
|
console.error(`Transport error for "${name}":`, error)
|
||||||
|
const connection = this.connections.find((conn) => conn.server.name === name)
|
||||||
|
if (connection) {
|
||||||
|
connection.server.status = "disconnected"
|
||||||
|
this.appendErrorMessage(connection, error.message)
|
||||||
|
}
|
||||||
|
await this.notifyWebviewOfServerChanges()
|
||||||
|
}
|
||||||
|
|
||||||
|
transport.onclose = async () => {
|
||||||
|
const connection = this.connections.find((conn) => conn.server.name === name)
|
||||||
|
if (connection) {
|
||||||
|
connection.server.status = "disconnected"
|
||||||
|
}
|
||||||
|
await this.notifyWebviewOfServerChanges()
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the config is invalid, show an error
|
||||||
|
if (!StdioConfigSchema.safeParse(config).success) {
|
||||||
|
console.error(`Invalid config for "${name}": missing or invalid parameters`)
|
||||||
|
const connection: McpConnection = {
|
||||||
|
server: {
|
||||||
|
name,
|
||||||
|
config: JSON.stringify(config),
|
||||||
|
status: "disconnected",
|
||||||
|
error: "Invalid config: missing or invalid parameters",
|
||||||
|
},
|
||||||
|
client,
|
||||||
|
transport,
|
||||||
|
}
|
||||||
|
this.connections.push(connection)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// valid schema
|
||||||
|
const parsedConfig = StdioConfigSchema.parse(config)
|
||||||
|
const connection: McpConnection = {
|
||||||
|
server: {
|
||||||
|
name,
|
||||||
|
config: JSON.stringify(config),
|
||||||
|
status: "connecting",
|
||||||
|
disabled: parsedConfig.disabled,
|
||||||
|
},
|
||||||
|
client,
|
||||||
|
transport,
|
||||||
|
}
|
||||||
|
this.connections.push(connection)
|
||||||
|
|
||||||
|
// transport.stderr is only available after the process has been started. However we can't start it separately from the .connect() call because it also starts the transport. And we can't place this after the connect call since we need to capture the stderr stream before the connection is established, in order to capture errors during the connection process.
|
||||||
|
// As a workaround, we start the transport ourselves, and then monkey-patch the start method to no-op so that .connect() doesn't try to start it again.
|
||||||
|
await transport.start()
|
||||||
|
const stderrStream = transport.stderr
|
||||||
|
if (stderrStream) {
|
||||||
|
stderrStream.on("data", async (data: Buffer) => {
|
||||||
|
const errorOutput = data.toString()
|
||||||
|
console.error(`Server "${name}" stderr:`, errorOutput)
|
||||||
|
const connection = this.connections.find((conn) => conn.server.name === name)
|
||||||
|
if (connection) {
|
||||||
|
// NOTE: we do not set server status to "disconnected" because stderr logs do not necessarily mean the server crashed or disconnected, it could just be informational. In fact when the server first starts up, it immediately logs "<name> server running on stdio" to stderr.
|
||||||
|
this.appendErrorMessage(connection, errorOutput)
|
||||||
|
// Only need to update webview right away if it's already disconnected
|
||||||
|
if (connection.server.status === "disconnected") {
|
||||||
|
await this.notifyWebviewOfServerChanges()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
console.error(`No stderr stream for ${name}`)
|
||||||
|
}
|
||||||
|
transport.start = async () => {} // No-op now, .connect() won't fail
|
||||||
|
|
||||||
|
// Connect
|
||||||
|
await client.connect(transport)
|
||||||
|
connection.server.status = "connected"
|
||||||
|
connection.server.error = ""
|
||||||
|
|
||||||
|
// Initial fetch of tools and resources
|
||||||
|
connection.server.tools = await this.fetchToolsList(name)
|
||||||
|
connection.server.resources = await this.fetchResourcesList(name)
|
||||||
|
connection.server.resourceTemplates = await this.fetchResourceTemplatesList(name)
|
||||||
|
} catch (error) {
|
||||||
|
// Update status with error
|
||||||
|
const connection = this.connections.find((conn) => conn.server.name === name)
|
||||||
|
if (connection) {
|
||||||
|
connection.server.status = "disconnected"
|
||||||
|
this.appendErrorMessage(connection, error instanceof Error ? error.message : String(error))
|
||||||
|
}
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private appendErrorMessage(connection: McpConnection, error: string) {
|
||||||
|
const newError = connection.server.error ? `${connection.server.error}\n${error}` : error
|
||||||
|
connection.server.error = newError //.slice(0, 800)
|
||||||
|
}
|
||||||
|
|
||||||
|
private async fetchToolsList(serverName: string): Promise<McpTool[]> {
|
||||||
|
try {
|
||||||
|
const response = await this.connections
|
||||||
|
.find((conn) => conn.server.name === serverName)
|
||||||
|
?.client.request({ method: "tools/list" }, ListToolsResultSchema)
|
||||||
|
|
||||||
|
// Get always allow settings
|
||||||
|
const settingsPath = await this.getMcpSettingsFilePath()
|
||||||
|
const content = await fs.readFile(settingsPath, "utf-8")
|
||||||
|
const config = JSON.parse(content)
|
||||||
|
const alwaysAllowConfig = config.mcpServers[serverName]?.alwaysAllow || []
|
||||||
|
|
||||||
|
// Mark tools as always allowed based on settings
|
||||||
|
const tools = (response?.tools || []).map((tool) => ({
|
||||||
|
...tool,
|
||||||
|
alwaysAllow: alwaysAllowConfig.includes(tool.name),
|
||||||
|
}))
|
||||||
|
|
||||||
|
console.log(`[MCP] Fetched tools for ${serverName}:`, tools)
|
||||||
|
return tools
|
||||||
|
} catch (error) {
|
||||||
|
// console.error(`Failed to fetch tools for ${serverName}:`, error)
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private async fetchResourcesList(serverName: string): Promise<McpResource[]> {
|
||||||
|
try {
|
||||||
|
const response = await this.connections
|
||||||
|
.find((conn) => conn.server.name === serverName)
|
||||||
|
?.client.request({ method: "resources/list" }, ListResourcesResultSchema)
|
||||||
|
return response?.resources || []
|
||||||
|
} catch (error) {
|
||||||
|
// console.error(`Failed to fetch resources for ${serverName}:`, error)
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private async fetchResourceTemplatesList(serverName: string): Promise<McpResourceTemplate[]> {
|
||||||
|
try {
|
||||||
|
const response = await this.connections
|
||||||
|
.find((conn) => conn.server.name === serverName)
|
||||||
|
?.client.request({ method: "resources/templates/list" }, ListResourceTemplatesResultSchema)
|
||||||
|
return response?.resourceTemplates || []
|
||||||
|
} catch (error) {
|
||||||
|
// console.error(`Failed to fetch resource templates for ${serverName}:`, error)
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async deleteConnection(name: string): Promise<void> {
|
||||||
|
const connection = this.connections.find((conn) => conn.server.name === name)
|
||||||
|
if (connection) {
|
||||||
|
try {
|
||||||
|
await connection.transport.close()
|
||||||
|
await connection.client.close()
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`Failed to close transport for ${name}:`, error)
|
||||||
|
}
|
||||||
|
this.connections = this.connections.filter((conn) => conn.server.name !== name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async updateServerConnections(newServers: Record<string, any>): Promise<void> {
|
||||||
|
this.isConnecting = true
|
||||||
|
this.removeAllFileWatchers()
|
||||||
|
const currentNames = new Set(this.connections.map((conn) => conn.server.name))
|
||||||
|
const newNames = new Set(Object.keys(newServers))
|
||||||
|
|
||||||
|
// Delete removed servers
|
||||||
|
for (const name of currentNames) {
|
||||||
|
if (!newNames.has(name)) {
|
||||||
|
await this.deleteConnection(name)
|
||||||
|
console.log(`Deleted MCP server: ${name}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update or add servers
|
||||||
|
for (const [name, config] of Object.entries(newServers)) {
|
||||||
|
const currentConnection = this.connections.find((conn) => conn.server.name === name)
|
||||||
|
|
||||||
|
if (!currentConnection) {
|
||||||
|
// New server
|
||||||
|
try {
|
||||||
|
this.setupFileWatcher(name, config)
|
||||||
|
await this.connectToServer(name, config)
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`Failed to connect to new MCP server ${name}:`, error)
|
||||||
|
}
|
||||||
|
} else if (!deepEqual(JSON.parse(currentConnection.server.config), config)) {
|
||||||
|
// Existing server with changed config
|
||||||
|
try {
|
||||||
|
this.setupFileWatcher(name, config)
|
||||||
|
await this.deleteConnection(name)
|
||||||
|
await this.connectToServer(name, config)
|
||||||
|
console.log(`Reconnected MCP server with updated config: ${name}`)
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`Failed to reconnect MCP server ${name}:`, error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If server exists with same config, do nothing
|
||||||
|
}
|
||||||
|
await this.notifyWebviewOfServerChanges()
|
||||||
|
this.isConnecting = false
|
||||||
|
}
|
||||||
|
|
||||||
|
private setupFileWatcher(name: string, config: any) {
|
||||||
|
const filePath = config.args?.find((arg: string) => arg.includes("build/index.js"))
|
||||||
|
if (filePath) {
|
||||||
|
// we use chokidar instead of onDidSaveTextDocument because it doesn't require the file to be open in the editor. The settings config is better suited for onDidSave since that will be manually updated by the user or Cline (and we want to detect save events, not every file change)
|
||||||
|
const watcher = chokidar.watch(filePath, {
|
||||||
|
// persistent: true,
|
||||||
|
// ignoreInitial: true,
|
||||||
|
// awaitWriteFinish: true, // This helps with atomic writes
|
||||||
|
})
|
||||||
|
|
||||||
|
watcher.on("change", () => {
|
||||||
|
console.log(`Detected change in ${filePath}. Restarting server ${name}...`)
|
||||||
|
this.restartConnection(name)
|
||||||
|
})
|
||||||
|
|
||||||
|
this.fileWatchers.set(name, watcher)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private removeAllFileWatchers() {
|
||||||
|
this.fileWatchers.forEach((watcher) => watcher.close())
|
||||||
|
this.fileWatchers.clear()
|
||||||
|
}
|
||||||
|
|
||||||
|
async restartConnection(serverName: string): Promise<void> {
|
||||||
|
this.isConnecting = true
|
||||||
|
const provider = this.providerRef.deref()
|
||||||
|
if (!provider) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get existing connection and update its status
|
||||||
|
const connection = this.connections.find((conn) => conn.server.name === serverName)
|
||||||
|
const config = connection?.server.config
|
||||||
|
if (config) {
|
||||||
|
vscode.window.showInformationMessage(`Restarting ${serverName} MCP server...`)
|
||||||
|
connection.server.status = "connecting"
|
||||||
|
connection.server.error = ""
|
||||||
|
await this.notifyWebviewOfServerChanges()
|
||||||
|
await delay(500) // artificial delay to show user that server is restarting
|
||||||
|
try {
|
||||||
|
await this.deleteConnection(serverName)
|
||||||
|
// Try to connect again using existing config
|
||||||
|
await this.connectToServer(serverName, JSON.parse(config))
|
||||||
|
vscode.window.showInformationMessage(`${serverName} MCP server connected`)
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`Failed to restart connection for ${serverName}:`, error)
|
||||||
|
vscode.window.showErrorMessage(`Failed to connect to ${serverName} MCP server`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await this.notifyWebviewOfServerChanges()
|
||||||
|
this.isConnecting = false
|
||||||
|
}
|
||||||
|
|
||||||
|
private async notifyWebviewOfServerChanges(): Promise<void> {
|
||||||
|
// servers should always be sorted in the order they are defined in the settings file
|
||||||
|
const settingsPath = await this.getMcpSettingsFilePath()
|
||||||
|
const content = await fs.readFile(settingsPath, "utf-8")
|
||||||
|
const config = JSON.parse(content)
|
||||||
|
const serverOrder = Object.keys(config.mcpServers || {})
|
||||||
|
await this.providerRef.deref()?.postMessageToWebview({
|
||||||
|
type: "mcpServers",
|
||||||
|
mcpServers: [...this.connections]
|
||||||
|
.sort((a, b) => {
|
||||||
|
const indexA = serverOrder.indexOf(a.server.name)
|
||||||
|
const indexB = serverOrder.indexOf(b.server.name)
|
||||||
|
return indexA - indexB
|
||||||
|
})
|
||||||
|
.map((connection) => connection.server),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
public async toggleServerDisabled(serverName: string, disabled: boolean): Promise<void> {
|
||||||
|
let settingsPath: string
|
||||||
|
try {
|
||||||
|
settingsPath = await this.getMcpSettingsFilePath()
|
||||||
|
|
||||||
|
// Ensure the settings file exists and is accessible
|
||||||
|
try {
|
||||||
|
await fs.access(settingsPath)
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Settings file not accessible:", error)
|
||||||
|
throw new Error("Settings file not accessible")
|
||||||
|
}
|
||||||
|
const content = await fs.readFile(settingsPath, "utf-8")
|
||||||
|
const config = JSON.parse(content)
|
||||||
|
|
||||||
|
// Validate the config structure
|
||||||
|
if (!config || typeof config !== "object") {
|
||||||
|
throw new Error("Invalid config structure")
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!config.mcpServers || typeof config.mcpServers !== "object") {
|
||||||
|
config.mcpServers = {}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (config.mcpServers[serverName]) {
|
||||||
|
// Create a new server config object to ensure clean structure
|
||||||
|
const serverConfig = {
|
||||||
|
...config.mcpServers[serverName],
|
||||||
|
disabled,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure required fields exist
|
||||||
|
if (!serverConfig.alwaysAllow) {
|
||||||
|
serverConfig.alwaysAllow = []
|
||||||
|
}
|
||||||
|
|
||||||
|
config.mcpServers[serverName] = serverConfig
|
||||||
|
|
||||||
|
// Write the entire config back
|
||||||
|
const updatedConfig = {
|
||||||
|
mcpServers: config.mcpServers,
|
||||||
|
}
|
||||||
|
|
||||||
|
await fs.writeFile(settingsPath, JSON.stringify(updatedConfig, null, 2))
|
||||||
|
|
||||||
|
const connection = this.connections.find((conn) => conn.server.name === serverName)
|
||||||
|
if (connection) {
|
||||||
|
try {
|
||||||
|
connection.server.disabled = disabled
|
||||||
|
|
||||||
|
// Only refresh capabilities if connected
|
||||||
|
if (connection.server.status === "connected") {
|
||||||
|
connection.server.tools = await this.fetchToolsList(serverName)
|
||||||
|
connection.server.resources = await this.fetchResourcesList(serverName)
|
||||||
|
connection.server.resourceTemplates = await this.fetchResourceTemplatesList(serverName)
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`Failed to refresh capabilities for ${serverName}:`, error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await this.notifyWebviewOfServerChanges()
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Failed to update server disabled state:", error)
|
||||||
|
if (error instanceof Error) {
|
||||||
|
console.error("Error details:", error.message, error.stack)
|
||||||
|
}
|
||||||
|
vscode.window.showErrorMessage(
|
||||||
|
`Failed to update server state: ${error instanceof Error ? error.message : String(error)}`,
|
||||||
|
)
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public async updateServerTimeout(serverName: string, timeout: number): Promise<void> {
|
||||||
|
let settingsPath: string
|
||||||
|
try {
|
||||||
|
settingsPath = await this.getMcpSettingsFilePath()
|
||||||
|
|
||||||
|
// Ensure the settings file exists and is accessible
|
||||||
|
try {
|
||||||
|
await fs.access(settingsPath)
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Settings file not accessible:", error)
|
||||||
|
throw new Error("Settings file not accessible")
|
||||||
|
}
|
||||||
|
const content = await fs.readFile(settingsPath, "utf-8")
|
||||||
|
const config = JSON.parse(content)
|
||||||
|
|
||||||
|
// Validate the config structure
|
||||||
|
if (!config || typeof config !== "object") {
|
||||||
|
throw new Error("Invalid config structure")
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!config.mcpServers || typeof config.mcpServers !== "object") {
|
||||||
|
config.mcpServers = {}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (config.mcpServers[serverName]) {
|
||||||
|
// Create a new server config object to ensure clean structure
|
||||||
|
const serverConfig = {
|
||||||
|
...config.mcpServers[serverName],
|
||||||
|
timeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
config.mcpServers[serverName] = serverConfig
|
||||||
|
|
||||||
|
// Write the entire config back
|
||||||
|
const updatedConfig = {
|
||||||
|
mcpServers: config.mcpServers,
|
||||||
|
}
|
||||||
|
|
||||||
|
await fs.writeFile(settingsPath, JSON.stringify(updatedConfig, null, 2))
|
||||||
|
await this.notifyWebviewOfServerChanges()
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Failed to update server timeout:", error)
|
||||||
|
if (error instanceof Error) {
|
||||||
|
console.error("Error details:", error.message, error.stack)
|
||||||
|
}
|
||||||
|
vscode.window.showErrorMessage(
|
||||||
|
`Failed to update server timeout: ${error instanceof Error ? error.message : String(error)}`,
|
||||||
|
)
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public async deleteServer(serverName: string): Promise<void> {
|
||||||
|
try {
|
||||||
|
const settingsPath = await this.getMcpSettingsFilePath()
|
||||||
|
|
||||||
|
// Ensure the settings file exists and is accessible
|
||||||
|
try {
|
||||||
|
await fs.access(settingsPath)
|
||||||
|
} catch (error) {
|
||||||
|
throw new Error("Settings file not accessible")
|
||||||
|
}
|
||||||
|
|
||||||
|
const content = await fs.readFile(settingsPath, "utf-8")
|
||||||
|
const config = JSON.parse(content)
|
||||||
|
|
||||||
|
// Validate the config structure
|
||||||
|
if (!config || typeof config !== "object") {
|
||||||
|
throw new Error("Invalid config structure")
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!config.mcpServers || typeof config.mcpServers !== "object") {
|
||||||
|
config.mcpServers = {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove the server from the settings
|
||||||
|
if (config.mcpServers[serverName]) {
|
||||||
|
delete config.mcpServers[serverName]
|
||||||
|
|
||||||
|
// Write the entire config back
|
||||||
|
const updatedConfig = {
|
||||||
|
mcpServers: config.mcpServers,
|
||||||
|
}
|
||||||
|
|
||||||
|
await fs.writeFile(settingsPath, JSON.stringify(updatedConfig, null, 2))
|
||||||
|
|
||||||
|
// Update server connections
|
||||||
|
await this.updateServerConnections(config.mcpServers)
|
||||||
|
|
||||||
|
vscode.window.showInformationMessage(`Deleted MCP server: ${serverName}`)
|
||||||
|
} else {
|
||||||
|
vscode.window.showWarningMessage(`Server "${serverName}" not found in configuration`)
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Failed to delete MCP server:", error)
|
||||||
|
if (error instanceof Error) {
|
||||||
|
console.error("Error details:", error.message, error.stack)
|
||||||
|
}
|
||||||
|
vscode.window.showErrorMessage(
|
||||||
|
`Failed to delete MCP server: ${error instanceof Error ? error.message : String(error)}`,
|
||||||
|
)
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async readResource(serverName: string, uri: string): Promise<McpResourceResponse> {
|
||||||
|
const connection = this.connections.find((conn) => conn.server.name === serverName)
|
||||||
|
if (!connection) {
|
||||||
|
throw new Error(`No connection found for server: ${serverName}`)
|
||||||
|
}
|
||||||
|
if (connection.server.disabled) {
|
||||||
|
throw new Error(`Server "${serverName}" is disabled`)
|
||||||
|
}
|
||||||
|
return await connection.client.request(
|
||||||
|
{
|
||||||
|
method: "resources/read",
|
||||||
|
params: {
|
||||||
|
uri,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ReadResourceResultSchema,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
async callTool(
|
||||||
|
serverName: string,
|
||||||
|
toolName: string,
|
||||||
|
toolArguments?: Record<string, unknown>,
|
||||||
|
): Promise<McpToolCallResponse> {
|
||||||
|
const connection = this.connections.find((conn) => conn.server.name === serverName)
|
||||||
|
if (!connection) {
|
||||||
|
throw new Error(
|
||||||
|
`No connection found for server: ${serverName}. Please make sure to use MCP servers available under 'Connected MCP Servers'.`,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if (connection.server.disabled) {
|
||||||
|
throw new Error(`Server "${serverName}" is disabled and cannot be used`)
|
||||||
|
}
|
||||||
|
|
||||||
|
let timeout: number
|
||||||
|
try {
|
||||||
|
const parsedConfig = StdioConfigSchema.parse(JSON.parse(connection.server.config))
|
||||||
|
timeout = (parsedConfig.timeout ?? 60) * 1000
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Failed to parse server config for timeout:", error)
|
||||||
|
// Default to 60 seconds if parsing fails
|
||||||
|
timeout = 60 * 1000
|
||||||
|
}
|
||||||
|
|
||||||
|
return await connection.client.request(
|
||||||
|
{
|
||||||
|
method: "tools/call",
|
||||||
|
params: {
|
||||||
|
name: toolName,
|
||||||
|
arguments: toolArguments,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
CallToolResultSchema,
|
||||||
|
{
|
||||||
|
timeout,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
async toggleToolAlwaysAllow(serverName: string, toolName: string, shouldAllow: boolean): Promise<void> {
|
||||||
|
try {
|
||||||
|
const settingsPath = await this.getMcpSettingsFilePath()
|
||||||
|
const content = await fs.readFile(settingsPath, "utf-8")
|
||||||
|
const config = JSON.parse(content)
|
||||||
|
|
||||||
|
// Initialize alwaysAllow if it doesn't exist
|
||||||
|
if (!config.mcpServers[serverName].alwaysAllow) {
|
||||||
|
config.mcpServers[serverName].alwaysAllow = []
|
||||||
|
}
|
||||||
|
|
||||||
|
const alwaysAllow = config.mcpServers[serverName].alwaysAllow
|
||||||
|
const toolIndex = alwaysAllow.indexOf(toolName)
|
||||||
|
|
||||||
|
if (shouldAllow && toolIndex === -1) {
|
||||||
|
// Add tool to always allow list
|
||||||
|
alwaysAllow.push(toolName)
|
||||||
|
} else if (!shouldAllow && toolIndex !== -1) {
|
||||||
|
// Remove tool from always allow list
|
||||||
|
alwaysAllow.splice(toolIndex, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write updated config back to file
|
||||||
|
await fs.writeFile(settingsPath, JSON.stringify(config, null, 2))
|
||||||
|
|
||||||
|
// Update the tools list to reflect the change
|
||||||
|
const connection = this.connections.find((conn) => conn.server.name === serverName)
|
||||||
|
if (connection) {
|
||||||
|
connection.server.tools = await this.fetchToolsList(serverName)
|
||||||
|
await this.notifyWebviewOfServerChanges()
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Failed to update always allow settings:", error)
|
||||||
|
vscode.window.showErrorMessage("Failed to update always allow settings")
|
||||||
|
throw error // Re-throw to ensure the error is properly handled
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async dispose(): Promise<void> {
|
||||||
|
this.removeAllFileWatchers()
|
||||||
|
for (const connection of this.connections) {
|
||||||
|
try {
|
||||||
|
await this.deleteConnection(connection.server.name)
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`Failed to close connection for ${connection.server.name}:`, error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
this.connections = []
|
||||||
|
if (this.settingsWatcher) {
|
||||||
|
this.settingsWatcher.dispose()
|
||||||
|
}
|
||||||
|
this.disposables.forEach((d) => d.dispose())
|
||||||
|
}
|
||||||
|
}
|
||||||
83
src/core/mcp/McpServerManager.ts
Normal file
83
src/core/mcp/McpServerManager.ts
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
import * as vscode from "vscode"
|
||||||
|
import { ClineProvider } from "../../core/webview/ClineProvider"
|
||||||
|
import { McpHub } from "./McpHub"
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Singleton manager for MCP server instances.
|
||||||
|
* Ensures only one set of MCP servers runs across all webviews.
|
||||||
|
*/
|
||||||
|
export class McpServerManager {
|
||||||
|
private static instance: McpHub | null = null
|
||||||
|
private static readonly GLOBAL_STATE_KEY = "mcpHubInstanceId"
|
||||||
|
private static providers: Set<ClineProvider> = new Set()
|
||||||
|
private static initializationPromise: Promise<McpHub> | null = null
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the singleton McpHub instance.
|
||||||
|
* Creates a new instance if one doesn't exist.
|
||||||
|
* Thread-safe implementation using a promise-based lock.
|
||||||
|
*/
|
||||||
|
static async getInstance(context: vscode.ExtensionContext, provider: ClineProvider): Promise<McpHub> {
|
||||||
|
// Register the provider
|
||||||
|
this.providers.add(provider)
|
||||||
|
|
||||||
|
// If we already have an instance, return it
|
||||||
|
if (this.instance) {
|
||||||
|
return this.instance
|
||||||
|
}
|
||||||
|
|
||||||
|
// If initialization is in progress, wait for it
|
||||||
|
if (this.initializationPromise) {
|
||||||
|
return this.initializationPromise
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new initialization promise
|
||||||
|
this.initializationPromise = (async () => {
|
||||||
|
try {
|
||||||
|
// Double-check instance in case it was created while we were waiting
|
||||||
|
if (!this.instance) {
|
||||||
|
this.instance = new McpHub(provider)
|
||||||
|
// Store a unique identifier in global state to track the primary instance
|
||||||
|
await context.globalState.update(this.GLOBAL_STATE_KEY, Date.now().toString())
|
||||||
|
}
|
||||||
|
return this.instance
|
||||||
|
} finally {
|
||||||
|
// Clear the initialization promise after completion or error
|
||||||
|
this.initializationPromise = null
|
||||||
|
}
|
||||||
|
})()
|
||||||
|
|
||||||
|
return this.initializationPromise
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Remove a provider from the tracked set.
|
||||||
|
* This is called when a webview is disposed.
|
||||||
|
*/
|
||||||
|
static unregisterProvider(provider: ClineProvider): void {
|
||||||
|
this.providers.delete(provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Notify all registered providers of server state changes.
|
||||||
|
*/
|
||||||
|
static notifyProviders(message: any): void {
|
||||||
|
this.providers.forEach((provider) => {
|
||||||
|
provider.postMessageToWebview(message).catch((error) => {
|
||||||
|
console.error("Failed to notify provider:", error)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Clean up the singleton instance and all its resources.
|
||||||
|
*/
|
||||||
|
static async cleanup(context: vscode.ExtensionContext): Promise<void> {
|
||||||
|
if (this.instance) {
|
||||||
|
await this.instance.dispose()
|
||||||
|
this.instance = null
|
||||||
|
await context.globalState.update(this.GLOBAL_STATE_KEY, undefined)
|
||||||
|
}
|
||||||
|
this.providers.clear()
|
||||||
|
}
|
||||||
|
}
|
||||||
206
src/core/services/ripgrep/index.ts
Normal file
206
src/core/services/ripgrep/index.ts
Normal file
@ -0,0 +1,206 @@
|
|||||||
|
// import * as vscode from "vscode"
|
||||||
|
import * as childProcess from "child_process"
|
||||||
|
import * as fs from "fs"
|
||||||
|
import * as path from "path"
|
||||||
|
import * as readline from "readline"
|
||||||
|
|
||||||
|
const isWindows = /^win/.test(process.platform)
|
||||||
|
const binName = isWindows ? "rg.exe" : "rg"
|
||||||
|
|
||||||
|
interface SearchResult {
|
||||||
|
file: string
|
||||||
|
line: number
|
||||||
|
column: number
|
||||||
|
match: string
|
||||||
|
beforeContext: string[]
|
||||||
|
afterContext: string[]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Constants
|
||||||
|
const MAX_RESULTS = 300
|
||||||
|
const MAX_LINE_LENGTH = 500
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Truncates a line if it exceeds the maximum length
|
||||||
|
* @param line The line to truncate
|
||||||
|
* @param maxLength The maximum allowed length (defaults to MAX_LINE_LENGTH)
|
||||||
|
* @returns The truncated line, or the original line if it's shorter than maxLength
|
||||||
|
*/
|
||||||
|
export function truncateLine(line: string, maxLength: number = MAX_LINE_LENGTH): string {
|
||||||
|
return line.length > maxLength ? line.substring(0, maxLength) + " [truncated...]" : line
|
||||||
|
}
|
||||||
|
|
||||||
|
async function getBinPath(): Promise<string | undefined> {
|
||||||
|
const binPath = path.join("/opt/homebrew/bin/", binName)
|
||||||
|
return (await pathExists(binPath)) ? binPath : undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
async function pathExists(path: string): Promise<boolean> {
|
||||||
|
return new Promise((resolve) => {
|
||||||
|
fs.access(path, (err) => {
|
||||||
|
resolve(err === null)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async function execRipgrep(bin: string, args: string[]): Promise<string> {
|
||||||
|
return new Promise((resolve, reject) => {
|
||||||
|
const rgProcess = childProcess.spawn(bin, args)
|
||||||
|
// cross-platform alternative to head, which is ripgrep author's recommendation for limiting output.
|
||||||
|
const rl = readline.createInterface({
|
||||||
|
input: rgProcess.stdout,
|
||||||
|
crlfDelay: Infinity, // treat \r\n as a single line break even if it's split across chunks. This ensures consistent behavior across different operating systems.
|
||||||
|
})
|
||||||
|
|
||||||
|
let output = ""
|
||||||
|
let lineCount = 0
|
||||||
|
const maxLines = MAX_RESULTS * 5 // limiting ripgrep output with max lines since there's no other way to limit results. it's okay that we're outputting as json, since we're parsing it line by line and ignore anything that's not part of a match. This assumes each result is at most 5 lines.
|
||||||
|
|
||||||
|
rl.on("line", (line) => {
|
||||||
|
if (lineCount < maxLines) {
|
||||||
|
output += line + "\n"
|
||||||
|
lineCount++
|
||||||
|
} else {
|
||||||
|
rl.close()
|
||||||
|
rgProcess.kill()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
let errorOutput = ""
|
||||||
|
rgProcess.stderr.on("data", (data) => {
|
||||||
|
errorOutput += data.toString()
|
||||||
|
})
|
||||||
|
rl.on("close", () => {
|
||||||
|
if (errorOutput) {
|
||||||
|
reject(new Error(`ripgrep process error: ${errorOutput}`))
|
||||||
|
} else {
|
||||||
|
resolve(output)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
rgProcess.on("error", (error) => {
|
||||||
|
reject(new Error(`ripgrep process error: ${error.message}`))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function regexSearchFiles(
|
||||||
|
directoryPath: string,
|
||||||
|
regex: string,
|
||||||
|
): Promise<string> {
|
||||||
|
const rgPath = await getBinPath()
|
||||||
|
|
||||||
|
if (!rgPath) {
|
||||||
|
throw new Error("Could not find ripgrep binary")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用--glob参数排除.obsidian目录
|
||||||
|
const args = [
|
||||||
|
"--json",
|
||||||
|
"-e",
|
||||||
|
regex,
|
||||||
|
"--glob",
|
||||||
|
"!.obsidian/**", // 排除.obsidian目录及其所有子目录
|
||||||
|
"--glob",
|
||||||
|
"!.git/**",
|
||||||
|
"--context",
|
||||||
|
"1",
|
||||||
|
directoryPath
|
||||||
|
]
|
||||||
|
|
||||||
|
let output: string
|
||||||
|
try {
|
||||||
|
output = await execRipgrep(rgPath, args)
|
||||||
|
console.log("output", output)
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error executing ripgrep:", error)
|
||||||
|
return "No results found"
|
||||||
|
}
|
||||||
|
const results: SearchResult[] = []
|
||||||
|
let currentResult: Partial<SearchResult> | null = null
|
||||||
|
|
||||||
|
output.split("\n").forEach((line) => {
|
||||||
|
if (line) {
|
||||||
|
try {
|
||||||
|
const parsed = JSON.parse(line)
|
||||||
|
if (parsed.type === "match") {
|
||||||
|
if (currentResult) {
|
||||||
|
results.push(currentResult as SearchResult)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Safety check: truncate extremely long lines to prevent excessive output
|
||||||
|
const matchText = parsed.data.lines.text
|
||||||
|
const truncatedMatch = truncateLine(matchText)
|
||||||
|
|
||||||
|
currentResult = {
|
||||||
|
file: parsed.data.path.text,
|
||||||
|
line: parsed.data.line_number,
|
||||||
|
column: parsed.data.submatches[0].start,
|
||||||
|
match: truncatedMatch,
|
||||||
|
beforeContext: [],
|
||||||
|
afterContext: [],
|
||||||
|
}
|
||||||
|
} else if (parsed.type === "context" && currentResult) {
|
||||||
|
// Apply the same truncation logic to context lines
|
||||||
|
const contextText = parsed.data.lines.text
|
||||||
|
const truncatedContext = truncateLine(contextText)
|
||||||
|
|
||||||
|
if (parsed.data.line_number < currentResult.line!) {
|
||||||
|
currentResult.beforeContext!.push(truncatedContext)
|
||||||
|
} else {
|
||||||
|
currentResult.afterContext!.push(truncatedContext)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error parsing ripgrep output:", error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if (currentResult) {
|
||||||
|
results.push(currentResult as SearchResult)
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log("results", results)
|
||||||
|
console.log("currentResult", currentResult)
|
||||||
|
|
||||||
|
return formatResults(results, directoryPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
function formatResults(results: SearchResult[], cwd: string): string {
|
||||||
|
const groupedResults: { [key: string]: SearchResult[] } = {}
|
||||||
|
|
||||||
|
let output = ""
|
||||||
|
if (results.length >= MAX_RESULTS) {
|
||||||
|
output += `Showing first ${MAX_RESULTS} of ${MAX_RESULTS}+ results. Use a more specific search if necessary.\n\n`
|
||||||
|
} else {
|
||||||
|
output += `Found ${results.length === 1 ? "1 result" : `${results.length.toLocaleString()} results`}.\n\n`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Group results by file name
|
||||||
|
results.slice(0, MAX_RESULTS).forEach((result) => {
|
||||||
|
const relativeFilePath = path.relative(cwd, result.file)
|
||||||
|
if (!groupedResults[relativeFilePath]) {
|
||||||
|
groupedResults[relativeFilePath] = []
|
||||||
|
}
|
||||||
|
groupedResults[relativeFilePath].push(result)
|
||||||
|
})
|
||||||
|
|
||||||
|
for (const [filePath, fileResults] of Object.entries(groupedResults)) {
|
||||||
|
output += `${filePath.toPosix()}\n│----\n`
|
||||||
|
|
||||||
|
fileResults.forEach((result, index) => {
|
||||||
|
const allLines = [...result.beforeContext, result.match, ...result.afterContext]
|
||||||
|
allLines.forEach((line) => {
|
||||||
|
output += `│${line?.trimEnd() ?? ""}\n`
|
||||||
|
})
|
||||||
|
|
||||||
|
if (index < fileResults.length - 1) {
|
||||||
|
output += "│----\n"
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
output += "│----\n\n"
|
||||||
|
}
|
||||||
|
|
||||||
|
return output.trim()
|
||||||
|
}
|
||||||
0
src/core/services/semantic/index.ts
Normal file
0
src/core/services/semantic/index.ts
Normal file
@ -94,6 +94,7 @@ export class DBManager {
|
|||||||
// return drizzle(this.pgClient)
|
// return drizzle(this.pgClient)
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Error loading database:', error)
|
console.error('Error loading database:', error)
|
||||||
|
console.log(this.dbPath)
|
||||||
return null
|
return null
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import { SerializedLexicalNode } from 'lexical'
|
import { SerializedLexicalNode } from 'lexical'
|
||||||
|
|
||||||
import { SUPPORT_EMBEDDING_SIMENTION } from '../constants'
|
import { SUPPORT_EMBEDDING_SIMENTION } from '../constants'
|
||||||
|
import { ApplyStatus } from '../types/apply'
|
||||||
// import { EmbeddingModelId } from '../types/embedding'
|
// import { EmbeddingModelId } from '../types/embedding'
|
||||||
|
|
||||||
// PostgreSQL column types
|
// PostgreSQL column types
|
||||||
@ -123,6 +124,7 @@ export type Conversation = {
|
|||||||
export type Message = {
|
export type Message = {
|
||||||
id: string // uuid
|
id: string // uuid
|
||||||
conversationId: string // uuid
|
conversationId: string // uuid
|
||||||
|
applyStatus: number
|
||||||
role: 'user' | 'assistant'
|
role: 'user' | 'assistant'
|
||||||
content: string | null
|
content: string | null
|
||||||
reasoningContent?: string | null
|
reasoningContent?: string | null
|
||||||
@ -151,6 +153,7 @@ export type InsertMessage = {
|
|||||||
id: string
|
id: string
|
||||||
conversationId: string
|
conversationId: string
|
||||||
role: 'user' | 'assistant'
|
role: 'user' | 'assistant'
|
||||||
|
apply_status: number
|
||||||
content: string | null
|
content: string | null
|
||||||
reasoningContent?: string | null
|
reasoningContent?: string | null
|
||||||
promptContent?: string | null
|
promptContent?: string | null
|
||||||
@ -163,6 +166,7 @@ export type InsertMessage = {
|
|||||||
export type SelectMessage = {
|
export type SelectMessage = {
|
||||||
id: string // uuid
|
id: string // uuid
|
||||||
conversation_id: string // uuid
|
conversation_id: string // uuid
|
||||||
|
apply_status: number
|
||||||
role: 'user' | 'assistant'
|
role: 'user' | 'assistant'
|
||||||
content: string | null
|
content: string | null
|
||||||
reasoning_content?: string | null
|
reasoning_content?: string | null
|
||||||
|
|||||||
@ -102,20 +102,10 @@ export const migrations: Record<string, SqlMigration> = {
|
|||||||
"updated_at" timestamp DEFAULT now() NOT NULL
|
"updated_at" timestamp DEFAULT now() NOT NULL
|
||||||
);
|
);
|
||||||
|
|
||||||
DO $$
|
|
||||||
BEGIN
|
|
||||||
IF NOT EXISTS (
|
|
||||||
SELECT 1 FROM information_schema.columns
|
|
||||||
WHERE table_name = 'messages'
|
|
||||||
AND column_name = 'reasoning_content'
|
|
||||||
) THEN
|
|
||||||
ALTER TABLE "messages" ADD COLUMN "reasoning_content" text;
|
|
||||||
END IF;
|
|
||||||
END $$;
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS "messages" (
|
CREATE TABLE IF NOT EXISTS "messages" (
|
||||||
"id" uuid PRIMARY KEY NOT NULL,
|
"id" uuid PRIMARY KEY NOT NULL,
|
||||||
"conversation_id" uuid NOT NULL REFERENCES "conversations"("id") ON DELETE CASCADE,
|
"conversation_id" uuid NOT NULL REFERENCES "conversations"("id") ON DELETE CASCADE,
|
||||||
|
"apply_status" integer NOT NULL DEFAULT 0,
|
||||||
"role" text NOT NULL,
|
"role" text NOT NULL,
|
||||||
"content" text,
|
"content" text,
|
||||||
"reasoning_content" text,
|
"reasoning_content" text,
|
||||||
|
|||||||
@ -25,6 +25,7 @@ import {
|
|||||||
InfioSettings,
|
InfioSettings,
|
||||||
parseInfioSettings,
|
parseInfioSettings,
|
||||||
} from './types/settings'
|
} from './types/settings'
|
||||||
|
import './utils/path'
|
||||||
import { getMentionableBlockData } from './utils/obsidian'
|
import { getMentionableBlockData } from './utils/obsidian'
|
||||||
|
|
||||||
// Remember to rename these classes and interfaces!
|
// Remember to rename these classes and interfaces!
|
||||||
|
|||||||
66
src/types/apply.ts
Normal file
66
src/types/apply.ts
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
/**
|
||||||
|
* 用于指定插入内容的工具参数
|
||||||
|
*/
|
||||||
|
|
||||||
|
export enum ApplyStatus {
|
||||||
|
Idle = 0,
|
||||||
|
Applied = 1,
|
||||||
|
Failed = 2,
|
||||||
|
Rejected = 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
export type ReadFileToolArgs = {
|
||||||
|
type: 'read_file';
|
||||||
|
filepath?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export type ListFilesToolArgs = {
|
||||||
|
type: 'list_files';
|
||||||
|
filepath?: string;
|
||||||
|
recursive?: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
export type RegexSearchFilesToolArgs = {
|
||||||
|
type: 'regex_search_files';
|
||||||
|
filepath?: string;
|
||||||
|
regex?: string;
|
||||||
|
file_pattern?: string;
|
||||||
|
finish?: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
export type SemanticSearchFilesToolArgs = {
|
||||||
|
type: 'semantic_search_files';
|
||||||
|
filepath?: string;
|
||||||
|
query?: string;
|
||||||
|
finish?: boolean;
|
||||||
|
}
|
||||||
|
export type WriteToFileToolArgs = {
|
||||||
|
type: 'write_to_file';
|
||||||
|
filepath?: string;
|
||||||
|
content?: string;
|
||||||
|
startLine?: number;
|
||||||
|
endLine?: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export type InsertContentToolArgs = {
|
||||||
|
type: 'insert_content';
|
||||||
|
filepath?: string;
|
||||||
|
content?: string;
|
||||||
|
startLine?: number;
|
||||||
|
endLine?: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export type SearchAndReplaceToolArgs = {
|
||||||
|
type: 'search_and_replace';
|
||||||
|
filepath: string;
|
||||||
|
operations: {
|
||||||
|
search: string;
|
||||||
|
replace: string;
|
||||||
|
startLine?: number;
|
||||||
|
endLine?: number;
|
||||||
|
useRegex?: boolean;
|
||||||
|
ignoreCase?: boolean;
|
||||||
|
regexFlags?: string;
|
||||||
|
}[];
|
||||||
|
}
|
||||||
|
export type ToolArgs = ReadFileToolArgs | WriteToFileToolArgs | InsertContentToolArgs | SearchAndReplaceToolArgs | ListFilesToolArgs | RegexSearchFilesToolArgs | SemanticSearchFilesToolArgs;
|
||||||
@ -2,6 +2,7 @@ import { SerializedEditorState } from 'lexical'
|
|||||||
|
|
||||||
import { SelectVector } from '../database/schema'
|
import { SelectVector } from '../database/schema'
|
||||||
|
|
||||||
|
import { ApplyStatus } from './apply'
|
||||||
import { LLMModel } from './llm/model'
|
import { LLMModel } from './llm/model'
|
||||||
import { ContentPart } from './llm/request'
|
import { ContentPart } from './llm/request'
|
||||||
import { ResponseUsage } from './llm/response'
|
import { ResponseUsage } from './llm/response'
|
||||||
@ -9,6 +10,7 @@ import { Mentionable, SerializedMentionable } from './mentionable'
|
|||||||
|
|
||||||
export type ChatUserMessage = {
|
export type ChatUserMessage = {
|
||||||
role: 'user'
|
role: 'user'
|
||||||
|
applyStatus: ApplyStatus
|
||||||
content: SerializedEditorState | null
|
content: SerializedEditorState | null
|
||||||
promptContent: string | ContentPart[] | null
|
promptContent: string | ContentPart[] | null
|
||||||
id: string
|
id: string
|
||||||
@ -20,6 +22,7 @@ export type ChatUserMessage = {
|
|||||||
|
|
||||||
export type ChatAssistantMessage = {
|
export type ChatAssistantMessage = {
|
||||||
role: 'assistant'
|
role: 'assistant'
|
||||||
|
applyStatus: ApplyStatus
|
||||||
content: string
|
content: string
|
||||||
reasoningContent: string
|
reasoningContent: string
|
||||||
id: string
|
id: string
|
||||||
@ -33,6 +36,7 @@ export type ChatMessage = ChatUserMessage | ChatAssistantMessage
|
|||||||
|
|
||||||
export type SerializedChatUserMessage = {
|
export type SerializedChatUserMessage = {
|
||||||
role: 'user'
|
role: 'user'
|
||||||
|
applyStatus: ApplyStatus
|
||||||
content: SerializedEditorState | null
|
content: SerializedEditorState | null
|
||||||
promptContent: string | ContentPart[] | null
|
promptContent: string | ContentPart[] | null
|
||||||
id: string
|
id: string
|
||||||
@ -44,6 +48,7 @@ export type SerializedChatUserMessage = {
|
|||||||
|
|
||||||
export type SerializedChatAssistantMessage = {
|
export type SerializedChatAssistantMessage = {
|
||||||
role: 'assistant'
|
role: 'assistant'
|
||||||
|
applyStatus: ApplyStatus
|
||||||
content: string
|
content: string
|
||||||
reasoningContent: string
|
reasoningContent: string
|
||||||
id: string
|
id: string
|
||||||
|
|||||||
@ -1,4 +1,6 @@
|
|||||||
import { TFile } from 'obsidian'
|
import { TFile } from 'obsidian';
|
||||||
|
|
||||||
|
import { SearchAndReplaceToolArgs } from '../types/apply';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Applies changes to a file by replacing content within specified line range
|
* Applies changes to a file by replacing content within specified line range
|
||||||
@ -9,36 +11,95 @@ import { TFile } from 'obsidian'
|
|||||||
* @param endLine - Ending line number (1-based indexing, optional)
|
* @param endLine - Ending line number (1-based indexing, optional)
|
||||||
* @returns Promise resolving to the modified content or null if operation fails
|
* @returns Promise resolving to the modified content or null if operation fails
|
||||||
*/
|
*/
|
||||||
export const manualApplyChangesToFile = async (
|
export const ApplyEditToFile = async (
|
||||||
content: string,
|
currentFile: TFile,
|
||||||
currentFile: TFile,
|
currentFileContent: string,
|
||||||
currentFileContent: string,
|
content: string,
|
||||||
startLine?: number,
|
startLine?: number,
|
||||||
endLine?: number,
|
endLine?: number,
|
||||||
): Promise<string | null> => {
|
): Promise<string | null> => {
|
||||||
try {
|
try {
|
||||||
// Input validation
|
// 如果文件为空,直接返回新内容
|
||||||
if (!content || !currentFileContent) {
|
if (!currentFileContent || currentFileContent.trim() === '') {
|
||||||
throw new Error('Content cannot be empty')
|
return content;
|
||||||
}
|
}
|
||||||
|
|
||||||
const lines = currentFileContent.split('\n')
|
// 如果要清空文件,直接返回空字符串
|
||||||
const effectiveStartLine = Math.max(1, startLine ?? 1)
|
if (content === '') {
|
||||||
const effectiveEndLine = Math.min(endLine ?? lines.length, lines.length)
|
return '';
|
||||||
|
}
|
||||||
|
|
||||||
// Validate line numbers
|
const lines = currentFileContent.split('\n')
|
||||||
if (effectiveStartLine > effectiveEndLine) {
|
const effectiveStartLine = Math.max(1, startLine ?? 1)
|
||||||
throw new Error('Start line cannot be greater than end line')
|
const effectiveEndLine = Math.min(endLine ?? lines.length, lines.length)
|
||||||
}
|
|
||||||
|
|
||||||
// Construct new content
|
// Validate line numbers
|
||||||
return [
|
if (effectiveStartLine > effectiveEndLine) {
|
||||||
...lines.slice(0, effectiveStartLine - 1),
|
throw new Error('Start line cannot be greater than end line')
|
||||||
content,
|
}
|
||||||
...lines.slice(effectiveEndLine)
|
|
||||||
].join('\n')
|
// Construct new content
|
||||||
} catch (error) {
|
return [
|
||||||
console.error('Error applying changes:', error instanceof Error ? error.message : 'Unknown error')
|
...lines.slice(0, effectiveStartLine - 1),
|
||||||
return null
|
content,
|
||||||
}
|
...lines.slice(effectiveEndLine)
|
||||||
|
].join('\n')
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Error applying changes:', error instanceof Error ? error.message : 'Unknown error')
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
function escapeRegExp(string: string): string {
|
||||||
|
return string.replace(/[.*+?^${}()|[\]\\]/g, "\\$&")
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 搜索和替换文件内容
|
||||||
|
* @param currentFile - 当前文件
|
||||||
|
* @param currentFileContent - 当前文件内容
|
||||||
|
* @param search - 搜索内容
|
||||||
|
* @param replace - 替换内容
|
||||||
|
*/
|
||||||
|
export const SearchAndReplace = async (
|
||||||
|
currentFile: TFile,
|
||||||
|
currentFileContent: string,
|
||||||
|
operations: SearchAndReplaceToolArgs['operations']
|
||||||
|
) => {
|
||||||
|
let lines = currentFileContent.split("\n")
|
||||||
|
|
||||||
|
for (const op of operations) {
|
||||||
|
const flags = op.regexFlags ?? (op.ignoreCase ? "gi" : "g")
|
||||||
|
const multilineFlags = flags.includes("m") ? flags : flags + "m"
|
||||||
|
|
||||||
|
const searchPattern = op.useRegex
|
||||||
|
? new RegExp(op.search, multilineFlags)
|
||||||
|
: new RegExp(escapeRegExp(op.search), multilineFlags)
|
||||||
|
|
||||||
|
if (op.startLine || op.endLine) {
|
||||||
|
const startLine = Math.max((op.startLine ?? 1) - 1, 0)
|
||||||
|
const endLine = Math.min((op.endLine ?? lines.length) - 1, lines.length - 1)
|
||||||
|
|
||||||
|
// Get the content before and after the target section
|
||||||
|
const beforeLines = lines.slice(0, startLine)
|
||||||
|
const afterLines = lines.slice(endLine + 1)
|
||||||
|
|
||||||
|
// Get the target section and perform replacement
|
||||||
|
const targetContent = lines.slice(startLine, endLine + 1).join("\n")
|
||||||
|
const modifiedContent = targetContent.replace(searchPattern, op.replace)
|
||||||
|
const modifiedLines = modifiedContent.split("\n")
|
||||||
|
|
||||||
|
// Reconstruct the full content with the modified section
|
||||||
|
lines = [...beforeLines, ...modifiedLines, ...afterLines]
|
||||||
|
} else {
|
||||||
|
// Global replacement
|
||||||
|
const fullContent = lines.join("\n")
|
||||||
|
const modifiedContent = fullContent.replace(searchPattern, op.replace)
|
||||||
|
lines = modifiedContent.split("\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const newContent = lines.join("\n")
|
||||||
|
return newContent;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import { minimatch } from 'minimatch'
|
import { minimatch } from 'minimatch'
|
||||||
import { Vault } from 'obsidian'
|
import { TFile, TFolder, Vault } from 'obsidian'
|
||||||
|
|
||||||
export const findFilesMatchingPatterns = async (
|
export const findFilesMatchingPatterns = async (
|
||||||
patterns: string[],
|
patterns: string[],
|
||||||
@ -10,3 +10,24 @@ export const findFilesMatchingPatterns = async (
|
|||||||
return patterns.some((pattern) => minimatch(file.path, pattern))
|
return patterns.some((pattern) => minimatch(file.path, pattern))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const listFilesAndFolders = async (vault: Vault, path: string) => {
|
||||||
|
const folder = vault.getAbstractFileByPath(path)
|
||||||
|
const childrenFiles: string[] = []
|
||||||
|
const childrenFolders: string[] = []
|
||||||
|
if (folder instanceof TFolder) {
|
||||||
|
folder.children.forEach((child) => {
|
||||||
|
if (child instanceof TFile) {
|
||||||
|
childrenFiles.push(child.path)
|
||||||
|
} else if (child instanceof TFolder) {
|
||||||
|
childrenFolders.push(child.path + "/")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return [...childrenFolders, ...childrenFiles]
|
||||||
|
}
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
|
||||||
|
export const regexSearchFiles = async (vault: Vault, path: string, regex: string, file_pattern: string) => {
|
||||||
|
|
||||||
|
}
|
||||||
|
|||||||
312
src/utils/modes.ts
Normal file
312
src/utils/modes.ts
Normal file
@ -0,0 +1,312 @@
|
|||||||
|
import { addCustomInstructions } from "../core/prompts/sections/custom-instructions"
|
||||||
|
|
||||||
|
import { ALWAYS_AVAILABLE_TOOLS, TOOL_GROUPS, ToolGroup } from "./tool-groups"
|
||||||
|
|
||||||
|
// Mode types
|
||||||
|
export type Mode = string
|
||||||
|
|
||||||
|
// Group options type
|
||||||
|
export type GroupOptions = {
|
||||||
|
fileRegex?: string // Regular expression pattern
|
||||||
|
description?: string // Human-readable description of the pattern
|
||||||
|
}
|
||||||
|
|
||||||
|
// Group entry can be either a string or tuple with options
|
||||||
|
export type GroupEntry = ToolGroup | readonly [ToolGroup, GroupOptions]
|
||||||
|
|
||||||
|
// Mode configuration type
|
||||||
|
export type ModeConfig = {
|
||||||
|
slug: string
|
||||||
|
name: string
|
||||||
|
roleDefinition: string
|
||||||
|
customInstructions?: string
|
||||||
|
groups: readonly GroupEntry[] // Now supports both simple strings and tuples with options
|
||||||
|
source?: "global" | "project" // Where this mode was loaded from
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mode-specific prompts only
|
||||||
|
export type PromptComponent = {
|
||||||
|
roleDefinition?: string
|
||||||
|
customInstructions?: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export type CustomModePrompts = {
|
||||||
|
[key: string]: PromptComponent | undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper to extract group name regardless of format
|
||||||
|
export function getGroupName(group: GroupEntry): ToolGroup {
|
||||||
|
if (typeof group === "string") {
|
||||||
|
return group
|
||||||
|
}
|
||||||
|
|
||||||
|
return group[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper to get group options if they exist
|
||||||
|
function getGroupOptions(group: GroupEntry): GroupOptions | undefined {
|
||||||
|
return Array.isArray(group) ? group[1] : undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper to check if a file path matches a regex pattern
|
||||||
|
export function doesFileMatchRegex(filePath: string, pattern: string): boolean {
|
||||||
|
try {
|
||||||
|
const regex = new RegExp(pattern)
|
||||||
|
return regex.test(filePath)
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`Invalid regex pattern: ${pattern}`, error)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper to get all tools for a mode
|
||||||
|
export function getToolsForMode(groups: readonly GroupEntry[]): string[] {
|
||||||
|
const tools = new Set<string>()
|
||||||
|
|
||||||
|
// Add tools from each group
|
||||||
|
groups.forEach((group) => {
|
||||||
|
const groupName = getGroupName(group)
|
||||||
|
const groupConfig = TOOL_GROUPS[groupName]
|
||||||
|
groupConfig.tools.forEach((tool: string) => tools.add(tool))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Always add required tools
|
||||||
|
ALWAYS_AVAILABLE_TOOLS.forEach((tool) => tools.add(tool))
|
||||||
|
|
||||||
|
return Array.from(tools)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Main modes configuration as an ordered array
|
||||||
|
export const modes: readonly ModeConfig[] = [
|
||||||
|
{
|
||||||
|
slug: "write",
|
||||||
|
name: "Write",
|
||||||
|
roleDefinition:
|
||||||
|
"You are Infio, a versatile content creator skilled in composing, editing, and organizing various text-based documents. You excel at structuring information clearly, creating well-formatted content, and helping users express their ideas effectively.",
|
||||||
|
groups: ["read", "edit"],
|
||||||
|
customInstructions:
|
||||||
|
"You can create and modify any text-based files, with particular expertise in Markdown formatting. Help users organize their thoughts, create documentation, take notes, or draft any written content they need. When appropriate, suggest structural improvements and formatting enhancements that make content more readable and accessible. Consider the purpose and audience of each document to provide the most relevant assistance."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
slug: "ask",
|
||||||
|
name: "Ask",
|
||||||
|
roleDefinition:
|
||||||
|
"You are Infio, a versatile assistant dedicated to providing informative responses, thoughtful explanations, and practical guidance on virtually any topic or challenge you face.",
|
||||||
|
groups: ["read"],
|
||||||
|
customInstructions:
|
||||||
|
"You can analyze information, explain concepts across various domains, and access external resources when helpful. Make sure to address the user's questions thoroughly with thoughtful explanations and practical guidance. Use visual aids like Mermaid diagrams when they help make complex topics clearer. Offer solutions to challenges from diverse fields, not just technical ones, and provide context that helps users better understand the subject matter.",
|
||||||
|
},
|
||||||
|
] as const
|
||||||
|
|
||||||
|
// Export the default mode slug
|
||||||
|
export const defaultModeSlug = modes[0].slug
|
||||||
|
|
||||||
|
// Helper functions
|
||||||
|
export function getModeBySlug(slug: string, customModes?: ModeConfig[]): ModeConfig | undefined {
|
||||||
|
// Check custom modes first
|
||||||
|
const customMode = customModes?.find((mode) => mode.slug === slug)
|
||||||
|
if (customMode) {
|
||||||
|
return customMode
|
||||||
|
}
|
||||||
|
// Then check built-in modes
|
||||||
|
return modes.find((mode) => mode.slug === slug)
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getModeConfig(slug: string, customModes?: ModeConfig[]): ModeConfig {
|
||||||
|
const mode = getModeBySlug(slug, customModes)
|
||||||
|
if (!mode) {
|
||||||
|
throw new Error(`No mode found for slug: ${slug}`)
|
||||||
|
}
|
||||||
|
return mode
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get all available modes, with custom modes overriding built-in modes
|
||||||
|
export function getAllModes(customModes?: ModeConfig[]): ModeConfig[] {
|
||||||
|
if (!customModes?.length) {
|
||||||
|
return [...modes]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start with built-in modes
|
||||||
|
const allModes = [...modes]
|
||||||
|
|
||||||
|
// Process custom modes
|
||||||
|
customModes.forEach((customMode) => {
|
||||||
|
const index = allModes.findIndex((mode) => mode.slug === customMode.slug)
|
||||||
|
if (index !== -1) {
|
||||||
|
// Override existing mode
|
||||||
|
allModes[index] = customMode
|
||||||
|
} else {
|
||||||
|
// Add new mode
|
||||||
|
allModes.push(customMode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return allModes
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if a mode is custom or an override
|
||||||
|
export function isCustomMode(slug: string, customModes?: ModeConfig[]): boolean {
|
||||||
|
return !!customModes?.some((mode) => mode.slug === slug)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Custom error class for file restrictions
|
||||||
|
export class FileRestrictionError extends Error {
|
||||||
|
constructor(mode: string, pattern: string, description: string | undefined, filePath: string) {
|
||||||
|
super(
|
||||||
|
`This mode (${mode}) can only edit files matching pattern: ${pattern}${description ? ` (${description})` : ""}. Got: ${filePath}`,
|
||||||
|
)
|
||||||
|
this.name = "FileRestrictionError"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export function isToolAllowedForMode(
|
||||||
|
tool: string,
|
||||||
|
modeSlug: string,
|
||||||
|
customModes: ModeConfig[],
|
||||||
|
toolRequirements?: Record<string, boolean>,
|
||||||
|
toolParams?: Record<string, any>, // All tool parameters
|
||||||
|
experiments?: Record<string, boolean>,
|
||||||
|
): boolean {
|
||||||
|
// Always allow these tools
|
||||||
|
if (ALWAYS_AVAILABLE_TOOLS.includes(tool as any)) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if (experiments && tool in experiments) {
|
||||||
|
if (!experiments[tool]) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check tool requirements if any exist
|
||||||
|
if (toolRequirements && tool in toolRequirements) {
|
||||||
|
if (!toolRequirements[tool]) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const mode = getModeBySlug(modeSlug, customModes)
|
||||||
|
if (!mode) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if tool is in any of the mode's groups and respects any group options
|
||||||
|
for (const group of mode.groups) {
|
||||||
|
const groupName = getGroupName(group)
|
||||||
|
const options = getGroupOptions(group)
|
||||||
|
|
||||||
|
const groupConfig = TOOL_GROUPS[groupName]
|
||||||
|
|
||||||
|
// If the tool isn't in this group's tools, continue to next group
|
||||||
|
if (!groupConfig.tools.includes(tool)) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there are no options, allow the tool
|
||||||
|
if (!options) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// For the edit group, check file regex if specified
|
||||||
|
if (groupName === "edit" && options.fileRegex) {
|
||||||
|
const filePath = toolParams?.path
|
||||||
|
if (
|
||||||
|
filePath &&
|
||||||
|
(toolParams.diff || toolParams.content || toolParams.operations) &&
|
||||||
|
!doesFileMatchRegex(filePath, options.fileRegex)
|
||||||
|
) {
|
||||||
|
throw new FileRestrictionError(mode.name, options.fileRegex, options.description, filePath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the mode-specific default prompts
|
||||||
|
export const defaultPrompts: Readonly<CustomModePrompts> = Object.freeze(
|
||||||
|
Object.fromEntries(
|
||||||
|
modes.map((mode) => [
|
||||||
|
mode.slug,
|
||||||
|
{
|
||||||
|
roleDefinition: mode.roleDefinition,
|
||||||
|
customInstructions: mode.customInstructions,
|
||||||
|
},
|
||||||
|
]),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Helper function to get all modes with their prompt overrides from extension state
|
||||||
|
export async function getAllModesWithPrompts(context: vscode.ExtensionContext): Promise<ModeConfig[]> {
|
||||||
|
const customModes = (await context.globalState.get<ModeConfig[]>("customModes")) || []
|
||||||
|
const customModePrompts = (await context.globalState.get<CustomModePrompts>("customModePrompts")) || {}
|
||||||
|
|
||||||
|
const allModes = getAllModes(customModes)
|
||||||
|
return allModes.map((mode) => ({
|
||||||
|
...mode,
|
||||||
|
roleDefinition: customModePrompts[mode.slug]?.roleDefinition ?? mode.roleDefinition,
|
||||||
|
customInstructions: customModePrompts[mode.slug]?.customInstructions ?? mode.customInstructions,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to get complete mode details with all overrides
|
||||||
|
export async function getFullModeDetails(
|
||||||
|
modeSlug: string,
|
||||||
|
customModes?: ModeConfig[],
|
||||||
|
customModePrompts?: CustomModePrompts,
|
||||||
|
options?: {
|
||||||
|
cwd?: string
|
||||||
|
globalCustomInstructions?: string
|
||||||
|
preferredLanguage?: string
|
||||||
|
},
|
||||||
|
): Promise<ModeConfig> {
|
||||||
|
// First get the base mode config from custom modes or built-in modes
|
||||||
|
const baseMode = getModeBySlug(modeSlug, customModes) || modes.find((m) => m.slug === modeSlug) || modes[0]
|
||||||
|
|
||||||
|
// Check for any prompt component overrides
|
||||||
|
const promptComponent = customModePrompts?.[modeSlug]
|
||||||
|
|
||||||
|
// Get the base custom instructions
|
||||||
|
const baseCustomInstructions = promptComponent?.customInstructions || baseMode.customInstructions || ""
|
||||||
|
|
||||||
|
// If we have cwd, load and combine all custom instructions
|
||||||
|
let fullCustomInstructions = baseCustomInstructions
|
||||||
|
if (options?.cwd) {
|
||||||
|
fullCustomInstructions = await addCustomInstructions(
|
||||||
|
baseCustomInstructions,
|
||||||
|
options.globalCustomInstructions || "",
|
||||||
|
options.cwd,
|
||||||
|
modeSlug,
|
||||||
|
{ preferredLanguage: options.preferredLanguage },
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return mode with any overrides applied
|
||||||
|
return {
|
||||||
|
...baseMode,
|
||||||
|
roleDefinition: promptComponent?.roleDefinition || baseMode.roleDefinition,
|
||||||
|
customInstructions: fullCustomInstructions,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to safely get role definition
|
||||||
|
export function getRoleDefinition(modeSlug: string, customModes?: ModeConfig[]): string {
|
||||||
|
const mode = getModeBySlug(modeSlug, customModes)
|
||||||
|
if (!mode) {
|
||||||
|
console.warn(`No mode found for slug: ${modeSlug}`)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return mode.roleDefinition
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to safely get custom instructions
|
||||||
|
export function getCustomInstructions(modeSlug: string, customModes?: ModeConfig[]): string {
|
||||||
|
const mode = getModeBySlug(modeSlug, customModes)
|
||||||
|
if (!mode) {
|
||||||
|
console.warn(`No mode found for slug: ${modeSlug}`)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return mode.customInstructions ?? ""
|
||||||
|
}
|
||||||
@ -1,4 +1,4 @@
|
|||||||
import { InfioBlockAction, ParsedInfioBlock, parseinfioBlocks } from './parse-infio-block'
|
import { InfioBlockAction, ParsedMsgBlock, parseMsgBlocks } from './parse-infio-block'
|
||||||
|
|
||||||
describe('parseinfioBlocks', () => {
|
describe('parseinfioBlocks', () => {
|
||||||
it('should parse a string with infio_block elements', () => {
|
it('should parse a string with infio_block elements', () => {
|
||||||
@ -22,7 +22,7 @@ print("Hello, world!")
|
|||||||
</infio_block>
|
</infio_block>
|
||||||
Some text after`
|
Some text after`
|
||||||
|
|
||||||
const expected: ParsedInfioBlock[] = [
|
const expected: ParsedMsgBlock[] = [
|
||||||
{ type: 'string', content: 'Some text before\n' },
|
{ type: 'string', content: 'Some text before\n' },
|
||||||
{
|
{
|
||||||
type: 'infio_block',
|
type: 'infio_block',
|
||||||
@ -49,7 +49,7 @@ print("Hello, world!")
|
|||||||
{ type: 'string', content: '\nSome text after' },
|
{ type: 'string', content: '\nSome text after' },
|
||||||
]
|
]
|
||||||
|
|
||||||
const result = parseinfioBlocks(input)
|
const result = parseMsgBlocks(input)
|
||||||
expect(result).toEqual(expected)
|
expect(result).toEqual(expected)
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -58,7 +58,7 @@ print("Hello, world!")
|
|||||||
<infio_block language="python"></infio_block>
|
<infio_block language="python"></infio_block>
|
||||||
`
|
`
|
||||||
|
|
||||||
const expected: ParsedInfioBlock[] = [
|
const expected: ParsedMsgBlock[] = [
|
||||||
{ type: 'string', content: '\n ' },
|
{ type: 'string', content: '\n ' },
|
||||||
{
|
{
|
||||||
type: 'infio_block',
|
type: 'infio_block',
|
||||||
@ -69,16 +69,16 @@ print("Hello, world!")
|
|||||||
{ type: 'string', content: '\n ' },
|
{ type: 'string', content: '\n ' },
|
||||||
]
|
]
|
||||||
|
|
||||||
const result = parseinfioBlocks(input)
|
const result = parseMsgBlocks(input)
|
||||||
expect(result).toEqual(expected)
|
expect(result).toEqual(expected)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should handle input without infio_block elements', () => {
|
it('should handle input without infio_block elements', () => {
|
||||||
const input = 'Just a regular string without any infio_block elements.'
|
const input = 'Just a regular string without any infio_block elements.'
|
||||||
|
|
||||||
const expected: ParsedInfioBlock[] = [{ type: 'string', content: input }]
|
const expected: ParsedMsgBlock[] = [{ type: 'string', content: input }]
|
||||||
|
|
||||||
const result = parseinfioBlocks(input)
|
const result = parseMsgBlocks(input)
|
||||||
expect(result).toEqual(expected)
|
expect(result).toEqual(expected)
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -100,7 +100,7 @@ print("Hello, world!")
|
|||||||
</infio_block>
|
</infio_block>
|
||||||
End`
|
End`
|
||||||
|
|
||||||
const expected: ParsedInfioBlock[] = [
|
const expected: ParsedMsgBlock[] = [
|
||||||
{ type: 'string', content: 'Start\n' },
|
{ type: 'string', content: 'Start\n' },
|
||||||
{
|
{
|
||||||
type: 'infio_block',
|
type: 'infio_block',
|
||||||
@ -129,7 +129,7 @@ print("Hello, world!")
|
|||||||
{ type: 'string', content: '\nEnd' },
|
{ type: 'string', content: '\nEnd' },
|
||||||
]
|
]
|
||||||
|
|
||||||
const result = parseinfioBlocks(input)
|
const result = parseMsgBlocks(input)
|
||||||
expect(result).toEqual(expected)
|
expect(result).toEqual(expected)
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -139,7 +139,7 @@ print("Hello, world!")
|
|||||||
# Unfinished infio_block
|
# Unfinished infio_block
|
||||||
|
|
||||||
Some text after without closing tag`
|
Some text after without closing tag`
|
||||||
const expected: ParsedInfioBlock[] = [
|
const expected: ParsedMsgBlock[] = [
|
||||||
{ type: 'string', content: 'Start\n' },
|
{ type: 'string', content: 'Start\n' },
|
||||||
{
|
{
|
||||||
type: 'infio_block',
|
type: 'infio_block',
|
||||||
@ -152,13 +152,13 @@ Some text after without closing tag`,
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
const result = parseinfioBlocks(input)
|
const result = parseMsgBlocks(input)
|
||||||
expect(result).toEqual(expected)
|
expect(result).toEqual(expected)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should handle infio_block with startline and endline attributes', () => {
|
it('should handle infio_block with startline and endline attributes', () => {
|
||||||
const input = `<infio_block language="markdown" startline="2" endline="5"></infio_block>`
|
const input = `<infio_block language="markdown" startline="2" endline="5"></infio_block>`
|
||||||
const expected: ParsedInfioBlock[] = [
|
const expected: ParsedMsgBlock[] = [
|
||||||
{
|
{
|
||||||
type: 'infio_block',
|
type: 'infio_block',
|
||||||
content: '',
|
content: '',
|
||||||
@ -168,13 +168,13 @@ Some text after without closing tag`,
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
const result = parseinfioBlocks(input)
|
const result = parseMsgBlocks(input)
|
||||||
expect(result).toEqual(expected)
|
expect(result).toEqual(expected)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should parse infio_block with action attribute', () => {
|
it('should parse infio_block with action attribute', () => {
|
||||||
const input = `<infio_block type="edit"></infio_block>`
|
const input = `<infio_block type="edit"></infio_block>`
|
||||||
const expected: ParsedInfioBlock[] = [
|
const expected: ParsedMsgBlock[] = [
|
||||||
{
|
{
|
||||||
type: 'infio_block',
|
type: 'infio_block',
|
||||||
content: '',
|
content: '',
|
||||||
@ -182,13 +182,13 @@ Some text after without closing tag`,
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
const result = parseinfioBlocks(input)
|
const result = parseMsgBlocks(input)
|
||||||
expect(result).toEqual(expected)
|
expect(result).toEqual(expected)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should handle invalid action attribute', () => {
|
it('should handle invalid action attribute', () => {
|
||||||
const input = `<infio_block type="invalid"></infio_block>`
|
const input = `<infio_block type="invalid"></infio_block>`
|
||||||
const expected: ParsedInfioBlock[] = [
|
const expected: ParsedMsgBlock[] = [
|
||||||
{
|
{
|
||||||
type: 'infio_block',
|
type: 'infio_block',
|
||||||
content: '',
|
content: '',
|
||||||
@ -196,7 +196,7 @@ Some text after without closing tag`,
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
const result = parseinfioBlocks(input)
|
const result = parseMsgBlocks(input)
|
||||||
expect(result).toEqual(expected)
|
expect(result).toEqual(expected)
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -208,7 +208,7 @@ It might contain multiple lines of text.
|
|||||||
</think>
|
</think>
|
||||||
Some text after`
|
Some text after`
|
||||||
|
|
||||||
const expected: ParsedInfioBlock[] = [
|
const expected: ParsedMsgBlock[] = [
|
||||||
{ type: 'string', content: 'Some text before\n' },
|
{ type: 'string', content: 'Some text before\n' },
|
||||||
{
|
{
|
||||||
type: 'think',
|
type: 'think',
|
||||||
@ -220,7 +220,7 @@ It might contain multiple lines of text.
|
|||||||
{ type: 'string', content: '\nSome text after' },
|
{ type: 'string', content: '\nSome text after' },
|
||||||
]
|
]
|
||||||
|
|
||||||
const result = parseinfioBlocks(input)
|
const result = parseMsgBlocks(input)
|
||||||
expect(result).toEqual(expected)
|
expect(result).toEqual(expected)
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -229,7 +229,7 @@ It might contain multiple lines of text.
|
|||||||
<think></think>
|
<think></think>
|
||||||
`
|
`
|
||||||
|
|
||||||
const expected: ParsedInfioBlock[] = [
|
const expected: ParsedMsgBlock[] = [
|
||||||
{ type: 'string', content: '\n ' },
|
{ type: 'string', content: '\n ' },
|
||||||
{
|
{
|
||||||
type: 'think',
|
type: 'think',
|
||||||
@ -238,7 +238,7 @@ It might contain multiple lines of text.
|
|||||||
{ type: 'string', content: '\n ' },
|
{ type: 'string', content: '\n ' },
|
||||||
]
|
]
|
||||||
|
|
||||||
const result = parseinfioBlocks(input)
|
const result = parseMsgBlocks(input)
|
||||||
expect(result).toEqual(expected)
|
expect(result).toEqual(expected)
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -255,7 +255,7 @@ I need to consider several approaches.
|
|||||||
</think>
|
</think>
|
||||||
End`
|
End`
|
||||||
|
|
||||||
const expected: ParsedInfioBlock[] = [
|
const expected: ParsedMsgBlock[] = [
|
||||||
{ type: 'string', content: 'Start\n' },
|
{ type: 'string', content: 'Start\n' },
|
||||||
{
|
{
|
||||||
type: 'infio_block',
|
type: 'infio_block',
|
||||||
@ -277,7 +277,7 @@ I need to consider several approaches.
|
|||||||
{ type: 'string', content: '\nEnd' },
|
{ type: 'string', content: '\nEnd' },
|
||||||
]
|
]
|
||||||
|
|
||||||
const result = parseinfioBlocks(input)
|
const result = parseMsgBlocks(input)
|
||||||
expect(result).toEqual(expected)
|
expect(result).toEqual(expected)
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -286,7 +286,7 @@ I need to consider several approaches.
|
|||||||
<think>
|
<think>
|
||||||
Some unfinished thought
|
Some unfinished thought
|
||||||
without closing tag`
|
without closing tag`
|
||||||
const expected: ParsedInfioBlock[] = [
|
const expected: ParsedMsgBlock[] = [
|
||||||
{ type: 'string', content: 'Start\n' },
|
{ type: 'string', content: 'Start\n' },
|
||||||
{
|
{
|
||||||
type: 'think',
|
type: 'think',
|
||||||
@ -296,7 +296,7 @@ without closing tag`,
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
const result = parseinfioBlocks(input)
|
const result = parseMsgBlocks(input)
|
||||||
expect(result).toEqual(expected)
|
expect(result).toEqual(expected)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@ -1,128 +1,433 @@
|
|||||||
|
import JSON5 from 'json5'
|
||||||
import { parseFragment } from 'parse5'
|
import { parseFragment } from 'parse5'
|
||||||
|
|
||||||
export enum InfioBlockAction {
|
export type ParsedMsgBlock =
|
||||||
Edit = 'edit',
|
|
||||||
New = 'new',
|
|
||||||
Reference = 'reference'
|
|
||||||
}
|
|
||||||
|
|
||||||
export type ParsedInfioBlock =
|
|
||||||
| { type: 'string'; content: string }
|
|
||||||
| {
|
| {
|
||||||
type: 'infio_block'
|
type: 'string'
|
||||||
content: string
|
content: string
|
||||||
language?: string
|
|
||||||
filename?: string
|
|
||||||
startLine?: number
|
|
||||||
endLine?: number
|
|
||||||
action?: InfioBlockAction
|
|
||||||
}
|
}
|
||||||
| { type: 'think'; content: string }
|
| {
|
||||||
|
type: 'think'
|
||||||
function isInfioBlockAction(value: string): value is InfioBlockAction {
|
content: string
|
||||||
return Object.values<string>(InfioBlockAction).includes(value)
|
} | {
|
||||||
}
|
type: 'thinking'
|
||||||
|
content: string
|
||||||
export function parseinfioBlocks(input: string): ParsedInfioBlock[] {
|
} | {
|
||||||
const parsedResult: ParsedInfioBlock[] = []
|
type: 'write_to_file'
|
||||||
const fragment = parseFragment(input, {
|
path: string
|
||||||
sourceCodeLocationInfo: true,
|
content: string
|
||||||
})
|
lineCount?: number
|
||||||
let lastEndOffset = 0
|
} | {
|
||||||
for (const node of fragment.childNodes) {
|
type: 'insert_content'
|
||||||
if (node.nodeName === 'infio_block') {
|
path: string
|
||||||
if (!node.sourceCodeLocation) {
|
startLine: number
|
||||||
throw new Error('sourceCodeLocation is undefined')
|
content: string
|
||||||
}
|
} | {
|
||||||
const startOffset = node.sourceCodeLocation.startOffset
|
type: 'read_file'
|
||||||
const endOffset = node.sourceCodeLocation.endOffset
|
path: string
|
||||||
if (startOffset > lastEndOffset) {
|
finish: boolean
|
||||||
parsedResult.push({
|
} | {
|
||||||
type: 'string',
|
type: 'attempt_completion'
|
||||||
content: input.slice(lastEndOffset, startOffset),
|
result: string
|
||||||
})
|
} | {
|
||||||
}
|
type: 'search_and_replace'
|
||||||
|
path: string
|
||||||
const language = node.attrs.find((attr) => attr.name === 'language')?.value
|
operations: {
|
||||||
const filename = node.attrs.find((attr) => attr.name === 'filename')?.value
|
search: string
|
||||||
const startLine = node.attrs.find((attr) => attr.name === 'startline')?.value
|
replace: string
|
||||||
const endLine = node.attrs.find((attr) => attr.name === 'endline')?.value
|
start_line?: number
|
||||||
const actionValue = node.attrs.find((attr) => attr.name === 'type')?.value
|
end_line?: number
|
||||||
const action = actionValue && isInfioBlockAction(actionValue)
|
use_regex?: boolean
|
||||||
? actionValue
|
ignore_case?: boolean
|
||||||
: undefined
|
regex_flags?: string
|
||||||
|
}[]
|
||||||
|
finish: boolean
|
||||||
const children = node.childNodes
|
} | {
|
||||||
if (children.length === 0) {
|
type: 'ask_followup_question'
|
||||||
parsedResult.push({
|
question: string
|
||||||
type: 'infio_block',
|
} | {
|
||||||
content: '',
|
type: 'list_files'
|
||||||
language,
|
path: string
|
||||||
filename,
|
recursive?: boolean
|
||||||
startLine: startLine ? parseInt(startLine) : undefined,
|
finish: boolean
|
||||||
endLine: endLine ? parseInt(endLine) : undefined,
|
} | {
|
||||||
action: action,
|
type: 'regex_search_files'
|
||||||
})
|
path: string
|
||||||
} else {
|
regex: string
|
||||||
const innerContentStartOffset =
|
finish: boolean
|
||||||
children[0].sourceCodeLocation?.startOffset
|
} | {
|
||||||
const innerContentEndOffset =
|
type: 'semantic_search_files'
|
||||||
children[children.length - 1].sourceCodeLocation?.endOffset
|
path: string
|
||||||
if (!innerContentStartOffset || !innerContentEndOffset) {
|
query: string
|
||||||
throw new Error('sourceCodeLocation is undefined')
|
finish: boolean
|
||||||
}
|
|
||||||
parsedResult.push({
|
|
||||||
type: 'infio_block',
|
|
||||||
content: input.slice(innerContentStartOffset, innerContentEndOffset),
|
|
||||||
language,
|
|
||||||
filename,
|
|
||||||
startLine: startLine ? parseInt(startLine) : undefined,
|
|
||||||
endLine: endLine ? parseInt(endLine) : undefined,
|
|
||||||
action: action,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
lastEndOffset = endOffset
|
|
||||||
} else if (node.nodeName === 'think') {
|
|
||||||
if (!node.sourceCodeLocation) {
|
|
||||||
throw new Error('sourceCodeLocation is undefined')
|
|
||||||
}
|
|
||||||
const startOffset = node.sourceCodeLocation.startOffset
|
|
||||||
const endOffset = node.sourceCodeLocation.endOffset
|
|
||||||
if (startOffset > lastEndOffset) {
|
|
||||||
parsedResult.push({
|
|
||||||
type: 'string',
|
|
||||||
content: input.slice(lastEndOffset, startOffset),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
const children = node.childNodes
|
|
||||||
if (children.length === 0) {
|
|
||||||
parsedResult.push({
|
|
||||||
type: 'think',
|
|
||||||
content: '',
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
const innerContentStartOffset =
|
|
||||||
children[0].sourceCodeLocation?.startOffset
|
|
||||||
const innerContentEndOffset =
|
|
||||||
children[children.length - 1].sourceCodeLocation?.endOffset
|
|
||||||
if (!innerContentStartOffset || !innerContentEndOffset) {
|
|
||||||
throw new Error('sourceCodeLocation is undefined')
|
|
||||||
}
|
|
||||||
parsedResult.push({
|
|
||||||
type: 'think',
|
|
||||||
content: input.slice(innerContentStartOffset, innerContentEndOffset),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
lastEndOffset = endOffset
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if (lastEndOffset < input.length) {
|
|
||||||
parsedResult.push({
|
export function parseMsgBlocks(
|
||||||
type: 'string',
|
input: string,
|
||||||
content: input.slice(lastEndOffset),
|
): ParsedMsgBlock[] {
|
||||||
|
try {
|
||||||
|
const parsedResult: ParsedMsgBlock[] = []
|
||||||
|
const fragment = parseFragment(input, {
|
||||||
|
sourceCodeLocationInfo: true,
|
||||||
})
|
})
|
||||||
|
let lastEndOffset = 0
|
||||||
|
for (const node of fragment.childNodes) {
|
||||||
|
if (node.nodeName === 'thinking') {
|
||||||
|
if (!node.sourceCodeLocation) {
|
||||||
|
throw new Error('sourceCodeLocation is undefined')
|
||||||
|
}
|
||||||
|
const startOffset = node.sourceCodeLocation.startOffset
|
||||||
|
const endOffset = node.sourceCodeLocation.endOffset
|
||||||
|
if (startOffset > lastEndOffset) {
|
||||||
|
parsedResult.push({
|
||||||
|
type: 'string',
|
||||||
|
content: input.slice(lastEndOffset, startOffset),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const children = node.childNodes
|
||||||
|
if (children.length === 0) {
|
||||||
|
parsedResult.push({
|
||||||
|
type: 'thinking',
|
||||||
|
content: '',
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
const innerContentStartOffset =
|
||||||
|
children[0].sourceCodeLocation?.startOffset
|
||||||
|
const innerContentEndOffset =
|
||||||
|
children[children.length - 1].sourceCodeLocation?.endOffset
|
||||||
|
if (!innerContentStartOffset || !innerContentEndOffset) {
|
||||||
|
throw new Error('sourceCodeLocation is undefined')
|
||||||
|
}
|
||||||
|
parsedResult.push({
|
||||||
|
type: 'thinking',
|
||||||
|
content: input.slice(innerContentStartOffset, innerContentEndOffset),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
lastEndOffset = endOffset
|
||||||
|
} else if (node.nodeName === 'think') {
|
||||||
|
if (!node.sourceCodeLocation) {
|
||||||
|
throw new Error('sourceCodeLocation is undefined')
|
||||||
|
}
|
||||||
|
const startOffset = node.sourceCodeLocation.startOffset
|
||||||
|
const endOffset = node.sourceCodeLocation.endOffset
|
||||||
|
if (startOffset > lastEndOffset) {
|
||||||
|
parsedResult.push({
|
||||||
|
type: 'string',
|
||||||
|
content: input.slice(lastEndOffset, startOffset),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const children = node.childNodes
|
||||||
|
if (children.length === 0) {
|
||||||
|
parsedResult.push({
|
||||||
|
type: 'think',
|
||||||
|
content: '',
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
const innerContentStartOffset =
|
||||||
|
children[0].sourceCodeLocation?.startOffset
|
||||||
|
const innerContentEndOffset =
|
||||||
|
children[children.length - 1].sourceCodeLocation?.endOffset
|
||||||
|
if (!innerContentStartOffset || !innerContentEndOffset) {
|
||||||
|
throw new Error('sourceCodeLocation is undefined')
|
||||||
|
}
|
||||||
|
parsedResult.push({
|
||||||
|
type: 'think',
|
||||||
|
content: input.slice(innerContentStartOffset, innerContentEndOffset),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
lastEndOffset = endOffset
|
||||||
|
} else if (node.nodeName === 'list_files') {
|
||||||
|
if (!node.sourceCodeLocation) {
|
||||||
|
throw new Error('sourceCodeLocation is undefined')
|
||||||
|
}
|
||||||
|
const startOffset = node.sourceCodeLocation.startOffset
|
||||||
|
const endOffset = node.sourceCodeLocation.endOffset
|
||||||
|
if (startOffset > lastEndOffset) {
|
||||||
|
parsedResult.push({
|
||||||
|
type: 'string',
|
||||||
|
content: input.slice(lastEndOffset, startOffset),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
let path: string | undefined
|
||||||
|
let recursive: boolean | undefined
|
||||||
|
|
||||||
|
for (const childNode of node.childNodes) {
|
||||||
|
if (childNode.nodeName === 'path' && childNode.childNodes.length > 0) {
|
||||||
|
path = childNode.childNodes[0].value
|
||||||
|
} else if (childNode.nodeName === 'recursive' && childNode.childNodes.length > 0) {
|
||||||
|
const recursiveValue = childNode.childNodes[0].value
|
||||||
|
recursive = recursiveValue ? recursiveValue.toLowerCase() === 'true' : false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedResult.push({
|
||||||
|
type: 'list_files',
|
||||||
|
path: path || '/',
|
||||||
|
recursive,
|
||||||
|
finish: node.sourceCodeLocation.endTag !== undefined
|
||||||
|
})
|
||||||
|
lastEndOffset = endOffset
|
||||||
|
} else if (node.nodeName === 'read_file') {
|
||||||
|
if (!node.sourceCodeLocation) {
|
||||||
|
throw new Error('sourceCodeLocation is undefined')
|
||||||
|
}
|
||||||
|
const startOffset = node.sourceCodeLocation.startOffset
|
||||||
|
const endOffset = node.sourceCodeLocation.endOffset
|
||||||
|
if (startOffset > lastEndOffset) {
|
||||||
|
parsedResult.push({
|
||||||
|
type: 'string',
|
||||||
|
content: input.slice(lastEndOffset, startOffset),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
let path: string | undefined
|
||||||
|
for (const childNode of node.childNodes) {
|
||||||
|
if (childNode.nodeName === 'path' && childNode.childNodes.length > 0) {
|
||||||
|
path = childNode.childNodes[0].value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
parsedResult.push({
|
||||||
|
type: 'read_file',
|
||||||
|
path,
|
||||||
|
// Check if the tag is completely parsed with proper closing tag
|
||||||
|
// In parse5, when a tag is properly closed, its sourceCodeLocation will include endTag
|
||||||
|
finish: node.sourceCodeLocation.endTag !== undefined
|
||||||
|
})
|
||||||
|
lastEndOffset = endOffset
|
||||||
|
} else if (node.nodeName === 'regex_search_files') {
|
||||||
|
if (!node.sourceCodeLocation) {
|
||||||
|
throw new Error('sourceCodeLocation is undefined')
|
||||||
|
}
|
||||||
|
const startOffset = node.sourceCodeLocation.startOffset
|
||||||
|
const endOffset = node.sourceCodeLocation.endOffset
|
||||||
|
if (startOffset > lastEndOffset) {
|
||||||
|
parsedResult.push({
|
||||||
|
type: 'string',
|
||||||
|
content: input.slice(lastEndOffset, startOffset),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
let path: string | undefined
|
||||||
|
let regex: string | undefined
|
||||||
|
|
||||||
|
for (const childNode of node.childNodes) {
|
||||||
|
if (childNode.nodeName === 'path' && childNode.childNodes.length > 0) {
|
||||||
|
path = childNode.childNodes[0].value
|
||||||
|
} else if (childNode.nodeName === 'regex' && childNode.childNodes.length > 0) {
|
||||||
|
regex = childNode.childNodes[0].value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedResult.push({
|
||||||
|
type: 'regex_search_files',
|
||||||
|
path: path,
|
||||||
|
regex: regex,
|
||||||
|
finish: node.sourceCodeLocation.endTag !== undefined
|
||||||
|
})
|
||||||
|
lastEndOffset = endOffset
|
||||||
|
} else if (node.nodeName === 'semantic_search_files') {
|
||||||
|
if (!node.sourceCodeLocation) {
|
||||||
|
throw new Error('sourceCodeLocation is undefined')
|
||||||
|
}
|
||||||
|
const startOffset = node.sourceCodeLocation.startOffset
|
||||||
|
const endOffset = node.sourceCodeLocation.endOffset
|
||||||
|
if (startOffset > lastEndOffset) {
|
||||||
|
parsedResult.push({
|
||||||
|
type: 'string',
|
||||||
|
content: input.slice(lastEndOffset, startOffset),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
let path: string | undefined
|
||||||
|
let query: string | undefined
|
||||||
|
|
||||||
|
for (const childNode of node.childNodes) {
|
||||||
|
if (childNode.nodeName === 'path' && childNode.childNodes.length > 0) {
|
||||||
|
path = childNode.childNodes[0].value
|
||||||
|
} else if (childNode.nodeName === 'query' && childNode.childNodes.length > 0) {
|
||||||
|
query = childNode.childNodes[0].value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedResult.push({
|
||||||
|
type: 'semantic_search_files',
|
||||||
|
path: path,
|
||||||
|
query: query,
|
||||||
|
finish: node.sourceCodeLocation.endTag !== undefined
|
||||||
|
})
|
||||||
|
lastEndOffset = endOffset
|
||||||
|
} else if (node.nodeName === 'write_to_file') {
|
||||||
|
if (!node.sourceCodeLocation) {
|
||||||
|
throw new Error('sourceCodeLocation is undefined')
|
||||||
|
}
|
||||||
|
const startOffset = node.sourceCodeLocation.startOffset
|
||||||
|
const endOffset = node.sourceCodeLocation.endOffset
|
||||||
|
if (startOffset > lastEndOffset) {
|
||||||
|
parsedResult.push({
|
||||||
|
type: 'string',
|
||||||
|
content: input.slice(lastEndOffset, startOffset),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
let path: string | undefined
|
||||||
|
let content: string = ''
|
||||||
|
let lineCount: number | undefined
|
||||||
|
// 处理子标签
|
||||||
|
for (const childNode of node.childNodes) {
|
||||||
|
if (childNode.nodeName === 'path' && childNode.childNodes.length > 0) {
|
||||||
|
path = childNode.childNodes[0].value
|
||||||
|
} else if (childNode.nodeName === 'content' && childNode.childNodes.length > 0) {
|
||||||
|
// 如果内容有多个子节点,需要合并它们
|
||||||
|
content = childNode.childNodes.map(n => n.value || '').join('')
|
||||||
|
} else if (childNode.nodeName === 'line_count' && childNode.childNodes.length > 0) {
|
||||||
|
const lineCountStr = childNode.childNodes[0].value
|
||||||
|
lineCount = lineCountStr ? parseInt(lineCountStr) : undefined
|
||||||
|
}
|
||||||
|
}
|
||||||
|
parsedResult.push({
|
||||||
|
type: 'write_to_file',
|
||||||
|
content,
|
||||||
|
path,
|
||||||
|
lineCount
|
||||||
|
})
|
||||||
|
lastEndOffset = endOffset
|
||||||
|
|
||||||
|
} else if (node.nodeName === 'insert_content') {
|
||||||
|
if (!node.sourceCodeLocation) {
|
||||||
|
throw new Error('sourceCodeLocation is undefined')
|
||||||
|
}
|
||||||
|
const startOffset = node.sourceCodeLocation.startOffset
|
||||||
|
const endOffset = node.sourceCodeLocation.endOffset
|
||||||
|
if (startOffset > lastEndOffset) {
|
||||||
|
parsedResult.push({
|
||||||
|
type: 'string',
|
||||||
|
content: input.slice(lastEndOffset, startOffset),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
let path: string | undefined
|
||||||
|
let content: string = ''
|
||||||
|
let startLine: number = 0
|
||||||
|
|
||||||
|
// 处理子标签
|
||||||
|
for (const childNode of node.childNodes) {
|
||||||
|
if (childNode.nodeName === 'path' && childNode.childNodes.length > 0) {
|
||||||
|
path = childNode.childNodes[0].value
|
||||||
|
} else if (childNode.nodeName === 'operations' && childNode.childNodes.length > 0) {
|
||||||
|
try {
|
||||||
|
const operationsJson = childNode.childNodes[0].value
|
||||||
|
const operations = JSON5.parse(operationsJson)
|
||||||
|
if (Array.isArray(operations) && operations.length > 0) {
|
||||||
|
const operation = operations[0]
|
||||||
|
startLine = operation.start_line || 1
|
||||||
|
content = operation.content || ''
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Failed to parse operations JSON', error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedResult.push({
|
||||||
|
type: 'insert_content',
|
||||||
|
path,
|
||||||
|
startLine,
|
||||||
|
content
|
||||||
|
})
|
||||||
|
lastEndOffset = endOffset
|
||||||
|
} else if (node.nodeName === 'search_and_replace') {
|
||||||
|
if (!node.sourceCodeLocation) {
|
||||||
|
throw new Error('sourceCodeLocation is undefined')
|
||||||
|
}
|
||||||
|
const startOffset = node.sourceCodeLocation.startOffset
|
||||||
|
const endOffset = node.sourceCodeLocation.endOffset
|
||||||
|
if (startOffset > lastEndOffset) {
|
||||||
|
parsedResult.push({
|
||||||
|
type: 'string',
|
||||||
|
content: input.slice(lastEndOffset, startOffset),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
let path: string | undefined
|
||||||
|
let operations = []
|
||||||
|
|
||||||
|
// 处理子标签
|
||||||
|
for (const childNode of node.childNodes) {
|
||||||
|
if (childNode.nodeName === 'path' && childNode.childNodes.length > 0) {
|
||||||
|
path = childNode.childNodes[0].value
|
||||||
|
} else if (childNode.nodeName === 'operations' && childNode.childNodes.length > 0) {
|
||||||
|
try {
|
||||||
|
const operationsJson = childNode.childNodes[0].value
|
||||||
|
operations = JSON5.parse(operationsJson)
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Failed to parse operations JSON', error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedResult.push({
|
||||||
|
type: 'search_and_replace',
|
||||||
|
path,
|
||||||
|
operations,
|
||||||
|
finish: node.sourceCodeLocation.endTag !== undefined
|
||||||
|
})
|
||||||
|
lastEndOffset = endOffset
|
||||||
|
} else if (node.nodeName === 'attempt_completion') {
|
||||||
|
if (!node.sourceCodeLocation) {
|
||||||
|
throw new Error('sourceCodeLocation is undefined')
|
||||||
|
}
|
||||||
|
const startOffset = node.sourceCodeLocation.startOffset
|
||||||
|
const endOffset = node.sourceCodeLocation.endOffset
|
||||||
|
if (startOffset > lastEndOffset) {
|
||||||
|
parsedResult.push({
|
||||||
|
type: 'string',
|
||||||
|
content: input.slice(lastEndOffset, startOffset),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
let result: string | undefined
|
||||||
|
for (const childNode of node.childNodes) {
|
||||||
|
if (childNode.nodeName === 'result' && childNode.childNodes.length > 0) {
|
||||||
|
result = childNode.childNodes[0].value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
parsedResult.push({
|
||||||
|
type: 'attempt_completion',
|
||||||
|
result,
|
||||||
|
})
|
||||||
|
lastEndOffset = endOffset
|
||||||
|
|
||||||
|
} else if (node.nodeName === 'ask_followup_question') {
|
||||||
|
if (!node.sourceCodeLocation) {
|
||||||
|
throw new Error('sourceCodeLocation is undefined')
|
||||||
|
}
|
||||||
|
const startOffset = node.sourceCodeLocation.startOffset
|
||||||
|
const endOffset = node.sourceCodeLocation.endOffset
|
||||||
|
if (startOffset > lastEndOffset) {
|
||||||
|
parsedResult.push({
|
||||||
|
type: 'string',
|
||||||
|
content: input.slice(lastEndOffset, startOffset),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
let question: string | undefined
|
||||||
|
for (const childNode of node.childNodes) {
|
||||||
|
if (childNode.nodeName === 'question' && childNode.childNodes.length > 0) {
|
||||||
|
question = childNode.childNodes[0].value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
parsedResult.push({
|
||||||
|
type: 'ask_followup_question',
|
||||||
|
question,
|
||||||
|
})
|
||||||
|
lastEndOffset = endOffset
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handle the last part of the input
|
||||||
|
if (lastEndOffset < input.length) {
|
||||||
|
parsedResult.push({
|
||||||
|
type: 'string',
|
||||||
|
content: input.slice(lastEndOffset),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return parsedResult
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Failed to parse infio block', error)
|
||||||
|
throw error
|
||||||
}
|
}
|
||||||
return parsedResult
|
|
||||||
}
|
}
|
||||||
|
|||||||
107
src/utils/path.ts
Normal file
107
src/utils/path.ts
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
import os from "os"
|
||||||
|
import * as path from "path"
|
||||||
|
|
||||||
|
|
||||||
|
/*
|
||||||
|
The Node.js 'path' module resolves and normalizes paths differently depending on the platform:
|
||||||
|
- On Windows, it uses backslashes (\) as the default path separator.
|
||||||
|
- On POSIX-compliant systems (Linux, macOS), it uses forward slashes (/) as the default path separator.
|
||||||
|
|
||||||
|
While modules like 'upath' can be used to normalize paths to use forward slashes consistently,
|
||||||
|
this can create inconsistencies when interfacing with other modules (like vscode.fs) that use
|
||||||
|
backslashes on Windows.
|
||||||
|
|
||||||
|
Our approach:
|
||||||
|
1. We present paths with forward slashes to the AI and user for consistency.
|
||||||
|
2. We use the 'arePathsEqual' function for safe path comparisons.
|
||||||
|
3. Internally, Node.js gracefully handles both backslashes and forward slashes.
|
||||||
|
|
||||||
|
This strategy ensures consistent path presentation while leveraging Node.js's built-in
|
||||||
|
path handling capabilities across different platforms.
|
||||||
|
|
||||||
|
Note: When interacting with the file system or VS Code APIs, we still use the native path module
|
||||||
|
to ensure correct behavior on all platforms. The toPosixPath and arePathsEqual functions are
|
||||||
|
primarily used for presentation and comparison purposes, not for actual file system operations.
|
||||||
|
|
||||||
|
Observations:
|
||||||
|
- Macos isn't so flexible with mixed separators, whereas windows can handle both. ("Node.js does automatically handle path separators on Windows, converting forward slashes to backslashes as needed. However, on macOS and other Unix-like systems, the path separator is always a forward slash (/), and backslashes are treated as regular characters.")
|
||||||
|
*/
|
||||||
|
|
||||||
|
function toPosixPath(p: string) {
|
||||||
|
// Extended-Length Paths in Windows start with "\\?\" to allow longer paths and bypass usual parsing. If detected, we return the path unmodified to maintain functionality, as altering these paths could break their special syntax.
|
||||||
|
const isExtendedLengthPath = p.startsWith("\\\\?\\")
|
||||||
|
|
||||||
|
if (isExtendedLengthPath) {
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.replace(/\\/g, "/")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Declaration merging allows us to add a new method to the String type
|
||||||
|
// You must import this file in your entry point (extension.ts) to have access at runtime
|
||||||
|
declare global {
|
||||||
|
interface String {
|
||||||
|
toPosix(): string
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
String.prototype.toPosix = function (this: string): string {
|
||||||
|
return toPosixPath(this)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Safe path comparison that works across different platforms
|
||||||
|
export function arePathsEqual(path1?: string, path2?: string): boolean {
|
||||||
|
if (!path1 && !path2) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if (!path1 || !path2) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
path1 = normalizePath(path1)
|
||||||
|
path2 = normalizePath(path2)
|
||||||
|
|
||||||
|
if (process.platform === "win32") {
|
||||||
|
return path1.toLowerCase() === path2.toLowerCase()
|
||||||
|
}
|
||||||
|
return path1 === path2
|
||||||
|
}
|
||||||
|
|
||||||
|
function normalizePath(p: string): string {
|
||||||
|
// normalize resolve ./.. segments, removes duplicate slashes, and standardizes path separators
|
||||||
|
let normalized = path.normalize(p)
|
||||||
|
// however it doesn't remove trailing slashes
|
||||||
|
// remove trailing slash, except for root paths
|
||||||
|
if (normalized.length > 1 && (normalized.endsWith("/") || normalized.endsWith("\\"))) {
|
||||||
|
normalized = normalized.slice(0, -1)
|
||||||
|
}
|
||||||
|
return normalized
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getReadablePath(cwd: string, relPath?: string): string {
|
||||||
|
relPath = relPath || ""
|
||||||
|
// path.resolve is flexible in that it will resolve relative paths like '../../' to the cwd and even ignore the cwd if the relPath is actually an absolute path
|
||||||
|
const absolutePath = path.resolve(cwd, relPath)
|
||||||
|
if (arePathsEqual(cwd, path.join(os.homedir(), "Desktop"))) {
|
||||||
|
// User opened vscode without a workspace, so cwd is the Desktop. Show the full absolute path to keep the user aware of where files are being created
|
||||||
|
return absolutePath.toPosix()
|
||||||
|
}
|
||||||
|
if (arePathsEqual(path.normalize(absolutePath), path.normalize(cwd))) {
|
||||||
|
return path.basename(absolutePath).toPosix()
|
||||||
|
} else {
|
||||||
|
// show the relative path to the cwd
|
||||||
|
const normalizedRelPath = path.relative(cwd, absolutePath)
|
||||||
|
if (absolutePath.includes(cwd)) {
|
||||||
|
return normalizedRelPath.toPosix()
|
||||||
|
} else {
|
||||||
|
// we are outside the cwd, so show the absolute path (useful for when cline passes in '../../' for example)
|
||||||
|
return absolutePath.toPosix()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export const toRelativePath = (filePath: string, cwd: string) => {
|
||||||
|
const relativePath = path.relative(cwd, filePath).toPosix()
|
||||||
|
return filePath.endsWith("/") ? relativePath + "/" : relativePath
|
||||||
|
}
|
||||||
@ -1,28 +1,107 @@
|
|||||||
import { App, TFile, htmlToMarkdown, requestUrl } from 'obsidian'
|
import { App, MarkdownView, TAbstractFile, TFile, TFolder, Vault, htmlToMarkdown, requestUrl } from 'obsidian'
|
||||||
|
|
||||||
import { editorStateToPlainText } from '../components/chat-view/chat-input/utils/editor-state-to-plain-text'
|
import { editorStateToPlainText } from '../components/chat-view/chat-input/utils/editor-state-to-plain-text'
|
||||||
import { QueryProgressState } from '../components/chat-view/QueryProgress'
|
import { QueryProgressState } from '../components/chat-view/QueryProgress'
|
||||||
|
import { SYSTEM_PROMPT } from '../core/prompts/system'
|
||||||
import { RAGEngine } from '../core/rag/rag-engine'
|
import { RAGEngine } from '../core/rag/rag-engine'
|
||||||
import { SelectVector } from '../database/schema'
|
import { SelectVector } from '../database/schema'
|
||||||
import { ChatMessage, ChatUserMessage } from '../types/chat'
|
import { ChatMessage, ChatUserMessage } from '../types/chat'
|
||||||
import { ContentPart, RequestMessage } from '../types/llm/request'
|
import { ContentPart, RequestMessage } from '../types/llm/request'
|
||||||
import {
|
import {
|
||||||
MentionableBlock, MentionableCurrentFile, MentionableFile,
|
MentionableBlock,
|
||||||
|
MentionableFile,
|
||||||
MentionableFolder,
|
MentionableFolder,
|
||||||
MentionableImage,
|
MentionableImage,
|
||||||
MentionableUrl,
|
MentionableUrl,
|
||||||
MentionableVault
|
MentionableVault
|
||||||
} from '../types/mentionable'
|
} from '../types/mentionable'
|
||||||
import { InfioSettings } from '../types/settings'
|
import { InfioSettings } from '../types/settings'
|
||||||
|
import { defaultModeSlug, getFullModeDetails } from "../utils/modes"
|
||||||
|
|
||||||
|
import { listFilesAndFolders } from './glob-utils'
|
||||||
import {
|
import {
|
||||||
getNestedFiles,
|
readTFileContent
|
||||||
readMultipleTFiles,
|
|
||||||
readTFileContent,
|
|
||||||
} from './obsidian'
|
} from './obsidian'
|
||||||
import { tokenCount } from './token'
|
import { tokenCount } from './token'
|
||||||
import { YoutubeTranscript, isYoutubeUrl } from './youtube-transcript'
|
import { YoutubeTranscript, isYoutubeUrl } from './youtube-transcript'
|
||||||
|
|
||||||
|
export function addLineNumbers(content: string, startLine: number = 1): string {
|
||||||
|
const lines = content.split("\n")
|
||||||
|
const maxLineNumberWidth = String(startLine + lines.length - 1).length
|
||||||
|
return lines
|
||||||
|
.map((line, index) => {
|
||||||
|
const lineNumber = String(startLine + index).padStart(maxLineNumberWidth, " ")
|
||||||
|
return `${lineNumber} | ${line}`
|
||||||
|
})
|
||||||
|
.join("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
async function getFolderTreeContent(path: TFolder): Promise<string> {
|
||||||
|
try {
|
||||||
|
const entries = path.children
|
||||||
|
let folderContent = ""
|
||||||
|
entries.forEach((entry, index) => {
|
||||||
|
const isLast = index === entries.length - 1
|
||||||
|
const linePrefix = isLast ? "└── " : "├── "
|
||||||
|
if (entry instanceof TFile) {
|
||||||
|
folderContent += `${linePrefix}${entry.name}\n`
|
||||||
|
} else if (entry instanceof TFolder) {
|
||||||
|
folderContent += `${linePrefix}${entry.name}/\n`
|
||||||
|
} else {
|
||||||
|
folderContent += `${linePrefix}${entry.name}\n`
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return folderContent
|
||||||
|
} catch (error) {
|
||||||
|
throw new Error(`Failed to access path "${path.path}": ${error.message}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function getFileOrFolderContent(path: TAbstractFile, vault: Vault): Promise<string> {
|
||||||
|
try {
|
||||||
|
if (path instanceof TFile) {
|
||||||
|
if (path.extension != 'md') {
|
||||||
|
return "(Binary file, unable to display content)"
|
||||||
|
}
|
||||||
|
return addLineNumbers(await readTFileContent(path, vault))
|
||||||
|
} else if (path instanceof TFolder) {
|
||||||
|
const entries = path.children
|
||||||
|
let folderContent = ""
|
||||||
|
const fileContentPromises: Promise<string | undefined>[] = []
|
||||||
|
entries.forEach((entry, index) => {
|
||||||
|
const isLast = index === entries.length - 1
|
||||||
|
const linePrefix = isLast ? "└── " : "├── "
|
||||||
|
if (entry instanceof TFile) {
|
||||||
|
folderContent += `${linePrefix}${entry.name}\n`
|
||||||
|
fileContentPromises.push(
|
||||||
|
(async () => {
|
||||||
|
try {
|
||||||
|
if (entry.extension != 'md') {
|
||||||
|
return undefined
|
||||||
|
}
|
||||||
|
const content = addLineNumbers(await readTFileContent(entry, vault))
|
||||||
|
return `<file_content path="${entry.path}">\n${content}\n</file_content>`
|
||||||
|
} catch (error) {
|
||||||
|
return undefined
|
||||||
|
}
|
||||||
|
})(),
|
||||||
|
)
|
||||||
|
} else if (entry instanceof TFolder) {
|
||||||
|
folderContent += `${linePrefix}${entry.name}/\n`
|
||||||
|
} else {
|
||||||
|
folderContent += `${linePrefix}${entry.name}\n`
|
||||||
|
}
|
||||||
|
})
|
||||||
|
const fileContents = (await Promise.all(fileContentPromises)).filter((content) => content)
|
||||||
|
return `${folderContent}\n${fileContents.join("\n\n")}`.trim()
|
||||||
|
} else {
|
||||||
|
return `(Failed to read contents of ${path.path})`
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
throw new Error(`Failed to access path "${path.path}": ${error.message}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
export class PromptGenerator {
|
export class PromptGenerator {
|
||||||
private getRagEngine: () => Promise<RAGEngine>
|
private getRagEngine: () => Promise<RAGEngine>
|
||||||
private app: App
|
private app: App
|
||||||
@ -47,12 +126,10 @@ export class PromptGenerator {
|
|||||||
messages,
|
messages,
|
||||||
useVaultSearch,
|
useVaultSearch,
|
||||||
onQueryProgressChange,
|
onQueryProgressChange,
|
||||||
type,
|
|
||||||
}: {
|
}: {
|
||||||
messages: ChatMessage[]
|
messages: ChatMessage[]
|
||||||
useVaultSearch?: boolean
|
useVaultSearch?: boolean
|
||||||
onQueryProgressChange?: (queryProgress: QueryProgressState) => void
|
onQueryProgressChange?: (queryProgress: QueryProgressState) => void
|
||||||
type?: string
|
|
||||||
}): Promise<{
|
}): Promise<{
|
||||||
requestMessages: RequestMessage[]
|
requestMessages: RequestMessage[]
|
||||||
compiledMessages: ChatMessage[]
|
compiledMessages: ChatMessage[]
|
||||||
@ -64,14 +141,16 @@ export class PromptGenerator {
|
|||||||
if (lastUserMessage.role !== 'user') {
|
if (lastUserMessage.role !== 'user') {
|
||||||
throw new Error('Last message is not a user message')
|
throw new Error('Last message is not a user message')
|
||||||
}
|
}
|
||||||
|
const isNewChat = messages.filter(message => message.role === 'user').length === 1
|
||||||
|
|
||||||
const { promptContent, shouldUseRAG, similaritySearchResults } =
|
const { promptContent, similaritySearchResults } =
|
||||||
await this.compileUserMessagePrompt({
|
await this.compileUserMessagePrompt({
|
||||||
|
isNewChat,
|
||||||
message: lastUserMessage,
|
message: lastUserMessage,
|
||||||
useVaultSearch,
|
useVaultSearch,
|
||||||
onQueryProgressChange,
|
onQueryProgressChange,
|
||||||
})
|
})
|
||||||
let compiledMessages = [
|
const compiledMessages = [
|
||||||
...messages.slice(0, -1),
|
...messages.slice(0, -1),
|
||||||
{
|
{
|
||||||
...lastUserMessage,
|
...lastUserMessage,
|
||||||
@ -80,39 +159,10 @@ export class PromptGenerator {
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
// Safeguard: ensure all user messages have parsed content
|
const systemMessage = await this.getSystemMessageNew()
|
||||||
compiledMessages = await Promise.all(
|
|
||||||
compiledMessages.map(async (message) => {
|
|
||||||
if (message.role === 'user' && !message.promptContent) {
|
|
||||||
const { promptContent, similaritySearchResults } =
|
|
||||||
await this.compileUserMessagePrompt({
|
|
||||||
message,
|
|
||||||
})
|
|
||||||
return {
|
|
||||||
...message,
|
|
||||||
promptContent,
|
|
||||||
similaritySearchResults,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return message
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
|
|
||||||
const systemMessage = this.getSystemMessage(shouldUseRAG, type)
|
|
||||||
|
|
||||||
const customInstructionMessage = this.getCustomInstructionMessage()
|
|
||||||
|
|
||||||
const currentFile = lastUserMessage.mentionables.find(
|
|
||||||
(m): m is MentionableCurrentFile => m.type === 'current-file',
|
|
||||||
)?.file
|
|
||||||
const currentFileMessage = currentFile
|
|
||||||
? await this.getCurrentFileMessage(currentFile)
|
|
||||||
: undefined
|
|
||||||
|
|
||||||
const requestMessages: RequestMessage[] = [
|
const requestMessages: RequestMessage[] = [
|
||||||
systemMessage,
|
systemMessage,
|
||||||
...(customInstructionMessage ? [customInstructionMessage, PromptGenerator.EMPTY_ASSISTANT_MESSAGE] : []),
|
|
||||||
...(currentFileMessage ? [currentFileMessage, PromptGenerator.EMPTY_ASSISTANT_MESSAGE] : []),
|
|
||||||
...compiledMessages.slice(-19).map((message): RequestMessage => {
|
...compiledMessages.slice(-19).map((message): RequestMessage => {
|
||||||
if (message.role === 'user') {
|
if (message.role === 'user') {
|
||||||
return {
|
return {
|
||||||
@ -126,7 +176,6 @@ export class PromptGenerator {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
...(shouldUseRAG ? [this.getRagInstructionMessage()] : []),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@ -135,27 +184,101 @@ export class PromptGenerator {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private async getEnvironmentDetails() {
|
||||||
|
let details = ""
|
||||||
|
// Obsidian Current File
|
||||||
|
details += "\n\n# Obsidian Current File"
|
||||||
|
const currentFile = this.app.workspace.getActiveFile()
|
||||||
|
if (currentFile) {
|
||||||
|
details += `\n${currentFile?.path}`
|
||||||
|
} else {
|
||||||
|
details += "\n(No current file)"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Obsidian Open Tabs
|
||||||
|
details += "\n\n# Obsidian Open Tabs"
|
||||||
|
const openTabs: string[] = [];
|
||||||
|
this.app.workspace.iterateAllLeaves(leaf => {
|
||||||
|
if (leaf.view instanceof MarkdownView && leaf.view.file) {
|
||||||
|
openTabs.push(leaf.view.file?.path);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
if (openTabs.length === 0) {
|
||||||
|
details += "\n(No open tabs)"
|
||||||
|
} else {
|
||||||
|
details += `\n${openTabs.join("\n")}`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add current time information with timezone
|
||||||
|
const now = new Date()
|
||||||
|
const formatter = new Intl.DateTimeFormat(undefined, {
|
||||||
|
year: "numeric",
|
||||||
|
month: "numeric",
|
||||||
|
day: "numeric",
|
||||||
|
hour: "numeric",
|
||||||
|
minute: "numeric",
|
||||||
|
second: "numeric",
|
||||||
|
hour12: true,
|
||||||
|
})
|
||||||
|
const timeZone = formatter.resolvedOptions().timeZone
|
||||||
|
const timeZoneOffset = -now.getTimezoneOffset() / 60 // Convert to hours and invert sign to match conventional notation
|
||||||
|
const timeZoneOffsetStr = `${timeZoneOffset >= 0 ? "+" : ""}${timeZoneOffset}:00`
|
||||||
|
details += `\n\n# Current Time\n${formatter.format(now)} (${timeZone}, UTC${timeZoneOffsetStr})`
|
||||||
|
|
||||||
|
// Add current mode details
|
||||||
|
const currentMode = defaultModeSlug
|
||||||
|
const modeDetails = await getFullModeDetails(currentMode)
|
||||||
|
details += `\n\n# Current Mode\n`
|
||||||
|
details += `<slug>${currentMode}</slug>\n`
|
||||||
|
details += `<name>${modeDetails.name}</name>\n`
|
||||||
|
|
||||||
|
// // Obsidian Current Folder
|
||||||
|
// const currentFolder = this.app.workspace.getActiveFile() ? this.app.workspace.getActiveFile()?.parent?.path : "/"
|
||||||
|
// // Obsidian Vault Files and Folders
|
||||||
|
// if (currentFolder) {
|
||||||
|
// details += `\n\n# Obsidian Current Folder (${currentFolder}) Files`
|
||||||
|
// const filesAndFolders = await listFilesAndFolders(this.app.vault, currentFolder)
|
||||||
|
// if (filesAndFolders.length > 0) {
|
||||||
|
// details += `\n${filesAndFolders.filter(Boolean).join("\n")}`
|
||||||
|
// } else {
|
||||||
|
// details += "\n(No Markdown files in current folder)"
|
||||||
|
// }
|
||||||
|
// } else {
|
||||||
|
// details += "\n(No current folder)"
|
||||||
|
// }
|
||||||
|
|
||||||
|
return `<environment_details>\n${details.trim()}\n</environment_details>`
|
||||||
|
}
|
||||||
|
|
||||||
private async compileUserMessagePrompt({
|
private async compileUserMessagePrompt({
|
||||||
|
isNewChat,
|
||||||
message,
|
message,
|
||||||
useVaultSearch,
|
useVaultSearch,
|
||||||
onQueryProgressChange,
|
onQueryProgressChange,
|
||||||
}: {
|
}: {
|
||||||
|
isNewChat: boolean
|
||||||
message: ChatUserMessage
|
message: ChatUserMessage
|
||||||
useVaultSearch?: boolean
|
useVaultSearch?: boolean
|
||||||
onQueryProgressChange?: (queryProgress: QueryProgressState) => void
|
onQueryProgressChange?: (queryProgress: QueryProgressState) => void
|
||||||
}): Promise<{
|
}): Promise<{
|
||||||
promptContent: ChatUserMessage['promptContent']
|
promptContent: ChatUserMessage['promptContent']
|
||||||
shouldUseRAG: boolean
|
|
||||||
similaritySearchResults?: (Omit<SelectVector, 'embedding'> & {
|
similaritySearchResults?: (Omit<SelectVector, 'embedding'> & {
|
||||||
similarity: number
|
similarity: number
|
||||||
})[]
|
})[]
|
||||||
}> {
|
}> {
|
||||||
if (!message.content) {
|
// Add environment details
|
||||||
|
const environmentDetails = isNewChat
|
||||||
|
? await this.getEnvironmentDetails()
|
||||||
|
: undefined
|
||||||
|
|
||||||
|
// if isToolCallReturn, add read_file_content to promptContent
|
||||||
|
if (message.content === null) {
|
||||||
return {
|
return {
|
||||||
promptContent: '',
|
promptContent: message.promptContent,
|
||||||
shouldUseRAG: false,
|
similaritySearchResults: undefined,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const query = editorStateToPlainText(message.content)
|
const query = editorStateToPlainText(message.content)
|
||||||
let similaritySearchResults = undefined
|
let similaritySearchResults = undefined
|
||||||
|
|
||||||
@ -169,33 +292,94 @@ export class PromptGenerator {
|
|||||||
onQueryProgressChange?.({
|
onQueryProgressChange?.({
|
||||||
type: 'reading-mentionables',
|
type: 'reading-mentionables',
|
||||||
})
|
})
|
||||||
|
|
||||||
|
const taskPrompt = isNewChat ? `<task>\n${query}\n</task>` : `<feedback>\n${query}\n</feedback>`
|
||||||
|
|
||||||
|
// user mention files
|
||||||
const files = message.mentionables
|
const files = message.mentionables
|
||||||
.filter((m): m is MentionableFile => m.type === 'file')
|
.filter((m): m is MentionableFile => m.type === 'file')
|
||||||
.map((m) => m.file)
|
.map((m) => m.file)
|
||||||
|
let fileContentsPrompts = files.length > 0
|
||||||
|
? (await Promise.all(files.map(async (file) => {
|
||||||
|
const content = await getFileOrFolderContent(file, this.app.vault)
|
||||||
|
return `<file_content path="${file.path}">\n${content}\n</file_content>`
|
||||||
|
}))).join('\n')
|
||||||
|
: undefined
|
||||||
|
|
||||||
|
// user mention folders
|
||||||
const folders = message.mentionables
|
const folders = message.mentionables
|
||||||
.filter((m): m is MentionableFolder => m.type === 'folder')
|
.filter((m): m is MentionableFolder => m.type === 'folder')
|
||||||
.map((m) => m.folder)
|
.map((m) => m.folder)
|
||||||
const nestedFiles = folders.flatMap((folder) =>
|
let folderContentsPrompts = folders.length > 0
|
||||||
getNestedFiles(folder, this.app.vault),
|
? (await Promise.all(folders.map(async (folder) => {
|
||||||
|
const content = await getFileOrFolderContent(folder, this.app.vault)
|
||||||
|
return `<folder_content path="${folder.path}">\n${content}\n</folder_content>`
|
||||||
|
}))).join('\n')
|
||||||
|
: undefined
|
||||||
|
|
||||||
|
// user mention blocks
|
||||||
|
const blocks = message.mentionables.filter(
|
||||||
|
(m): m is MentionableBlock => m.type === 'block',
|
||||||
)
|
)
|
||||||
const allFiles = [...files, ...nestedFiles]
|
const blockContentsPrompt = blocks.length > 0
|
||||||
const fileContents = await readMultipleTFiles(allFiles, this.app.vault)
|
? blocks
|
||||||
|
.map(({ file, content, startLine, endLine }) => {
|
||||||
|
const content_with_line_numbers = addLineNumbers(content, startLine)
|
||||||
|
return `<file_block_content location="${file.path}#L${startLine}-${endLine}">\n${content_with_line_numbers}\n</file_block_content>`
|
||||||
|
})
|
||||||
|
.join('\n')
|
||||||
|
: undefined
|
||||||
|
|
||||||
// Count tokens incrementally to avoid long processing times on large content sets
|
// user mention urls
|
||||||
const exceedsTokenThreshold = async () => {
|
const urls = message.mentionables.filter(
|
||||||
let accTokenCount = 0
|
(m): m is MentionableUrl => m.type === 'url',
|
||||||
for (const content of fileContents) {
|
)
|
||||||
const count = await tokenCount(content)
|
const urlContents = await Promise.all(
|
||||||
accTokenCount += count
|
urls.map(async ({ url }) => ({
|
||||||
if (accTokenCount > this.settings.ragOptions.thresholdTokens) {
|
url,
|
||||||
return true
|
content: await this.getWebsiteContent(url)
|
||||||
}
|
}))
|
||||||
|
)
|
||||||
|
const urlContentsPrompt = urlContents.length > 0
|
||||||
|
? urlContents
|
||||||
|
.map(({ url, content }) => (
|
||||||
|
`<url_content url="${url}">\n${content}\n</url_content>`
|
||||||
|
))
|
||||||
|
.join('\n') : undefined
|
||||||
|
|
||||||
|
const currentFile = message.mentionables
|
||||||
|
.filter((m): m is MentionableFile => m.type === 'current-file')
|
||||||
|
.first()
|
||||||
|
const currentFileContent = currentFile && currentFile.file != null
|
||||||
|
? await getFileOrFolderContent(currentFile.file, this.app.vault)
|
||||||
|
: undefined
|
||||||
|
|
||||||
|
const currentFileContentPrompt = isNewChat && currentFileContent
|
||||||
|
? `<current_file_content path="${currentFile.file.path}">\n${currentFileContent}\n</current_file_content>`
|
||||||
|
: undefined
|
||||||
|
|
||||||
|
// Count file and folder tokens
|
||||||
|
let accTokenCount = 0
|
||||||
|
let isOverThreshold = false
|
||||||
|
for (const content of [fileContentsPrompts, folderContentsPrompts].filter(Boolean)) {
|
||||||
|
const count = await tokenCount(content)
|
||||||
|
accTokenCount += count
|
||||||
|
if (accTokenCount > this.settings.ragOptions.thresholdTokens) {
|
||||||
|
isOverThreshold = true
|
||||||
}
|
}
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
const shouldUseRAG = useVaultSearch || (await exceedsTokenThreshold())
|
if (isOverThreshold) {
|
||||||
|
fileContentsPrompts = files.map((file) => {
|
||||||
|
return `<file_content path="${file.path}">\n(Content omitted due to token limit. Relevant sections will be provided by semantic search below.)\n</file_content>`
|
||||||
|
}).join('\n')
|
||||||
|
folderContentsPrompts = folders.map(async (folder) => {
|
||||||
|
const tree_content = await getFolderTreeContent(folder)
|
||||||
|
return `<folder_content path="${folder.path}">\n${tree_content}\n(Content omitted due to token limit. Relevant sections will be provided by semantic search below.)\n</folder_content>`
|
||||||
|
}).join('\n')
|
||||||
|
}
|
||||||
|
|
||||||
let filePrompt: string
|
const shouldUseRAG = useVaultSearch || isOverThreshold
|
||||||
|
let similaritySearchContents
|
||||||
if (shouldUseRAG) {
|
if (shouldUseRAG) {
|
||||||
similaritySearchResults = useVaultSearch
|
similaritySearchResults = useVaultSearch
|
||||||
? await (
|
? await (
|
||||||
@ -203,7 +387,7 @@ export class PromptGenerator {
|
|||||||
).processQuery({
|
).processQuery({
|
||||||
query,
|
query,
|
||||||
onQueryProgressChange: onQueryProgressChange,
|
onQueryProgressChange: onQueryProgressChange,
|
||||||
}) // TODO: Add similarity boosting for mentioned files or folders
|
})
|
||||||
: await (
|
: await (
|
||||||
await this.getRagEngine()
|
await this.getRagEngine()
|
||||||
).processQuery({
|
).processQuery({
|
||||||
@ -214,60 +398,42 @@ export class PromptGenerator {
|
|||||||
},
|
},
|
||||||
onQueryProgressChange: onQueryProgressChange,
|
onQueryProgressChange: onQueryProgressChange,
|
||||||
})
|
})
|
||||||
filePrompt = `## Potentially relevant snippets from the current vault
|
const snippets = similaritySearchResults.map(({ path, content, metadata }) => {
|
||||||
${similaritySearchResults
|
const contentWithLineNumbers = this.addLineNumbersToContent({
|
||||||
.map(({ path, content, metadata }) => {
|
content,
|
||||||
const contentWithLineNumbers = this.addLineNumbersToContent({
|
startLine: metadata.startLine,
|
||||||
content,
|
|
||||||
startLine: metadata.startLine,
|
|
||||||
})
|
|
||||||
return `\`\`\`${path}\n${contentWithLineNumbers}\n\`\`\`\n`
|
|
||||||
})
|
|
||||||
.join('')}\n`
|
|
||||||
} else {
|
|
||||||
filePrompt = allFiles
|
|
||||||
.map((file, index) => {
|
|
||||||
return `\`\`\`${file.path}\n${fileContents[index]}\n\`\`\`\n`
|
|
||||||
})
|
})
|
||||||
.join('')
|
return `<file_block_content location="${path}#L${metadata.startLine}-${metadata.endLine}">\n${contentWithLineNumbers}\n</file_block_content>`
|
||||||
|
}).join('\n')
|
||||||
|
similaritySearchContents = snippets.length > 0
|
||||||
|
? `<similarity_search_results>\n${snippets}\n</similarity_search_results>`
|
||||||
|
: '<similarity_search_results>\n(No relevant results found)\n</similarity_search_results>'
|
||||||
|
} else {
|
||||||
|
similaritySearchContents = undefined
|
||||||
}
|
}
|
||||||
|
|
||||||
const blocks = message.mentionables.filter(
|
const parsedText = [
|
||||||
(m): m is MentionableBlock => m.type === 'block',
|
taskPrompt,
|
||||||
)
|
blockContentsPrompt,
|
||||||
const blockPrompt = blocks
|
fileContentsPrompts,
|
||||||
.map(({ file, content, startLine, endLine }) => {
|
folderContentsPrompts,
|
||||||
return `\`\`\`${file.path}#L${startLine}-${endLine}\n${content}\n\`\`\`\n`
|
urlContentsPrompt,
|
||||||
})
|
similaritySearchContents,
|
||||||
.join('')
|
currentFileContentPrompt,
|
||||||
|
environmentDetails,
|
||||||
const urls = message.mentionables.filter(
|
].filter(Boolean).join('\n\n')
|
||||||
(m): m is MentionableUrl => m.type === 'url',
|
|
||||||
)
|
|
||||||
|
|
||||||
const urlPrompt =
|
|
||||||
urls.length > 0
|
|
||||||
? `## Potentially relevant web search results
|
|
||||||
${(
|
|
||||||
await Promise.all(
|
|
||||||
urls.map(
|
|
||||||
async ({ url }) => `\`\`\`
|
|
||||||
Website URL: ${url}
|
|
||||||
Website Content:
|
|
||||||
${await this.getWebsiteContent(url)}
|
|
||||||
\`\`\``,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
).join('\n')}
|
|
||||||
`
|
|
||||||
: ''
|
|
||||||
|
|
||||||
|
// user mention images
|
||||||
const imageDataUrls = message.mentionables
|
const imageDataUrls = message.mentionables
|
||||||
.filter((m): m is MentionableImage => m.type === 'image')
|
.filter((m): m is MentionableImage => m.type === 'image')
|
||||||
.map(({ data }) => data)
|
.map(({ data }) => data)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
promptContent: [
|
promptContent: [
|
||||||
|
{
|
||||||
|
type: 'text',
|
||||||
|
text: parsedText,
|
||||||
|
},
|
||||||
...imageDataUrls.map(
|
...imageDataUrls.map(
|
||||||
(data): ContentPart => ({
|
(data): ContentPart => ({
|
||||||
type: 'image_url',
|
type: 'image_url',
|
||||||
@ -275,14 +441,18 @@ ${await this.getWebsiteContent(url)}
|
|||||||
url: data,
|
url: data,
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
),
|
)
|
||||||
{
|
|
||||||
type: 'text',
|
|
||||||
text: `${filePrompt}${blockPrompt}${urlPrompt}\n\n${query}\n\n`,
|
|
||||||
},
|
|
||||||
],
|
],
|
||||||
shouldUseRAG,
|
similaritySearchResults,
|
||||||
similaritySearchResults: similaritySearchResults,
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private async getSystemMessageNew(): Promise<RequestMessage> {
|
||||||
|
const systemPrompt = await SYSTEM_PROMPT(this.app.vault.getRoot().path, false)
|
||||||
|
|
||||||
|
return {
|
||||||
|
role: 'system',
|
||||||
|
content: systemPrompt,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -392,7 +562,7 @@ ${customInstruction}
|
|||||||
return {
|
return {
|
||||||
role: 'user',
|
role: 'user',
|
||||||
content: `# Inputs
|
content: `# Inputs
|
||||||
## Current file
|
## Current Open File
|
||||||
Here is the file I'm looking at.
|
Here is the file I'm looking at.
|
||||||
\`\`\`${currentFile.path}
|
\`\`\`${currentFile.path}
|
||||||
${fileContent}
|
${fileContent}
|
||||||
|
|||||||
77
src/utils/tool-groups.ts
Normal file
77
src/utils/tool-groups.ts
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
// Define tool group configuration
|
||||||
|
export type ToolGroupConfig = {
|
||||||
|
tools: readonly string[]
|
||||||
|
alwaysAvailable?: boolean // Whether this group is always available and shouldn't show in prompts view
|
||||||
|
}
|
||||||
|
|
||||||
|
// Map of tool slugs to their display names
|
||||||
|
export const TOOL_DISPLAY_NAMES = {
|
||||||
|
execute_command: "run commands",
|
||||||
|
read_file: "read files",
|
||||||
|
write_to_file: "write files",
|
||||||
|
apply_diff: "apply changes",
|
||||||
|
search_files: "search files",
|
||||||
|
list_files: "list files",
|
||||||
|
// list_code_definition_names: "list definitions",
|
||||||
|
browser_action: "use a browser",
|
||||||
|
use_mcp_tool: "use mcp tools",
|
||||||
|
access_mcp_resource: "access mcp resources",
|
||||||
|
ask_followup_question: "ask questions",
|
||||||
|
attempt_completion: "complete tasks",
|
||||||
|
switch_mode: "switch modes",
|
||||||
|
new_task: "create new task",
|
||||||
|
} as const
|
||||||
|
|
||||||
|
// Define available tool groups
|
||||||
|
export const TOOL_GROUPS: Record<string, ToolGroupConfig> = {
|
||||||
|
read: {
|
||||||
|
tools: ["read_file", "list_files", "search_files"],
|
||||||
|
},
|
||||||
|
edit: {
|
||||||
|
tools: ["apply_diff", "write_to_file", "insert_content", "search_and_replace"],
|
||||||
|
},
|
||||||
|
// browser: {
|
||||||
|
// tools: ["browser_action"],
|
||||||
|
// },
|
||||||
|
// command: {
|
||||||
|
// tools: ["execute_command"],
|
||||||
|
// },
|
||||||
|
mcp: {
|
||||||
|
tools: ["use_mcp_tool", "access_mcp_resource"],
|
||||||
|
},
|
||||||
|
modes: {
|
||||||
|
tools: ["switch_mode",],
|
||||||
|
alwaysAvailable: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
export type ToolGroup = keyof typeof TOOL_GROUPS
|
||||||
|
|
||||||
|
// Tools that are always available to all modes
|
||||||
|
export const ALWAYS_AVAILABLE_TOOLS = [
|
||||||
|
"ask_followup_question",
|
||||||
|
"attempt_completion",
|
||||||
|
"switch_mode",
|
||||||
|
"new_task",
|
||||||
|
] as const
|
||||||
|
|
||||||
|
// Tool name types for type safety
|
||||||
|
export type ToolName = keyof typeof TOOL_DISPLAY_NAMES
|
||||||
|
|
||||||
|
// Tool helper functions
|
||||||
|
export function getToolName(toolConfig: string | readonly [ToolName, ...any[]]): ToolName {
|
||||||
|
return typeof toolConfig === "string" ? (toolConfig as ToolName) : toolConfig[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getToolOptions(toolConfig: string | readonly [ToolName, ...any[]]): any {
|
||||||
|
return typeof toolConfig === "string" ? undefined : toolConfig[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Display names for groups in UI
|
||||||
|
export const GROUP_DISPLAY_NAMES: Record<ToolGroup, string> = {
|
||||||
|
read: "Read Files",
|
||||||
|
edit: "Edit Files",
|
||||||
|
browser: "Use Browser",
|
||||||
|
command: "Run Commands",
|
||||||
|
mcp: "Use MCP",
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user