69 lines
2.5 KiB
Python
69 lines
2.5 KiB
Python
from pydantic import Field, validator
|
||
from pydantic_settings import BaseSettings
|
||
from typing import Any
|
||
from transformers import AutoTokenizer
|
||
import os
|
||
|
||
|
||
class Settings(BaseSettings):
|
||
MODEL_CHECKPOINT: str = Field(
|
||
'cointegrated/rubert-tiny-toxicity',
|
||
description="Контрольная точка модели для оценки токсичности"
|
||
)
|
||
USE_CUDA: bool = Field(
|
||
True,
|
||
description="Использовать ли CUDA для ускорения вычислений"
|
||
)
|
||
REDIS_URL: str = Field(
|
||
"redis://localhost:6379/0",
|
||
description="URL для подключения к Redis"
|
||
)
|
||
CELERY_BROKER_URL: str = Field(
|
||
"redis://localhost:6379/0",
|
||
description="URL брокера для Celery"
|
||
)
|
||
CELERY_RESULT_BACKEND: str = Field(
|
||
"redis://localhost:6379/0",
|
||
description="URL для хранения результатов Celery"
|
||
)
|
||
|
||
class Config:
|
||
env_file = ".env"
|
||
env_file_encoding = 'utf-8'
|
||
case_sensitive = False
|
||
|
||
@validator('MODEL_CHECKPOINT')
|
||
def validate_model_checkpoint(cls, v: str) -> str:
|
||
if not v or not isinstance(v, str):
|
||
raise ValueError('MODEL_CHECKPOINT должен быть непустой строкой')
|
||
|
||
if os.path.exists(v):
|
||
# Проверяем, что это директория модели с config.json
|
||
config_path = os.path.join(v, 'config.json')
|
||
if not os.path.exists(config_path):
|
||
raise ValueError(f'В локальной модели по пути "{v}" отсутствует файл config.json')
|
||
elif v.startswith("http"):
|
||
pass
|
||
else:
|
||
pass
|
||
|
||
try:
|
||
AutoTokenizer.from_pretrained(v, cache_dir=None, force_download=False)
|
||
except Exception as e:
|
||
raise ValueError(f'Невозможно загрузить конфигурацию модели из transformers для "{v}": {e}')
|
||
|
||
return v
|
||
|
||
@validator('USE_CUDA', pre=True)
|
||
def validate_use_cuda(cls, v: Any) -> bool:
|
||
if isinstance(v, bool):
|
||
return v
|
||
if isinstance(v, str):
|
||
if v.lower() in {'true', '1', 'yes', 'y'}:
|
||
return True
|
||
elif v.lower() in {'false', '0', 'no', 'n'}:
|
||
return False
|
||
raise ValueError('USE_CUDA должен быть булевым значением (True/False)')
|
||
|
||
settings = Settings()
|