from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from typing import Dict, List
from langchain_core.documents.base import Document
from config.settings import settings
import torch
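
# NOTE: `settings` is assumed to expose an `HF_MODEL_NAME` attribute that names
# a seq2seq checkpoint on the Hugging Face Hub. A hypothetical configuration
# (illustrative only; the real config/settings.py may differ):
#
#     class Settings:
#         HF_MODEL_NAME: str = "google/flan-t5-base"  # hypothetical default
#
#     settings = Settings()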

class ResearchAgent:
    def __init__(self):
        """
        Initialize the research agent with local Ollama LLM.
        """

        print("Initializing RelevanceChecker with lightweight Hugging Face model...")

        # Model name comes from configuration; expected to be a small, CPU-friendly seq2seq checkpoint
        model_name = settings.HF_MODEL_NAME
        
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        
        # fp16 is only reliable on GPU; fall back to float32 on CPU
        torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
        
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch_dtype).to(self.device)
        
        print(f"Model '{model_name}' loaded on {self.device} with dtype={torch_dtype}.")

        
    def sanitize_response(self, response_text: str) -> str:
        """
        Sanitize the LLM's response by stripping unnecessary whitespace.
        """
        return response_text.strip()

    def generate_prompt(self, question: str, context: str) -> str:
        """
        Generate a structured prompt for the LLM to generate a precise and factual answer.
        """
        prompt = f"""
        You are an AI assistant designed to provide precise and factual answers based on the given context.

        **Instructions:**
        - Answer the following question using only the provided context.
        - Be clear, concise, and factual.
        - Return as much information as you can get from the context.
        
        **Question:** {question}
        **Context:**
        {context}

        **Provide your answer below:**
        """
        return prompt

    def generate(self, question: str, documents: List[Document]) -> Dict:
        """
        Generate an initial answer using the provided documents.
        """
        print(f"ResearchAgent.generate called with question='{question}' and {len(documents)} documents.")

        # Combine the top document contents into one string
        context = "\n\n".join([doc.page_content for doc in documents])
        print(f"Combined context length: {len(context)} characters.")

        # Create a prompt for the LLM
        prompt = self.generate_prompt(question, context)
        print("Prompt created for the LLM.")

        # Call the LLM to generate the answer
        try:
            print("Running inference with Transformers...")
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(self.device)
            # do_sample=True is required for temperature to take effect; otherwise generation is greedy
            outputs = self.model.generate(**inputs, max_new_tokens=300, do_sample=True, temperature=0.3)
            llm_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            print("Model response received.")
        except Exception as e:
            print(f"Error during model inference: {e}")
            raise RuntimeError("Failed to generate answer due to a model error.") from e

        # Sanitize the response
        draft_answer = self.sanitize_response(llm_response) if llm_response else "I cannot answer this question based on the provided documents."

        print(f"Generated answer: {draft_answer}")

        return {
            "draft_answer": draft_answer,
            "context_used": context
        }
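

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). Assumes `settings.HF_MODEL_NAME`
# points at a valid seq2seq checkpoint and wraps raw text in LangChain
# Document objects the way the retrieval layer would.
if __name__ == "__main__":
    sample_docs = [
        Document(page_content="The Eiffel Tower is located in Paris, France."),
        Document(page_content="It was completed in 1889 and is 330 metres tall."),
    ]
    agent = ResearchAgent()
    result = agent.generate("Where is the Eiffel Tower and when was it built?", sample_docs)
    print(result["draft_answer"])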