from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from config.settings import settings
import torch
import logging

logger = logging.getLogger(__name__)


class RelevanceChecker:
    def __init__(self):
        # Initialize the Hugging Face LLM
        print("Initializing RelevanceChecker with lightweight Hugging Face model...")

        # Use a smaller, CPU-friendly model by default
        model_name = getattr(settings, "HF_MODEL_NAME")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Use float32 on CPU (fp16 only works on GPU)
        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch_dtype).to(self.device)
        print(f"Model '{model_name}' loaded on {self.device} with dtype={torch_dtype}.")
    def check(self, question: str, retriever, k: int = 3) -> str:
        """
        1. Retrieve the top-k document chunks from the global retriever.
        2. Combine them into a single text string.
        3. Pass that text + question to the LLM for classification.

        Returns: "CAN_ANSWER", "PARTIAL", or "NO_MATCH".
        """
        logger.debug(f"RelevanceChecker.check called with question='{question}' and k={k}")

        # Retrieve doc chunks from the ensemble retriever
        top_docs = retriever.invoke(question)
        if not top_docs:
            logger.debug("No documents returned from retriever.invoke(). Classifying as NO_MATCH.")
            return "NO_MATCH"

        # Combine the top-k chunk texts into one string
        document_content = "\n\n".join(doc.page_content for doc in top_docs[:k])

        # Create a prompt for the LLM to classify relevance
        prompt = f"""
You are an AI relevance checker between a user's question and provided document content.

**Instructions:**
- Classify how well the document content addresses the user's question.
- Respond with only one of the following labels: CAN_ANSWER, PARTIAL, NO_MATCH.
- Do not include any additional text or explanation.

**Labels:**
1) "CAN_ANSWER": The passages contain enough explicit information to fully answer the question.
2) "PARTIAL": The passages mention or discuss the question's topic but do not provide all the details needed for a complete answer.
3) "NO_MATCH": The passages do not discuss or mention the question's topic at all.

**Important:** If the passages mention or reference the topic or timeframe of the question in any way, even if incomplete, respond with "PARTIAL" instead of "NO_MATCH".

**Question:** {question}

**Passages:** {document_content}

**Respond ONLY with one of the following labels: CAN_ANSWER, PARTIAL, NO_MATCH**
"""

        # Call the LLM
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(self.device)
            outputs = self.model.generate(**inputs, max_new_tokens=10)
            llm_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip().upper()
        except Exception as e:
            logger.error(f"Error during model inference: {e}")
            return "NO_MATCH"

        logger.debug(f"LLM response: {llm_response}")

        # Validate the response; anything unexpected falls back to NO_MATCH
        valid_labels = {"CAN_ANSWER", "PARTIAL", "NO_MATCH"}
        if llm_response not in valid_labels:
            logger.debug("LLM did not respond with a valid label. Forcing 'NO_MATCH'.")
            classification = "NO_MATCH"
        else:
            logger.debug(f"Classification recognized as '{llm_response}'.")
            classification = llm_response

        print(f"Checker response: {classification}")
        return classification
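
For reference, a minimal usage sketch follows. The stub retriever and its documents are invented purely for illustration; the real application would pass its own retriever (anything exposing .invoke(question) and returning objects with a .page_content attribute), and RelevanceChecker() still requires settings.HF_MODEL_NAME to name a seq2seq checkpoint the machine can load.

from dataclasses import dataclass


@dataclass
class _Doc:
    page_content: str


class _StubRetriever:
    """Stand-in for the real retriever: .invoke(question) returns canned chunks."""

    def __init__(self, docs):
        self._docs = docs

    def invoke(self, question):
        return self._docs


retriever = _StubRetriever([
    _Doc("The 2023 annual report states that revenue grew 12% year over year."),
    _Doc("Returns are accepted within 30 days of purchase with a valid receipt."),
])

checker = RelevanceChecker()  # loads the model named by settings.HF_MODEL_NAME
label = checker.check("How much did revenue grow in 2023?", retriever, k=2)
print(label)  # one of: CAN_ANSWER, PARTIAL, NO_MATCH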