from typing import Dict, List

import torch
from langchain_core.documents.base import Document
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

from config.settings import settings


class ResearchAgent:
    def __init__(self):
        """
        Initialize the research agent with a lightweight Hugging Face seq2seq model.
        """
        print("Initializing ResearchAgent with lightweight Hugging Face model...")

        # Model name comes from application settings (a smaller, CPU-friendly model by default)
        model_name = settings.HF_MODEL_NAME

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Use float32 on CPU (fp16 is only reliable on GPU)
        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name, torch_dtype=torch_dtype
        ).to(self.device)

        print(f"Model '{model_name}' loaded on {self.device} with dtype={torch_dtype}.")

    def sanitize_response(self, response_text: str) -> str:
        """
        Sanitize the model's response by stripping unnecessary whitespace.
        """
        return response_text.strip()

    def generate_prompt(self, question: str, context: str) -> str:
        """
        Generate a structured prompt instructing the model to produce a precise, factual answer.
        """
        prompt = f"""
You are an AI assistant designed to provide precise and factual answers based on the given context.

**Instructions:**
- Answer the following question using only the provided context.
- Be clear, concise, and factual.
- Return as much information as you can get from the context.

**Question:** {question}

**Context:**
{context}

**Provide your answer below:**
"""
        return prompt

    def generate(self, question: str, documents: List[Document]) -> Dict:
        """
        Generate an initial draft answer using the provided documents.
        """
        print(f"ResearchAgent.generate called with question='{question}' and {len(documents)} documents.")

        # Combine the top document contents into one context string
        context = "\n\n".join(doc.page_content for doc in documents)
        print(f"Combined context length: {len(context)} characters.")

        # Create a prompt for the model
        prompt = self.generate_prompt(question, context)
        print("Prompt created for the LLM.")

        # Run inference to generate the answer
        try:
            print("Running inference with Transformers...")
            inputs = self.tokenizer(
                prompt, return_tensors="pt", truncation=True, max_length=1024
            ).to(self.device)
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=300,
                do_sample=True,  # enable sampling so the temperature setting actually takes effect
                temperature=0.3,
            )
            llm_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            print("Model response received.")
        except Exception as e:
            print(f"Error during model inference: {e}")
            raise RuntimeError("Failed to generate answer due to a model error.") from e

        # Sanitize the response; fall back to a default message if the model returned nothing
        draft_answer = (
            self.sanitize_response(llm_response)
            if llm_response
            else "I cannot answer this question based on the provided documents."
        )
        print(f"Generated answer: {draft_answer}")

        return {
            "draft_answer": draft_answer,
            "context_used": context,
        }
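

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal example of how ResearchAgent.generate() might be called. It assumes
# settings.HF_MODEL_NAME points to a valid seq2seq checkpoint (e.g. a T5-family
# model) and wraps two hand-written snippets in LangChain Document objects; the
# question and snippet texts below are hypothetical.
if __name__ == "__main__":
    sample_docs = [
        Document(page_content="The Transformer architecture was introduced in the 2017 paper 'Attention Is All You Need'."),
        Document(page_content="Transformers rely on self-attention instead of recurrence to model token dependencies."),
    ]

    agent = ResearchAgent()
    result = agent.generate(
        question="What mechanism do Transformers use instead of recurrence?",
        documents=sample_docs,
    )

    # The returned dict contains the draft answer plus the exact context that was fed to the model.
    print(result["draft_answer"])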