2024-10-22 22:51:03 +02:00
|
|
|
|
from pydantic import Field, validator
|
|
|
|
|
from pydantic_settings import BaseSettings
|
|
|
|
|
from typing import Any
|
|
|
|
|
from transformers import AutoTokenizer
|
|
|
|
|
import os
|
|
|
|
|
|
2024-10-23 00:07:56 +02:00
|
|
|
|
|
2024-10-22 22:51:03 +02:00
|
|
|
|
class Settings(BaseSettings):
    """Application settings for the toxicity model, Redis and Celery.

    Values come from the field defaults below and may be overridden through
    environment variables or the ``.env`` file named in ``Config``.
    """

    MODEL_CHECKPOINT: str = Field(
        'cointegrated/rubert-tiny-toxicity',
        description="Контрольная точка модели для оценки токсичности"
    )
    USE_CUDA: bool = Field(
        True,
        description="Использовать ли CUDA для ускорения вычислений"
    )
    REDIS_URL: str = Field(
        "redis://localhost:6379/0",
        description="URL для подключения к Redis"
    )
    CELERY_BROKER_URL: str = Field(
        "redis://localhost:6379/0",
        description="URL брокера для Celery"
    )
    CELERY_RESULT_BACKEND: str = Field(
        "redis://localhost:6379/0",
        description="URL для хранения результатов Celery"
    )

    class Config:
        env_file = ".env"
        env_file_encoding = 'utf-8'
        case_sensitive = False

    @validator('MODEL_CHECKPOINT')
    def validate_model_checkpoint(cls, v: str) -> str:
        """Check that the checkpoint is usable and return it unchanged.

        Args:
            v: Local path or model identifier of the checkpoint.

        Returns:
            The same value ``v`` when every check passes.

        Raises:
            ValueError: If ``v`` is empty or not a string, if ``v`` is an
                existing local path without a ``config.json`` inside it, or
                if ``AutoTokenizer.from_pretrained`` fails for ``v``.
        """
        if not v or not isinstance(v, str):
            raise ValueError('MODEL_CHECKPOINT должен быть непустой строкой')

        if os.path.exists(v):
            # A local checkpoint must be a model directory holding config.json.
            config_path = os.path.join(v, 'config.json')
            if not os.path.exists(config_path):
                raise ValueError(f'В локальной модели по пути "{v}" отсутствует файл config.json')

        # Local paths, URLs and other identifiers all go through the same
        # load attempt; any failure is surfaced as a validation error.
        try:
            AutoTokenizer.from_pretrained(v, cache_dir=None, force_download=False)
        except Exception as e:
            # Chain the original exception so the root cause stays visible.
            raise ValueError(
                f'Невозможно загрузить конфигурацию модели из transformers для "{v}": {e}'
            ) from e

        return v

    @validator('USE_CUDA', pre=True)
    def validate_use_cuda(cls, v: Any) -> bool:
        """Coerce the raw ``USE_CUDA`` value to a boolean.

        Accepts real booleans, the integers ``0`` and ``1``, and the strings
        ``true/1/yes/y`` or ``false/0/no/n`` (case-insensitive, surrounding
        whitespace ignored).

        Args:
            v: Raw value before pydantic's own type coercion (``pre=True``).

        Returns:
            The boolean the value stands for.

        Raises:
            ValueError: If ``v`` is none of the accepted spellings.
        """
        if isinstance(v, bool):
            return v

        # bool is handled above, so this only matches plain integers.
        if isinstance(v, int) and v in (0, 1):
            return bool(v)

        if isinstance(v, str):
            # Values read from .env files often carry stray whitespace.
            token = v.strip().lower()
            if token in {'true', '1', 'yes', 'y'}:
                return True
            if token in {'false', '0', 'no', 'n'}:
                return False

        raise ValueError('USE_CUDA должен быть булевым значением (True/False)')
|
|
|
|
|
|
|
|
|
|
# Module-level settings instance, built once when this module is imported.
settings = Settings()
|