73 lines
2.9 KiB
Python
73 lines
2.9 KiB
Python
|
# app/core/config.py
|
|||
|
|
|||
|
from pydantic import Field, validator
|
|||
|
from pydantic_settings import BaseSettings
|
|||
|
from typing import Any
|
|||
|
from transformers import AutoTokenizer
|
|||
|
import os
|
|||
|
|
|||
|
class Settings(BaseSettings):
|
|||
|
MODEL_CHECKPOINT: str = Field(
|
|||
|
'cointegrated/rubert-tiny-toxicity',
|
|||
|
description="Контрольная точка модели для оценки токсичности"
|
|||
|
)
|
|||
|
USE_CUDA: bool = Field(
|
|||
|
True,
|
|||
|
description="Использовать ли CUDA для ускорения вычислений"
|
|||
|
)
|
|||
|
REDIS_URL: str = Field(
|
|||
|
"redis://localhost:6379/0",
|
|||
|
description="URL для подключения к Redis"
|
|||
|
)
|
|||
|
CELERY_BROKER_URL: str = Field(
|
|||
|
"redis://localhost:6379/0",
|
|||
|
description="URL брокера для Celery"
|
|||
|
)
|
|||
|
CELERY_RESULT_BACKEND: str = Field(
|
|||
|
"redis://localhost:6379/0",
|
|||
|
description="URL для хранения результатов Celery"
|
|||
|
)
|
|||
|
|
|||
|
class Config:
|
|||
|
env_file = ".env"
|
|||
|
env_file_encoding = 'utf-8'
|
|||
|
case_sensitive = False
|
|||
|
|
|||
|
@validator('MODEL_CHECKPOINT')
|
|||
|
def validate_model_checkpoint(cls, v: str) -> str:
|
|||
|
if not v or not isinstance(v, str):
|
|||
|
raise ValueError('MODEL_CHECKPOINT должен быть непустой строкой')
|
|||
|
|
|||
|
if os.path.exists(v):
|
|||
|
# Проверяем, что это директория модели с config.json
|
|||
|
config_path = os.path.join(v, 'config.json')
|
|||
|
if not os.path.exists(config_path):
|
|||
|
raise ValueError(f'В локальной модели по пути "{v}" отсутствует файл config.json')
|
|||
|
elif v.startswith("http"):
|
|||
|
pass # Предполагаем, что это URL, проверим ниже
|
|||
|
else:
|
|||
|
# Предполагаем, что это название модели в HuggingFace
|
|||
|
pass # Проверим ниже
|
|||
|
|
|||
|
try:
|
|||
|
# Попытка загрузить конфигурацию модели
|
|||
|
# Это не загрузит модель полностью, но проверит доступность модели
|
|||
|
AutoTokenizer.from_pretrained(v, cache_dir=None, force_download=False)
|
|||
|
except Exception as e:
|
|||
|
raise ValueError(f'Невозможно загрузить конфигурацию модели из transformers для "{v}": {e}')
|
|||
|
|
|||
|
return v
|
|||
|
|
|||
|
@validator('USE_CUDA', pre=True)
|
|||
|
def validate_use_cuda(cls, v: Any) -> bool:
|
|||
|
if isinstance(v, bool):
|
|||
|
return v
|
|||
|
if isinstance(v, str):
|
|||
|
if v.lower() in {'true', '1', 'yes', 'y'}:
|
|||
|
return True
|
|||
|
elif v.lower() in {'false', '0', 'no', 'n'}:
|
|||
|
return False
|
|||
|
raise ValueError('USE_CUDA должен быть булевым значением (True/False)')
|
|||
|
|
|||
|
settings = Settings()
|