Add summarize service v0.1
parent 053bed6e3c
commit 8e5e9562f5
Dockerfile
@@ -0,0 +1,22 @@
FROM nvidia/cuda:12.2.0-runtime-ubuntu22.04

RUN apt-get update && apt-get install -y --no-install-recommends \
    python3.10 \
    python3-pip \
    curl \
    pciutils \
    lshw \
    ffmpeg \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

COPY requirements.txt ./
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

ENV NVIDIA_VISIBLE_DEVICES=all
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility

# Run as a module so absolute imports like app.config resolve from /app.
CMD ["python3", "-m", "app.worker"]
app/config.py
@@ -0,0 +1,15 @@
import os

# Tokenizer/generation limits (in tokens) and batching behaviour.
MAX_INPUT_LENGTH = int(os.environ.get("MAX_INPUT_LENGTH", "1024"))
MAX_OUTPUT_LENGTH = int(os.environ.get("MAX_OUTPUT_LENGTH", "1024"))
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "4"))
WAIT_TIMEOUT = int(os.environ.get("WAIT_TIMEOUT", "1"))

# Redis connection and the list keys used as task/result queues.
REDIS_HOST = os.environ.get("REDIS_HOST", "redis")
REDIS_PORT = int(os.environ.get("REDIS_PORT", "6379"))
TEXT_TASK_CHANNEL = os.environ.get("TEXT_TASK_CHANNEL", "text_task_channel")
TEXT_RESULT_CHANNEL = os.environ.get("TEXT_RESULT_CHANNEL", "text_result_channel")

# Base model, local LoRA adapter directory, and Hugging Face access token.
BASE_MODEL = os.environ.get("BASE_MODEL", "google/gemma-2-2b")
ADAPTER_DIR = os.environ.get("ADAPTER_DIR", "./gemma-2-2b_lora")
HF_TOKEN = os.environ.get("HF_TOKEN", "")
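requirements.txt below ships python-dotenv, but nothing in this commit loads a .env file, so values presumably arrive through the container environment. If local runs outside Docker are intended, a sketch of wiring it in at the top of app/config.py (assumes a .env at the project root):

import os

from dotenv import load_dotenv

# Populate os.environ from a local .env before the lookups below;
# a no-op when the file is absent (e.g. inside the container).
load_dotenv()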
app/inference_service.py
@@ -0,0 +1,50 @@
from app.config import MAX_INPUT_LENGTH, MAX_OUTPUT_LENGTH
from app.model_loader import ModelLoader


class InferenceService:
    def __init__(self, model_loader: ModelLoader):
        self.model_loader = model_loader
        self.model = self.model_loader.get_model()
        self.tokenizer = self.model_loader.get_tokenizer()
        self.device = self.model_loader.get_device()
        self.max_input_length = MAX_INPUT_LENGTH
        self.max_output_length = MAX_OUTPUT_LENGTH

    def generate_response(self, prompt: str) -> str:
        # Russian prompt template ("Request: ... Expected answer:") that the
        # LoRA adapter was fine-tuned on; kept verbatim on purpose.
        full_prompt = f"Запрос: {prompt}\nОжидаемый ответ:"
        inputs = self.tokenizer(
            full_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_input_length,
        ).to(self.device)

        outputs = self.model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=self.max_output_length,
            num_beams=5,
            early_stopping=True,
            no_repeat_ngram_size=2,
        )
        # Decode only the newly generated tokens, not the echoed prompt.
        return self.tokenizer.decode(
            outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True
        )

    def generate_batch(self, prompts: list) -> list:
        full_prompts = [f"Запрос: {p}\nОжидаемый ответ:" for p in prompts]
        inputs = self.tokenizer(
            full_prompts,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_input_length,
            padding="longest",
        ).to(self.device)

        outputs = self.model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=self.max_output_length,
            num_beams=5,
            early_stopping=True,
            no_repeat_ngram_size=2,
        )
        # With left padding every sequence starts generating at the same
        # offset, so stripping the prompt is a uniform slice.
        prompt_len = inputs.input_ids.shape[1]
        return [
            self.tokenizer.decode(out[prompt_len:], skip_special_tokens=True)
            for out in outputs
        ]
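For a quick manual check outside the Redis loop, the service can be driven directly; a sketch assuming the base model is reachable (cached locally or via HF_TOKEN) and the LoRA adapter exists at ADAPTER_DIR:

from app.config import ADAPTER_DIR, BASE_MODEL, HF_TOKEN
from app.inference_service import InferenceService
from app.model_loader import ModelLoader

loader = ModelLoader(BASE_MODEL, ADAPTER_DIR, HF_TOKEN)
service = InferenceService(loader)  # the getters lazy-load model and tokenizer
print(service.generate_response("Text to summarize..."))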
app/model_loader.py
@@ -0,0 +1,57 @@
# app/model_loader.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel


class ModelLoader:
    def __init__(self, base_model_name: str, adapter_dir: str, hf_token: str, use_4bit: bool = True):
        self.base_model_name = base_model_name
        self.adapter_dir = adapter_dir
        self.hf_token = hf_token
        self.use_4bit = use_4bit

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None

    def load_model(self):
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=self.use_4bit,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

        base_model = AutoModelForCausalLM.from_pretrained(
            self.base_model_name,
            torch_dtype=torch.float16,
            device_map="auto",            # requires accelerate; places layers automatically
            attn_implementation="eager",  # recommended for Gemma 2 (attention soft-capping)
            quantization_config=bnb_config,
            token=self.hf_token,
        )

        # Attach the locally trained LoRA adapter on top of the quantized base.
        self.model = PeftModel.from_pretrained(
            base_model,
            self.adapter_dir,
            local_files_only=True,
        )
        self.model.eval()
        # No .to(self.device) here: bitsandbytes-quantized models reject .to(),
        # and device_map="auto" has already placed the weights.

    def load_tokenizer(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name, token=self.hf_token)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Left padding so batched generation with a causal LM lines up correctly.
        self.tokenizer.padding_side = "left"

    def get_model(self):
        if self.model is None:
            self.load_model()
        return self.model

    def get_tokenizer(self):
        if self.tokenizer is None:
            self.load_tokenizer()
        return self.tokenizer

    def get_device(self):
        return self.device
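With use_4bit the NF4-quantized gemma-2-2b base should occupy well under its fp16 footprint; a hedged way to check the actual number (get_memory_footprint is a standard transformers model method, forwarded through PeftModel):

from app.config import ADAPTER_DIR, BASE_MODEL, HF_TOKEN
from app.model_loader import ModelLoader

loader = ModelLoader(BASE_MODEL, ADAPTER_DIR, HF_TOKEN, use_4bit=True)
model = loader.get_model()
# Approximate resident size of the quantized base model plus adapter, in GiB.
print(f"{model.get_memory_footprint() / 2**30:.2f} GiB")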
app/redis_client.py
@@ -0,0 +1,47 @@
# app/redis_client.py
import json
from typing import List

import redis
from pydantic import BaseModel


class Task(BaseModel):
    chat_id: int
    user_id: int
    message_id: int
    text: str


class RedisClient:
    def __init__(self, host: str, port: int, task_channel: str, result_channel: str):
        self.host = host
        self.port = port
        self.task_channel = task_channel
        self.result_channel = result_channel
        self.client = redis.Redis(host=self.host, port=self.port, decode_responses=True)

    def get_tasks(self, batch_size: int, wait_timeout: int = 5) -> List[Task]:
        """Block for the first task, then drain up to batch_size more without waiting."""
        tasks = []
        res = self.client.blpop(self.task_channel, timeout=wait_timeout)
        if res:
            _, task_json = res
            try:
                tasks.append(Task.model_validate_json(task_json))
            except Exception as e:
                print("Task parse error:", e)

        while len(tasks) < batch_size:
            task_json = self.client.lpop(self.task_channel)
            if task_json is None:
                break
            try:
                tasks.append(Task.model_validate_json(task_json))
            except Exception as e:
                print("Task parse error:", e)
        return tasks

    def publish_result(self, result: dict):
        result_json = json.dumps(result)
        self.client.rpush(self.result_channel, result_json)
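The producing side is not in this commit, and despite the *_channel names these are Redis lists, not pub/sub channels: the worker blpop/lpops tasks and rpushes results. A sketch of what a producer might look like (host and key names are the defaults from app/config.py):

import json

import redis

r = redis.Redis(host="redis", port=6379, decode_responses=True)

# Enqueue a task in the JSON shape the worker's Task model expects.
task = {"chat_id": 1, "user_id": 1, "message_id": 42, "text": "Text to summarize..."}
r.rpush("text_task_channel", json.dumps(task))

# Block until the worker pushes the matching result.
_, result_json = r.blpop("text_result_channel")
print(json.loads(result_json)["text"])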
app/worker.py
@@ -0,0 +1,42 @@
# app/worker.py
import time

from app.config import (
    BASE_MODEL, ADAPTER_DIR, HF_TOKEN,
    REDIS_HOST, REDIS_PORT, TEXT_TASK_CHANNEL, TEXT_RESULT_CHANNEL,
    BATCH_SIZE, WAIT_TIMEOUT,
)
from app.inference_service import InferenceService
from app.model_loader import ModelLoader
from app.redis_client import RedisClient


def main():
    model_loader = ModelLoader(BASE_MODEL, ADAPTER_DIR, HF_TOKEN)
    model_loader.load_model()
    model_loader.load_tokenizer()
    inference_service = InferenceService(model_loader)

    redis_client = RedisClient(
        host=REDIS_HOST,
        port=REDIS_PORT,
        task_channel=TEXT_TASK_CHANNEL,
        result_channel=TEXT_RESULT_CHANNEL,
    )

    print("Worker started, waiting for tasks...")

    while True:
        tasks = redis_client.get_tasks(BATCH_SIZE, wait_timeout=WAIT_TIMEOUT)
        if not tasks:
            time.sleep(0.5)
            continue

        texts = [task.text for task in tasks]
        responses = inference_service.generate_batch(texts)
        for task, response in zip(tasks, responses):
            result = {
                "chat_id": task.chat_id,
                "user_id": task.user_id,
                "message_id": task.message_id,
                "text": response,
            }
            redis_client.publish_result(result)
            print(f"Processed task {task.message_id}")


if __name__ == "__main__":
    main()
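One caveat: a task popped by get_tasks is already gone from the list, so an unhandled exception in generate_batch drops the whole batch. A hedged sketch of a more defensive variant of the batch-processing step inside the while loop (re-queues failed tasks at the tail for a later retry):

        try:
            responses = inference_service.generate_batch([t.text for t in tasks])
        except Exception as e:
            # Put the batch back so another iteration (or worker) can retry it.
            print(f"Batch failed, re-queuing {len(tasks)} task(s): {e}")
            for t in tasks:
                redis_client.client.rpush(redis_client.task_channel, t.model_dump_json())
            continue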
requirements.txt
@@ -0,0 +1,14 @@
# PyTorch CUDA 12.1 wheels come from the extra index; everything else from PyPI.
--extra-index-url https://download.pytorch.org/whl/cu121

torch==2.5.1

transformers
accelerate        # required for device_map="auto"
redis>=4.2.0
python-dotenv
pydantic>=2
peft
bitsandbytes
# flash-attn is not needed: the model loads with attn_implementation="eager"