diff --git a/summarize_service/Dockerfile b/summarize_service/Dockerfile
new file mode 100644
index 0000000..92e34a5
--- /dev/null
+++ b/summarize_service/Dockerfile
@@ -0,0 +1,22 @@
+FROM nvidia/cuda:12.2.0-runtime-ubuntu22.04
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    python3.10 \
+    python3-pip \
+    curl \
+    pciutils \
+    lshw \
+    ffmpeg \
+    && rm -rf /var/lib/apt/lists/*
+
+WORKDIR /app
+
+COPY requirements.txt ./
+RUN pip install --no-cache-dir -r requirements.txt
+
+COPY . .
+
+ENV NVIDIA_VISIBLE_DEVICES=all
+ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
+
+CMD ["python3", "-m", "app.worker"]
diff --git a/summarize_service/app/config.py b/summarize_service/app/config.py
new file mode 100644
index 0000000..415f647
--- /dev/null
+++ b/summarize_service/app/config.py
@@ -0,0 +1,15 @@
+import os
+
+MAX_INPUT_LENGTH = int(os.environ.get("MAX_INPUT_LENGTH", "1024"))
+MAX_OUTPUT_LENGTH = int(os.environ.get("MAX_OUTPUT_LENGTH", "1024"))
+BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "4"))
+WAIT_TIMEOUT = int(os.environ.get("WAIT_TIMEOUT", "1"))
+
+REDIS_HOST = os.environ.get("REDIS_HOST", "redis")
+REDIS_PORT = int(os.environ.get("REDIS_PORT", "6379"))
+TEXT_TASK_CHANNEL = os.environ.get("TEXT_TASK_CHANNEL", "text_task_channel")
+TEXT_RESULT_CHANNEL = os.environ.get("TEXT_RESULT_CHANNEL", "text_result_channel")
+
+BASE_MODEL = os.environ.get("BASE_MODEL", "google/gemma-2-2b")
+ADAPTER_DIR = os.environ.get("ADAPTER_DIR", "./gemma-2-2b_lora")
+HF_TOKEN = os.environ.get("HF_TOKEN", "")
\ No newline at end of file
diff --git a/summarize_service/app/inference_service.py b/summarize_service/app/inference_service.py
new file mode 100644
index 0000000..800c2e0
--- /dev/null
+++ b/summarize_service/app/inference_service.py
@@ -0,0 +1,56 @@
+from app.config import MAX_INPUT_LENGTH, MAX_OUTPUT_LENGTH
+
+class InferenceService:
+    def __init__(self, model_loader: "ModelLoader"):
+        self.model_loader = model_loader
+        self.model = self.model_loader.get_model()
+        self.tokenizer = self.model_loader.get_tokenizer()
+        self.device = self.model_loader.get_device()
+        self.max_input_length = MAX_INPUT_LENGTH
+        self.max_output_length = MAX_OUTPUT_LENGTH
+
+    def generate_response(self, prompt: str) -> str:
+        full_prompt = f"Запрос: {prompt}\nОжидаемый ответ:"
+        inputs = self.tokenizer(
+            full_prompt,
+            return_tensors="pt",
+            truncation=True,
+            max_length=self.max_input_length
+        ).to(self.device)
+
+        outputs = self.model.generate(
+            inputs.input_ids,
+            attention_mask=inputs.attention_mask,
+            max_new_tokens=self.max_output_length,
+            num_beams=5,
+            early_stopping=True,
+            no_repeat_ngram_size=2
+        )
+        # Decode only the newly generated tokens, not the echoed prompt.
+        generated = outputs[0][inputs.input_ids.shape[1]:]
+        return self.tokenizer.decode(generated, skip_special_tokens=True)
+
+    def generate_batch(self, prompts: list) -> list:
+        full_prompts = [f"Запрос: {p}\nОжидаемый ответ:" for p in prompts]
+        inputs = self.tokenizer(
+            full_prompts,
+            return_tensors="pt",
+            truncation=True,
+            max_length=self.max_input_length,
+            padding="longest"
+        ).to(self.device)
+
+        outputs = self.model.generate(
+            inputs.input_ids,
+            attention_mask=inputs.attention_mask,
+            max_new_tokens=self.max_output_length,
+            num_beams=5,
+            early_stopping=True,
+            no_repeat_ngram_size=2
+        )
+        # With left padding, every prompt occupies the first prompt_len
+        # positions, so the generated continuation starts right after them.
+        prompt_len = inputs.input_ids.shape[1]
+        return [
+            self.tokenizer.decode(out[prompt_len:], skip_special_tokens=True)
+            for out in outputs
+        ]
diff --git a/summarize_service/app/model_loader.py b/summarize_service/app/model_loader.py
new file mode 100644
index 0000000..896ceea
--- /dev/null
+++ b/summarize_service/app/model_loader.py
@@ -0,0 +1,60 @@
+# app/model_loader.py
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from peft import PeftModel
+
+class ModelLoader:
+    def __init__(self, base_model_name: str, adapter_dir: str, hf_token: str, use_4bit: bool = True):
+        self.base_model_name = base_model_name
+        self.adapter_dir = adapter_dir
+        self.hf_token = hf_token
+        self.use_4bit = use_4bit
+
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.model = None
+        self.tokenizer = None
+
+    def load_model(self):
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=self.use_4bit,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.float16,
+        )
+
+        base_model = AutoModelForCausalLM.from_pretrained(
+            self.base_model_name,
+            torch_dtype=torch.float16,
+            device_map="auto",
+            attn_implementation="eager",
+            quantization_config=bnb_config,
+            token=self.hf_token
+        )
+
+        self.model = PeftModel.from_pretrained(
+            base_model,
+            self.adapter_dir,
+            local_files_only=True
+        )
+        self.model.eval()
+        # Device placement is already handled by device_map="auto"; calling
+        # .to() on a 4-bit bitsandbytes model raises an error.
+
+    def load_tokenizer(self):
+        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name, token=self.hf_token)
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+        # Decoder-only models should be padded on the left for batched generation.
+        self.tokenizer.padding_side = "left"
+
+    def get_model(self):
+        if self.model is None:
+            self.load_model()
+        return self.model
+
+    def get_tokenizer(self):
+        if self.tokenizer is None:
+            self.load_tokenizer()
+        return self.tokenizer
+
+    def get_device(self):
+        return self.device
diff --git a/summarize_service/app/redis_client.py b/summarize_service/app/redis_client.py
new file mode 100644
index 0000000..7c15f1c
--- /dev/null
+++ b/summarize_service/app/redis_client.py
@@ -0,0 +1,47 @@
+# app/redis_client.py
+import json
+import redis
+from typing import List
+from pydantic import BaseModel
+
+class Task(BaseModel):
+    chat_id: int
+    user_id: int
+    message_id: int
+    text: str
+
+class RedisClient:
+    def __init__(self, host: str, port: int, task_channel: str, result_channel: str):
+        self.host = host
+        self.port = port
+        self.task_channel = task_channel
+        self.result_channel = result_channel
+        self.client = redis.Redis(host=self.host, port=self.port, decode_responses=True)
+
+    def get_tasks(self, batch_size: int, wait_timeout: int = 5) -> List[Task]:
+        tasks = []
+        # Block until at least one task arrives or the timeout expires.
+        res = self.client.blpop(self.task_channel, timeout=wait_timeout)
+        if res:
+            _, task_json = res
+            try:
+                task = Task.parse_raw(task_json)
+                tasks.append(task)
+            except Exception as e:
+                print("Failed to parse task:", e)
+
+        # Drain up to batch_size tasks without blocking further.
+        while len(tasks) < batch_size:
+            task_json = self.client.lpop(self.task_channel)
+            if task_json is None:
+                break
+            try:
+                task = Task.parse_raw(task_json)
+                tasks.append(task)
+            except Exception as e:
+                print("Failed to parse task:", e)
+        return tasks
+
+    def publish_result(self, result: dict):
+        result_json = json.dumps(result)
+        self.client.rpush(self.result_channel, result_json)
diff --git a/summarize_service/app/worker.py b/summarize_service/app/worker.py
new file mode 100644
index 0000000..159a159
--- /dev/null
+++ b/summarize_service/app/worker.py
@@ -0,0 +1,42 @@
+# app/worker.py
+import time
+from app.model_loader import ModelLoader
+from app.inference_service import InferenceService
+from app.redis_client import RedisClient
+from app.config import BASE_MODEL, ADAPTER_DIR, HF_TOKEN, REDIS_HOST, REDIS_PORT, TEXT_RESULT_CHANNEL, TEXT_TASK_CHANNEL, BATCH_SIZE, WAIT_TIMEOUT
+
+def main():
+    model_loader = ModelLoader(BASE_MODEL, ADAPTER_DIR, HF_TOKEN)
+    model_loader.load_model()
+    model_loader.load_tokenizer()
+    inference_service = InferenceService(model_loader)
+
+    redis_client = RedisClient(
+        host=REDIS_HOST,
+        port=REDIS_PORT,
+        task_channel=TEXT_TASK_CHANNEL,
+        result_channel=TEXT_RESULT_CHANNEL
+    )
+
+    print("Worker started, waiting for tasks...")
+
+    while True:
+        tasks = redis_client.get_tasks(BATCH_SIZE, wait_timeout=WAIT_TIMEOUT)
+        if not tasks:
+            time.sleep(0.5)
+            continue
+
+        texts = [task.text for task in tasks]
+        responses = inference_service.generate_batch(texts)
+        for task, response in zip(tasks, responses):
+            result = {
+                "chat_id": task.chat_id,
+                "user_id": task.user_id,
+                "message_id": task.message_id,
+                "text": response
+            }
+            redis_client.publish_result(result)
+            print(f"Processed task {task.message_id}")
+
+if __name__ == "__main__":
+    main()
diff --git a/summarize_service/requirements.txt b/summarize_service/requirements.txt
new file mode 100644
index 0000000..dfb9bae
--- /dev/null
+++ b/summarize_service/requirements.txt
@@ -0,0 +1,12 @@
+--extra-index-url https://download.pytorch.org/whl/cu121
+
+torch==2.5.1+cu121
+
+transformers
+accelerate
+peft
+bitsandbytes
+redis>=4.2.0
+pydantic
+python-dotenv
+# flash-attn is not needed here: the model is loaded with attn_implementation="eager"
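For context, here is a minimal sketch of how a producer outside this diff could enqueue work for the worker and read back the summary. It assumes the default Redis host, port, and list names from app/config.py; the producer_example.py name and the literal payload are purely illustrative.

# producer_example.py (illustrative, not part of this diff)
import json
import redis

r = redis.Redis(host="redis", port=6379, decode_responses=True)

# Enqueue a task in the shape the worker's Task model expects.
task = {"chat_id": 1, "user_id": 42, "message_id": 100, "text": "text to summarize"}
r.rpush("text_task_channel", json.dumps(task))

# Wait for the worker to push the result onto the result list.
res = r.blpop("text_result_channel", timeout=60)
if res is not None:
    _, result_json = res
    result = json.loads(result_json)
    print(result["message_id"], result["text"])

Because the result list is shared, a real consumer would match replies to requests by message_id rather than assuming the next item is its own.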