This commit is contained in:
itqop 2024-10-23 01:10:51 +03:00
parent be286b87f8
commit 623be62342
5 changed files with 5 additions and 17 deletions

View File

@ -31,7 +31,6 @@ async def assess_toxicity(input: TextInput):
cache_key = get_cache_key(preprocessed_text)
# Попытка получить результат из кеша
cached_result = await cache.get(cache_key)
if cached_result:
try:
@ -42,12 +41,10 @@ async def assess_toxicity(input: TextInput):
logger.warning(f"Кеш для ключа {cache_key} повреждён. Переходим к обработке.")
try:
# Отправляем задачу в очередь Celery
result = assess_toxicity_task.delay(preprocessed_text)
logger.info(f"Задача отправлена в очередь Celery для текста: {preprocessed_text}")
toxicity_score = result.get(timeout=10) # Ждем результат до 10 секунд
toxicity_score = result.get(timeout=5)
# Сохраняем результат в кеш
await cache.set(cache_key, json.dumps(toxicity_score))
logger.info(f"Результат сохранён в кеш для ключа {cache_key}: {toxicity_score}")

View File

@ -43,14 +43,11 @@ class Settings(BaseSettings):
if not os.path.exists(config_path):
raise ValueError(f'В локальной модели по пути "{v}" отсутствует файл config.json')
elif v.startswith("http"):
pass # Предполагаем, что это URL, проверим ниже
pass
else:
# Предполагаем, что это название модели в HuggingFace
pass # Проверим ниже
pass
try:
# Попытка загрузить конфигурацию модели
# Это не загрузит модель полностью, но проверит доступность модели
AutoTokenizer.from_pretrained(v, cache_dir=None, force_download=False)
except Exception as e:
raise ValueError(f'Невозможно загрузить конфигурацию модели из transformers для "{v}": {e}')

View File

@ -11,7 +11,7 @@ class ToxicityHandler:
self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
self.model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
self.model.to(self.device)
self.model.eval() # Перевод модели в режим оценки
self.model.eval()
def text_to_toxicity(self, text: str, aggregate: bool = True) -> float:
"""Вычисляет токсичность текста.
@ -30,6 +30,5 @@ class ToxicityHandler:
proba = proba.cpu().numpy()[0]
if aggregate:
# Пример агрегирования, можно настроить по необходимости
return float(1 - proba[0] * (1 - proba[-1]))
return proba.tolist()

View File

@ -18,16 +18,14 @@ app = FastAPI(
version="1.0"
)
# Разрешение CORS (опционально, настройте по необходимости)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Замените на конкретные домены при необходимости
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Включение маршрутов API
app.include_router(routes.router)
@app.on_event("startup")

View File

@ -2,14 +2,12 @@ from celery import Celery
from app.core.config import settings
from app.handlers.toxicity_handler import ToxicityHandler
# Инициализация Celery
celery_app = Celery(
'tasks',
broker=settings.CELERY_BROKER_URL,
backend=settings.CELERY_RESULT_BACKEND
)
# Конфигурация Celery
celery_app.conf.update(
task_serializer='json',
result_serializer='json',
@ -18,7 +16,6 @@ celery_app.conf.update(
enable_utc=True,
)
# Инициализация обработчика токсичности
toxicity_handler = ToxicityHandler()
@celery_app.task