import google.generativeai as genai
import tiktoken

from src.llm_clients.base import LLMClient
from src.models.email import LLMRawOutput
from src.models.errors import LLMError
from src.app.config import settings


class GeminiClient(LLMClient):
    """LLM client backed by Google's Gemini API via the google-generativeai SDK."""

    def __init__(self):
        if not settings.gemini_api_key:
            raise ValueError("Gemini API key is required")

        genai.configure(api_key=settings.gemini_api_key)
        self.model = genai.GenerativeModel("gemini-pro")
        # cl100k_base is an OpenAI tokenizer, not Gemini's; the counts from
        # count_tokens() below are therefore approximations.
        self.encoding = tiktoken.get_encoding("cl100k_base")

    def generate_completion(
        self,
        system_prompt: str,
        user_prompt: str,
        max_tokens: int = 1024,
        temperature: float = 0.7,
        **kwargs,
    ) -> LLMRawOutput:
        try:
            # gemini-pro has no dedicated system-message channel, so the
            # system prompt is folded into a single text prompt.
            prompt = f"System instruction: {system_prompt}\n\nUser request: {user_prompt}"

            generation_config = genai.types.GenerationConfig(
                max_output_tokens=max_tokens,
                temperature=temperature,
            )

            response = self.model.generate_content(
                prompt, generation_config=generation_config
            )

            # .text raises if the response was blocked or contains no text
            # part; the except below converts that into an LLMError.
            content = response.text

            prompt_tokens = self.count_tokens(prompt)
            completion_tokens = self.count_tokens(content)

            return LLMRawOutput(
                content=content,
                tokens_prompt=prompt_tokens,
                tokens_completion=completion_tokens,
                model="gemini-pro",
            )

        except Exception as e:
            raise LLMError(
                f"Gemini generation error: {e}", "gemini", "gemini-pro"
            ) from e

    def count_tokens(self, text: str) -> int:
        """Approximate token count via tiktoken's cl100k_base encoding."""
        return len(self.encoding.encode(text))
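

# A minimal usage sketch (illustrative, not part of the original module). It
# assumes settings.gemini_api_key is configured and that LLMRawOutput exposes
# the fields populated above. For exact rather than approximate token counts,
# the SDK itself provides GenerativeModel.count_tokens, e.g.
# `self.model.count_tokens(prompt).total_tokens`, at the cost of an extra
# API call per request.
if __name__ == "__main__":
    client = GeminiClient()
    result = client.generate_completion(
        system_prompt="You are a helpful assistant.",
        user_prompt="Write a one-line greeting.",
    )
    print(result.content)
    print(f"~{result.tokens_prompt} prompt tokens, "
          f"~{result.tokens_completion} completion tokens")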