diff --git a/src/components/chat-view/chat-input/ModelSelect.tsx b/src/components/chat-view/chat-input/ModelSelect.tsx index 86e150f..9449b6b 100644 --- a/src/components/chat-view/chat-input/ModelSelect.tsx +++ b/src/components/chat-view/chat-input/ModelSelect.tsx @@ -1,33 +1,192 @@ import * as DropdownMenu from '@radix-ui/react-dropdown-menu' +import Fuse, { FuseResult } from 'fuse.js' import { ChevronDown, ChevronUp } from 'lucide-react' -import { useEffect, useState } from 'react' +import { useEffect, useMemo, useState } from 'react' import { useSettings } from '../../../contexts/SettingsContext' -import { GetProviderModelIds } from "../../../utils/api" +import { ApiProvider } from '../../../types/llm/model' +import { GetAllProviders, GetProviderModelIds } from "../../../utils/api" + +type TextSegment = { + text: string; + isHighlighted: boolean; +}; + +type SearchableItem = { + id: string; + html: string | TextSegment[]; +}; + +type HighlightedItem = { + id: string; + html: TextSegment[]; +}; + +// Reuse highlight function from ProviderModelsPicker +const highlight = (fuseSearchResult: Array>): HighlightedItem[] => { + const set = (obj: Record, path: string, value: TextSegment[]): void => { + const pathValue = path.split(".") + let i: number + let current = obj as Record + + for (i = 0; i < pathValue.length - 1; i++) { + const nextValue = current[pathValue[i]] + if (typeof nextValue === 'object' && nextValue !== null) { + current = nextValue as Record + } else { + throw new Error(`Invalid path: ${path}`) + } + } + + current[pathValue[i]] = value + } + + const mergeRegions = (regions: [number, number][]): [number, number][] => { + if (regions.length === 0) return regions + regions.sort((a, b) => a[0] - b[0]) + const merged: [number, number][] = [regions[0]] + for (let i = 1; i < regions.length; i++) { + const last = merged[merged.length - 1] + const current = regions[i] + if (current[0] <= last[1] + 1) { + last[1] = Math.max(last[1], current[1]) + } else { + merged.push(current) + } + } + return merged + } + + const generateHighlightedSegments = (inputText: string, regions: [number, number][] = []): TextSegment[] => { + if (regions.length === 0) { + return [{ text: inputText, isHighlighted: false }]; + } + + const mergedRegions = mergeRegions(regions); + const segments: TextSegment[] = []; + let nextUnhighlightedRegionStartingIndex = 0; + + mergedRegions.forEach((region) => { + const start = region[0]; + const end = region[1]; + const lastRegionNextIndex = end + 1; + + if (nextUnhighlightedRegionStartingIndex < start) { + segments.push({ + text: inputText.substring(nextUnhighlightedRegionStartingIndex, start), + isHighlighted: false, + }); + } + + segments.push({ + text: inputText.substring(start, lastRegionNextIndex), + isHighlighted: true, + }); + + nextUnhighlightedRegionStartingIndex = lastRegionNextIndex; + }); + + if (nextUnhighlightedRegionStartingIndex < inputText.length) { + segments.push({ + text: inputText.substring(nextUnhighlightedRegionStartingIndex), + isHighlighted: false, + }); + } + + return segments; + } + + return fuseSearchResult + .filter(({ matches }) => matches && matches.length) + .map(({ item, matches }): HighlightedItem => { + const highlightedItem: HighlightedItem = { + id: item.id, + html: typeof item.html === 'string' ? [{ text: item.html, isHighlighted: false }] : [...item.html] + } + + matches?.forEach((match) => { + if (match.key && typeof match.value === "string" && match.indices) { + const mergedIndices = mergeRegions([...match.indices]) + set(highlightedItem, match.key, generateHighlightedSegments(match.value, mergedIndices)) + } + }) + + return highlightedItem + }) +} + +const HighlightedText: React.FC<{ segments: TextSegment[] }> = ({ segments }) => { + return ( + <> + {segments.map((segment, index) => ( + segment.isHighlighted ? ( + {segment.text} + ) : ( + {segment.text} + ) + ))} + + ); +}; export function ModelSelect() { const { settings, setSettings } = useSettings() const [isOpen, setIsOpen] = useState(false) + const [modelProvider, setModelProvider] = useState(settings.chatModelProvider) const [chatModelId, setChatModelId] = useState(settings.chatModelId) - const [providerModels, setProviderModels] = useState([]) + const [modelIds, setModelIds] = useState([]) const [isLoading, setIsLoading] = useState(true) + const [searchTerm, setSearchTerm] = useState("") + const [selectedIndex, setSelectedIndex] = useState(0) + + const providers = GetAllProviders() useEffect(() => { const fetchModels = async () => { setIsLoading(true) try { - const models = await GetProviderModelIds(settings.chatModelProvider) - setProviderModels(models) + const models = await GetProviderModelIds(modelProvider) + setModelIds(models) setChatModelId(settings.chatModelId) } catch (error) { console.error('Failed to fetch provider models:', error) + setModelIds([]) } finally { setIsLoading(false) } } fetchModels() - }, [settings]) + }, [modelProvider, settings.chatModelId]) + + const searchableItems = useMemo(() => { + return modelIds.map((id) => ({ + id, + html: id, + })) + }, [modelIds]) + + const fuse = useMemo(() => { + return new Fuse(searchableItems, { + keys: ["html"], + threshold: 0.6, + shouldSort: true, + isCaseSensitive: false, + ignoreLocation: false, + includeMatches: true, + minMatchCharLength: 1, + }) + }, [searchableItems]) + + const filteredOptions = useMemo(() => { + const results: HighlightedItem[] = searchTerm + ? highlight(fuse.search(searchTerm)) + : searchableItems.map(item => ({ + ...item, + html: typeof item.html === 'string' ? [{ text: item.html, isHighlighted: false }] : item.html + })) + return results + }, [searchableItems, searchTerm, fuse]) return ( @@ -36,30 +195,126 @@ export function ModelSelect() { {isOpen ? : }
- {chatModelId} + [{modelProvider}] {chatModelId}
- + +
+ + {modelIds.length > 0 ? ( + { + setSearchTerm(e.target.value) + setSelectedIndex(0) + }} + onKeyDown={(e) => { + switch (e.key) { + case "ArrowDown": + e.preventDefault() + setSelectedIndex((prev) => + Math.min(prev + 1, filteredOptions.length - 1) + ) + break + case "ArrowUp": + e.preventDefault() + setSelectedIndex((prev) => Math.max(prev - 1, 0)) + break + case "Enter": { + e.preventDefault() + const selectedOption = filteredOptions[selectedIndex] + if (selectedOption) { + setSettings({ + ...settings, + chatModelProvider: modelProvider, + chatModelId: selectedOption.id, + }) + setChatModelId(selectedOption.id) + setSearchTerm("") + setIsOpen(false) + } + break + } + case "Escape": + e.preventDefault() + setIsOpen(false) + setSearchTerm("") + break + } + }} + /> + ) : ( + { + setSearchTerm(e.target.value) + }} + onKeyDown={(e) => { + if (e.key === "Enter") { + e.preventDefault() + setSettings({ + ...settings, + chatModelProvider: modelProvider, + chatModelId: searchTerm, + }) + setChatModelId(searchTerm) + setIsOpen(false) + } + }} + /> + )} +
    {isLoading ? (
  • Loading...
  • ) : ( - providerModels.map((modelId) => ( + filteredOptions.map((option, index) => ( { - setChatModelId(modelId) setSettings({ ...settings, - chatModelId: modelId, + chatModelProvider: modelProvider, + chatModelId: option.id, }) + setChatModelId(option.id) + setSearchTerm("") + setIsOpen(false) }} + className={`infio-llm-setting-combobox-option ${index === selectedIndex ? 'is-selected' : ''}`} + onMouseEnter={() => setSelectedIndex(index)} asChild > -
  • {modelId}
  • +
  • + +
  • )) )} diff --git a/src/styles.css b/src/styles.css index 645ab3d..513e4e2 100644 --- a/src/styles.css +++ b/src/styles.css @@ -478,6 +478,10 @@ button:not(.clickable-icon).infio-chat-list-dropdown { max-width: 240px; } +.infio-popover.infio-llm-setting-combobox-dropdown { + max-width: 340px; +} + .infio-popover ul { padding: 0; list-style: none;