# app/worker.py

import time
import logging

from model_loader import ModelLoader
from inference_service import InferenceService
from redis_client import RedisClient
from config import (
    BASE_MODEL,
    ADAPTER_DIR,
    HF_TOKEN,
    REDIS_HOST,
    REDIS_PORT,
    TEXT_RESULT_CHANNEL,
    TEXT_TASK_CHANNEL,
    BATCH_SIZE,
    WAIT_TIMEOUT,
)

def main():
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # Load the base model, adapter weights and tokenizer once at startup.
    logging.info("Initializing model...")
    model_loader = ModelLoader(BASE_MODEL, ADAPTER_DIR, HF_TOKEN)
    model_loader.load_model()
    model_loader.load_tokenizer()
    inference_service = InferenceService(model_loader)

    logging.info(f"Connecting to Redis: {REDIS_HOST}:{REDIS_PORT}")
    redis_client = RedisClient(
        host=REDIS_HOST,
        port=REDIS_PORT,
        task_channel=TEXT_TASK_CHANNEL,
        result_channel=TEXT_RESULT_CHANNEL
    )

    logging.info(f"Worker started, waiting for tasks on channel {TEXT_TASK_CHANNEL}...")

    while True:
        # Pull up to BATCH_SIZE tasks, blocking for at most WAIT_TIMEOUT.
        tasks = redis_client.get_tasks(BATCH_SIZE, wait_timeout=WAIT_TIMEOUT)
        if not tasks:
            time.sleep(0.5)
            continue

        logging.info(f"Received {len(tasks)} tasks to process")

        # Run the whole batch through the model in a single call.
        texts = [task.text for task in tasks]
        responses = inference_service.generate_batch(texts)

        # Publish each response together with the routing metadata of its task.
        for task, response in zip(tasks, responses):
            result = {
                "chat_id": task.chat_id,
                "user_id": task.user_id,
                "message_id": task.message_id,
                "text": response
            }
            redis_client.publish_result(result)
            logging.info(f"Processed task {task.message_id}, result published to channel {TEXT_RESULT_CHANNEL}")


if __name__ == "__main__":
    main()
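

# --- Illustrative sketch, not from the original repo ---
# app/redis_client.py could look roughly like this, assuming redis-py with
# list-based queues (BLPOP / RPUSH) and JSON-encoded payloads; the RedisClient
# methods, the Task fields, and the wire format are inferred from the calls
# in worker.py and may differ from the real implementation.

import json
from dataclasses import dataclass

import redis


@dataclass
class Task:
    chat_id: int
    user_id: int
    message_id: int
    text: str


class RedisClient:
    def __init__(self, host, port, task_channel, result_channel):
        # decode_responses=True makes redis-py return str instead of bytes.
        self._redis = redis.Redis(host=host, port=port, decode_responses=True)
        self._task_channel = task_channel
        self._result_channel = result_channel

    def get_tasks(self, batch_size, wait_timeout=1):
        # Block up to wait_timeout seconds for the first task, then drain
        # whatever is already queued, up to batch_size items in total.
        tasks = []
        first = self._redis.blpop(self._task_channel, timeout=wait_timeout)
        if first is None:
            return tasks
        tasks.append(Task(**json.loads(first[1])))
        while len(tasks) < batch_size:
            raw = self._redis.lpop(self._task_channel)
            if raw is None:
                break
            tasks.append(Task(**json.loads(raw)))
        return tasks

    def publish_result(self, result):
        # Results are pushed onto the result queue as JSON.
        self._redis.rpush(self._result_channel, json.dumps(result))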