Fix batch
This commit is contained in:
parent
ed203a9fca
commit
7f1a2e3d2d
|
@ -11,5 +11,7 @@ def load_config():
|
|||
"AUDIO_TASK_CHANNEL": os.getenv("AUDIO_TASK_CHANNEL", "audio_tasks"),
|
||||
"TEXT_RESULT_CHANNEL": os.getenv("TEXT_RESULT_CHANNEL", "text_result_channel"),
|
||||
"TEXT_TASK_CHANNEL": os.getenv("TEXT_TASK_CHANNEL", "text_task_channel"),
|
||||
"BATCH_SIZE": int(os.getenv("BATCH_SIZE", "4")),
|
||||
"WAIT_TIMEOUT": int(os.getenv("WAIT_TIMEOUT", "1")),
|
||||
"OLLAMA_URL": os.getenv("OLLAMA_URL", "http://ollama:11434/api/generate/"),
|
||||
}
|
||||
|
|
|
@ -1,31 +1,53 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from config import load_config
|
||||
from models import AudioTask
|
||||
from models import AudioTask, TextTask
|
||||
from redis_client import RedisClient
|
||||
from transcriber import WhisperTranscriber
|
||||
from typing import List, Dict
|
||||
|
||||
async def process_audio_task(redis_client: RedisClient, transcriber: WhisperTranscriber, task_data: dict):
|
||||
async def process_audio_tasks_batch(redis_client: RedisClient, transcriber: WhisperTranscriber, tasks_data: List[dict]):
|
||||
audio_tasks = []
|
||||
for task_data in tasks_data:
|
||||
try:
|
||||
task = AudioTask(**task_data)
|
||||
audio_tasks.append(task)
|
||||
except Exception as e:
|
||||
logging.error(f"Error creating AudioTask from data: {e}")
|
||||
|
||||
if not audio_tasks:
|
||||
return
|
||||
|
||||
logging.info(f"Processing task {task.uuid} ...")
|
||||
logging.info(f"Processing batch of {len(audio_tasks)} audio tasks...")
|
||||
|
||||
transcription_tasks = []
|
||||
for task in audio_tasks:
|
||||
transcription_tasks.append(transcribe_audio(transcriber, task))
|
||||
|
||||
text_tasks = await asyncio.gather(*transcription_tasks)
|
||||
|
||||
text_tasks = [task for task in text_tasks if task is not None]
|
||||
|
||||
if text_tasks:
|
||||
await redis_client.send_texts_batch(text_tasks)
|
||||
logging.info(f"Sent {len(text_tasks)} texts to summarize service")
|
||||
|
||||
async def transcribe_audio(transcriber: WhisperTranscriber, task: AudioTask) -> TextTask:
|
||||
try:
|
||||
logging.info(f"Transcribing audio for task {task.uuid}...")
|
||||
loop = asyncio.get_running_loop()
|
||||
text = await loop.run_in_executor(None, transcriber.transcribe, task.file_path)
|
||||
logging.info(f"Transcription completed for task {task.uuid}, text length: {len(text)}")
|
||||
|
||||
summarize_task = {
|
||||
"chat_id": task.chat_id,
|
||||
"user_id": task.user_id,
|
||||
"message_id": task.message_id,
|
||||
"text": text
|
||||
}
|
||||
await redis_client.send_to_summarize(summarize_task)
|
||||
logging.info(f"Sent text to summarize service for task {task.uuid}")
|
||||
|
||||
return TextTask(
|
||||
chat_id=task.chat_id,
|
||||
user_id=task.user_id,
|
||||
message_id=task.message_id,
|
||||
text=text
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error transcribing audio for task {task.uuid}: {e}")
|
||||
return None
|
||||
|
||||
async def main():
|
||||
logging.basicConfig(
|
||||
|
@ -41,18 +63,20 @@ async def main():
|
|||
port=config["REDIS_PORT"],
|
||||
task_channel=config["AUDIO_TASK_CHANNEL"],
|
||||
result_channel=config["TEXT_RESULT_CHANNEL"],
|
||||
text_task_channel="text_task_channel"
|
||||
text_task_channel=config["TEXT_TASK_CHANNEL"]
|
||||
)
|
||||
transcriber = WhisperTranscriber(config["WHISPER_MODEL"], config["DEVICE"])
|
||||
logging.info(f"Initialized transcriber with model {config['WHISPER_MODEL']} on {config['DEVICE']}")
|
||||
|
||||
logging.info(f"Waiting for audio tasks in channel {config['AUDIO_TASK_CHANNEL']}...")
|
||||
batch_size = config["BATCH_SIZE"]
|
||||
wait_timeout = config["WAIT_TIMEOUT"]
|
||||
logging.info(f"Waiting for audio tasks in channel {config['AUDIO_TASK_CHANNEL']} with batch size {batch_size}...")
|
||||
|
||||
while True:
|
||||
task_data = await redis_client.get_task(timeout=1)
|
||||
if task_data:
|
||||
logging.info(f"Received task: {task_data.get('uuid', 'unknown')}")
|
||||
asyncio.create_task(process_audio_task(redis_client, transcriber, task_data))
|
||||
tasks_data = await redis_client.get_tasks_batch(batch_size, wait_timeout)
|
||||
if tasks_data:
|
||||
logging.info(f"Received {len(tasks_data)} tasks")
|
||||
asyncio.create_task(process_audio_tasks_batch(redis_client, transcriber, tasks_data))
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
@dataclass
|
||||
class AudioTask:
|
||||
|
@ -7,3 +8,10 @@ class AudioTask:
|
|||
user_id: int
|
||||
chat_id: int
|
||||
message_id: int
|
||||
|
||||
@dataclass
|
||||
class TextTask:
|
||||
user_id: int
|
||||
chat_id: int
|
||||
message_id: int
|
||||
text: str
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
import json
|
||||
import redis.asyncio as redis
|
||||
from typing import List, Optional
|
||||
from models import TextTask
|
||||
|
||||
class RedisClient:
|
||||
def __init__(self, host: str, port: int, task_channel: str, result_channel: str, text_task_channel: str = "text_task_channel"):
|
||||
|
@ -19,6 +21,31 @@ class RedisClient:
|
|||
print(f"Error parsing task message: {e}")
|
||||
return None
|
||||
|
||||
async def get_tasks_batch(self, batch_size: int, timeout: int = 1) -> List[dict]:
|
||||
"""Получает батч задач из очереди"""
|
||||
tasks = []
|
||||
result = await self.client.blpop(self.task_channel, timeout=timeout)
|
||||
if result:
|
||||
_, task_json = result
|
||||
try:
|
||||
task = json.loads(task_json)
|
||||
tasks.append(task)
|
||||
except Exception as e:
|
||||
print(f"Error parsing task message: {e}")
|
||||
|
||||
if tasks:
|
||||
for _ in range(batch_size - 1):
|
||||
task_json = await self.client.lpop(self.task_channel)
|
||||
if not task_json:
|
||||
break
|
||||
try:
|
||||
task = json.loads(task_json)
|
||||
tasks.append(task)
|
||||
except Exception as e:
|
||||
print(f"Error parsing task message: {e}")
|
||||
|
||||
return tasks
|
||||
|
||||
async def publish_result(self, result: dict):
|
||||
"""Отправляет результат в очередь результатов"""
|
||||
await self.client.rpush(self.result_channel, json.dumps(result))
|
||||
|
@ -26,3 +53,16 @@ class RedisClient:
|
|||
async def send_to_summarize(self, task_data: dict):
|
||||
"""Отправляет текст в сервис суммаризации"""
|
||||
await self.client.rpush(self.text_task_channel, json.dumps(task_data))
|
||||
|
||||
async def send_texts_batch(self, tasks: List[TextTask]):
|
||||
"""Отправляет батч текстов в сервис суммаризации"""
|
||||
pipeline = self.client.pipeline()
|
||||
for task in tasks:
|
||||
task_data = {
|
||||
"chat_id": task.chat_id,
|
||||
"user_id": task.user_id,
|
||||
"message_id": task.message_id,
|
||||
"text": task.text
|
||||
}
|
||||
pipeline.rpush(self.text_task_channel, json.dumps(task_data))
|
||||
await pipeline.execute()
|
||||
|
|
Loading…
Reference in New Issue