qopscribe/summarize_service/app/model_loader.py


# app/model_loader.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel


class ModelLoader:
    """Loads a (optionally 4-bit quantized) base model with a PEFT adapter and its tokenizer."""

    def __init__(self, base_model_name: str, adapter_dir: str, hf_token: str, use_4bit: bool = True):
        self.base_model_name = base_model_name
        self.adapter_dir = adapter_dir
        self.hf_token = hf_token
        self.use_4bit = use_4bit
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None

    def load_model(self):
        # Build the quantization config only when 4-bit loading is requested;
        # passing a BitsAndBytesConfig with load_in_4bit=False would silently
        # disable quantization while still going through the bitsandbytes path.
        quantization_config = None
        if self.use_4bit:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
            )
        base_model = AutoModelForCausalLM.from_pretrained(
            self.base_model_name,
            torch_dtype=torch.float16,
            device_map="auto",  # accelerate places the weights across available devices
            attn_implementation="eager",
            quantization_config=quantization_config,
            token=self.hf_token,
        )
        # Attach the fine-tuned LoRA adapter on top of the base weights.
        self.model = PeftModel.from_pretrained(base_model, self.adapter_dir)
        self.model.eval()
        # Note: no .to(self.device) here. device_map="auto" already placed the
        # weights, and .to() raises a ValueError on bitsandbytes-quantized models.

    def load_tokenizer(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.adapter_dir)
        # Causal LMs often ship without a pad token; reuse EOS for padding.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def get_model(self):
        if self.model is None:
            self.load_model()
        return self.model

    def get_tokenizer(self):
        if self.tokenizer is None:
            self.load_tokenizer()
        return self.tokenizer

    def get_device(self):
        return self.device
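

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of wiring the loader up end to end. The base model name,
# adapter path, and HF_TOKEN environment variable below are assumptions for
# the example, not values taken from the service's configuration.
if __name__ == "__main__":
    import os

    loader = ModelLoader(
        base_model_name="meta-llama/Llama-2-7b-hf",  # hypothetical base model
        adapter_dir="./adapter",                     # hypothetical adapter path
        hf_token=os.environ.get("HF_TOKEN", ""),
    )
    tokenizer = loader.get_tokenizer()
    model = loader.get_model()

    prompt = "Summarize: The quick brown fox jumps over the lazy dog."
    inputs = tokenizer(prompt, return_tensors="pt").to(loader.get_device())
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=128)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))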