from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from config.settings import settings
import torch
import logging

logger = logging.getLogger(__name__)


class RelevanceChecker:
    def __init__(self):
        # Initialize the Hugging Face LLM
        print("Initializing RelevanceChecker with lightweight Hugging Face model...")

        # Use a smaller, CPU-friendly model by default
        model_name = getattr(settings, "HF_MODEL_NAME")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Use float32 on CPU (fp16 only works on GPU)
        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name, torch_dtype=torch_dtype
        ).to(self.device)
        print(f"Model '{model_name}' loaded on {self.device} with dtype={torch_dtype}.")

    def check(self, question: str, retriever, k=3) -> str:
        """
        1. Retrieve the top-k document chunks from the global retriever.
        2. Combine them into a single text string.
        3. Pass that text + question to the LLM for classification.

        Returns: "CAN_ANSWER", "PARTIAL", or "NO_MATCH".
        """
        logger.debug(f"RelevanceChecker.check called with question='{question}' and k={k}")

        # Retrieve doc chunks from the ensemble retriever
        top_docs = retriever.invoke(question)
        if not top_docs:
            logger.debug("No documents returned from retriever.invoke(). Classifying as NO_MATCH.")
            return "NO_MATCH"

        # Combine the top-k chunk texts into one string
        document_content = "\n\n".join(doc.page_content for doc in top_docs[:k])

        # Create a prompt for the LLM to classify relevance
        prompt = f"""
You are an AI relevance checker between a user's question and provided document content.

**Instructions:**
- Classify how well the document content addresses the user's question.
- Respond with only one of the following labels: CAN_ANSWER, PARTIAL, NO_MATCH.
- Do not include any additional text or explanation.

**Labels:**
1) "CAN_ANSWER": The passages contain enough explicit information to fully answer the question.
2) "PARTIAL": The passages mention or discuss the question's topic but do not provide all the details needed for a complete answer.
3) "NO_MATCH": The passages do not discuss or mention the question's topic at all.

**Important:** If the passages mention or reference the topic or timeframe of the question in any way, even if incomplete, respond with "PARTIAL" instead of "NO_MATCH".

**Question:** {question}

**Passages:**
{document_content}

**Respond ONLY with one of the following labels: CAN_ANSWER, PARTIAL, NO_MATCH**
"""

        # Call the LLM
        try:
            inputs = self.tokenizer(
                prompt, return_tensors="pt", truncation=True, max_length=1024
            ).to(self.device)
            outputs = self.model.generate(**inputs, max_new_tokens=10)
            llm_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip().upper()
        except Exception as e:
            logger.error(f"Error during model inference: {e}")
            return "NO_MATCH"

        logger.debug(f"LLM response: {llm_response}")

        # Validate the response; anything outside the allowed labels is treated as NO_MATCH
        valid_labels = {"CAN_ANSWER", "PARTIAL", "NO_MATCH"}
        if llm_response not in valid_labels:
            logger.debug("LLM did not respond with a valid label. Forcing 'NO_MATCH'.")
            classification = "NO_MATCH"
        else:
            logger.debug(f"Classification recognized as '{llm_response}'.")
            classification = llm_response

        print(f"Checker response: {classification}")
        return classification
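

# Minimal usage sketch: shows how RelevanceChecker.check could be exercised with a
# stand-in retriever. The Doc and FakeRetriever classes below are hypothetical and
# only mimic a LangChain-style retriever that exposes .invoke() and returns objects
# carrying a .page_content attribute; the real project wires in its own retriever.
if __name__ == "__main__":
    class Doc:
        """Hypothetical document wrapper mimicking a retrieved chunk."""
        def __init__(self, page_content: str):
            self.page_content = page_content

    class FakeRetriever:
        """Hypothetical retriever returning canned chunks for any question."""
        def invoke(self, question: str):
            return [Doc("Quarterly revenue grew 12% year over year in fiscal 2023.")]

    checker = RelevanceChecker()
    label = checker.check("How did revenue change in 2023?", FakeRetriever(), k=3)
    print(label)  # One of: CAN_ANSWER, PARTIAL, NO_MATCH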