# agents/relevance_checker.py
import logging

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

from config.settings import settings

logger = logging.getLogger(__name__)
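
# NOTE: config.settings is assumed to expose HF_MODEL_NAME, e.g. (hypothetical value)
# in config/settings.py:
#   HF_MODEL_NAME = "google/flan-t5-small"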


class RelevanceChecker:
    def __init__(self):
        # Initialize the Hugging Face LLM used for relevance classification
        print("Initializing RelevanceChecker with lightweight Hugging Face model...")

        # Model name comes from settings; a small, CPU-friendly seq2seq model is expected
        model_name = getattr(settings, "HF_MODEL_NAME")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # fp16 is only reliable on GPU; fall back to float32 on CPU
        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch_dtype).to(self.device)
        print(f"Model '{model_name}' loaded on {self.device} with dtype={torch_dtype}.")

    def check(self, question: str, retriever, k: int = 3) -> str:
        """
        1. Retrieve document chunks from the provided retriever.
        2. Combine the top-k chunk texts into a single string.
        3. Ask the LLM to classify how well that text addresses the question.

        Returns one of: "CAN_ANSWER", "PARTIAL", or "NO_MATCH".
        """
        logger.debug(f"RelevanceChecker.check called with question='{question}' and k={k}")

        # Retrieve doc chunks from the ensemble retriever
        top_docs = retriever.invoke(question)
        if not top_docs:
            logger.debug("No documents returned from retriever.invoke(). Classifying as NO_MATCH.")
            return "NO_MATCH"

        # Combine the top k chunk texts into one string
        document_content = "\n\n".join(doc.page_content for doc in top_docs[:k])

        # Create a prompt for the LLM to classify relevance
        prompt = f"""
You are an AI relevance checker between a user's question and provided document content.

**Instructions:**
- Classify how well the document content addresses the user's question.
- Respond with only one of the following labels: CAN_ANSWER, PARTIAL, NO_MATCH.
- Do not include any additional text or explanation.

**Labels:**
1) "CAN_ANSWER": The passages contain enough explicit information to fully answer the question.
2) "PARTIAL": The passages mention or discuss the question's topic but do not provide all the details needed for a complete answer.
3) "NO_MATCH": The passages do not discuss or mention the question's topic at all.

**Important:** If the passages mention or reference the topic or timeframe of the question in any way, even if incomplete, respond with "PARTIAL" instead of "NO_MATCH".

**Question:** {question}

**Passages:** {document_content}

**Respond ONLY with one of the following labels: CAN_ANSWER, PARTIAL, NO_MATCH**
"""

        # Call the LLM
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(self.device)
            outputs = self.model.generate(**inputs, max_new_tokens=10)
            llm_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip().upper()
        except Exception as e:
            logger.error(f"Error during model inference: {e}")
            return "NO_MATCH"

        logger.debug(f"LLM response: {llm_response}")

        # Validate the response; fall back to NO_MATCH on anything unexpected
        valid_labels = {"CAN_ANSWER", "PARTIAL", "NO_MATCH"}
        if llm_response not in valid_labels:
            logger.debug("LLM did not respond with a valid label. Forcing 'NO_MATCH'.")
            classification = "NO_MATCH"
        else:
            logger.debug(f"Classification recognized as '{llm_response}'.")
            classification = llm_response

        print(f"Checker response: {classification}")
        return classification
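

# --- Illustrative usage sketch (assumptions, not part of the original module) ---
# This demo shows how RelevanceChecker.check might be called. It assumes
# config.settings defines HF_MODEL_NAME, and that the retriever follows the
# LangChain convention of .invoke(question) returning documents that expose a
# .page_content attribute. _DemoDoc and _DemoRetriever are hypothetical stand-ins
# used only to illustrate the expected interface.
if __name__ == "__main__":
    class _DemoDoc:
        def __init__(self, page_content: str):
            self.page_content = page_content

    class _DemoRetriever:
        def invoke(self, question: str):
            # Return a couple of fake chunks regardless of the question
            return [
                _DemoDoc("Refunds are accepted within 30 days of purchase."),
                _DemoDoc("Shipping typically takes 3-5 business days."),
            ]

    checker = RelevanceChecker()
    label = checker.check("What is the refund window?", retriever=_DemoRetriever(), k=2)
    print(f"Demo classification: {label}")  # expected: CAN_ANSWER, PARTIAL, or NO_MATCH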