From 053bed6e3c14d5f7d36d787aabe025f840b529cd Mon Sep 17 00:00:00 2001 From: itqop Date: Mon, 24 Feb 2025 00:22:12 +0300 Subject: [PATCH] Add LLM summarization of transcripts via Ollama --- speech_service/config.py | 1 + speech_service/main.py | 43 ++++++++++++++++++++++++-- speech_service/requirements.txt | 1 + telegram_bot/handlers/audio_handler.py | 3 +- 4 files changed, 45 insertions(+), 3 deletions(-) diff --git a/speech_service/config.py b/speech_service/config.py index b95ee30..e2aa109 100644 --- a/speech_service/config.py +++ b/speech_service/config.py @@ -10,4 +10,5 @@ def load_config(): "DEVICE": os.getenv("DEVICE", "cuda"), "AUDIO_TASK_CHANNEL": os.getenv("AUDIO_TASK_CHANNEL", "audio_tasks"), "TEXT_RESULT_CHANNEL": os.getenv("TEXT_RESULT_CHANNEL", "texts"), + "OLLAMA_URL": os.getenv("OLLAMA_URL", "http://ollama:11434/api/generate"), } diff --git a/speech_service/main.py b/speech_service/main.py index 405b896..37a42aa 100644 --- a/speech_service/main.py +++ b/speech_service/main.py @@ -4,8 +4,43 @@ from config import load_config from models import AudioTask from redis_client import RedisClient from transcriber import WhisperTranscriber +import httpx +import json -async def process_audio_task(redis_client: RedisClient, transcriber: WhisperTranscriber, task_data: dict): +async def request_summary(text: str, ollama_url: str) -> str: + """ + Делает запрос к Ollama API для суммаризации текста. + Использует модель gemma2:2b, системное сообщение и температуру 0.6. + Запрос делается без стриминга. + """ + payload = { + "model": "gemma2:2b", + "prompt": text, + "system": ( + "Ты — помощник для суммаризации текста. Твоя задача: выделить ключевые моменты " + "из высказывания человека, переформулируя их кратко и сохраняя оригинальный смысл. " + "Очень важно: не отвечай на вопросы, не рассуждай, не комментируй, не добавляй ничего от себя, " + "выполняй только суммаризацию."
+ ), + "stream": False, + "options": { + "temperature": 0.6 + } + } + async with httpx.AsyncClient() as client: + try: + response = await client.post(ollama_url, json=payload, timeout=60.0) + response.raise_for_status() + except Exception as e: + print(f"LLM API request failed: {e}") + return text + data = response.json() + summary = data.get("response") + if summary: + return summary.strip() + return text + +async def process_audio_task(redis_client: RedisClient, transcriber: WhisperTranscriber, task_data: dict, ollama_url: str): try: task = AudioTask(**task_data) except Exception as e: @@ -16,6 +51,9 @@ async def process_audio_task(redis_client: RedisClient, transcriber: WhisperTran loop = asyncio.get_running_loop() text = await loop.run_in_executor(None, transcriber.transcribe, task.file_path) + if task_data.get("sum") == 1: + text = await request_summary(text, ollama_url) + result = { "chat_id": task.chat_id, "user_id": task.user_id, @@ -25,6 +63,7 @@ async def process_audio_task(redis_client: RedisClient, transcriber: WhisperTran await redis_client.publish_result(result) print(f"Published result for task {task.uuid}") + async def main(): config = load_config() redis_client = RedisClient( @@ -46,7 +85,7 @@ async def main(): except Exception as e: print(f"Error parsing task message: {e}") continue - asyncio.create_task(process_audio_task(redis_client, transcriber, task_data)) + asyncio.create_task(process_audio_task(redis_client, transcriber, task_data, config["OLLAMA_URL"])) await asyncio.sleep(0.1) if __name__ == "__main__": diff --git a/speech_service/requirements.txt b/speech_service/requirements.txt index c9f0cd2..07e0882 100644 --- a/speech_service/requirements.txt +++ b/speech_service/requirements.txt @@ -9,3 +9,4 @@ torchaudio==2.5.1 transformers redis>=4.2.0 python-dotenv +httpx \ No newline at end of file diff --git a/telegram_bot/handlers/audio_handler.py b/telegram_bot/handlers/audio_handler.py index b883e59..bd626d8 100644 --- 
a/telegram_bot/handlers/audio_handler.py +++ b/telegram_bot/handlers/audio_handler.py @@ -39,7 +39,8 @@ async def handle_voice_and_video(message: types.Message, redis_service, storage_ "file_path": wav_destination, "user_id": message.from_user.id, "chat_id": message.chat.id, - "message_id": message.message_id + "message_id": message.message_id, + "sum": 1 } await redis_service.publish_task(task_data)