Init commit
This commit is contained in:
parent
0e52f6a84f
commit
a2fe4ad302
|
@ -0,0 +1,6 @@
|
||||||
|
__pycache__
|
||||||
|
*.pyc
|
||||||
|
*.pyo
|
||||||
|
*.pyd
|
||||||
|
*.env
|
||||||
|
.env
|
|
@ -0,0 +1,30 @@
|
||||||
|
# Dockerfile
|
||||||
|
|
||||||
|
FROM python:3.11-alpine
|
||||||
|
|
||||||
|
# Установка системных зависимостей
|
||||||
|
RUN apk update && apk add --no-cache \
|
||||||
|
build-base \
|
||||||
|
libffi-dev \
|
||||||
|
openssl-dev \
|
||||||
|
&& rm -rf /var/cache/apk/*
|
||||||
|
|
||||||
|
# Установка зависимостей Python
|
||||||
|
COPY app/requirements.txt /app/requirements.txt
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Установка зависимостей
|
||||||
|
RUN pip install --upgrade pip
|
||||||
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
# Копирование исходного кода
|
||||||
|
COPY app /app
|
||||||
|
|
||||||
|
# Экспонирование порта
|
||||||
|
EXPOSE 8000
|
||||||
|
|
||||||
|
# Установка Gunicorn
|
||||||
|
RUN pip install gunicorn
|
||||||
|
|
||||||
|
# Команда запуска приложения с использованием Gunicorn и Uvicorn workers
|
||||||
|
CMD ["gunicorn", "main:app", "--worker-class", "uvicorn.workers.UvicornWorker", "--bind", "0.0.0.0:8000", "--workers", "4"]
|
|
@ -0,0 +1,51 @@
|
||||||
|
# app/api/routes.py
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Depends
|
||||||
|
from app.models.schemas import TextInput, ToxicityOutput
|
||||||
|
from app.core.cache import cache
|
||||||
|
import json
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from app.tasks import assess_toxicity_task
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def get_cache_key(text: str) -> str:
|
||||||
|
"""Генерация уникального ключа для кеша на основе текста."""
|
||||||
|
return f"toxicity:{hashlib.sha256(text.encode()).hexdigest()}"
|
||||||
|
|
||||||
|
@router.post("/toxicity", response_model=ToxicityOutput, summary="Оценка токсичности текста", response_description="Оценка токсичности")
|
||||||
|
async def assess_toxicity(input: TextInput):
|
||||||
|
"""
|
||||||
|
Принимает текст и возвращает оценку его токсичности.
|
||||||
|
|
||||||
|
- **text**: Текст для оценки
|
||||||
|
"""
|
||||||
|
cache_key = get_cache_key(input.text)
|
||||||
|
|
||||||
|
# Попытка получить результат из кеша
|
||||||
|
cached_result = await cache.get(cache_key)
|
||||||
|
if cached_result:
|
||||||
|
try:
|
||||||
|
toxicity_score = json.loads(cached_result)
|
||||||
|
logger.info(f"Получен результат из кеша для ключа {cache_key}: {toxicity_score}")
|
||||||
|
return {"toxicity_score": toxicity_score}
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning(f"Кеш для ключа {cache_key} повреждён. Переходим к обработке.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Отправляем задачу в очередь Celery
|
||||||
|
result = assess_toxicity_task.delay(input.text)
|
||||||
|
logger.info(f"Задача отправлена в очередь Celery для текста: {input.text}")
|
||||||
|
toxicity_score = result.get(timeout=10) # Ждем результат до 10 секунд
|
||||||
|
|
||||||
|
# Сохраняем результат в кеш
|
||||||
|
await cache.set(cache_key, json.dumps(toxicity_score))
|
||||||
|
logger.info(f"Результат сохранён в кеш для ключа {cache_key}: {toxicity_score}")
|
||||||
|
|
||||||
|
return {"toxicity_score": toxicity_score}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Ошибка при обработке текста: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
|
@ -0,0 +1,26 @@
|
||||||
|
import aioredis
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
class RedisCache:
|
||||||
|
def __init__(self):
|
||||||
|
self.redis = None
|
||||||
|
|
||||||
|
async def connect(self):
|
||||||
|
self.redis = await aioredis.from_url(settings.REDIS_URL, encoding="utf-8", decode_responses=True)
|
||||||
|
|
||||||
|
async def disconnect(self):
|
||||||
|
if self.redis:
|
||||||
|
await self.redis.close()
|
||||||
|
|
||||||
|
async def get(self, key: str):
|
||||||
|
if not self.redis:
|
||||||
|
await self.connect()
|
||||||
|
return await self.redis.get(key)
|
||||||
|
|
||||||
|
async def set(self, key: str, value: str, expire: int = 3600):
|
||||||
|
if not self.redis:
|
||||||
|
await self.connect()
|
||||||
|
await self.redis.set(key, value, ex=expire)
|
||||||
|
|
||||||
|
cache = RedisCache()
|
|
@ -0,0 +1,72 @@
|
||||||
|
# app/core/config.py
|
||||||
|
|
||||||
|
from pydantic import Field, validator
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
from typing import Any
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
import os
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
MODEL_CHECKPOINT: str = Field(
|
||||||
|
'cointegrated/rubert-tiny-toxicity',
|
||||||
|
description="Контрольная точка модели для оценки токсичности"
|
||||||
|
)
|
||||||
|
USE_CUDA: bool = Field(
|
||||||
|
True,
|
||||||
|
description="Использовать ли CUDA для ускорения вычислений"
|
||||||
|
)
|
||||||
|
REDIS_URL: str = Field(
|
||||||
|
"redis://localhost:6379/0",
|
||||||
|
description="URL для подключения к Redis"
|
||||||
|
)
|
||||||
|
CELERY_BROKER_URL: str = Field(
|
||||||
|
"redis://localhost:6379/0",
|
||||||
|
description="URL брокера для Celery"
|
||||||
|
)
|
||||||
|
CELERY_RESULT_BACKEND: str = Field(
|
||||||
|
"redis://localhost:6379/0",
|
||||||
|
description="URL для хранения результатов Celery"
|
||||||
|
)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
env_file = ".env"
|
||||||
|
env_file_encoding = 'utf-8'
|
||||||
|
case_sensitive = False
|
||||||
|
|
||||||
|
@validator('MODEL_CHECKPOINT')
|
||||||
|
def validate_model_checkpoint(cls, v: str) -> str:
|
||||||
|
if not v or not isinstance(v, str):
|
||||||
|
raise ValueError('MODEL_CHECKPOINT должен быть непустой строкой')
|
||||||
|
|
||||||
|
if os.path.exists(v):
|
||||||
|
# Проверяем, что это директория модели с config.json
|
||||||
|
config_path = os.path.join(v, 'config.json')
|
||||||
|
if not os.path.exists(config_path):
|
||||||
|
raise ValueError(f'В локальной модели по пути "{v}" отсутствует файл config.json')
|
||||||
|
elif v.startswith("http"):
|
||||||
|
pass # Предполагаем, что это URL, проверим ниже
|
||||||
|
else:
|
||||||
|
# Предполагаем, что это название модели в HuggingFace
|
||||||
|
pass # Проверим ниже
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Попытка загрузить конфигурацию модели
|
||||||
|
# Это не загрузит модель полностью, но проверит доступность модели
|
||||||
|
AutoTokenizer.from_pretrained(v, cache_dir=None, force_download=False)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f'Невозможно загрузить конфигурацию модели из transformers для "{v}": {e}')
|
||||||
|
|
||||||
|
return v
|
||||||
|
|
||||||
|
@validator('USE_CUDA', pre=True)
|
||||||
|
def validate_use_cuda(cls, v: Any) -> bool:
|
||||||
|
if isinstance(v, bool):
|
||||||
|
return v
|
||||||
|
if isinstance(v, str):
|
||||||
|
if v.lower() in {'true', '1', 'yes', 'y'}:
|
||||||
|
return True
|
||||||
|
elif v.lower() in {'false', '0', 'no', 'n'}:
|
||||||
|
return False
|
||||||
|
raise ValueError('USE_CUDA должен быть булевым значением (True/False)')
|
||||||
|
|
||||||
|
settings = Settings()
|
|
@ -0,0 +1,35 @@
|
||||||
|
# app/handlers/toxicity_handler.py
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||||
|
from typing import Optional, List
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
class ToxicityHandler:
|
||||||
|
def __init__(self, model_checkpoint: str = settings.MODEL_CHECKPOINT, use_cuda: Optional[bool] = settings.USE_CUDA):
|
||||||
|
self.device = torch.device('cuda' if torch.cuda.is_available() and use_cuda else 'cpu')
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
|
||||||
|
self.model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
|
||||||
|
self.model.to(self.device)
|
||||||
|
self.model.eval() # Перевод модели в режим оценки
|
||||||
|
|
||||||
|
def text_to_toxicity(self, text: str, aggregate: bool = True) -> float:
|
||||||
|
"""Вычисляет токсичность текста.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text (str): Входной текст.
|
||||||
|
aggregate (bool): Если True, возвращает агрегированную оценку токсичности.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: Оценка токсичности.
|
||||||
|
"""
|
||||||
|
with torch.no_grad():
|
||||||
|
inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(self.device)
|
||||||
|
logits = self.model(**inputs).logits
|
||||||
|
proba = torch.sigmoid(logits)
|
||||||
|
|
||||||
|
proba = proba.cpu().numpy()[0]
|
||||||
|
if aggregate:
|
||||||
|
# Пример агрегирования, можно настроить по необходимости
|
||||||
|
return float(1 - proba[0] * (1 - proba[-1]))
|
||||||
|
return proba.tolist()
|
|
@ -0,0 +1,50 @@
|
||||||
|
# app/main.py
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from fastapi import FastAPI, Request
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from app.api import routes
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.core.cache import cache
|
||||||
|
|
||||||
|
# Настройка логирования
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="Toxicity Assessment API",
|
||||||
|
description="API для оценки токсичности текста",
|
||||||
|
version="1.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Разрешение CORS (опционально, настройте по необходимости)
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"], # Замените на конкретные домены при необходимости
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Включение маршрутов API
|
||||||
|
app.include_router(routes.router)
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
async def startup_event():
|
||||||
|
logger.info("Toxicity Assessment API запущен и готов к работе.")
|
||||||
|
logger.info(f"MODEL_CHECKPOINT: {settings.MODEL_CHECKPOINT}")
|
||||||
|
logger.info(f"USE_CUDA: {settings.USE_CUDA}")
|
||||||
|
await cache.connect()
|
||||||
|
|
||||||
|
@app.on_event("shutdown")
|
||||||
|
async def shutdown_event():
|
||||||
|
await cache.disconnect()
|
||||||
|
|
||||||
|
@app.exception_handler(Exception)
|
||||||
|
async def global_exception_handler(request: Request, exc: Exception):
|
||||||
|
logger.error(f"Ошибка: {exc}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content={"detail": "Внутренняя ошибка сервера."},
|
||||||
|
)
|
|
@ -0,0 +1,9 @@
|
||||||
|
# app/models/schemas.py
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
class TextInput(BaseModel):
|
||||||
|
text: str = Field(..., example="Это просто пиииииииииииииииииииздец")
|
||||||
|
|
||||||
|
class ToxicityOutput(BaseModel):
|
||||||
|
toxicity_score: float = Field(..., example=0.85)
|
|
@ -0,0 +1,8 @@
|
||||||
|
fastapi
|
||||||
|
uvicorn[standard]
|
||||||
|
transformers
|
||||||
|
torch
|
||||||
|
gunicorn
|
||||||
|
aioredis==2.0.1
|
||||||
|
celery==5.3.0
|
||||||
|
redis==4.5.5
|
|
@ -0,0 +1,35 @@
|
||||||
|
from celery import Celery
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.handlers.toxicity_handler import ToxicityHandler
|
||||||
|
|
||||||
|
# Инициализация Celery
|
||||||
|
celery_app = Celery(
|
||||||
|
'tasks',
|
||||||
|
broker=settings.CELERY_BROKER_URL,
|
||||||
|
backend=settings.CELERY_RESULT_BACKEND
|
||||||
|
)
|
||||||
|
|
||||||
|
# Конфигурация Celery
|
||||||
|
celery_app.conf.update(
|
||||||
|
task_serializer='json',
|
||||||
|
result_serializer='json',
|
||||||
|
accept_content=['json'],
|
||||||
|
timezone='UTC',
|
||||||
|
enable_utc=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Инициализация обработчика токсичности
|
||||||
|
toxicity_handler = ToxicityHandler()
|
||||||
|
|
||||||
|
@celery_app.task
|
||||||
|
def assess_toxicity_task(text: str) -> float:
|
||||||
|
"""
|
||||||
|
Задача для оценки токсичности текста.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text (str): Входной текст.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: Оценка токсичности.
|
||||||
|
"""
|
||||||
|
return toxicity_handler.text_to_toxicity(text)
|
|
@ -0,0 +1,36 @@
|
||||||
|
services:
|
||||||
|
toxicity_api:
|
||||||
|
build: .
|
||||||
|
container_name: toxicity_assessment_api
|
||||||
|
ports:
|
||||||
|
- "8000:8000"
|
||||||
|
environment:
|
||||||
|
- MODEL_CHECKPOINT=cointegrated/rubert-tiny-toxicity
|
||||||
|
- USE_CUDA=False # Установите True, если используете GPU и это поддерживается
|
||||||
|
- REDIS_URL=redis://redis:6379/0
|
||||||
|
- CELERY_BROKER_URL=redis://redis:6379/0
|
||||||
|
- CELERY_RESULT_BACKEND=redis://redis:6379/0
|
||||||
|
depends_on:
|
||||||
|
- redis
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
celery_worker:
|
||||||
|
build: .
|
||||||
|
container_name: celery_worker
|
||||||
|
command: celery -A app.worker worker --loglevel=info
|
||||||
|
environment:
|
||||||
|
- MODEL_CHECKPOINT=cointegrated/rubert-tiny-toxicity
|
||||||
|
- USE_CUDA=False
|
||||||
|
- REDIS_URL=redis://redis:6379/0
|
||||||
|
- CELERY_BROKER_URL=redis://redis:6379/0
|
||||||
|
- CELERY_RESULT_BACKEND=redis://redis:6379/0
|
||||||
|
depends_on:
|
||||||
|
- redis
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
redis:
|
||||||
|
image: redis:7-alpine
|
||||||
|
container_name: redis
|
||||||
|
ports:
|
||||||
|
- "6379:6379"
|
||||||
|
restart: unless-stopped
|
Loading…
Reference in New Issue