Add summarize service v0.1
This commit is contained in:
parent
053bed6e3c
commit
8e5e9562f5
|
@ -0,0 +1,22 @@
|
||||||
|
FROM nvidia/cuda:12.2.0-runtime-ubuntu22.04
|
||||||
|
|
||||||
|
RUN apt update && apt install -y \
|
||||||
|
python3.10 \
|
||||||
|
python3-pip \
|
||||||
|
curl \
|
||||||
|
pciutils \
|
||||||
|
lshw \
|
||||||
|
ffmpeg \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY requirements.txt ./
|
||||||
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
ENV NVIDIA_VISIBLE_DEVICES=all
|
||||||
|
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
|
||||||
|
|
||||||
|
CMD ["python3", "app/worker.py"]
|
|
@ -0,0 +1,15 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
MAX_INPUT_LENGTH = int(os.environ.get("MAX_INPUT_LENGTH", "1024"))
|
||||||
|
MAX_OUTPUT_LENGTH = int(os.environ.get("MAX_OUTPUT_LENGTH", "1024"))
|
||||||
|
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "4"))
|
||||||
|
WAIT_TIMEOUT = int(os.environ.get("WAIT_TIMEOUT", "1"))
|
||||||
|
|
||||||
|
REDIS_HOST = os.environ.get("REDIS_HOST", "redis")
|
||||||
|
REDIS_PORT = int(os.environ.get("REDIS_PORT", "6379"))
|
||||||
|
TEXT_TASK_CHANNEL = os.environ.get("TEXT_TASK_CHANNEL", "text_task_channel")
|
||||||
|
TEXT_RESULT_CHANNEL = os.environ.get("TEXT_RESULT_CHANNEL", "text_result_channel")
|
||||||
|
|
||||||
|
BASE_MODEL = os.environ.get("BASE_MODEL", "google/gemma-2-2b")
|
||||||
|
ADAPTER_DIR = os.environ.get("ADAPTER_DIR", "./gemma-2-2b_lora")
|
||||||
|
HF_TOKEN = os.environ.get("HF_TOKEN", "")
|
|
@ -0,0 +1,50 @@
|
||||||
|
from app.config import MAX_INPUT_LENGTH, MAX_OUTPUT_LENGTH
|
||||||
|
|
||||||
|
class InferenceService:
|
||||||
|
def __init__(self, model_loader: "ModelLoader"):
|
||||||
|
self.model_loader = model_loader
|
||||||
|
self.model = self.model_loader.get_model()
|
||||||
|
self.tokenizer = self.model_loader.get_tokenizer()
|
||||||
|
self.device = self.model_loader.get_device()
|
||||||
|
self.max_input_length = MAX_INPUT_LENGTH
|
||||||
|
self.max_output_length = MAX_OUTPUT_LENGTH
|
||||||
|
|
||||||
|
def generate_response(self, prompt: str) -> str:
|
||||||
|
full_prompt = f"Запрос: {prompt}\nОжидаемый ответ:"
|
||||||
|
inputs = self.tokenizer(
|
||||||
|
full_prompt,
|
||||||
|
return_tensors="pt",
|
||||||
|
truncation=True,
|
||||||
|
max_length=self.max_input_length,
|
||||||
|
padding="max_length"
|
||||||
|
).to(self.device)
|
||||||
|
|
||||||
|
outputs = self.model.generate(
|
||||||
|
inputs.input_ids,
|
||||||
|
max_length=self.max_output_length,
|
||||||
|
num_beams=5,
|
||||||
|
early_stopping=True,
|
||||||
|
no_repeat_ngram_size=2
|
||||||
|
)
|
||||||
|
result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def generate_batch(self, prompts: list) -> list:
|
||||||
|
full_prompts = [f"Запрос: {p}\nОжидаемый ответ:" for p in prompts]
|
||||||
|
inputs = self.tokenizer(
|
||||||
|
full_prompts,
|
||||||
|
return_tensors="pt",
|
||||||
|
truncation=True,
|
||||||
|
max_length=self.max_input_length,
|
||||||
|
padding="longest"
|
||||||
|
).to(self.device)
|
||||||
|
|
||||||
|
outputs = self.model.generate(
|
||||||
|
inputs.input_ids,
|
||||||
|
max_length=self.max_output_length,
|
||||||
|
num_beams=5,
|
||||||
|
early_stopping=True,
|
||||||
|
no_repeat_ngram_size=2
|
||||||
|
)
|
||||||
|
responses = [self.tokenizer.decode(out, skip_special_tokens=True) for out in outputs]
|
||||||
|
return responses
|
|
@ -0,0 +1,57 @@
|
||||||
|
# app/model_loader.py
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
||||||
|
from peft import PeftModel
|
||||||
|
|
||||||
|
class ModelLoader:
|
||||||
|
def __init__(self, base_model_name: str, adapter_dir: str, hf_token: str, use_4bit: bool = True):
|
||||||
|
self.base_model_name = base_model_name
|
||||||
|
self.adapter_dir = adapter_dir
|
||||||
|
self.hf_token = hf_token
|
||||||
|
self.use_4bit = use_4bit
|
||||||
|
|
||||||
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
self.model = None
|
||||||
|
self.tokenizer = None
|
||||||
|
|
||||||
|
def load_model(self):
|
||||||
|
bnb_config = BitsAndBytesConfig(
|
||||||
|
load_in_4bit=self.use_4bit,
|
||||||
|
bnb_4bit_quant_type="nf4",
|
||||||
|
bnb_4bit_compute_dtype=torch.float16,
|
||||||
|
)
|
||||||
|
|
||||||
|
base_model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
self.base_model_name,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
device_map="auto",
|
||||||
|
attn_implementation="eager",
|
||||||
|
quantization_config=bnb_config,
|
||||||
|
token=self.hf_token
|
||||||
|
)
|
||||||
|
|
||||||
|
self.model = PeftModel.from_pretrained(
|
||||||
|
base_model,
|
||||||
|
self.adapter_dir,
|
||||||
|
local_files_only=True
|
||||||
|
)
|
||||||
|
self.model.eval()
|
||||||
|
self.model.to(self.device)
|
||||||
|
|
||||||
|
def load_tokenizer(self):
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name, local_files_only=True)
|
||||||
|
if self.tokenizer.pad_token is None:
|
||||||
|
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||||
|
|
||||||
|
def get_model(self):
|
||||||
|
if self.model is None:
|
||||||
|
self.load_model()
|
||||||
|
return self.model
|
||||||
|
|
||||||
|
def get_tokenizer(self):
|
||||||
|
if self.tokenizer is None:
|
||||||
|
self.load_tokenizer()
|
||||||
|
return self.tokenizer
|
||||||
|
|
||||||
|
def get_device(self):
|
||||||
|
return self.device
|
|
@ -0,0 +1,47 @@
|
||||||
|
# app/redis_client.py
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import redis
|
||||||
|
from typing import List
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
class Task(BaseModel):
|
||||||
|
chat_id: int
|
||||||
|
user_id: int
|
||||||
|
message_id: int
|
||||||
|
text: str
|
||||||
|
|
||||||
|
class RedisClient:
|
||||||
|
def __init__(self, host: str, port: int, task_channel: str, result_channel: str):
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.task_channel = task_channel
|
||||||
|
self.result_channel = result_channel
|
||||||
|
self.client = redis.Redis(host=self.host, port=self.port, decode_responses=True)
|
||||||
|
|
||||||
|
def get_tasks(self, batch_size: int, wait_timeout: int = 5) -> List[Task]:
|
||||||
|
tasks = []
|
||||||
|
res = self.client.blpop(self.task_channel, timeout=wait_timeout)
|
||||||
|
if res:
|
||||||
|
_, task_json = res
|
||||||
|
try:
|
||||||
|
task = Task.parse_raw(task_json)
|
||||||
|
tasks.append(task)
|
||||||
|
except Exception as e:
|
||||||
|
print("Ошибка парсинга задачи:", e)
|
||||||
|
|
||||||
|
while len(tasks) < batch_size:
|
||||||
|
task_json = self.client.lpop(self.task_channel)
|
||||||
|
if task_json is None:
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
task = Task.parse_raw(task_json)
|
||||||
|
tasks.append(task)
|
||||||
|
except Exception as e:
|
||||||
|
print("Ошибка парсинга задачи:", e)
|
||||||
|
return tasks
|
||||||
|
|
||||||
|
def publish_result(self, result: dict):
|
||||||
|
result_json = json.dumps(result)
|
||||||
|
self.client.rpush(self.result_channel, result_json)
|
|
@ -0,0 +1,42 @@
|
||||||
|
# app/worker.py
|
||||||
|
import time
|
||||||
|
from app.model_loader import ModelLoader
|
||||||
|
from app.inference_service import InferenceService
|
||||||
|
from app.redis_client import RedisClient
|
||||||
|
from config import BASE_MODEL, ADAPTER_DIR, HF_TOKEN, REDIS_HOST, REDIS_PORT, TEXT_RESULT_CHANNEL, TEXT_TASK_CHANNEL, BATCH_SIZE, WAIT_TIMEOUT
|
||||||
|
|
||||||
|
def main():
|
||||||
|
model_loader = ModelLoader(BASE_MODEL, ADAPTER_DIR, HF_TOKEN)
|
||||||
|
model_loader.load_model()
|
||||||
|
model_loader.load_tokenizer()
|
||||||
|
inference_service = InferenceService(model_loader)
|
||||||
|
|
||||||
|
redis_client = RedisClient(
|
||||||
|
host=REDIS_HOST,
|
||||||
|
port=REDIS_PORT,
|
||||||
|
task_channel=TEXT_TASK_CHANNEL,
|
||||||
|
result_channel=TEXT_RESULT_CHANNEL
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Worker запущен, ожидаем задачи...")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
tasks = redis_client.get_tasks(BATCH_SIZE, wait_timeout=WAIT_TIMEOUT)
|
||||||
|
if not tasks:
|
||||||
|
time.sleep(0.5)
|
||||||
|
continue
|
||||||
|
|
||||||
|
texts = [task.text for task in tasks]
|
||||||
|
responses = inference_service.generate_batch(texts)
|
||||||
|
for task, response in zip(tasks, responses):
|
||||||
|
result = {
|
||||||
|
"chat_id": task.chat_id,
|
||||||
|
"user_id": task.user_id,
|
||||||
|
"message_id": task.message_id,
|
||||||
|
"text": response
|
||||||
|
}
|
||||||
|
redis_client.publish_result(result)
|
||||||
|
print(f"Обработана задача {task.message_id}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -0,0 +1,14 @@
|
||||||
|
--index-url https://download.pytorch.org/whl/cu121
|
||||||
|
|
||||||
|
torch==2.5.1
|
||||||
|
|
||||||
|
--index-url https://pypi.org/simple
|
||||||
|
|
||||||
|
transformers
|
||||||
|
redis>=4.2.0
|
||||||
|
python-dotenv
|
||||||
|
redis
|
||||||
|
pydantic
|
||||||
|
peft
|
||||||
|
bitsandbytes
|
||||||
|
flash-attention
|
Loading…
Reference in New Issue