Fix batch
parent ed203a9fca
commit 7f1a2e3d2d
@@ -11,5 +11,7 @@ def load_config():
         "AUDIO_TASK_CHANNEL": os.getenv("AUDIO_TASK_CHANNEL", "audio_tasks"),
         "TEXT_RESULT_CHANNEL": os.getenv("TEXT_RESULT_CHANNEL", "text_result_channel"),
         "TEXT_TASK_CHANNEL": os.getenv("TEXT_TASK_CHANNEL", "text_task_channel"),
+        "BATCH_SIZE": int(os.getenv("BATCH_SIZE", "4")),
+        "WAIT_TIMEOUT": int(os.getenv("WAIT_TIMEOUT", "1")),
         "OLLAMA_URL": os.getenv("OLLAMA_URL", "http://ollama:11434/api/generate/"),
     }
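As a hedged aside (not part of the commit): environment variables arrive as strings, so the int() coercion in load_config() is what makes BATCH_SIZE and WAIT_TIMEOUT directly usable as a loop bound and a blocking-pop timeout. A minimal illustration:

import os

# Illustrative only; mirrors the pattern used in load_config() above.
batch_size = int(os.getenv("BATCH_SIZE", "4"))      # falls back to 4 when unset
wait_timeout = int(os.getenv("WAIT_TIMEOUT", "1"))  # seconds for the blocking pop
print(batch_size, wait_timeout)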
@@ -1,31 +1,53 @@
 import asyncio
 import logging
 from config import load_config
-from models import AudioTask
+from models import AudioTask, TextTask
 from redis_client import RedisClient
 from transcriber import WhisperTranscriber
+from typing import List, Dict
 
 
-async def process_audio_task(redis_client: RedisClient, transcriber: WhisperTranscriber, task_data: dict):
-    try:
-        task = AudioTask(**task_data)
-    except Exception as e:
-        logging.error(f"Error creating AudioTask from data: {e}")
-        return
-
-    logging.info(f"Processing task {task.uuid} ...")
-    loop = asyncio.get_running_loop()
-    text = await loop.run_in_executor(None, transcriber.transcribe, task.file_path)
-    logging.info(f"Transcription completed for task {task.uuid}, text length: {len(text)}")
-
-    summarize_task = {
-        "chat_id": task.chat_id,
-        "user_id": task.user_id,
-        "message_id": task.message_id,
-        "text": text
-    }
-    await redis_client.send_to_summarize(summarize_task)
-    logging.info(f"Sent text to summarize service for task {task.uuid}")
+async def process_audio_tasks_batch(redis_client: RedisClient, transcriber: WhisperTranscriber, tasks_data: List[dict]):
+    audio_tasks = []
+    for task_data in tasks_data:
+        try:
+            task = AudioTask(**task_data)
+            audio_tasks.append(task)
+        except Exception as e:
+            logging.error(f"Error creating AudioTask from data: {e}")
+
+    if not audio_tasks:
+        return
+
+    logging.info(f"Processing batch of {len(audio_tasks)} audio tasks...")
+
+    transcription_tasks = []
+    for task in audio_tasks:
+        transcription_tasks.append(transcribe_audio(transcriber, task))
+
+    text_tasks = await asyncio.gather(*transcription_tasks)
+
+    text_tasks = [task for task in text_tasks if task is not None]
+
+    if text_tasks:
+        await redis_client.send_texts_batch(text_tasks)
+        logging.info(f"Sent {len(text_tasks)} texts to summarize service")
+
+
+async def transcribe_audio(transcriber: WhisperTranscriber, task: AudioTask) -> TextTask:
+    try:
+        logging.info(f"Transcribing audio for task {task.uuid}...")
+        loop = asyncio.get_running_loop()
+        text = await loop.run_in_executor(None, transcriber.transcribe, task.file_path)
+        logging.info(f"Transcription completed for task {task.uuid}, text length: {len(text)}")
+
+        return TextTask(
+            chat_id=task.chat_id,
+            user_id=task.user_id,
+            message_id=task.message_id,
+            text=text
+        )
+    except Exception as e:
+        logging.error(f"Error transcribing audio for task {task.uuid}: {e}")
+        return None
 
 
 async def main():
     logging.basicConfig(
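For orientation, a self-contained sketch (not from the commit) of the run_in_executor + asyncio.gather pattern the new process_audio_tasks_batch / transcribe_audio pair relies on; FakeTranscriber and the sample file names are placeholders standing in for WhisperTranscriber and real audio paths.

import asyncio
import logging
import time

class FakeTranscriber:
    # Placeholder for WhisperTranscriber: sleeps instead of running Whisper.
    def transcribe(self, path: str) -> str:
        time.sleep(0.1)
        return f"text from {path}"

async def transcribe_one(transcriber: FakeTranscriber, path: str):
    loop = asyncio.get_running_loop()
    try:
        # Blocking work is pushed to a thread so the event loop stays responsive,
        # as transcribe_audio() does above.
        return await loop.run_in_executor(None, transcriber.transcribe, path)
    except Exception as e:
        logging.error(f"failed on {path}: {e}")
        return None

async def demo():
    transcriber = FakeTranscriber()
    results = await asyncio.gather(*(transcribe_one(transcriber, p) for p in ["a.ogg", "b.ogg"]))
    # Failed items come back as None and are dropped, like in the batch function.
    print([r for r in results if r is not None])

asyncio.run(demo())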
@@ -41,18 +63,20 @@ async def main():
         port=config["REDIS_PORT"],
         task_channel=config["AUDIO_TASK_CHANNEL"],
         result_channel=config["TEXT_RESULT_CHANNEL"],
-        text_task_channel="text_task_channel"
+        text_task_channel=config["TEXT_TASK_CHANNEL"]
     )
     transcriber = WhisperTranscriber(config["WHISPER_MODEL"], config["DEVICE"])
     logging.info(f"Initialized transcriber with model {config['WHISPER_MODEL']} on {config['DEVICE']}")
 
-    logging.info(f"Waiting for audio tasks in channel {config['AUDIO_TASK_CHANNEL']}...")
+    batch_size = config["BATCH_SIZE"]
+    wait_timeout = config["WAIT_TIMEOUT"]
+    logging.info(f"Waiting for audio tasks in channel {config['AUDIO_TASK_CHANNEL']} with batch size {batch_size}...")
 
     while True:
-        task_data = await redis_client.get_task(timeout=1)
-        if task_data:
-            logging.info(f"Received task: {task_data.get('uuid', 'unknown')}")
-            asyncio.create_task(process_audio_task(redis_client, transcriber, task_data))
+        tasks_data = await redis_client.get_tasks_batch(batch_size, wait_timeout)
+        if tasks_data:
+            logging.info(f"Received {len(tasks_data)} tasks")
+            asyncio.create_task(process_audio_tasks_batch(redis_client, transcriber, tasks_data))
         await asyncio.sleep(0.1)
 
 
 if __name__ == "__main__":
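A hedged producer-side sketch showing how a task could be pushed onto the queue this loop drains; the field names follow AudioTask as used in the diff, while the host, uuid, and file_path values are made-up examples.

import json
import redis

# Illustrative producer: enqueue one audio task for the worker's polling loop.
r = redis.Redis(host="localhost", port=6379)  # assumed local Redis
task = {
    "uuid": "demo-0001",           # placeholder id
    "file_path": "/tmp/demo.ogg",  # placeholder audio file
    "user_id": 1,
    "chat_id": 1,
    "message_id": 1,
}
r.rpush("audio_tasks", json.dumps(task))  # default AUDIO_TASK_CHANNEL from config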
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from typing import List
 
 @dataclass
 class AudioTask:
@@ -7,3 +8,10 @@ class AudioTask:
     user_id: int
     chat_id: int
     message_id: int
+
+@dataclass
+class TextTask:
+    user_id: int
+    chat_id: int
+    message_id: int
+    text: str
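As a side note (not in the commit), dataclasses.asdict would produce the same payload that send_texts_batch later assembles by hand; a minimal sketch:

import json
from dataclasses import asdict, dataclass

@dataclass
class TextTask:
    user_id: int
    chat_id: int
    message_id: int
    text: str

task = TextTask(user_id=1, chat_id=2, message_id=3, text="hello")
print(json.dumps(asdict(task)))  # {"user_id": 1, "chat_id": 2, "message_id": 3, "text": "hello"}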
@@ -1,5 +1,7 @@
 import json
 import redis.asyncio as redis
+from typing import List, Optional
+from models import TextTask
 
 class RedisClient:
     def __init__(self, host: str, port: int, task_channel: str, result_channel: str, text_task_channel: str = "text_task_channel"):
@@ -19,6 +21,31 @@ class RedisClient:
             print(f"Error parsing task message: {e}")
             return None
 
+    async def get_tasks_batch(self, batch_size: int, timeout: int = 1) -> List[dict]:
+        """Fetches a batch of tasks from the queue"""
+        tasks = []
+        result = await self.client.blpop(self.task_channel, timeout=timeout)
+        if result:
+            _, task_json = result
+            try:
+                task = json.loads(task_json)
+                tasks.append(task)
+            except Exception as e:
+                print(f"Error parsing task message: {e}")
+
+        if tasks:
+            for _ in range(batch_size - 1):
+                task_json = await self.client.lpop(self.task_channel)
+                if not task_json:
+                    break
+                try:
+                    task = json.loads(task_json)
+                    tasks.append(task)
+                except Exception as e:
+                    print(f"Error parsing task message: {e}")
+
+        return tasks
+
     async def publish_result(self, result: dict):
         """Pushes the result onto the results queue"""
         await self.client.rpush(self.result_channel, json.dumps(result))
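For orientation, a hedged standalone sketch of the blpop-then-lpop drain pattern get_tasks_batch uses: block up to the timeout for the first item, then take whatever else is already queued without blocking. It assumes a Redis instance on localhost:6379 and a demo queue name.

import asyncio
import json
from typing import List

import redis.asyncio as redis

async def drain_batch(queue: str, batch_size: int, timeout: int = 1) -> List[dict]:
    client = redis.Redis(host="localhost", port=6379)  # assumed local Redis
    items = []
    # blpop blocks for up to `timeout` seconds waiting for the first item...
    first = await client.blpop(queue, timeout=timeout)
    if first:
        items.append(json.loads(first[1]))
        # ...then lpop grabs up to batch_size - 1 more items without blocking.
        for _ in range(batch_size - 1):
            raw = await client.lpop(queue)
            if raw is None:
                break
            items.append(json.loads(raw))
    await client.aclose()  # redis-py >= 5; older versions use close()
    return items

# asyncio.run(drain_batch("audio_tasks", 4))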
@@ -26,3 +53,16 @@ class RedisClient:
     async def send_to_summarize(self, task_data: dict):
         """Sends the text to the summarization service"""
         await self.client.rpush(self.text_task_channel, json.dumps(task_data))
+
+    async def send_texts_batch(self, tasks: List[TextTask]):
+        """Sends a batch of texts to the summarization service"""
+        pipeline = self.client.pipeline()
+        for task in tasks:
+            task_data = {
+                "chat_id": task.chat_id,
+                "user_id": task.user_id,
+                "message_id": task.message_id,
+                "text": task.text
+            }
+            pipeline.rpush(self.text_task_channel, json.dumps(task_data))
+        await pipeline.execute()
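Batching the rpush calls through one pipeline keeps the whole batch to a single round trip to Redis. On the consuming side, a hedged sketch (not from the commit) of how the summarize service might read these payloads back off text_task_channel; the handling here is a placeholder print.

import asyncio
import json

import redis.asyncio as redis

async def consume_text_tasks(channel: str = "text_task_channel"):
    client = redis.Redis(host="localhost", port=6379)  # assumed local Redis
    while True:
        item = await client.blpop(channel, timeout=1)
        if not item:
            continue
        payload = json.loads(item[1])  # same fields send_texts_batch writes
        print(payload["chat_id"], payload["text"][:40])  # placeholder handling

# asyncio.run(consume_text_tasks())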