Init commit
This commit is contained in:
parent
0e52f6a84f
commit
a2fe4ad302
|
@ -0,0 +1,6 @@
|
|||
__pycache__
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pyd
|
||||
*.env
|
||||
.env
|
|
@ -0,0 +1,30 @@
|
|||
# Dockerfile
|
||||
|
||||
FROM python:3.11-alpine
|
||||
|
||||
# Установка системных зависимостей
|
||||
RUN apk update && apk add --no-cache \
|
||||
build-base \
|
||||
libffi-dev \
|
||||
openssl-dev \
|
||||
&& rm -rf /var/cache/apk/*
|
||||
|
||||
# Установка зависимостей Python
|
||||
COPY app/requirements.txt /app/requirements.txt
|
||||
WORKDIR /app
|
||||
|
||||
# Установка зависимостей
|
||||
RUN pip install --upgrade pip
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Копирование исходного кода
|
||||
COPY app /app
|
||||
|
||||
# Экспонирование порта
|
||||
EXPOSE 8000
|
||||
|
||||
# Установка Gunicorn
|
||||
RUN pip install gunicorn
|
||||
|
||||
# Команда запуска приложения с использованием Gunicorn и Uvicorn workers
|
||||
CMD ["gunicorn", "main:app", "--worker-class", "uvicorn.workers.UvicornWorker", "--bind", "0.0.0.0:8000", "--workers", "4"]
|
|
@ -0,0 +1,51 @@
|
|||
# app/api/routes.py
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from app.models.schemas import TextInput, ToxicityOutput
|
||||
from app.core.cache import cache
|
||||
import json
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
from app.tasks import assess_toxicity_task
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_cache_key(text: str) -> str:
|
||||
"""Генерация уникального ключа для кеша на основе текста."""
|
||||
return f"toxicity:{hashlib.sha256(text.encode()).hexdigest()}"
|
||||
|
||||
@router.post("/toxicity", response_model=ToxicityOutput, summary="Оценка токсичности текста", response_description="Оценка токсичности")
|
||||
async def assess_toxicity(input: TextInput):
|
||||
"""
|
||||
Принимает текст и возвращает оценку его токсичности.
|
||||
|
||||
- **text**: Текст для оценки
|
||||
"""
|
||||
cache_key = get_cache_key(input.text)
|
||||
|
||||
# Попытка получить результат из кеша
|
||||
cached_result = await cache.get(cache_key)
|
||||
if cached_result:
|
||||
try:
|
||||
toxicity_score = json.loads(cached_result)
|
||||
logger.info(f"Получен результат из кеша для ключа {cache_key}: {toxicity_score}")
|
||||
return {"toxicity_score": toxicity_score}
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Кеш для ключа {cache_key} повреждён. Переходим к обработке.")
|
||||
|
||||
try:
|
||||
# Отправляем задачу в очередь Celery
|
||||
result = assess_toxicity_task.delay(input.text)
|
||||
logger.info(f"Задача отправлена в очередь Celery для текста: {input.text}")
|
||||
toxicity_score = result.get(timeout=10) # Ждем результат до 10 секунд
|
||||
|
||||
# Сохраняем результат в кеш
|
||||
await cache.set(cache_key, json.dumps(toxicity_score))
|
||||
logger.info(f"Результат сохранён в кеш для ключа {cache_key}: {toxicity_score}")
|
||||
|
||||
return {"toxicity_score": toxicity_score}
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка при обработке текста: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
|
@ -0,0 +1,26 @@
|
|||
import aioredis
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class RedisCache:
|
||||
def __init__(self):
|
||||
self.redis = None
|
||||
|
||||
async def connect(self):
|
||||
self.redis = await aioredis.from_url(settings.REDIS_URL, encoding="utf-8", decode_responses=True)
|
||||
|
||||
async def disconnect(self):
|
||||
if self.redis:
|
||||
await self.redis.close()
|
||||
|
||||
async def get(self, key: str):
|
||||
if not self.redis:
|
||||
await self.connect()
|
||||
return await self.redis.get(key)
|
||||
|
||||
async def set(self, key: str, value: str, expire: int = 3600):
|
||||
if not self.redis:
|
||||
await self.connect()
|
||||
await self.redis.set(key, value, ex=expire)
|
||||
|
||||
cache = RedisCache()
|
|
@ -0,0 +1,72 @@
|
|||
# app/core/config.py
|
||||
|
||||
from pydantic import Field, validator
|
||||
from pydantic_settings import BaseSettings
|
||||
from typing import Any
|
||||
from transformers import AutoTokenizer
|
||||
import os
|
||||
|
||||
class Settings(BaseSettings):
|
||||
MODEL_CHECKPOINT: str = Field(
|
||||
'cointegrated/rubert-tiny-toxicity',
|
||||
description="Контрольная точка модели для оценки токсичности"
|
||||
)
|
||||
USE_CUDA: bool = Field(
|
||||
True,
|
||||
description="Использовать ли CUDA для ускорения вычислений"
|
||||
)
|
||||
REDIS_URL: str = Field(
|
||||
"redis://localhost:6379/0",
|
||||
description="URL для подключения к Redis"
|
||||
)
|
||||
CELERY_BROKER_URL: str = Field(
|
||||
"redis://localhost:6379/0",
|
||||
description="URL брокера для Celery"
|
||||
)
|
||||
CELERY_RESULT_BACKEND: str = Field(
|
||||
"redis://localhost:6379/0",
|
||||
description="URL для хранения результатов Celery"
|
||||
)
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_file_encoding = 'utf-8'
|
||||
case_sensitive = False
|
||||
|
||||
@validator('MODEL_CHECKPOINT')
|
||||
def validate_model_checkpoint(cls, v: str) -> str:
|
||||
if not v or not isinstance(v, str):
|
||||
raise ValueError('MODEL_CHECKPOINT должен быть непустой строкой')
|
||||
|
||||
if os.path.exists(v):
|
||||
# Проверяем, что это директория модели с config.json
|
||||
config_path = os.path.join(v, 'config.json')
|
||||
if not os.path.exists(config_path):
|
||||
raise ValueError(f'В локальной модели по пути "{v}" отсутствует файл config.json')
|
||||
elif v.startswith("http"):
|
||||
pass # Предполагаем, что это URL, проверим ниже
|
||||
else:
|
||||
# Предполагаем, что это название модели в HuggingFace
|
||||
pass # Проверим ниже
|
||||
|
||||
try:
|
||||
# Попытка загрузить конфигурацию модели
|
||||
# Это не загрузит модель полностью, но проверит доступность модели
|
||||
AutoTokenizer.from_pretrained(v, cache_dir=None, force_download=False)
|
||||
except Exception as e:
|
||||
raise ValueError(f'Невозможно загрузить конфигурацию модели из transformers для "{v}": {e}')
|
||||
|
||||
return v
|
||||
|
||||
@validator('USE_CUDA', pre=True)
|
||||
def validate_use_cuda(cls, v: Any) -> bool:
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
if isinstance(v, str):
|
||||
if v.lower() in {'true', '1', 'yes', 'y'}:
|
||||
return True
|
||||
elif v.lower() in {'false', '0', 'no', 'n'}:
|
||||
return False
|
||||
raise ValueError('USE_CUDA должен быть булевым значением (True/False)')
|
||||
|
||||
settings = Settings()
|
|
@ -0,0 +1,35 @@
|
|||
# app/handlers/toxicity_handler.py
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from typing import Optional, List
|
||||
from app.core.config import settings
|
||||
|
||||
class ToxicityHandler:
|
||||
def __init__(self, model_checkpoint: str = settings.MODEL_CHECKPOINT, use_cuda: Optional[bool] = settings.USE_CUDA):
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() and use_cuda else 'cpu')
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
|
||||
self.model.to(self.device)
|
||||
self.model.eval() # Перевод модели в режим оценки
|
||||
|
||||
def text_to_toxicity(self, text: str, aggregate: bool = True) -> float:
|
||||
"""Вычисляет токсичность текста.
|
||||
|
||||
Args:
|
||||
text (str): Входной текст.
|
||||
aggregate (bool): Если True, возвращает агрегированную оценку токсичности.
|
||||
|
||||
Returns:
|
||||
float: Оценка токсичности.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(self.device)
|
||||
logits = self.model(**inputs).logits
|
||||
proba = torch.sigmoid(logits)
|
||||
|
||||
proba = proba.cpu().numpy()[0]
|
||||
if aggregate:
|
||||
# Пример агрегирования, можно настроить по необходимости
|
||||
return float(1 - proba[0] * (1 - proba[-1]))
|
||||
return proba.tolist()
|
|
@ -0,0 +1,50 @@
|
|||
# app/main.py
|
||||
|
||||
import logging
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from app.api import routes
|
||||
from app.core.config import settings
|
||||
from app.core.cache import cache
|
||||
|
||||
# Настройка логирования
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = FastAPI(
|
||||
title="Toxicity Assessment API",
|
||||
description="API для оценки токсичности текста",
|
||||
version="1.0"
|
||||
)
|
||||
|
||||
# Разрешение CORS (опционально, настройте по необходимости)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # Замените на конкретные домены при необходимости
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Включение маршрутов API
|
||||
app.include_router(routes.router)
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
logger.info("Toxicity Assessment API запущен и готов к работе.")
|
||||
logger.info(f"MODEL_CHECKPOINT: {settings.MODEL_CHECKPOINT}")
|
||||
logger.info(f"USE_CUDA: {settings.USE_CUDA}")
|
||||
await cache.connect()
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
await cache.disconnect()
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def global_exception_handler(request: Request, exc: Exception):
|
||||
logger.error(f"Ошибка: {exc}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": "Внутренняя ошибка сервера."},
|
||||
)
|
|
@ -0,0 +1,9 @@
|
|||
# app/models/schemas.py
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class TextInput(BaseModel):
|
||||
text: str = Field(..., example="Это просто пиииииииииииииииииииздец")
|
||||
|
||||
class ToxicityOutput(BaseModel):
|
||||
toxicity_score: float = Field(..., example=0.85)
|
|
@ -0,0 +1,8 @@
|
|||
fastapi
|
||||
uvicorn[standard]
|
||||
transformers
|
||||
torch
|
||||
gunicorn
|
||||
aioredis==2.0.1
|
||||
celery==5.3.0
|
||||
redis==4.5.5
|
|
@ -0,0 +1,35 @@
|
|||
from celery import Celery
|
||||
from app.core.config import settings
|
||||
from app.handlers.toxicity_handler import ToxicityHandler
|
||||
|
||||
# Инициализация Celery
|
||||
celery_app = Celery(
|
||||
'tasks',
|
||||
broker=settings.CELERY_BROKER_URL,
|
||||
backend=settings.CELERY_RESULT_BACKEND
|
||||
)
|
||||
|
||||
# Конфигурация Celery
|
||||
celery_app.conf.update(
|
||||
task_serializer='json',
|
||||
result_serializer='json',
|
||||
accept_content=['json'],
|
||||
timezone='UTC',
|
||||
enable_utc=True,
|
||||
)
|
||||
|
||||
# Инициализация обработчика токсичности
|
||||
toxicity_handler = ToxicityHandler()
|
||||
|
||||
@celery_app.task
|
||||
def assess_toxicity_task(text: str) -> float:
|
||||
"""
|
||||
Задача для оценки токсичности текста.
|
||||
|
||||
Args:
|
||||
text (str): Входной текст.
|
||||
|
||||
Returns:
|
||||
float: Оценка токсичности.
|
||||
"""
|
||||
return toxicity_handler.text_to_toxicity(text)
|
|
@ -0,0 +1,36 @@
|
|||
services:
|
||||
toxicity_api:
|
||||
build: .
|
||||
container_name: toxicity_assessment_api
|
||||
ports:
|
||||
- "8000:8000"
|
||||
environment:
|
||||
- MODEL_CHECKPOINT=cointegrated/rubert-tiny-toxicity
|
||||
- USE_CUDA=False # Установите True, если используете GPU и это поддерживается
|
||||
- REDIS_URL=redis://redis:6379/0
|
||||
- CELERY_BROKER_URL=redis://redis:6379/0
|
||||
- CELERY_RESULT_BACKEND=redis://redis:6379/0
|
||||
depends_on:
|
||||
- redis
|
||||
restart: unless-stopped
|
||||
|
||||
celery_worker:
|
||||
build: .
|
||||
container_name: celery_worker
|
||||
command: celery -A app.worker worker --loglevel=info
|
||||
environment:
|
||||
- MODEL_CHECKPOINT=cointegrated/rubert-tiny-toxicity
|
||||
- USE_CUDA=False
|
||||
- REDIS_URL=redis://redis:6379/0
|
||||
- CELERY_BROKER_URL=redis://redis:6379/0
|
||||
- CELERY_RESULT_BACKEND=redis://redis:6379/0
|
||||
depends_on:
|
||||
- redis
|
||||
restart: unless-stopped
|
||||
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
container_name: redis
|
||||
ports:
|
||||
- "6379:6379"
|
||||
restart: unless-stopped
|
Loading…
Reference in New Issue