toxic-detector/app/handlers/toxicity_handler.py

35 lines
1.4 KiB
Python
Raw Permalink Normal View History

2024-10-22 22:51:03 +02:00
# app/handlers/toxicity_handler.py
from typing import List, Optional, Union

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from app.core.config import settings
class ToxicityHandler:
    """Score text toxicity with a pretrained sequence-classification model.

    The tokenizer and model are loaded once when the handler is constructed;
    the model is moved to the selected device and put into evaluation mode.
    """

    def __init__(self, model_checkpoint: str = settings.MODEL_CHECKPOINT, use_cuda: Optional[bool] = settings.USE_CUDA):
        """Load the tokenizer and model and place the model on a device.

        Args:
            model_checkpoint: Checkpoint identifier handed to
                ``from_pretrained`` for both the tokenizer and the model.
            use_cuda: When truthy and CUDA is available, run on the GPU;
                otherwise (including ``None``) run on the CPU.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() and use_cuda else 'cpu')
        self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
        self.model.to(self.device)
        # Inference only: switch the model out of training mode.
        self.model.eval()

    def text_to_toxicity(self, text: str, aggregate: bool = True) -> Union[float, List[float]]:
        """Compute the toxicity of a text.

        Args:
            text: Input text.
            aggregate: If True, return a single aggregated toxicity score;
                if False, return the per-label probabilities.

        Returns:
            With ``aggregate=True``, a float equal to
            ``1 - p[0] * (1 - p[-1])``, where ``p`` holds the sigmoid of the
            model's logits for the first input row. With ``aggregate=False``,
            ``p`` itself as a list of floats.
        """
        with torch.no_grad():
            inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(self.device)
            logits = self.model(**inputs).logits
            # Independent sigmoid per label (not a softmax across labels).
            proba = torch.sigmoid(logits)
            # Only the first row of the batch is used.
            proba = proba.cpu().numpy()[0]
        if aggregate:
            # NOTE(review): presumably label 0 is the "non-toxic" class and the
            # last label a severe class, so the score is one minus
            # "non-toxic and not severe" -- confirm against the checkpoint's
            # label order.
            return float(1 - proba[0] * (1 - proba[-1]))
        return proba.tolist()