Add summarize service v0.1

itqop 2025-03-03 03:35:19 +03:00
parent 053bed6e3c
commit 8e5e9562f5
7 changed files with 247 additions and 0 deletions

Dockerfile

@@ -0,0 +1,22 @@
FROM nvidia/cuda:12.2.0-runtime-ubuntu22.04

RUN apt-get update && apt-get install -y \
        python3.10 \
        python3-pip \
        curl \
        pciutils \
        lshw \
        ffmpeg \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

COPY requirements.txt ./
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

ENV NVIDIA_VISIBLE_DEVICES=all
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility

# Run as a module so that `from app.config import ...` imports resolve from /app
CMD ["python3", "-m", "app.worker"]

app/config.py

@@ -0,0 +1,15 @@
import os

# Generation limits (tokens)
MAX_INPUT_LENGTH = int(os.environ.get("MAX_INPUT_LENGTH", "1024"))
MAX_OUTPUT_LENGTH = int(os.environ.get("MAX_OUTPUT_LENGTH", "1024"))

# Batching of incoming tasks
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "4"))
WAIT_TIMEOUT = int(os.environ.get("WAIT_TIMEOUT", "1"))

# Redis connection and queue names
REDIS_HOST = os.environ.get("REDIS_HOST", "redis")
REDIS_PORT = int(os.environ.get("REDIS_PORT", "6379"))
TEXT_TASK_CHANNEL = os.environ.get("TEXT_TASK_CHANNEL", "text_task_channel")
TEXT_RESULT_CHANNEL = os.environ.get("TEXT_RESULT_CHANNEL", "text_result_channel")

# Model: base checkpoint, local LoRA adapter directory, Hugging Face token
BASE_MODEL = os.environ.get("BASE_MODEL", "google/gemma-2-2b")
ADAPTER_DIR = os.environ.get("ADAPTER_DIR", "./gemma-2-2b_lora")
HF_TOKEN = os.environ.get("HF_TOKEN", "")

app/inference_service.py

@@ -0,0 +1,50 @@
import torch

from app.config import MAX_INPUT_LENGTH, MAX_OUTPUT_LENGTH


class InferenceService:
    def __init__(self, model_loader: "ModelLoader"):
        self.model_loader = model_loader
        self.model = self.model_loader.get_model()
        self.tokenizer = self.model_loader.get_tokenizer()
        self.device = self.model_loader.get_device()
        self.max_input_length = MAX_INPUT_LENGTH
        self.max_output_length = MAX_OUTPUT_LENGTH

    def generate_response(self, prompt: str) -> str:
        return self.generate_batch([prompt])[0]

    def generate_batch(self, prompts: list) -> list:
        full_prompts = [f"Запрос: {p}\nОжидаемый ответ:" for p in prompts]
        inputs = self.tokenizer(
            full_prompts,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_input_length,
            padding="longest"
        ).to(self.device)
        with torch.inference_mode():
            outputs = self.model.generate(
                **inputs,  # pass attention_mask along with input_ids so padding is handled
                max_new_tokens=self.max_output_length,  # max_length would also count prompt tokens
                num_beams=5,
                early_stopping=True,
                no_repeat_ngram_size=2
            )
        # generate() returns prompt + continuation; decode only the newly generated tokens
        prompt_length = inputs.input_ids.shape[1]
        responses = [
            self.tokenizer.decode(out[prompt_length:], skip_special_tokens=True)
            for out in outputs
        ]
        return responses

app/model_loader.py

@@ -0,0 +1,57 @@
# app/model_loader.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel


class ModelLoader:
    def __init__(self, base_model_name: str, adapter_dir: str, hf_token: str, use_4bit: bool = True):
        self.base_model_name = base_model_name
        self.adapter_dir = adapter_dir
        self.hf_token = hf_token
        self.use_4bit = use_4bit
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None

    def load_model(self):
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=self.use_4bit,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
        base_model = AutoModelForCausalLM.from_pretrained(
            self.base_model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            attn_implementation="eager",
            quantization_config=bnb_config,
            token=self.hf_token
        )
        # Attach the local LoRA adapter on top of the quantized base model
        self.model = PeftModel.from_pretrained(
            base_model,
            self.adapter_dir,
            local_files_only=True
        )
        self.model.eval()
        # Note: .to(self.device) is intentionally omitted; 4-bit models loaded with
        # device_map="auto" are already placed on the GPU and do not support .to()

    def load_tokenizer(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name, token=self.hf_token)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Left padding so that batched generation with a decoder-only model works correctly
        self.tokenizer.padding_side = "left"

    def get_model(self):
        if self.model is None:
            self.load_model()
        return self.model

    def get_tokenizer(self):
        if self.tokenizer is None:
            self.load_tokenizer()
        return self.tokenizer

    def get_device(self):
        return self.device

app/redis_client.py

@@ -0,0 +1,47 @@
# app/redis_client.py
import json
import redis
from typing import List
from pydantic import BaseModel


class Task(BaseModel):
    chat_id: int
    user_id: int
    message_id: int
    text: str


class RedisClient:
    def __init__(self, host: str, port: int, task_channel: str, result_channel: str):
        self.host = host
        self.port = port
        self.task_channel = task_channel
        self.result_channel = result_channel
        self.client = redis.Redis(host=self.host, port=self.port, decode_responses=True)

    def get_tasks(self, batch_size: int, wait_timeout: int = 5) -> List[Task]:
        """Block for the first task, then drain up to batch_size tasks without waiting."""
        tasks = []
        res = self.client.blpop(self.task_channel, timeout=wait_timeout)
        if res:
            _, task_json = res
            try:
                tasks.append(Task.parse_raw(task_json))
            except Exception as e:
                print("Failed to parse task:", e)
        while len(tasks) < batch_size:
            task_json = self.client.lpop(self.task_channel)
            if task_json is None:
                break
            try:
                tasks.append(Task.parse_raw(task_json))
            except Exception as e:
                print("Failed to parse task:", e)
        return tasks

    def publish_result(self, result: dict):
        result_json = json.dumps(result)
        self.client.rpush(self.result_channel, result_json)
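For context, the worker expects producers to push JSON-serialized Task objects onto the task list (RPUSH gives FIFO order, since get_tasks pops from the left). A minimal producer-side sketch, assuming the default queue names and Redis host from app/config.py; the bot side is not part of this commit and the values below are placeholders:

import json
import redis

# Hypothetical producer: enqueue one summarization task for the worker
client = redis.Redis(host="redis", port=6379, decode_responses=True)
task = {"chat_id": 123, "user_id": 456, "message_id": 789, "text": "text to summarize"}
client.rpush("text_task_channel", json.dumps(task))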

app/worker.py

@@ -0,0 +1,42 @@
# app/worker.py
import time
from app.model_loader import ModelLoader
from app.inference_service import InferenceService
from app.redis_client import RedisClient
from app.config import BASE_MODEL, ADAPTER_DIR, HF_TOKEN, REDIS_HOST, REDIS_PORT, TEXT_RESULT_CHANNEL, TEXT_TASK_CHANNEL, BATCH_SIZE, WAIT_TIMEOUT


def main():
    model_loader = ModelLoader(BASE_MODEL, ADAPTER_DIR, HF_TOKEN)
    model_loader.load_model()
    model_loader.load_tokenizer()
    inference_service = InferenceService(model_loader)
    redis_client = RedisClient(
        host=REDIS_HOST,
        port=REDIS_PORT,
        task_channel=TEXT_TASK_CHANNEL,
        result_channel=TEXT_RESULT_CHANNEL
    )
    print("Worker started, waiting for tasks...")
    while True:
        tasks = redis_client.get_tasks(BATCH_SIZE, wait_timeout=WAIT_TIMEOUT)
        if not tasks:
            time.sleep(0.5)
            continue
        texts = [task.text for task in tasks]
        responses = inference_service.generate_batch(texts)
        for task, response in zip(tasks, responses):
            result = {
                "chat_id": task.chat_id,
                "user_id": task.user_id,
                "message_id": task.message_id,
                "text": response
            }
            redis_client.publish_result(result)
            print(f"Processed task {task.message_id}")


if __name__ == "__main__":
    main()
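On the other side of the queue, results are plain JSON dictionaries pushed onto TEXT_RESULT_CHANNEL by publish_result. A minimal consumer sketch under the same assumptions (default names from app/config.py; the consuming bot is outside this commit):

import json
import redis

# Hypothetical consumer: block until the worker publishes a result, then route it by chat_id
client = redis.Redis(host="redis", port=6379, decode_responses=True)
_, raw = client.blpop("text_result_channel")
result = json.loads(raw)
print(result["chat_id"], result["text"])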

requirements.txt

@@ -0,0 +1,14 @@
# PyTorch CUDA 12.1 wheels (pip index options apply to the whole file, so use an extra index rather than two competing --index-url lines)
--extra-index-url https://download.pytorch.org/whl/cu121
torch==2.5.1

transformers
accelerate
peft
bitsandbytes
redis>=4.2.0
pydantic
python-dotenv
flash-attn  # optional: the loader currently runs with attn_implementation="eager"