import torch

from config import MAX_INPUT_LENGTH, MAX_OUTPUT_LENGTH


class InferenceService:
    """Runs text generation on a model/tokenizer pair supplied by a ModelLoader.

    Prompts are wrapped in a Russian template: "Запрос:" ("Request:")
    followed by "Ожидаемый ответ:" ("Expected answer:").
    """

    def __init__(self, model_loader: "ModelLoader"):
        self.model_loader = model_loader
        self.model = self.model_loader.get_model()
        self.tokenizer = self.model_loader.get_tokenizer()
        self.device = self.model_loader.get_device()
        self.max_input_length = MAX_INPUT_LENGTH
        self.max_output_length = MAX_OUTPUT_LENGTH

    def generate_response(self, prompt: str) -> str:
        full_prompt = f"Запрос: {prompt}\nОжидаемый ответ:"
        inputs = self.tokenizer(
            full_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_input_length,
            padding="max_length",
        ).to(self.device)

        # Pass the attention mask so beam search ignores padding tokens,
        # and skip gradient tracking during inference.
        with torch.no_grad():
            outputs = self.model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=self.max_output_length,
                num_beams=5,
                early_stopping=True,
                no_repeat_ngram_size=2,
            )
        # Note: unlike generate_batch, the decoded text is returned as-is,
        # without stripping the prompt template.
        result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return result

    def generate_batch(self, prompts: list[str]) -> list[str]:
        full_prompts = [f"Запрос: {p}\nОжидаемый ответ:" for p in prompts]
        inputs = self.tokenizer(
            full_prompts,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_input_length,
            padding="longest",  # pad only to the longest prompt in the batch
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=self.max_output_length,
                num_beams=5,
                early_stopping=True,
                no_repeat_ngram_size=2,
            )
        # Post-process each decoded sequence: drop the "Ожидаемый ответ:"
        # ("Expected answer:") marker and rename "Запрос:" ("Request:") to
        # "Исходный текст:" ("Source text:").
        responses = [
            self.tokenizer.decode(out, skip_special_tokens=True)
            .replace("Ожидаемый ответ:", "")
            .replace("Запрос:", "Исходный текст:")
            .strip()
            for out in outputs
        ]
        return responses
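

# A minimal usage sketch, not part of the original file: it assumes a
# ModelLoader exposing get_model()/get_tokenizer()/get_device(), as used by
# __init__ above; the "model_loader" import path is hypothetical.
if __name__ == "__main__":
    from model_loader import ModelLoader  # hypothetical module path

    service = InferenceService(ModelLoader())
    print(service.generate_response("Как дела?"))  # "How are you?"
    print(service.generate_batch(["Первый запрос", "Второй запрос"]))  # two sample requests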