# app/model_loader.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

class ModelLoader:
    """Lazily load a causal-LM base model with a PEFT adapter, plus its tokenizer.

    Nothing heavy happens at construction time. The model and tokenizer are
    loaded on first access via ``get_model`` / ``get_tokenizer`` and then
    cached on the instance.
    """

    def __init__(self, base_model_name: str, adapter_dir: str, hf_token: str, use_4bit: bool = True):
        """Store loading parameters; no model or tokenizer is loaded here.

        Args:
            base_model_name: Hub id or local path of the base causal-LM.
            adapter_dir: Hub id or local path of the PEFT adapter. The
                tokenizer is loaded from this location as well.
            hf_token: Hugging Face access token forwarded to ``from_pretrained``.
            use_4bit: If True, quantize the base model to 4-bit NF4 with
                bitsandbytes; if False, load it unquantized in float16.
        """
        self.base_model_name = base_model_name
        self.adapter_dir = adapter_dir
        self.hf_token = hf_token
        self.use_4bit = use_4bit

        # Device that callers should move input tensors to. Placement of the
        # model itself is handled by device_map="auto" in load_model.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None

    def _build_quantization_config(self):
        """Return a 4-bit NF4 ``BitsAndBytesConfig``, or None when ``use_4bit`` is off."""
        # A BitsAndBytesConfig with load_in_4bit=False is not a reliable "off"
        # switch: transformers picks a bitsandbytes quantizer from the config
        # object itself. Passing None is the unambiguous way to disable it.
        if not self.use_4bit:
            return None
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

    def load_model(self):
        """Load the base model, attach the PEFT adapter, and store it in ``self.model``.

        The resulting model is switched to eval mode. Device placement is left
        entirely to ``device_map="auto"``. The model is deliberately not moved
        with ``.to()`` afterwards, for two reasons:
        - some transformers/bitsandbytes versions reject ``.to()`` on quantized
          models;
        - moving a model that accelerate has dispatched can undo multi-device
          or offloaded placement.
        """
        base_model = AutoModelForCausalLM.from_pretrained(
            self.base_model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            attn_implementation="eager",
            quantization_config=self._build_quantization_config(),
            token=self.hf_token,
        )

        self.model = PeftModel.from_pretrained(base_model, self.adapter_dir)
        self.model.eval()

    def load_tokenizer(self):
        """Load the tokenizer from ``adapter_dir`` and store it in ``self.tokenizer``.

        If the tokenizer defines no pad token, the EOS token is reused as the
        pad token so that padded batching works.
        """
        # token is ignored for local paths but needed for private Hub repos,
        # matching how the base model is loaded.
        self.tokenizer = AutoTokenizer.from_pretrained(self.adapter_dir, token=self.hf_token)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def get_model(self):
        """Return the adapter-wrapped model, loading it on first call."""
        if self.model is None:
            self.load_model()
        return self.model

    def get_tokenizer(self):
        """Return the tokenizer, loading it on first call."""
        if self.tokenizer is None:
            self.load_tokenizer()
        return self.tokenizer

    def get_device(self) -> str:
        """Return ``"cuda"`` if CUDA was available at construction time, else ``"cpu"``."""
        return self.device