import {
  Content,
  EnhancedGenerateContentResponse,
  GenerateContentResult,
  GenerateContentStreamResult,
  GoogleGenerativeAI,
  Part,
} from '@google/generative-ai'
import { CustomLLMModel } from '../../types/llm/model'
import {
  LLMOptions,
  LLMRequestNonStreaming,
  LLMRequestStreaming,
  RequestMessage,
} from '../../types/llm/request'
import {
  LLMResponseNonStreaming,
  LLMResponseStreaming,
} from '../../types/llm/response'
import { parseImageDataUrl } from '../../utils/image'
import { BaseLLMProvider } from './base'
import {
  LLMAPIKeyInvalidException,
  LLMAPIKeyNotSetException,
} from './exception'

/**
 * Note on the OpenAI compatibility API:
 * Gemini provides an OpenAI-compatible endpoint (https://ai.google.dev/gemini-api/docs/openai)
 * that allows using the OpenAI SDK with Gemini models. However, there are currently CORS issues
 * preventing its use in Obsidian. Consider switching to this endpoint once those
 * issues are resolved.
 */
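
// A minimal sketch of what that migration could look like, assuming the standard
// 'openai' npm package and the base URL documented at the link above; untested
// here because of the CORS issue, and GEMINI_API_KEY is a hypothetical placeholder:
//
//   import OpenAI from 'openai'
//
//   const openai = new OpenAI({
//     apiKey: GEMINI_API_KEY,
//     baseURL: 'https://generativelanguage.googleapis.com/v1beta/openai/',
//   })
//   const completion = await openai.chat.completions.create({
//     model: 'gemini-1.5-flash',
//     messages: [{ role: 'user', content: 'Hello' }],
//   })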

export class GeminiProvider implements BaseLLMProvider {
  private client: GoogleGenerativeAI
  private apiKey: string

  constructor(apiKey: string) {
    this.apiKey = apiKey
    this.client = new GoogleGenerativeAI(apiKey)
  }
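
  /**
   * Sends a non-streaming chat completion request to Gemini and returns the
   * full response, normalized to the OpenAI-style LLMResponseNonStreaming
   * shape. Falls back to the per-model API key when no provider-level key
   * is set.
   */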
  async generateResponse(
    model: CustomLLMModel,
    request: LLMRequestNonStreaming,
    options?: LLMOptions,
  ): Promise<LLMResponseNonStreaming> {
    if (!this.apiKey) {
      if (!model.apiKey) {
        throw new LLMAPIKeyNotSetException(
          `Gemini API key is missing. Please set it in the settings menu.`,
        )
      }
      this.apiKey = model.apiKey
      this.client = new GoogleGenerativeAI(model.apiKey)
    }

    // Gemini takes system prompts as a separate systemInstruction rather than
    // as chat messages, so system messages are extracted and concatenated here.
    const systemMessages = request.messages.filter((m) => m.role === 'system')
    const systemInstruction: string | undefined =
      systemMessages.length > 0
        ? systemMessages.map((m) => m.content).join('\n')
        : undefined

    try {
      // Named generativeModel to avoid shadowing the CustomLLMModel parameter.
      const generativeModel = this.client.getGenerativeModel({
        model: request.model,
        generationConfig: {
          maxOutputTokens: request.max_tokens,
          temperature: request.temperature,
          topP: request.top_p,
          presencePenalty: request.presence_penalty,
          frequencyPenalty: request.frequency_penalty,
        },
        systemInstruction: systemInstruction,
      })

      const result = await generativeModel.generateContent(
        {
          contents: request.messages
            .map((message) => GeminiProvider.parseRequestMessage(message))
            .filter((m): m is Content => m !== null),
        },
        {
          signal: options?.signal,
        },
      )

      const messageId = crypto.randomUUID() // Gemini does not return a message id
      return GeminiProvider.parseNonStreamingResponse(
        result,
        request.model,
        messageId,
      )
    } catch (error) {
      const message = error instanceof Error ? error.message : String(error)
      const isInvalidApiKey =
        message.includes('API_KEY_INVALID') ||
        message.includes('API key not valid')

      if (isInvalidApiKey) {
        throw new LLMAPIKeyInvalidException(
          `Gemini API key is invalid. Please update it in the settings menu.`,
        )
      }

      throw error
    }
  }
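
  /**
   * Sends a streaming chat completion request to Gemini and returns an async
   * iterable of LLMResponseStreaming chunks in the OpenAI-style shape.
   */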
  async streamResponse(
    model: CustomLLMModel,
    request: LLMRequestStreaming,
    options?: LLMOptions,
  ): Promise<AsyncIterable<LLMResponseStreaming>> {
    if (!this.apiKey) {
      if (!model.apiKey) {
        throw new LLMAPIKeyNotSetException(
          `Gemini API key is missing. Please set it in the settings menu.`,
        )
      }
      this.apiKey = model.apiKey
      this.client = new GoogleGenerativeAI(model.apiKey)
    }

    const systemMessages = request.messages.filter((m) => m.role === 'system')
    const systemInstruction: string | undefined =
      systemMessages.length > 0
        ? systemMessages.map((m) => m.content).join('\n')
        : undefined

    try {
      // Named generativeModel to avoid shadowing the CustomLLMModel parameter.
      const generativeModel = this.client.getGenerativeModel({
        model: request.model,
        generationConfig: {
          maxOutputTokens: request.max_tokens,
          temperature: request.temperature,
          topP: request.top_p,
          presencePenalty: request.presence_penalty,
          frequencyPenalty: request.frequency_penalty,
        },
        systemInstruction: systemInstruction,
      })

      const stream = await generativeModel.generateContentStream(
        {
          contents: request.messages
            .map((message) => GeminiProvider.parseRequestMessage(message))
            .filter((m): m is Content => m !== null),
        },
        {
          signal: options?.signal,
        },
      )

      const messageId = crypto.randomUUID() // Gemini does not return a message id
      return this.streamResponseGenerator(stream, request.model, messageId)
    } catch (error) {
      const message = error instanceof Error ? error.message : String(error)
      const isInvalidApiKey =
        message.includes('API_KEY_INVALID') ||
        message.includes('API key not valid')

      if (isInvalidApiKey) {
        throw new LLMAPIKeyInvalidException(
          `Gemini API key is invalid. Please update it in the settings menu.`,
        )
      }

      throw error
    }
  }
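
  /**
   * Adapts the SDK's GenerateContentStreamResult into this plugin's
   * LLMResponseStreaming chunks, reusing a single generated message id
   * across all chunks of one response.
   */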
  private async *streamResponseGenerator(
    stream: GenerateContentStreamResult,
    model: string,
    messageId: string,
  ): AsyncIterable<LLMResponseStreaming> {
    for await (const chunk of stream.stream) {
      yield GeminiProvider.parseStreamingResponseChunk(chunk, model, messageId)
    }
  }
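
  /**
   * Converts a RequestMessage into Gemini's Content format. System messages
   * return null (they are passed separately via systemInstruction), and
   * image_url parts are converted to inlineData parts.
   */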
  static parseRequestMessage(message: RequestMessage): Content | null {
    if (message.role === 'system') {
      return null
    }

    if (Array.isArray(message.content)) {
      return {
        role: message.role === 'user' ? 'user' : 'model',
        parts: message.content.map((part) => {
          switch (part.type) {
            case 'text':
              return { text: part.text }
            case 'image_url': {
              const { mimeType, base64Data } = parseImageDataUrl(
                part.image_url.url,
              )
              GeminiProvider.validateImageType(mimeType)
              return {
                inlineData: {
                  data: base64Data,
                  mimeType,
                },
              }
            }
          }
        }) as Part[],
      }
    }

    return {
      role: message.role === 'user' ? 'user' : 'model',
      parts: [
        {
          text: message.content,
        },
      ],
    }
  }
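
  /**
   * Maps a Gemini GenerateContentResult onto the OpenAI-style
   * LLMResponseNonStreaming shape, including usage metadata when present.
   */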
  static parseNonStreamingResponse(
    response: GenerateContentResult,
    model: string,
    messageId: string,
  ): LLMResponseNonStreaming {
    return {
      id: messageId,
      choices: [
        {
          finish_reason:
            response.response.candidates?.[0]?.finishReason ?? null,
          message: {
            content: response.response.text(),
            role: 'assistant',
          },
        },
      ],
      created: Date.now(),
      model: model,
      object: 'chat.completion',
      usage: response.response.usageMetadata
        ? {
            prompt_tokens: response.response.usageMetadata.promptTokenCount,
            completion_tokens:
              response.response.usageMetadata.candidatesTokenCount,
            total_tokens: response.response.usageMetadata.totalTokenCount,
          }
        : undefined,
    }
  }
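
  /**
   * Maps a single streamed Gemini chunk onto the OpenAI-style
   * LLMResponseStreaming delta shape.
   */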
  static parseStreamingResponseChunk(
    chunk: EnhancedGenerateContentResponse,
    model: string,
    messageId: string,
  ): LLMResponseStreaming {
    return {
      id: messageId,
      choices: [
        {
          finish_reason: chunk.candidates?.[0]?.finishReason ?? null,
          delta: {
            content: chunk.text(),
          },
        },
      ],
      created: Date.now(),
      model: model,
      object: 'chat.completion.chunk',
      usage: chunk.usageMetadata
        ? {
            prompt_tokens: chunk.usageMetadata.promptTokenCount,
            completion_tokens: chunk.usageMetadata.candidatesTokenCount,
            total_tokens: chunk.usageMetadata.totalTokenCount,
          }
        : undefined,
    }
  }
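
  /**
   * Throws when the image MIME type is not one Gemini accepts as inline data.
   */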
  private static validateImageType(mimeType: string) {
    const SUPPORTED_IMAGE_TYPES = [
      'image/png',
      'image/jpeg',
      'image/webp',
      'image/heic',
      'image/heif',
    ]
    if (!SUPPORTED_IMAGE_TYPES.includes(mimeType)) {
      throw new Error(
        `Gemini does not support image type ${mimeType}. Supported types: ${SUPPORTED_IMAGE_TYPES.join(
          ', ',
        )}`,
      )
    }
  }
}