This commit is contained in:
itqop 2024-10-23 01:10:51 +03:00
parent be286b87f8
commit 623be62342
5 changed files with 5 additions and 17 deletions

View File

@ -31,7 +31,6 @@ async def assess_toxicity(input: TextInput):
cache_key = get_cache_key(preprocessed_text) cache_key = get_cache_key(preprocessed_text)
# Попытка получить результат из кеша
cached_result = await cache.get(cache_key) cached_result = await cache.get(cache_key)
if cached_result: if cached_result:
try: try:
@ -42,12 +41,10 @@ async def assess_toxicity(input: TextInput):
logger.warning(f"Кеш для ключа {cache_key} повреждён. Переходим к обработке.") logger.warning(f"Кеш для ключа {cache_key} повреждён. Переходим к обработке.")
try: try:
# Отправляем задачу в очередь Celery
result = assess_toxicity_task.delay(preprocessed_text) result = assess_toxicity_task.delay(preprocessed_text)
logger.info(f"Задача отправлена в очередь Celery для текста: {preprocessed_text}") logger.info(f"Задача отправлена в очередь Celery для текста: {preprocessed_text}")
toxicity_score = result.get(timeout=10) # Ждем результат до 10 секунд toxicity_score = result.get(timeout=5)
# Сохраняем результат в кеш
await cache.set(cache_key, json.dumps(toxicity_score)) await cache.set(cache_key, json.dumps(toxicity_score))
logger.info(f"Результат сохранён в кеш для ключа {cache_key}: {toxicity_score}") logger.info(f"Результат сохранён в кеш для ключа {cache_key}: {toxicity_score}")

View File

@ -43,14 +43,11 @@ class Settings(BaseSettings):
if not os.path.exists(config_path): if not os.path.exists(config_path):
raise ValueError(f'В локальной модели по пути "{v}" отсутствует файл config.json') raise ValueError(f'В локальной модели по пути "{v}" отсутствует файл config.json')
elif v.startswith("http"): elif v.startswith("http"):
pass # Предполагаем, что это URL, проверим ниже pass
else: else:
# Предполагаем, что это название модели в HuggingFace pass
pass # Проверим ниже
try: try:
# Попытка загрузить конфигурацию модели
# Это не загрузит модель полностью, но проверит доступность модели
AutoTokenizer.from_pretrained(v, cache_dir=None, force_download=False) AutoTokenizer.from_pretrained(v, cache_dir=None, force_download=False)
except Exception as e: except Exception as e:
raise ValueError(f'Невозможно загрузить конфигурацию модели из transformers для "{v}": {e}') raise ValueError(f'Невозможно загрузить конфигурацию модели из transformers для "{v}": {e}')

View File

@ -11,7 +11,7 @@ class ToxicityHandler:
self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
self.model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint) self.model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
self.model.to(self.device) self.model.to(self.device)
self.model.eval() # Перевод модели в режим оценки self.model.eval()
def text_to_toxicity(self, text: str, aggregate: bool = True) -> float: def text_to_toxicity(self, text: str, aggregate: bool = True) -> float:
"""Вычисляет токсичность текста. """Вычисляет токсичность текста.
@ -30,6 +30,5 @@ class ToxicityHandler:
proba = proba.cpu().numpy()[0] proba = proba.cpu().numpy()[0]
if aggregate: if aggregate:
# Пример агрегирования, можно настроить по необходимости
return float(1 - proba[0] * (1 - proba[-1])) return float(1 - proba[0] * (1 - proba[-1]))
return proba.tolist() return proba.tolist()

View File

@ -18,16 +18,14 @@ app = FastAPI(
version="1.0" version="1.0"
) )
# Разрешение CORS (опционально, настройте по необходимости)
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=["*"], # Замените на конкретные домены при необходимости allow_origins=["*"],
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],
) )
# Включение маршрутов API
app.include_router(routes.router) app.include_router(routes.router)
@app.on_event("startup") @app.on_event("startup")

View File

@ -2,14 +2,12 @@ from celery import Celery
from app.core.config import settings from app.core.config import settings
from app.handlers.toxicity_handler import ToxicityHandler from app.handlers.toxicity_handler import ToxicityHandler
# Инициализация Celery
celery_app = Celery( celery_app = Celery(
'tasks', 'tasks',
broker=settings.CELERY_BROKER_URL, broker=settings.CELERY_BROKER_URL,
backend=settings.CELERY_RESULT_BACKEND backend=settings.CELERY_RESULT_BACKEND
) )
# Конфигурация Celery
celery_app.conf.update( celery_app.conf.update(
task_serializer='json', task_serializer='json',
result_serializer='json', result_serializer='json',
@ -18,7 +16,6 @@ celery_app.conf.update(
enable_utc=True, enable_utc=True,
) )
# Инициализация обработчика токсичности
toxicity_handler = ToxicityHandler() toxicity_handler = ToxicityHandler()
@celery_app.task @celery_app.task