test llm
This commit is contained in:
parent
ac7596c183
commit
053bed6e3c
|
@ -10,4 +10,5 @@ def load_config():
|
||||||
"DEVICE": os.getenv("DEVICE", "cuda"),
|
"DEVICE": os.getenv("DEVICE", "cuda"),
|
||||||
"AUDIO_TASK_CHANNEL": os.getenv("AUDIO_TASK_CHANNEL", "audio_tasks"),
|
"AUDIO_TASK_CHANNEL": os.getenv("AUDIO_TASK_CHANNEL", "audio_tasks"),
|
||||||
"TEXT_RESULT_CHANNEL": os.getenv("TEXT_RESULT_CHANNEL", "texts"),
|
"TEXT_RESULT_CHANNEL": os.getenv("TEXT_RESULT_CHANNEL", "texts"),
|
||||||
|
"OLLAMA_URL": os.getenv("OLLAMA_URL", "http://ollama:11434/api/generate/"),
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,8 +4,43 @@ from config import load_config
|
||||||
from models import AudioTask
|
from models import AudioTask
|
||||||
from redis_client import RedisClient
|
from redis_client import RedisClient
|
||||||
from transcriber import WhisperTranscriber
|
from transcriber import WhisperTranscriber
|
||||||
|
import httpx
|
||||||
|
import json
|
||||||
|
|
||||||
async def process_audio_task(redis_client: RedisClient, transcriber: WhisperTranscriber, task_data: dict):
|
async def request_summary(text: str, ollama_url: str) -> str:
|
||||||
|
"""
|
||||||
|
Делает запрос к Ollama API для суммаризации текста.
|
||||||
|
Использует модель gemma2:2b, системное сообщение и температуру 0.6.
|
||||||
|
Запрос делается без стриминга.
|
||||||
|
"""
|
||||||
|
payload = {
|
||||||
|
"model": "gemma2:2b",
|
||||||
|
"prompt": text,
|
||||||
|
"system": (
|
||||||
|
"Ты — помощник для суммаризации текста. Твоя задача: выделить ключевые моменты "
|
||||||
|
"из высказывания человека, переформулируя их кратко и сохраняя оригинальный смысл. "
|
||||||
|
"Очень важно: не отвечай на вопросы, не рассуждай, не комментируй, не добавляй ничего от себя, "
|
||||||
|
"выполняй только суммаризацию."
|
||||||
|
),
|
||||||
|
"stream": False,
|
||||||
|
"options": {
|
||||||
|
"temperature": 0.6
|
||||||
|
}
|
||||||
|
}
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
try:
|
||||||
|
response = await client.post(ollama_url, json=payload, timeout=60.0)
|
||||||
|
response.raise_for_status()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"LLM API request failed: {e}")
|
||||||
|
return text
|
||||||
|
data = response.json()
|
||||||
|
summary = data.get("response")
|
||||||
|
if summary:
|
||||||
|
return summary.strip()
|
||||||
|
return text
|
||||||
|
|
||||||
|
async def process_audio_task(redis_client: RedisClient, transcriber: WhisperTranscriber, task_data: dict, ollama_url: str):
|
||||||
try:
|
try:
|
||||||
task = AudioTask(**task_data)
|
task = AudioTask(**task_data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -16,6 +51,9 @@ async def process_audio_task(redis_client: RedisClient, transcriber: WhisperTran
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
text = await loop.run_in_executor(None, transcriber.transcribe, task.file_path)
|
text = await loop.run_in_executor(None, transcriber.transcribe, task.file_path)
|
||||||
|
|
||||||
|
if task_data.get("sum") == 1:
|
||||||
|
text = await request_summary(text, ollama_url)
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"chat_id": task.chat_id,
|
"chat_id": task.chat_id,
|
||||||
"user_id": task.user_id,
|
"user_id": task.user_id,
|
||||||
|
@ -25,6 +63,7 @@ async def process_audio_task(redis_client: RedisClient, transcriber: WhisperTran
|
||||||
await redis_client.publish_result(result)
|
await redis_client.publish_result(result)
|
||||||
print(f"Published result for task {task.uuid}")
|
print(f"Published result for task {task.uuid}")
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
config = load_config()
|
config = load_config()
|
||||||
redis_client = RedisClient(
|
redis_client = RedisClient(
|
||||||
|
@ -46,7 +85,7 @@ async def main():
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error parsing task message: {e}")
|
print(f"Error parsing task message: {e}")
|
||||||
continue
|
continue
|
||||||
asyncio.create_task(process_audio_task(redis_client, transcriber, task_data))
|
asyncio.create_task(process_audio_task(redis_client, transcriber, task_data, config["OLLAMA_URL"]))
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -9,3 +9,4 @@ torchaudio==2.5.1
|
||||||
transformers
|
transformers
|
||||||
redis>=4.2.0
|
redis>=4.2.0
|
||||||
python-dotenv
|
python-dotenv
|
||||||
|
httpx
|
|
@ -39,7 +39,8 @@ async def handle_voice_and_video(message: types.Message, redis_service, storage_
|
||||||
"file_path": wav_destination,
|
"file_path": wav_destination,
|
||||||
"user_id": message.from_user.id,
|
"user_id": message.from_user.id,
|
||||||
"chat_id": message.chat.id,
|
"chat_id": message.chat.id,
|
||||||
"message_id": message.message_id
|
"message_id": message.message_id,
|
||||||
|
"sum": 1
|
||||||
}
|
}
|
||||||
|
|
||||||
await redis_service.publish_task(task_data)
|
await redis_service.publish_task(task_data)
|
||||||
|
|
Loading…
Reference in New Issue