From 7f1a2e3d2d683a9ff4782a59d298b5c64ccc5006 Mon Sep 17 00:00:00 2001 From: itqop Date: Mon, 3 Mar 2025 05:22:31 +0300 Subject: [PATCH] Fix batch --- speech_service/config.py | 2 + speech_service/main.py | 76 ++++++++++++++++++++++------------ speech_service/models.py | 8 ++++ speech_service/redis_client.py | 40 ++++++++++++++++++ 4 files changed, 100 insertions(+), 26 deletions(-) diff --git a/speech_service/config.py b/speech_service/config.py index 85d037d..95fed57 100644 --- a/speech_service/config.py +++ b/speech_service/config.py @@ -11,5 +11,7 @@ def load_config(): "AUDIO_TASK_CHANNEL": os.getenv("AUDIO_TASK_CHANNEL", "audio_tasks"), "TEXT_RESULT_CHANNEL": os.getenv("TEXT_RESULT_CHANNEL", "text_result_channel"), "TEXT_TASK_CHANNEL": os.getenv("TEXT_TASK_CHANNEL", "text_task_channel"), + "BATCH_SIZE": int(os.getenv("BATCH_SIZE", "4")), + "WAIT_TIMEOUT": int(os.getenv("WAIT_TIMEOUT", "1")), "OLLAMA_URL": os.getenv("OLLAMA_URL", "http://ollama:11434/api/generate/"), } diff --git a/speech_service/main.py b/speech_service/main.py index 338ecb3..7a53a7f 100644 --- a/speech_service/main.py +++ b/speech_service/main.py @@ -1,31 +1,53 @@ import asyncio import logging from config import load_config -from models import AudioTask +from models import AudioTask, TextTask from redis_client import RedisClient from transcriber import WhisperTranscriber +from typing import List, Dict -async def process_audio_task(redis_client: RedisClient, transcriber: WhisperTranscriber, task_data: dict): - try: - task = AudioTask(**task_data) - except Exception as e: - logging.error(f"Error creating AudioTask from data: {e}") +async def process_audio_tasks_batch(redis_client: RedisClient, transcriber: WhisperTranscriber, tasks_data: List[dict]): + audio_tasks = [] + for task_data in tasks_data: + try: + task = AudioTask(**task_data) + audio_tasks.append(task) + except Exception as e: + logging.error(f"Error creating AudioTask from data: {e}") + + if not audio_tasks: return + + logging.info(f"Processing batch of {len(audio_tasks)} audio tasks...") + + transcription_tasks = [] + for task in audio_tasks: + transcription_tasks.append(transcribe_audio(transcriber, task)) + + text_tasks = await asyncio.gather(*transcription_tasks) + + text_tasks = [task for task in text_tasks if task is not None] + + if text_tasks: + await redis_client.send_texts_batch(text_tasks) + logging.info(f"Sent {len(text_tasks)} texts to summarize service") - logging.info(f"Processing task {task.uuid} ...") - loop = asyncio.get_running_loop() - text = await loop.run_in_executor(None, transcriber.transcribe, task.file_path) - logging.info(f"Transcription completed for task {task.uuid}, text length: {len(text)}") - - summarize_task = { - "chat_id": task.chat_id, - "user_id": task.user_id, - "message_id": task.message_id, - "text": text - } - await redis_client.send_to_summarize(summarize_task) - logging.info(f"Sent text to summarize service for task {task.uuid}") - +async def transcribe_audio(transcriber: WhisperTranscriber, task: AudioTask) -> TextTask: + try: + logging.info(f"Transcribing audio for task {task.uuid}...") + loop = asyncio.get_running_loop() + text = await loop.run_in_executor(None, transcriber.transcribe, task.file_path) + logging.info(f"Transcription completed for task {task.uuid}, text length: {len(text)}") + + return TextTask( + chat_id=task.chat_id, + user_id=task.user_id, + message_id=task.message_id, + text=text + ) + except Exception as e: + logging.error(f"Error transcribing audio for task {task.uuid}: {e}") + return None async def main(): logging.basicConfig( @@ -41,18 +63,20 @@ async def main(): port=config["REDIS_PORT"], task_channel=config["AUDIO_TASK_CHANNEL"], result_channel=config["TEXT_RESULT_CHANNEL"], - text_task_channel="text_task_channel" + text_task_channel=config["TEXT_TASK_CHANNEL"] ) transcriber = WhisperTranscriber(config["WHISPER_MODEL"], config["DEVICE"]) logging.info(f"Initialized transcriber with model {config['WHISPER_MODEL']} on {config['DEVICE']}") - logging.info(f"Waiting for audio tasks in channel {config['AUDIO_TASK_CHANNEL']}...") + batch_size = config["BATCH_SIZE"] + wait_timeout = config["WAIT_TIMEOUT"] + logging.info(f"Waiting for audio tasks in channel {config['AUDIO_TASK_CHANNEL']} with batch size {batch_size}...") while True: - task_data = await redis_client.get_task(timeout=1) - if task_data: - logging.info(f"Received task: {task_data.get('uuid', 'unknown')}") - asyncio.create_task(process_audio_task(redis_client, transcriber, task_data)) + tasks_data = await redis_client.get_tasks_batch(batch_size, wait_timeout) + if tasks_data: + logging.info(f"Received {len(tasks_data)} tasks") + asyncio.create_task(process_audio_tasks_batch(redis_client, transcriber, tasks_data)) await asyncio.sleep(0.1) if __name__ == "__main__": diff --git a/speech_service/models.py b/speech_service/models.py index 8b7127e..4671d5b 100644 --- a/speech_service/models.py +++ b/speech_service/models.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import List @dataclass class AudioTask: @@ -7,3 +8,10 @@ class AudioTask: user_id: int chat_id: int message_id: int + +@dataclass +class TextTask: + user_id: int + chat_id: int + message_id: int + text: str diff --git a/speech_service/redis_client.py b/speech_service/redis_client.py index 4d5b4f1..baf07e6 100644 --- a/speech_service/redis_client.py +++ b/speech_service/redis_client.py @@ -1,5 +1,7 @@ import json import redis.asyncio as redis +from typing import List, Optional +from models import TextTask class RedisClient: def __init__(self, host: str, port: int, task_channel: str, result_channel: str, text_task_channel: str = "text_task_channel"): @@ -19,6 +21,31 @@ class RedisClient: print(f"Error parsing task message: {e}") return None + async def get_tasks_batch(self, batch_size: int, timeout: int = 1) -> List[dict]: + """Получает батч задач из очереди""" + tasks = [] + result = await self.client.blpop(self.task_channel, timeout=timeout) + if result: + _, task_json = result + try: + task = json.loads(task_json) + tasks.append(task) + except Exception as e: + print(f"Error parsing task message: {e}") + + if tasks: + for _ in range(batch_size - 1): + task_json = await self.client.lpop(self.task_channel) + if not task_json: + break + try: + task = json.loads(task_json) + tasks.append(task) + except Exception as e: + print(f"Error parsing task message: {e}") + + return tasks + async def publish_result(self, result: dict): """Отправляет результат в очередь результатов""" await self.client.rpush(self.result_channel, json.dumps(result)) @@ -26,3 +53,16 @@ class RedisClient: async def send_to_summarize(self, task_data: dict): """Отправляет текст в сервис суммаризации""" await self.client.rpush(self.text_task_channel, json.dumps(task_data)) + + async def send_texts_batch(self, tasks: List[TextTask]): + """Отправляет батч текстов в сервис суммаризации""" + pipeline = self.client.pipeline() + for task in tasks: + task_data = { + "chat_id": task.chat_id, + "user_id": task.user_id, + "message_id": task.message_id, + "text": task.text + } + pipeline.rpush(self.text_task_channel, json.dumps(task_data)) + await pipeline.execute()