from config import MAX_INPUT_LENGTH, MAX_OUTPUT_LENGTH

class InferenceService:
    """Generate text responses from a model supplied by a ``ModelLoader``.

    The model, tokenizer and device are fetched once from the loader at
    construction time. The tokenizer/model calls (``return_tensors="pt"``,
    ``model.generate``) appear to follow the Hugging Face transformers API;
    this is an assumption to be confirmed against ``ModelLoader``.
    """

    # Beam-search settings shared by the single-prompt and batch paths.
    _GENERATION_KWARGS = {
        "num_beams": 5,
        "early_stopping": True,
        "no_repeat_ngram_size": 2,
    }

    def __init__(self, model_loader: "ModelLoader"):
        """Capture the model, tokenizer and device from *model_loader*.

        Args:
            model_loader: object exposing ``get_model()``, ``get_tokenizer()``
                and ``get_device()``.
        """
        self.model_loader = model_loader
        self.model = self.model_loader.get_model()
        self.tokenizer = self.model_loader.get_tokenizer()
        self.device = self.model_loader.get_device()
        self.max_input_length = MAX_INPUT_LENGTH
        self.max_output_length = MAX_OUTPUT_LENGTH

    @staticmethod
    def _build_prompt(prompt: str) -> str:
        """Wrap a raw user prompt in the request/expected-answer template."""
        return f"Запрос: {prompt}\nОжидаемый ответ:"

    @staticmethod
    def _clean_batch_output(text: str) -> str:
        """Strip the answer label and relabel the request marker in decoded text."""
        return (
            text.replace("Ожидаемый ответ:", "")
            .replace("Запрос:", "Исходный текст:")
            .strip()
        )

    def _generate(self, texts, padding):
        """Tokenize *texts* and run beam-search generation on them.

        Args:
            texts: a single prompt string or a list of prompt strings.
            padding: padding strategy forwarded to the tokenizer.

        Returns:
            Whatever ``self.model.generate`` returns (one token sequence per
            input row).
        """
        inputs = self.tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_input_length,
            padding=padding,
        ).to(self.device)

        # Forward the attention mask so padded positions are ignored by the
        # model; without it, pad tokens are attended to like real input.
        attention_mask = getattr(inputs, "attention_mask", None)

        # NOTE(review): for decoder-only models ``max_length`` counts the prompt
        # tokens too, so long prompts leave little room for the answer;
        # consider ``max_new_tokens`` once the model type is confirmed.
        return self.model.generate(
            inputs.input_ids,
            attention_mask=attention_mask,
            max_length=self.max_output_length,
            **self._GENERATION_KWARGS,
        )

    def generate_response(self, prompt: str) -> str:
        """Generate a response for a single prompt.

        Args:
            prompt: raw user text; it is wrapped in the prompt template.

        Returns:
            The decoded first output sequence with special tokens removed.
            NOTE(review): unlike ``generate_batch``, no label cleanup is
            applied here; kept as-is because callers may rely on it.
        """
        # A single sequence needs no padding; padding it to max_length only
        # adds dead tokens to the model input.
        outputs = self._generate(self._build_prompt(prompt), padding=False)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def generate_batch(self, prompts: list) -> list:
        """Generate responses for several prompts in one padded batch.

        Args:
            prompts: list of raw user texts.

        Returns:
            One cleaned response string per output sequence; ``[]`` when
            *prompts* is empty.
        """
        if not prompts:
            return []

        full_prompts = [self._build_prompt(p) for p in prompts]
        outputs = self._generate(full_prompts, padding="longest")
        return [
            self._clean_batch_output(
                self.tokenizer.decode(out, skip_special_tokens=True)
            )
            for out in outputs
        ]