from typing import Dict, List

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from langchain_core.documents.base import Document

from config.settings import settings

class ResearchAgent:
    def __init__(self):
        """
        Initialize the research agent with a lightweight Hugging Face seq2seq model.
        """
        print("Initializing ResearchAgent with a lightweight Hugging Face model...")
        # Use a smaller, CPU-friendly model (configured via settings)
        model_name = settings.HF_MODEL_NAME
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Use float32 on CPU (fp16 only works on GPU)
        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name, torch_dtype=torch_dtype
        ).to(self.device)
        print(f"Model '{model_name}' loaded on {self.device} with dtype={torch_dtype}.")

    def sanitize_response(self, response_text: str) -> str:
        """
        Sanitize the LLM's response by stripping unnecessary whitespace.
        """
        return response_text.strip()

    def generate_prompt(self, question: str, context: str) -> str:
        """
        Generate a structured prompt that asks the LLM for a precise, factual answer.
        """
        prompt = f"""
You are an AI assistant designed to provide precise and factual answers based on the given context.

**Instructions:**
- Answer the following question using only the provided context.
- Be clear, concise, and factual.
- Include as much relevant information from the context as possible.

**Question:** {question}

**Context:**
{context}

**Provide your answer below:**
"""
        return prompt

    def generate(self, question: str, documents: List[Document]) -> Dict:
        """
        Generate an initial answer using the provided documents.
        """
        print(f"ResearchAgent.generate called with question='{question}' and {len(documents)} documents.")

        # Combine the top document contents into one string
        context = "\n\n".join(doc.page_content for doc in documents)
        print(f"Combined context length: {len(context)} characters.")

        # Create a prompt for the LLM
        prompt = self.generate_prompt(question, context)
        print("Prompt created for the LLM.")

        # Call the LLM to generate the answer
        try:
            print("Running inference with Transformers...")
            # Long prompts are truncated to 1024 tokens; anything past that is dropped.
            inputs = self.tokenizer(
                prompt, return_tensors="pt", truncation=True, max_length=1024
            ).to(self.device)
            # temperature only takes effect when sampling is enabled, so
            # do_sample=True accompanies it here.
            outputs = self.model.generate(
                **inputs, max_new_tokens=300, do_sample=True, temperature=0.3
            )
            llm_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            print("Model response received.")
        except Exception as e:
            print(f"Error during model inference: {e}")
            raise RuntimeError("Failed to generate answer due to a model error.") from e
        # Sanitize the response; fall back to a refusal if the model returned nothing
        draft_answer = (
            self.sanitize_response(llm_response)
            if llm_response
            else "I cannot answer this question based on the provided documents."
        )
        print(f"Generated answer: {draft_answer}")

        return {
            "draft_answer": draft_answer,
            "context_used": context,
        }
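
# --- Minimal usage sketch (illustrative, not part of the agent itself) ---
# Assumes settings.HF_MODEL_NAME points at a seq2seq model (e.g. a flan-t5
# variant); the sample documents below are hypothetical stand-ins for real
# retrieval output.
if __name__ == "__main__":
    docs = [
        Document(page_content="The Eiffel Tower is 330 metres tall and stands in Paris."),
        Document(page_content="It was completed in 1889 for the World's Fair."),
    ]
    agent = ResearchAgent()
    result = agent.generate("How tall is the Eiffel Tower?", docs)
    print(result["draft_answer"])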