qopscribe/summarize_service/app/inference_service.py

import torch

from config import MAX_INPUT_LENGTH, MAX_OUTPUT_LENGTH
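
# config.py is not included in this snippet; presumably it defines the two
# limits imported above, along these lines (values are illustrative
# assumptions, not taken from the repo):
#     MAX_INPUT_LENGTH = 512    # tokens kept from the prompt after truncation
#     MAX_OUTPUT_LENGTH = 256   # upper bound passed to model.generate()
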

class InferenceService:
    """Wraps a loaded model/tokenizer pair and generates text from prompts."""

    def __init__(self, model_loader: "ModelLoader"):
        self.model_loader = model_loader
        self.model = self.model_loader.get_model()
        self.tokenizer = self.model_loader.get_tokenizer()
        self.device = self.model_loader.get_device()
        self.max_input_length = MAX_INPUT_LENGTH
        self.max_output_length = MAX_OUTPUT_LENGTH
        # Batched padding requires a pad token; fall back to EOS if the
        # tokenizer does not define one (common for causal LMs).
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
    def generate_response(self, prompt: str) -> str:
        # "Запрос" / "Ожидаемый ответ" are Russian for "Request" /
        # "Expected answer"; the literals are kept as-is since this is
        # presumably the format the model was tuned on.
        full_prompt = f"Запрос: {prompt}\nОжидаемый ответ:"
        # No padding needed for a single sequence (the original padded to
        # max_length, which only wastes compute here).
        inputs = self.tokenizer(
            full_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_input_length,
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=self.max_output_length,
                num_beams=5,
                early_stopping=True,
                no_repeat_ngram_size=2,
            )
        result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Strip the answer marker so single and batched calls return the
        # same shape of output.
        return result.replace("Ожидаемый ответ:", "").strip()
    def generate_batch(self, prompts: list[str]) -> list[str]:
        full_prompts = [f"Запрос: {p}\nОжидаемый ответ:" for p in prompts]
        inputs = self.tokenizer(
            full_prompts,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_input_length,
            padding="longest",  # pad only to the longest prompt in the batch
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,  # mask out pad tokens
                max_length=self.max_output_length,
                num_beams=5,
                early_stopping=True,
                no_repeat_ngram_size=2,
            )
        return [
            self.tokenizer.decode(out, skip_special_tokens=True)
            .replace("Ожидаемый ответ:", "")
            .strip()
            for out in outputs
        ]