from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from typing import Dict, List
from langchain_core.documents.base import Document
from config.settings import settings


class VerificationAgent:
    def __init__(self):
        """
        Initialize the verification agent with a local LLM.
        """
        print("Initializing VerificationAgent with lightweight Hugging Face model...")

        # Read the model name from settings (a smaller, CPU-friendly model is recommended)
        model_name = getattr(settings, "HF_MODEL_NAME")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Use float32 on CPU (fp16 only works on GPU)
        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name, torch_dtype=torch_dtype
        ).to(self.device)

        print(f"Model '{model_name}' loaded on {self.device} with dtype={torch_dtype}.")

    def sanitize_response(self, response_text: str) -> str:
        """
        Sanitize the LLM's response by stripping unnecessary whitespace.
        """
        return response_text.strip()

    def generate_prompt(self, answer: str, context: str) -> str:
        """
        Generate a structured prompt for the LLM to verify the answer against the context.
        """
        prompt = f"""You are a strict verification agent. Your task is to verify if an answer is supported by the provided context.

CRITICAL RULES:
1. ONLY use information from the context provided below. Do NOT use any external knowledge or assumptions.
2. If a claim in the answer is NOT explicitly or implicitly supported by the context, mark it as unsupported.
3. If the answer contradicts information in the context, mark it as a contradiction.
4. If you cannot verify a claim using ONLY the context, mark it as unsupported.
5. Be strict - do not assume or infer beyond what is clearly stated in the context.
6. Respond EXACTLY in the format specified below - no additional text, explanations, or formatting.

**VERIFICATION FORMAT (follow exactly):**

Supported: YES
Unsupported Claims: []
Contradictions: []
Relevant: YES
Additional Details: None

OR if unsupported/contradictions found:

Supported: NO
Unsupported Claims: [list each unsupported claim exactly as it appears in the answer]
Contradictions: [list each contradiction exactly as it appears]
Relevant: YES or NO
Additional Details: [brief explanation of why claims are unsupported or contradicted]

**Answer to verify:**
{answer}

**Context (use ONLY this for verification):**
{context}

**Your verification (respond ONLY with the format above):**
"""
        return prompt

    def generate_with_hf(self, prompt: str, max_new_tokens=512) -> str:
        """
        Generate output using the local Hugging Face model.
        """
        inputs = self.tokenizer(
            prompt, return_tensors="pt", truncation=True, max_length=2048
        ).to(self.device)
        outputs = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def parse_verification_response(self, response_text: str) -> Dict:
        """
        Parse the LLM's verification response into a structured dictionary.
""" try: # Normalize the response - remove markdown formatting, extra whitespace response_text = response_text.strip() # Remove any markdown code blocks if present if response_text.startswith('```'): lines = response_text.split('\n') response_text = '\n'.join(lines[1:-1]) if len(lines) > 2 else response_text print(f"[DEBUG] Parsing verification response (first 500 chars): {response_text[:500]}") verification = {} lines = response_text.split('\n') for line in lines: line = line.strip() if not line or not ':' in line: continue # Split on first colon only parts = line.split(':', 1) if len(parts) != 2: continue key = parts[0].strip() value = parts[1].strip() # Normalize key names (case-insensitive matching) key_lower = key.lower() if 'supported' in key_lower: # Extract YES/NO, handle variations value_upper = value.upper() print(f"[DEBUG] Found 'Supported' key with value: '{value}' (upper: '{value_upper}')") if 'YES' in value_upper or 'TRUE' in value_upper or 'Y' == value_upper.strip(): verification["Supported"] = "YES" print(f"[DEBUG] Set Supported to YES") elif 'NO' in value_upper or 'FALSE' in value_upper or 'N' == value_upper.strip(): verification["Supported"] = "NO" print(f"[DEBUG] Set Supported to NO") else: # If value is empty or unclear, check if there are unsupported claims/contradictions # If no issues found later, default to YES; otherwise NO print(f"[DEBUG] Supported value unclear: '{value}', will decide based on claims/contradictions") verification["Supported"] = None # Mark as undecided elif 'unsupported' in key_lower: # Handle list parsing items = [] value = value.strip() if value.lower() in ['none', 'n/a', '[]', '']: items = [] elif value.startswith('[') and value.endswith(']'): # Parse list items list_content = value[1:-1].strip() if list_content: items = [item.strip().strip('"').strip("'").strip() for item in list_content.split(',') if item.strip()] else: # Single item or comma-separated without brackets items = [item.strip().strip('"').strip("'") for item in value.split(',') if item.strip() and item.strip().lower() not in ['none', 'n/a']] verification["Unsupported Claims"] = items elif 'contradiction' in key_lower: # Handle list parsing (same logic as unsupported) items = [] value = value.strip() if value.lower() in ['none', 'n/a', '[]', '']: items = [] elif value.startswith('[') and value.endswith(']'): list_content = value[1:-1].strip() if list_content: items = [item.strip().strip('"').strip("'").strip() for item in list_content.split(',') if item.strip()] else: items = [item.strip().strip('"').strip("'") for item in value.split(',') if item.strip() and item.strip().lower() not in ['none', 'n/a']] verification["Contradictions"] = items elif 'relevant' in key_lower: value_upper = value.upper() if 'YES' in value_upper or 'TRUE' in value_upper: verification["Relevant"] = "YES" elif 'NO' in value_upper or 'FALSE' in value_upper: verification["Relevant"] = "NO" else: verification["Relevant"] = "YES" # Default to YES if unclear elif 'additional' in key_lower or 'detail' in key_lower: if value.lower() in ['none', 'n/a', '']: verification["Additional Details"] = "" else: verification["Additional Details"] = value # Ensure all required keys are present with defaults if "Supported" not in verification or verification.get("Supported") is None: # If undecided, check if there are unsupported claims or contradictions unsupported_claims = verification.get("Unsupported Claims", []) contradictions = verification.get("Contradictions", []) if not unsupported_claims and not contradictions: 
verification["Supported"] = "YES" # No issues found, default to YES print(f"[DEBUG] Supported was missing/undecided, but no claims/contradictions found, defaulting to YES") else: verification["Supported"] = "NO" # Issues found, default to NO print(f"[DEBUG] Supported was missing/undecided, but found {len(unsupported_claims)} unsupported claims and {len(contradictions)} contradictions, defaulting to NO") if "Unsupported Claims" not in verification: verification["Unsupported Claims"] = [] if "Contradictions" not in verification: verification["Contradictions"] = [] if "Relevant" not in verification: verification["Relevant"] = "YES" if "Additional Details" not in verification: verification["Additional Details"] = "" print(f"[DEBUG] Final parsed verification: Supported={verification.get('Supported')}, Unsupported Claims={len(verification.get('Unsupported Claims', []))}, Contradictions={len(verification.get('Contradictions', []))}") return verification except Exception as e: print(f"Error parsing verification response: {e}") print(f"Response text was: {response_text}") # Return a safe default return { "Supported": "NO", "Unsupported Claims": [], "Contradictions": [], "Relevant": "NO", "Additional Details": f"Parsing error: {str(e)}" } def format_verification_report(self, verification: Dict) -> str: """ Format the verification report dictionary into a readable markdown-formatted report. """ supported = verification.get("Supported", "NO") unsupported_claims = verification.get("Unsupported Claims", []) contradictions = verification.get("Contradictions", []) relevant = verification.get("Relevant", "NO") additional_details = verification.get("Additional Details", "") # Use markdown formatting for better display report = f"### Verification Report\n\n" # Add status indicators supported_icon = "✅" if supported == "YES" else "❌" report += f"**Supported:** {supported_icon} {supported}\n\n" if unsupported_claims: report += f"**⚠️ Unsupported Claims:**\n" for claim in unsupported_claims: report += f"- {claim}\n" report += "\n" else: report += f"**Unsupported Claims:** None\n\n" if contradictions: report += f"**🔴 Contradictions:**\n" for contradiction in contradictions: report += f"- {contradiction}\n" report += "\n" else: report += f"**Contradictions:** None\n\n" relevant_icon = "✅" if relevant == "YES" else "❌" report += f"**Relevant:** {relevant_icon} {relevant}\n\n" if additional_details and additional_details.lower() not in ['none', 'n/a', '']: report += f"**Additional Details:**\n{additional_details}\n" else: report += f"**Additional Details:** None\n" return report def generate_out_of_context_report(self) -> str: """ Generate a verification report for questions that are out of context. """ verification = { "Supported": "NO", "Unsupported Claims": ["The question is not related to the provided documents."], "Contradictions": [], "Relevant": "NO", "Additional Details": "The question cannot be answered using the provided documents as it is out of context." } return self.format_verification_report(verification) def check(self, answer: str, documents: List[Document]) -> Dict: """ Verify the answer against the provided documents. 
""" print(f"VerificationAgent.check called with answer='{answer}' and {len(documents)} documents.") # Combine all document contents into one string # Limit context size to prevent token overflow (keep last 8000 chars if too long) context_parts = [doc.page_content for doc in documents] context = "\n\n".join(context_parts) # Truncate context if too long (keep most recent content which is usually more relevant) MAX_CONTEXT_LENGTH = 10000 # Approximate character limit if len(context) > MAX_CONTEXT_LENGTH: print(f"Context too long ({len(context)} chars), truncating to last {MAX_CONTEXT_LENGTH} chars") context = context[-MAX_CONTEXT_LENGTH:] print(f"Combined context length: {len(context)} characters.") # Create a prompt for the LLM to verify the answer prompt = self.generate_prompt(answer, context) print("Prompt created for the LLM.") try: print("Generating response with local Hugging Face model...") llm_response = self.generate_with_hf(prompt) print("LLM response received.") except Exception as e: print(f"Error during model inference: {e}") raise RuntimeError("Failed to verify answer due to a model error.") from e # Sanitize the response sanitized_response = self.sanitize_response(llm_response) if llm_response else "" if not sanitized_response: print("LLM returned an empty response.") verification_report = { "Supported": "NO", "Unsupported Claims": [], "Contradictions": [], "Relevant": "NO", "Additional Details": "Empty response from the model." } else: # Parse the response into the expected format verification_report = self.parse_verification_response(sanitized_response) if verification_report is None: print("LLM did not respond with the expected format. Using default verification report.") verification_report = { "Supported": "NO", "Unsupported Claims": [], "Contradictions": [], "Relevant": "NO", "Additional Details": "Failed to parse the model's response." } # Format the verification report into a paragraph verification_report_formatted = self.format_verification_report(verification_report) print(f"Verification report:\n{verification_report_formatted}") print(f"Context used: {context}") return { "verification_report": verification_report_formatted, "context_used": context }