fix redis
This commit is contained in:
parent
9ed038557c
commit
e7f7081bc6
@@ -1,23 +1,22 @@
import asyncio
import json
import logging
from config import load_config
from models import AudioTask
from redis_client import RedisClient
from transcriber import WhisperTranscriber
import httpx

async def process_audio_task(redis_client: RedisClient, transcriber: WhisperTranscriber, task_data: dict):
    try:
        task = AudioTask(**task_data)
    except Exception as e:
        print(f"Error creating AudioTask from data: {e}")
        logging.error(f"Error creating AudioTask from data: {e}")
        return

    print(f"Processing task {task.uuid} ...")
    logging.info(f"Processing task {task.uuid} ...")
    loop = asyncio.get_running_loop()
    text = await loop.run_in_executor(None, transcriber.transcribe, task.file_path)
    logging.info(f"Transcription completed for task {task.uuid}, text length: {len(text)}")

    # Send the text to the summarization service
    summarize_task = {
        "chat_id": task.chat_id,
        "user_id": task.user_id,

@@ -25,11 +24,18 @@ async def process_audio_task(redis_client: RedisClient, transcriber: WhisperTranscriber, task_data: dict):
        "text": text
    }
    await redis_client.send_to_summarize(summarize_task)
    print(f"Sent text to summarize service for task {task.uuid}")
    logging.info(f"Sent text to summarize service for task {task.uuid}")

async def main():
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    config = load_config()
    logging.info(f"Loaded config: REDIS_HOST={config['REDIS_HOST']}, REDIS_PORT={config['REDIS_PORT']}")

    redis_client = RedisClient(
        host=config["REDIS_HOST"],
        port=config["REDIS_PORT"],

@@ -38,12 +44,14 @@ async def main():
        text_task_channel="text_task_channel"
    )
    transcriber = WhisperTranscriber(config["WHISPER_MODEL"], config["DEVICE"])
    logging.info(f"Initialized transcriber with model {config['WHISPER_MODEL']} on {config['DEVICE']}")

    print("Waiting for audio tasks...")
    logging.info(f"Waiting for audio tasks in channel {config['AUDIO_TASK_CHANNEL']}...")

    while True:
        task_data = await redis_client.get_task(timeout=1)
        if task_data:
            logging.info(f"Received task: {task_data.get('uuid', 'unknown')}")
            asyncio.create_task(process_audio_task(redis_client, transcriber, task_data))
        await asyncio.sleep(0.1)

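The loop.run_in_executor call above is what keeps transcription from blocking the event loop. A minimal, self-contained sketch of the same pattern; slow_transcribe below is a hypothetical stand-in for WhisperTranscriber.transcribe, not part of this repository:

import asyncio
import time

def slow_transcribe(path: str) -> str:
    # Stands in for CPU/GPU-bound Whisper inference; any blocking call behaves the same way.
    time.sleep(2)
    return f"text from {path}"

async def handle(path: str) -> str:
    loop = asyncio.get_running_loop()
    # None selects the default ThreadPoolExecutor; the coroutine suspends while a worker thread runs.
    return await loop.run_in_executor(None, slow_transcribe, path)

async def demo():
    # Two tasks overlap their blocking work instead of serializing it on the event loop.
    print(await asyncio.gather(handle("a.wav"), handle("b.wav")))

if __name__ == "__main__":
    asyncio.run(demo())
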
@@ -20,8 +20,8 @@ class RedisClient:
        return None

    async def publish_result(self, result: dict):
        """Publishes the result to the results channel"""
        await self.client.publish(self.result_channel, json.dumps(result))
        """Sends the result to the results queue"""
        await self.client.rpush(self.result_channel, json.dumps(result))

    async def send_to_summarize(self, task_data: dict):
        """Sends the text to the summarization service"""

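This hunk swaps PUBLISH for RPUSH, turning the result channel into a Redis list that behaves like a work queue: an item pushed with RPUSH stays in the list until a consumer pops it, whereas a PUBLISH message is lost if nobody happens to be subscribed at that moment. A minimal sketch of both sides of such a queue, assuming the client wraps redis.asyncio; the QueueSketch class and pop_result name are illustrative, not part of the repository:

import json
import redis.asyncio as redis

class QueueSketch:
    def __init__(self, host: str, port: int, result_channel: str):
        # decode_responses=True returns str instead of bytes, so json.loads works directly.
        self.client = redis.Redis(host=host, port=port, decode_responses=True)
        self.result_channel = result_channel  # used as a list key, not a pub/sub channel

    async def publish_result(self, result: dict):
        # RPUSH appends to the tail; the payload persists until some consumer pops it.
        await self.client.rpush(self.result_channel, json.dumps(result))

    async def pop_result(self, timeout: int = 1):
        # BLPOP blocks up to `timeout` seconds and hands each item to exactly one consumer.
        item = await self.client.blpop(self.result_channel, timeout=timeout)
        return json.loads(item[1]) if item else None
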
@@ -9,4 +9,3 @@ torchaudio==2.5.1
transformers
redis>=4.2.0
python-dotenv
httpx

@@ -41,5 +41,6 @@ class RedisClient:
        return tasks

    def publish_result(self, result: dict):
        """Sends the result to the results queue"""
        result_json = json.dumps(result)
        self.client.publish(self.result_channel, result_json)
        self.client.rpush(self.result_channel, result_json)

@@ -1,16 +1,24 @@
# app/worker.py
import time
import logging
from model_loader import ModelLoader
from inference_service import InferenceService
from redis_client import RedisClient
from config import BASE_MODEL, ADAPTER_DIR, HF_TOKEN, REDIS_HOST, REDIS_PORT, TEXT_RESULT_CHANNEL, TEXT_TASK_CHANNEL, BATCH_SIZE, WAIT_TIMEOUT

def main():
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    logging.info("Initializing the model...")
    model_loader = ModelLoader(BASE_MODEL, ADAPTER_DIR, HF_TOKEN)
    model_loader.load_model()
    model_loader.load_tokenizer()
    inference_service = InferenceService(model_loader)

    logging.info(f"Connecting to Redis: {REDIS_HOST}:{REDIS_PORT}")
    redis_client = RedisClient(
        host=REDIS_HOST,
        port=REDIS_PORT,

@@ -18,7 +26,7 @@ def main():
        result_channel=TEXT_RESULT_CHANNEL
    )

    print("Worker started, waiting for tasks...")
    logging.info(f"Worker started, waiting for tasks in channel {TEXT_TASK_CHANNEL}...")

    while True:
        tasks = redis_client.get_tasks(BATCH_SIZE, wait_timeout=WAIT_TIMEOUT)

@@ -26,6 +34,7 @@ def main():
            time.sleep(0.5)
            continue

        logging.info(f"Received {len(tasks)} tasks for processing")
        texts = [task.text for task in tasks]
        responses = inference_service.generate_batch(texts)
        for task, response in zip(tasks, responses):

@@ -36,7 +45,7 @@ def main():
                "text": response
            }
            redis_client.publish_result(result)
            print(f"Processed task {task.message_id}")
            logging.info(f"Processed task {task.message_id}, result sent to channel {TEXT_RESULT_CHANNEL}")

if __name__ == "__main__":
    main()

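get_tasks(BATCH_SIZE, wait_timeout=WAIT_TIMEOUT) is not shown in this diff. One plausible way to drain a batch from a Redis list is to block for the first task and then take whatever else is already queued, up to the batch size. A hedged sketch with the synchronous redis client; it returns plain dicts, whereas the real method apparently returns task objects with .text and .message_id:

import json
import redis

def drain_batch(client: redis.Redis, queue_key: str, batch_size: int, wait_timeout: int) -> list:
    tasks = []
    # Block up to wait_timeout seconds for the first task; give up if nothing arrives.
    first = client.blpop(queue_key, timeout=wait_timeout)
    if not first:
        return tasks
    tasks.append(json.loads(first[1]))
    # Opportunistically grab up to batch_size - 1 more tasks without blocking.
    while len(tasks) < batch_size:
        item = client.lpop(queue_key)
        if item is None:
            break
        tasks.append(json.loads(item))
    return tasks
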
@@ -1,6 +1,7 @@
import os
import subprocess
import uuid
import logging
from functools import partial
from aiogram import types, Dispatcher, F

@@ -34,7 +35,6 @@ async def handle_voice_and_video(message: types.Message, redis_service, storage_

    os.remove(temp_destination)

    # Notify the user that processing has started
    processing_msg = await message.reply("Processing the audio, please wait...")

    task_data = {

@@ -45,23 +45,25 @@ async def handle_voice_and_video(message: types.Message, redis_service, storage_
        "message_id": message.message_id
    }

    logging.info(f"Sending task to Redis: {task_data}")
    await redis_service.publish_task(task_data)

    logging.info(f"Waiting for the result for message {message.message_id}")
    text = await redis_service.wait_for_text(
        user_id=message.from_user.id,
        chat_id=message.chat.id,
        message_id=message.message_id
    )

    # Remove the temporary file
    os.remove(wav_destination)

    # Delete the processing notice
    await processing_msg.delete()

    if text:
        logging.info(f"Received result for message {message.message_id}, text length: {len(text)}")
        await send_long_message(message, text)
    else:
        logging.warning(f"No result received for message {message.message_id}")
        await message.reply("Unfortunately, the processing result could not be received within the allotted time.")

async def send_long_message(message: types.Message, text: str):

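send_long_message itself is outside this hunk. Since Telegram rejects messages longer than 4096 characters, such a helper typically replies in consecutive chunks; a minimal sketch, where the constant and the function body are assumptions rather than the repository's implementation:

from aiogram import types

TELEGRAM_MAX_LEN = 4096  # Telegram's per-message character limit

async def send_long_message_sketch(message: types.Message, text: str):
    # Reply in order, one chunk at a time, so arbitrarily long transcripts still go through.
    for start in range(0, len(text), TELEGRAM_MAX_LEN):
        await message.reply(text[start:start + TELEGRAM_MAX_LEN])
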
@@ -7,6 +7,11 @@ from handlers import register_all_handlers
from services.redis_service import RedisService

async def main():
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    config = load_config()

    bot = Bot(token=config.TELEGRAM_TOKEN, default=DefaultBotProperties(parse_mode="HTML"))

@@ -14,14 +19,15 @@ async def main():
    dp = Dispatcher(bot=bot)

    redis_service = RedisService(config.REDIS_HOST, config.REDIS_PORT)
    logging.info(f"Connecting to Redis: {config.REDIS_HOST}:{config.REDIS_PORT}")

    register_all_handlers(dp, redis_service, config.BOT_STORAGE_PATH)

    try:
        logging.info("Starting the bot")
        await dp.start_polling(bot)
    finally:
        await bot.session.close()

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(main())

@@ -17,6 +17,7 @@ class RedisService:
        start_time = asyncio.get_event_loop().time()

        while asyncio.get_event_loop().time() - start_time < timeout:
            try:
                result = await self.client.blpop(self.result_channel, timeout=1)
                if not result:
                    continue

@@ -24,12 +25,20 @@ class RedisService:
                _, data_json = result
                try:
                    data = json.loads(data_json)
                except Exception:
                    print(f"Received result: {data}")
                except Exception as e:
                    print(f"Error parsing JSON: {e}")
                    continue

                if (data.get("user_id") == user_id and
                        data.get("chat_id") == chat_id and
                        data.get("message_id") == message_id):
                    return data.get("text")
                else:
                    # Not our result; push it back onto the queue
                    await self.client.rpush(self.result_channel, data_json)
            except Exception as e:
                print(f"Error while waiting for the result: {e}")
                await asyncio.sleep(0.5)

        return None

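wait_for_text correlates results by (user_id, chat_id, message_id) and requeues anything that belongs to another handler so that handler's own BLPOP can find it later. A small end-to-end sketch of that hand-off, assuming a Redis instance on localhost and using an illustrative list key rather than the service's real configuration:

import asyncio
import json
import redis.asyncio as redis

async def demo():
    client = redis.Redis(host="localhost", port=6379, decode_responses=True)
    key = "text_result_channel_demo"  # illustrative key, not the real channel name

    # A worker pushes two results; only the second belongs to "our" message.
    await client.rpush(key, json.dumps({"user_id": 1, "chat_id": 1, "message_id": 10, "text": "other"}))
    await client.rpush(key, json.dumps({"user_id": 2, "chat_id": 2, "message_id": 20, "text": "ours"}))

    while True:
        item = await client.blpop(key, timeout=1)
        if not item:
            break
        data = json.loads(item[1])
        if data.get("message_id") == 20:
            print("matched:", data["text"])
            break
        # Not ours: push it back so the consumer waiting for message 10 can still claim it.
        await client.rpush(key, json.dumps(data))

if __name__ == "__main__":
    asyncio.run(demo())
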