Fix batch

This commit is contained in:
itqop 2025-03-03 05:22:31 +03:00
parent ed203a9fca
commit 7f1a2e3d2d
4 changed files with 100 additions and 26 deletions

View File

@ -11,5 +11,7 @@ def load_config():
"AUDIO_TASK_CHANNEL": os.getenv("AUDIO_TASK_CHANNEL", "audio_tasks"), "AUDIO_TASK_CHANNEL": os.getenv("AUDIO_TASK_CHANNEL", "audio_tasks"),
"TEXT_RESULT_CHANNEL": os.getenv("TEXT_RESULT_CHANNEL", "text_result_channel"), "TEXT_RESULT_CHANNEL": os.getenv("TEXT_RESULT_CHANNEL", "text_result_channel"),
"TEXT_TASK_CHANNEL": os.getenv("TEXT_TASK_CHANNEL", "text_task_channel"), "TEXT_TASK_CHANNEL": os.getenv("TEXT_TASK_CHANNEL", "text_task_channel"),
"BATCH_SIZE": int(os.getenv("BATCH_SIZE", "4")),
"WAIT_TIMEOUT": int(os.getenv("WAIT_TIMEOUT", "1")),
"OLLAMA_URL": os.getenv("OLLAMA_URL", "http://ollama:11434/api/generate/"), "OLLAMA_URL": os.getenv("OLLAMA_URL", "http://ollama:11434/api/generate/"),
} }

View File

@ -1,31 +1,53 @@
import asyncio import asyncio
import logging import logging
from config import load_config from config import load_config
from models import AudioTask from models import AudioTask, TextTask
from redis_client import RedisClient from redis_client import RedisClient
from transcriber import WhisperTranscriber from transcriber import WhisperTranscriber
from typing import List, Dict
async def process_audio_task(redis_client: RedisClient, transcriber: WhisperTranscriber, task_data: dict): async def process_audio_tasks_batch(redis_client: RedisClient, transcriber: WhisperTranscriber, tasks_data: List[dict]):
try: audio_tasks = []
task = AudioTask(**task_data) for task_data in tasks_data:
except Exception as e: try:
logging.error(f"Error creating AudioTask from data: {e}") task = AudioTask(**task_data)
audio_tasks.append(task)
except Exception as e:
logging.error(f"Error creating AudioTask from data: {e}")
if not audio_tasks:
return return
logging.info(f"Processing batch of {len(audio_tasks)} audio tasks...")
transcription_tasks = []
for task in audio_tasks:
transcription_tasks.append(transcribe_audio(transcriber, task))
text_tasks = await asyncio.gather(*transcription_tasks)
text_tasks = [task for task in text_tasks if task is not None]
if text_tasks:
await redis_client.send_texts_batch(text_tasks)
logging.info(f"Sent {len(text_tasks)} texts to summarize service")
logging.info(f"Processing task {task.uuid} ...") async def transcribe_audio(transcriber: WhisperTranscriber, task: AudioTask) -> TextTask:
loop = asyncio.get_running_loop() try:
text = await loop.run_in_executor(None, transcriber.transcribe, task.file_path) logging.info(f"Transcribing audio for task {task.uuid}...")
logging.info(f"Transcription completed for task {task.uuid}, text length: {len(text)}") loop = asyncio.get_running_loop()
text = await loop.run_in_executor(None, transcriber.transcribe, task.file_path)
summarize_task = { logging.info(f"Transcription completed for task {task.uuid}, text length: {len(text)}")
"chat_id": task.chat_id,
"user_id": task.user_id, return TextTask(
"message_id": task.message_id, chat_id=task.chat_id,
"text": text user_id=task.user_id,
} message_id=task.message_id,
await redis_client.send_to_summarize(summarize_task) text=text
logging.info(f"Sent text to summarize service for task {task.uuid}") )
except Exception as e:
logging.error(f"Error transcribing audio for task {task.uuid}: {e}")
return None
async def main(): async def main():
logging.basicConfig( logging.basicConfig(
@ -41,18 +63,20 @@ async def main():
port=config["REDIS_PORT"], port=config["REDIS_PORT"],
task_channel=config["AUDIO_TASK_CHANNEL"], task_channel=config["AUDIO_TASK_CHANNEL"],
result_channel=config["TEXT_RESULT_CHANNEL"], result_channel=config["TEXT_RESULT_CHANNEL"],
text_task_channel="text_task_channel" text_task_channel=config["TEXT_TASK_CHANNEL"]
) )
transcriber = WhisperTranscriber(config["WHISPER_MODEL"], config["DEVICE"]) transcriber = WhisperTranscriber(config["WHISPER_MODEL"], config["DEVICE"])
logging.info(f"Initialized transcriber with model {config['WHISPER_MODEL']} on {config['DEVICE']}") logging.info(f"Initialized transcriber with model {config['WHISPER_MODEL']} on {config['DEVICE']}")
logging.info(f"Waiting for audio tasks in channel {config['AUDIO_TASK_CHANNEL']}...") batch_size = config["BATCH_SIZE"]
wait_timeout = config["WAIT_TIMEOUT"]
logging.info(f"Waiting for audio tasks in channel {config['AUDIO_TASK_CHANNEL']} with batch size {batch_size}...")
while True: while True:
task_data = await redis_client.get_task(timeout=1) tasks_data = await redis_client.get_tasks_batch(batch_size, wait_timeout)
if task_data: if tasks_data:
logging.info(f"Received task: {task_data.get('uuid', 'unknown')}") logging.info(f"Received {len(tasks_data)} tasks")
asyncio.create_task(process_audio_task(redis_client, transcriber, task_data)) asyncio.create_task(process_audio_tasks_batch(redis_client, transcriber, tasks_data))
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,4 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import List
@dataclass @dataclass
class AudioTask: class AudioTask:
@ -7,3 +8,10 @@ class AudioTask:
user_id: int user_id: int
chat_id: int chat_id: int
message_id: int message_id: int
@dataclass
class TextTask:
user_id: int
chat_id: int
message_id: int
text: str

View File

@ -1,5 +1,7 @@
import json import json
import redis.asyncio as redis import redis.asyncio as redis
from typing import List, Optional
from models import TextTask
class RedisClient: class RedisClient:
def __init__(self, host: str, port: int, task_channel: str, result_channel: str, text_task_channel: str = "text_task_channel"): def __init__(self, host: str, port: int, task_channel: str, result_channel: str, text_task_channel: str = "text_task_channel"):
@ -19,6 +21,31 @@ class RedisClient:
print(f"Error parsing task message: {e}") print(f"Error parsing task message: {e}")
return None return None
async def get_tasks_batch(self, batch_size: int, timeout: int = 1) -> List[dict]:
"""Получает батч задач из очереди"""
tasks = []
result = await self.client.blpop(self.task_channel, timeout=timeout)
if result:
_, task_json = result
try:
task = json.loads(task_json)
tasks.append(task)
except Exception as e:
print(f"Error parsing task message: {e}")
if tasks:
for _ in range(batch_size - 1):
task_json = await self.client.lpop(self.task_channel)
if not task_json:
break
try:
task = json.loads(task_json)
tasks.append(task)
except Exception as e:
print(f"Error parsing task message: {e}")
return tasks
async def publish_result(self, result: dict): async def publish_result(self, result: dict):
"""Отправляет результат в очередь результатов""" """Отправляет результат в очередь результатов"""
await self.client.rpush(self.result_channel, json.dumps(result)) await self.client.rpush(self.result_channel, json.dumps(result))
@ -26,3 +53,16 @@ class RedisClient:
async def send_to_summarize(self, task_data: dict): async def send_to_summarize(self, task_data: dict):
"""Отправляет текст в сервис суммаризации""" """Отправляет текст в сервис суммаризации"""
await self.client.rpush(self.text_task_channel, json.dumps(task_data)) await self.client.rpush(self.text_task_channel, json.dumps(task_data))
async def send_texts_batch(self, tasks: List[TextTask]):
"""Отправляет батч текстов в сервис суммаризации"""
pipeline = self.client.pipeline()
for task in tasks:
task_data = {
"chat_id": task.chat_id,
"user_id": task.user_id,
"message_id": task.message_id,
"text": task.text
}
pipeline.rpush(self.text_task_channel, json.dumps(task_data))
await pipeline.execute()