from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from typing import Dict, List
from langchain_core.documents.base import Document
from config.settings import settings
import torch


class ResearchAgent:
    def __init__(self):
        """
        Initialize the research agent with a lightweight Hugging Face seq2seq model.
        """
        print("Initializing ResearchAgent with lightweight Hugging Face model...")
        # Use a smaller, CPU-friendly model by default
        model_name = getattr(settings, "HF_MODEL_NAME")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Use float32 on CPU (fp16 only works on GPU)
        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch_dtype).to(self.device)
        print(f"Model '{model_name}' loaded on {self.device} with dtype={torch_dtype}.")

    def sanitize_response(self, response_text: str) -> str:
        """
        Sanitize the LLM's response by stripping leading and trailing whitespace.
        """
        return response_text.strip()

    def generate_prompt(self, question: str, context: str) -> str:
        """
        Generate a structured prompt that instructs the LLM to produce a precise, factual answer.
        """
        prompt = f"""
You are an AI assistant designed to provide precise and factual answers based on the given context.

**Instructions:**
- Answer the following question using only the provided context.
- Be clear, concise, and factual.
- Return as much information as you can extract from the context.

**Question:** {question}

**Context:**
{context}

**Provide your answer below:**
"""
        return prompt

    def generate(self, question: str, documents: List[Document]) -> Dict:
        """
        Generate an initial answer using the provided documents.
        """
        print(f"ResearchAgent.generate called with question='{question}' and {len(documents)} documents.")

        # Combine the top document contents into one string
        context = "\n\n".join([doc.page_content for doc in documents])
        print(f"Combined context length: {len(context)} characters.")

        # Create a prompt for the LLM
        prompt = self.generate_prompt(question, context)
        print("Prompt created for the LLM.")

        # Call the LLM to generate the answer
        try:
            print("Running inference with Transformers...")
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(self.device)
            # Enable sampling so the temperature setting actually takes effect
            outputs = self.model.generate(**inputs, max_new_tokens=300, do_sample=True, temperature=0.3)
            llm_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            print("Model response received.")
        except Exception as e:
            print(f"Error during model inference: {e}")
            raise RuntimeError("Failed to generate answer due to a model error.") from e

        # Sanitize the response
        draft_answer = (
            self.sanitize_response(llm_response)
            if llm_response
            else "I cannot answer this question based on the provided documents."
        )
        print(f"Generated answer: {draft_answer}")

        return {
            "draft_answer": draft_answer,
            "context_used": context,
        }
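

# Minimal usage sketch (illustrative only). It assumes config.settings exposes an
# HF_MODEL_NAME pointing at a seq2seq checkpoint (e.g. a flan-t5 variant) and that
# the retrieved chunks are wrapped in langchain Document objects; the sample
# documents and question below are placeholders, not part of the agent itself.
if __name__ == "__main__":
    docs = [
        Document(page_content="The Eiffel Tower is located in Paris and was completed in 1889."),
        Document(page_content="It was designed and built by the company of engineer Gustave Eiffel."),
    ]
    agent = ResearchAgent()
    result = agent.generate("When was the Eiffel Tower completed?", docs)
    print(result["draft_answer"])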