# app/handlers/toxicity_handler.py
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from typing import List, Optional, Union
from app.core.config import settings


class ToxicityHandler:
    """Scores text toxicity with a Hugging Face sequence-classification model."""

    def __init__(self, model_checkpoint: str = settings.MODEL_CHECKPOINT, use_cuda: Optional[bool] = settings.USE_CUDA):
        # Run on the GPU only when CUDA is available and explicitly enabled.
        self.device = torch.device('cuda' if torch.cuda.is_available() and use_cuda else 'cpu')
        self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
        self.model.to(self.device)
        # Inference only: disable dropout and other training-time behaviour.
        self.model.eval()

    def text_to_toxicity(self, text: str, aggregate: bool = True) -> Union[float, List[float]]:
|
|
|
|
"""Вычисляет токсичность текста.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
text (str): Входной текст.
|
|
|
|
aggregate (bool): Если True, возвращает агрегированную оценку токсичности.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
float: Оценка токсичности.
|
|
|
|
"""
        with torch.no_grad():
            inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(self.device)
            logits = self.model(**inputs).logits
            proba = torch.sigmoid(logits)
        proba = proba.cpu().numpy()[0]
        if aggregate:
            # Aggregated score: 1 - P(label 0) * (1 - P(last label)); this assumes a
            # multi-label checkpoint whose first label is the "non-toxic" class.
            return float(1 - proba[0] * (1 - proba[-1]))
        return proba.tolist()
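

# --- Usage sketch (illustrative only) ---
# A minimal example of how ToxicityHandler might be exercised; the sample text
# is a placeholder, and the checkpoint actually loaded is whatever
# settings.MODEL_CHECKPOINT points to.
if __name__ == '__main__':
    handler = ToxicityHandler()
    # Aggregated score: a single float, higher means more toxic.
    print(handler.text_to_toxicity('Example input text'))
    # Per-label probabilities: one float per model label.
    print(handler.text_to_toxicity('Example input text', aggregate=False))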