Fix typo
This commit is contained in:
parent
be286b87f8
commit
623be62342
|
@ -31,7 +31,6 @@ async def assess_toxicity(input: TextInput):
|
||||||
|
|
||||||
cache_key = get_cache_key(preprocessed_text)
|
cache_key = get_cache_key(preprocessed_text)
|
||||||
|
|
||||||
# Попытка получить результат из кеша
|
|
||||||
cached_result = await cache.get(cache_key)
|
cached_result = await cache.get(cache_key)
|
||||||
if cached_result:
|
if cached_result:
|
||||||
try:
|
try:
|
||||||
|
@ -42,12 +41,10 @@ async def assess_toxicity(input: TextInput):
|
||||||
logger.warning(f"Кеш для ключа {cache_key} повреждён. Переходим к обработке.")
|
logger.warning(f"Кеш для ключа {cache_key} повреждён. Переходим к обработке.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Отправляем задачу в очередь Celery
|
|
||||||
result = assess_toxicity_task.delay(preprocessed_text)
|
result = assess_toxicity_task.delay(preprocessed_text)
|
||||||
logger.info(f"Задача отправлена в очередь Celery для текста: {preprocessed_text}")
|
logger.info(f"Задача отправлена в очередь Celery для текста: {preprocessed_text}")
|
||||||
toxicity_score = result.get(timeout=10) # Ждем результат до 10 секунд
|
toxicity_score = result.get(timeout=5)
|
||||||
|
|
||||||
# Сохраняем результат в кеш
|
|
||||||
await cache.set(cache_key, json.dumps(toxicity_score))
|
await cache.set(cache_key, json.dumps(toxicity_score))
|
||||||
logger.info(f"Результат сохранён в кеш для ключа {cache_key}: {toxicity_score}")
|
logger.info(f"Результат сохранён в кеш для ключа {cache_key}: {toxicity_score}")
|
||||||
|
|
||||||
|
|
|
@ -43,14 +43,11 @@ class Settings(BaseSettings):
|
||||||
if not os.path.exists(config_path):
|
if not os.path.exists(config_path):
|
||||||
raise ValueError(f'В локальной модели по пути "{v}" отсутствует файл config.json')
|
raise ValueError(f'В локальной модели по пути "{v}" отсутствует файл config.json')
|
||||||
elif v.startswith("http"):
|
elif v.startswith("http"):
|
||||||
pass # Предполагаем, что это URL, проверим ниже
|
pass
|
||||||
else:
|
else:
|
||||||
# Предполагаем, что это название модели в HuggingFace
|
pass
|
||||||
pass # Проверим ниже
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Попытка загрузить конфигурацию модели
|
|
||||||
# Это не загрузит модель полностью, но проверит доступность модели
|
|
||||||
AutoTokenizer.from_pretrained(v, cache_dir=None, force_download=False)
|
AutoTokenizer.from_pretrained(v, cache_dir=None, force_download=False)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f'Невозможно загрузить конфигурацию модели из transformers для "{v}": {e}')
|
raise ValueError(f'Невозможно загрузить конфигурацию модели из transformers для "{v}": {e}')
|
||||||
|
|
|
@ -11,7 +11,7 @@ class ToxicityHandler:
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
|
self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
|
||||||
self.model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
|
self.model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
|
||||||
self.model.to(self.device)
|
self.model.to(self.device)
|
||||||
self.model.eval() # Перевод модели в режим оценки
|
self.model.eval()
|
||||||
|
|
||||||
def text_to_toxicity(self, text: str, aggregate: bool = True) -> float:
|
def text_to_toxicity(self, text: str, aggregate: bool = True) -> float:
|
||||||
"""Вычисляет токсичность текста.
|
"""Вычисляет токсичность текста.
|
||||||
|
@ -30,6 +30,5 @@ class ToxicityHandler:
|
||||||
|
|
||||||
proba = proba.cpu().numpy()[0]
|
proba = proba.cpu().numpy()[0]
|
||||||
if aggregate:
|
if aggregate:
|
||||||
# Пример агрегирования, можно настроить по необходимости
|
|
||||||
return float(1 - proba[0] * (1 - proba[-1]))
|
return float(1 - proba[0] * (1 - proba[-1]))
|
||||||
return proba.tolist()
|
return proba.tolist()
|
||||||
|
|
|
@ -18,16 +18,14 @@ app = FastAPI(
|
||||||
version="1.0"
|
version="1.0"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Разрешение CORS (опционально, настройте по необходимости)
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=["*"], # Замените на конкретные домены при необходимости
|
allow_origins=["*"],
|
||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Включение маршрутов API
|
|
||||||
app.include_router(routes.router)
|
app.include_router(routes.router)
|
||||||
|
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
|
|
|
@ -2,14 +2,12 @@ from celery import Celery
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.handlers.toxicity_handler import ToxicityHandler
|
from app.handlers.toxicity_handler import ToxicityHandler
|
||||||
|
|
||||||
# Инициализация Celery
|
|
||||||
celery_app = Celery(
|
celery_app = Celery(
|
||||||
'tasks',
|
'tasks',
|
||||||
broker=settings.CELERY_BROKER_URL,
|
broker=settings.CELERY_BROKER_URL,
|
||||||
backend=settings.CELERY_RESULT_BACKEND
|
backend=settings.CELERY_RESULT_BACKEND
|
||||||
)
|
)
|
||||||
|
|
||||||
# Конфигурация Celery
|
|
||||||
celery_app.conf.update(
|
celery_app.conf.update(
|
||||||
task_serializer='json',
|
task_serializer='json',
|
||||||
result_serializer='json',
|
result_serializer='json',
|
||||||
|
@ -18,7 +16,6 @@ celery_app.conf.update(
|
||||||
enable_utc=True,
|
enable_utc=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Инициализация обработчика токсичности
|
|
||||||
toxicity_handler = ToxicityHandler()
|
toxicity_handler = ToxicityHandler()
|
||||||
|
|
||||||
@celery_app.task
|
@celery_app.task
|
||||||
|
|
Loading…
Reference in New Issue