# Uploaded by krogoldAI ("Upload 5 files", commit decae83, verified).
import math
import re
import xml.etree.ElementTree as ET
from typing import Any, Dict, Tuple


class SmoothXMLRewardEvaluator:
    """
    Graded reward function for XML query-analysis outputs.

    Every component score lies in [0.0, 1.0]. Most components give partial
    credit (e.g. the fraction of required tags present) rather than a
    pass/fail verdict; ``xml_valid`` is the exception and is binary.
    :meth:`evaluate` combines the components with ``composite_weights`` and
    subtracts a structural penalty.
    """

    def __init__(self):
        # Weights of the positive components. They sum to 1.0, so the
        # weighted reward before penalties lies in [0.0, 1.0].
        self.composite_weights = {
            'structure': 0.30,
            'xml_valid': 0.25,
            'order': 0.25,
            'confidence': 0.18,
            'distribution': 0.02,
        }

    def evaluate_structure(self, output: str) -> float:
        """
        Return the fraction of required tag patterns found in ``output``.

        Each pattern is searched independently, with ``re.DOTALL`` so that
        tag contents may span lines. Nesting and order are not checked here.
        """
        required_tags = [
            r"<query_analysis>.*</query_analysis>",
            r"<domain ambiguous=\"(true|false)\">.*</domain>",
            r"<intent ambiguous=\"(true|false)\">.*</intent>",
            r"<candidate confidence=\"(?:0\.\d|1\.0)\">.*?</candidate>",
            r"<insufficient_context>(true|false)</insufficient_context>",
            r"<rephrased>(true|false)</rephrased>",
            r"<rephrased_query>.*?</rephrased_query>",
        ]
        hits = sum(bool(re.search(tag, output, re.DOTALL)) for tag in required_tags)
        return hits / len(required_tags)

    def evaluate_xml_validity(self, output: str) -> float:
        """Return 1.0 if ``output`` (whitespace-stripped) parses as XML, else 0.0."""
        try:
            ET.fromstring(output.strip())
            return 1.0
        except ET.ParseError:
            return 0.0

    def evaluate_order(self, output: str) -> float:
        """
        Return the fraction of top-level tags that appear in the expected order.

        Tags are visited in the expected sequence. A tag earns credit when its
        first occurrence comes after the last credited tag. Missing or
        out-of-order tags earn nothing and do not move the reference position,
        so a single misplaced tag does not zero the whole score.
        """
        sequence = ["<query_analysis>", "<domain", "<intent", "<insufficient_context>", "<rephrased>", "<rephrased_query>"]
        last_index = -1
        correct_count = 0
        for tag in sequence:
            # find() returns -1 when absent, which can never exceed last_index.
            idx = output.find(tag)
            if idx > last_index:
                correct_count += 1
                last_index = idx
        return correct_count / len(sequence)

    def evaluate_confidence(self, output: str) -> float:
        """
        Score whether candidate confidences agree with each block's ambiguity flag.

        For each of the ``domain`` and ``intent`` blocks:

        - ambiguous="true": ``max(0, 1 - |sum(confidences) - 1|)``;
        - ambiguous="false": ``1 - |c - 1|`` when exactly one candidate is
          present, otherwise 0.

        A missing block, or a block without a well-formed candidate, scores 0.
        The result is the mean over the two blocks.
        """
        blocks = ['domain', 'intent']
        score = 0.0
        for block_name in blocks:
            block_match = re.search(f"<{block_name} ambiguous=\"(true|false)\">.*?</{block_name}>", output, re.DOTALL)
            if not block_match:
                continue
            is_ambiguous = block_match.group(1) == "true"
            # The pattern only admits 0.0-0.9 and 1.0, so float() cannot fail.
            confidences = [float(c) for c in re.findall(r"<candidate confidence=\"(0\.\d|1\.0)\">", block_match.group(0))]
            if not confidences:
                continue
            if is_ambiguous:
                # fsum avoids binary rounding error, so e.g. 0.7 + 0.2 + 0.1
                # sums to exactly 1.0 and earns full credit.
                actual_sum = math.fsum(confidences)
                score += max(0.0, 1.0 - abs(actual_sum - 1.0))
            elif len(confidences) == 1:
                score += 1.0 - abs(confidences[0] - 1.0)
        return score / len(blocks)

    def evaluate_distribution(self, output: str) -> float:
        """
        Score how evenly confidence is spread among an ambiguous block's candidates.

        Only ``domain``/``intent`` blocks with at least one parseable
        confidence are evaluated:

        - An ambiguous block with several candidates scores the normalized
          Shannon entropy of its confidences (1.0 means a uniform split).
        - Any other evaluated block scores 1.0.

        The result is the mean over evaluated blocks, or 0.0 if none were
        evaluated.

        Confidences are rescaled to sum to 1 before the entropy is computed.
        This keeps the score in [0.0, 1.0] even when they do not already sum
        to 1; that mismatch is scored separately by :meth:`evaluate_confidence`.
        """
        total_score = 0.0
        blocks_evaluated = 0
        for block_name in ['domain', 'intent']:
            block_match = re.search(f'<{block_name} ambiguous="(true|false)".*?</{block_name}>', output, re.DOTALL)
            if not block_match:
                continue
            is_ambiguous = block_match.group(1) == "true"
            # Same admissible values as the other checks (0.0-0.9 and 1.0). The
            # former [01]\.\d also accepted 1.1-1.9, which could drive the
            # entropy score negative.
            confidences = [float(c) for c in re.findall(r'confidence="(0\.\d|1\.0)"', block_match.group(0))]
            if not confidences:
                continue
            blocks_evaluated += 1
            if is_ambiguous and len(confidences) > 1:
                total_score += self._normalized_entropy(confidences)
            else:
                total_score += 1.0
        return total_score / blocks_evaluated if blocks_evaluated > 0 else 0.0

    @staticmethod
    def _normalized_entropy(weights: list) -> float:
        """
        Return the entropy of ``weights`` (rescaled to sum to 1) divided by log(n).

        Returns 0.0 when there are fewer than two weights or they sum to 0.
        The result is clamped to [0.0, 1.0] to absorb floating-point rounding.
        """
        total = math.fsum(weights)
        if len(weights) < 2 or total <= 0:
            return 0.0
        # Zero-probability terms contribute 0 to the entropy (lim p*log p = 0).
        entropy = -math.fsum((w / total) * math.log(w / total) for w in weights if w > 0)
        return max(0.0, min(1.0, entropy / math.log(len(weights))))

    def structural_penalty(self, output: str) -> float:
        """
        Return a subtractive structural penalty in [0.0, 1.0].

        Unparsable XML gets the maximum penalty of 1.0. Otherwise the penalty
        accumulates 0.1 per ``domain``/``intent`` element lacking an
        ``ambiguous`` attribute and 0.05 per ``candidate`` lacking a
        ``confidence`` attribute, capped at 1.0.
        """
        penalty = 0.0
        try:
            root = ET.fromstring(output.strip())
        except ET.ParseError:
            return 1.0
        for tag in root.findall(".//domain") + root.findall(".//intent"):
            if "ambiguous" not in tag.attrib:
                penalty += 0.1
        for cand in root.findall(".//candidate"):
            if "confidence" not in cand.attrib:
                penalty += 0.05
        return min(1.0, penalty)

    def answering_penalty(self, output: str) -> float:
        """
        Return 1.0 unless the whole output is a single ``<query_analysis>`` block.

        The check passes (returns 0.0) only when the whitespace-stripped
        output both starts with ``<query_analysis>`` and ends with
        ``</query_analysis>``. Any surrounding text is penalized.
        """
        stripped = output.strip()
        if stripped.startswith('<query_analysis>') and stripped.endswith('</query_analysis>'):
            return 0.0
        return 1.0

    def evaluate(self, output: str) -> Tuple[float, Dict[str, float]]:
        """
        Return ``(final_score, components)`` for ``output``.

        If the answering penalty applies, the result is
        ``(0.0, {"answering_penalty": 1.0})`` and no other checks run.
        Otherwise ``components`` holds every component score, the structural
        penalty and the final score. The final score is the weighted sum of
        the components minus ``0.5 * structural_penalty``, floored at 0.0.
        """
        ap = self.answering_penalty(output)
        if ap > 0:
            return 0.0, {"answering_penalty": ap}
        components = {
            "structure": self.evaluate_structure(output),
            "xml_valid": self.evaluate_xml_validity(output),
            "order": self.evaluate_order(output),
            "confidence": self.evaluate_confidence(output),
            "distribution": self.evaluate_distribution(output),
        }
        sp = self.structural_penalty(output)
        components["structural_penalty"] = sp
        reward = sum(weight * components[name] for name, weight in self.composite_weights.items())
        # The penalty is subtractive rather than multiplicative.
        final_score = max(0.0, reward - 0.5 * sp)
        components["final_score"] = final_score
        return final_score, components

    def get_detailed_analysis(self, output: str) -> Dict[str, Any]:
        """
        Return the final score, component scores, weights and recommendations.

        ``weights_used`` is a copy, so callers cannot mutate the evaluator's
        weights through the returned dict.
        """
        final_score, component_scores = self.evaluate(output)
        return {
            'final_score': final_score,
            'component_scores': component_scores,
            'weights_used': dict(self.composite_weights),
            'recommendations': self._get_recommendations(component_scores),
        }

    def _get_recommendations(self, component_scores: Dict[str, float]) -> list:
        """
        Return human-readable improvement hints for low component scores.

        Components absent from ``component_scores`` are treated as passing.
        """
        recommendations = []
        # evaluate() short-circuits with only this key, so without this check
        # a rejected output would get no hints at all.
        if component_scores.get('answering_penalty', 0.0) > 0:
            recommendations.append("Return only the <query_analysis> XML block, with no surrounding text or direct answer")
        if component_scores.get('structure', 1.0) < 0.85:
            recommendations.append("Improve XML structure - ensure all required tags are present")
        if component_scores.get('xml_valid', 1.0) < 1.0:
            recommendations.append("Fix XML syntax errors - ensure proper tag closing and nesting")
        if component_scores.get('order', 1.0) < 0.9:
            recommendations.append("Reorder XML elements to match expected structure")
        if component_scores.get('confidence', 1.0) < 0.8:
            recommendations.append("Fix confidence values - ensure they sum to 1.0 for ambiguous cases")
        if component_scores.get('distribution', 1.0) < 0.8:
            recommendations.append("Improve confidence distribution balance for ambiguous classifications")
        if component_scores.get('structural_penalty', 0.0) > 0.2:
            recommendations.append("Address structural issues - missing attributes or malformed tags")
        return recommendations