|
|
import math
import re
import xml.etree.ElementTree as ET
from typing import Any, Dict, Tuple
|
|
|
|
|
|
|
|
|
class SmoothXMLRewardEvaluator:
    """
    Smooth, differentiable XML reward function.

    Scores an XML-formatted query analysis on several continuous
    components (structure, XML validity, tag ordering, confidence
    calibration, confidence distribution), each in [0.0, 1.0], and
    combines them into a weighted composite reward with soft penalties.
    """

    def __init__(self):
        # Relative weight of each positive component in the composite
        # reward.  The weights sum to 1.0, so the pre-penalty reward is
        # itself bounded to [0.0, 1.0].
        self.composite_weights = {
            'structure': 0.30,
            'xml_valid': 0.25,
            'order': 0.25,
            'confidence': 0.18,
            'distribution': 0.02
        }

    def evaluate_structure(self, output: str) -> float:
        """
        Check structure (tags present).

        Returns the fraction of required tag patterns found in ``output``,
        so partially complete outputs still earn partial credit.
        """
        required_tags = [
            r"<query_analysis>.*</query_analysis>",
            r"<domain ambiguous=\"(true|false)\">.*</domain>",
            r"<intent ambiguous=\"(true|false)\">.*</intent>",
            r"<candidate confidence=\"(?:0\.\d|1\.0)\">.*?</candidate>",
            r"<insufficient_context>(true|false)</insufficient_context>",
            r"<rephrased>(true|false)</rephrased>",
            r"<rephrased_query>.*?</rephrased_query>"
        ]
        # DOTALL lets `.*` span newlines inside multi-line tag bodies.
        hits = sum(bool(re.search(tag, output, re.DOTALL)) for tag in required_tags)
        return hits / len(required_tags)

    def evaluate_xml_validity(self, output: str) -> float:
        """
        Check XML validity.

        Returns 1.0 if ``output`` parses as well-formed XML, else 0.0.
        """
        try:
            ET.fromstring(output.strip())
            return 1.0
        except ET.ParseError:
            return 0.0

    def evaluate_order(self, output: str) -> float:
        """
        Check order of elements (fraction of correct sequence for continuity).

        A tag only counts when its first occurrence appears after the
        previously matched tag, so partially ordered outputs still earn
        proportional credit.
        """
        sequence = ["<query_analysis>", "<domain", "<intent",
                    "<insufficient_context>", "<rephrased>", "<rephrased_query>"]
        last_index = -1
        correct_count = 0
        for tag in sequence:
            match = re.search(tag, output)
            if match:
                idx = match.start()
                if idx > last_index:
                    correct_count += 1
                    last_index = idx
        return correct_count / len(sequence)

    def evaluate_confidence(self, output: str) -> float:
        """
        Check confidence correctness.

        For an ambiguous <domain>/<intent> block the candidate confidences
        should sum to 1.0; for an unambiguous block a single candidate
        should carry confidence 1.0.  Each block earns credit proportional
        to its distance from that target; the result is the mean over both
        blocks.
        """
        score = 0.0
        blocks = ['domain', 'intent']
        for block_name in blocks:
            block_match = re.search(
                f"<{block_name} ambiguous=\"(true|false)\">.*?</{block_name}>",
                output, re.DOTALL)
            if not block_match:
                continue
            try:
                is_ambiguous = block_match.group(1) == "true"
                # Only one-decimal confidences (0.x or 1.0) are accepted,
                # matching the structural pattern used elsewhere.
                confidences = [float(c) for c in re.findall(
                    r"<candidate confidence=\"(0\.\d|1\.0)\">", block_match.group(0))]
                if not confidences:
                    continue
                if is_ambiguous:
                    # Reward closeness of the total probability mass to 1.0,
                    # clamped so gross violations never go negative.
                    target_sum = 1.0
                    actual_sum = sum(confidences)
                    score += max(0, 1 - abs(actual_sum - target_sum))
                else:
                    # Unambiguous blocks must commit to exactly one candidate.
                    if len(confidences) == 1:
                        score += 1.0 - abs(confidences[0] - 1.0)
            except (ValueError, AttributeError):
                continue
        return score / len(blocks)

    def evaluate_distribution(self, output: str) -> float:
        """
        Check confidence distribution (entropy-based, normalized).

        Ambiguous multi-candidate blocks are rewarded for spreading
        confidence across candidates (Shannon entropy normalized by the
        maximum possible entropy); all other evaluated blocks receive
        full credit.  Blocks with no parsable confidences are skipped.
        """
        total_score = 0.0
        blocks_evaluated = 0
        for block_name in ['domain', 'intent']:
            block_match = re.search(
                f'<{block_name} ambiguous="(true|false)".*?</{block_name}>',
                output, re.DOTALL)
            if not block_match:
                continue
            is_ambiguous = block_match.group(1) == "true"
            confidences = [float(c) for c in
                           re.findall(r'confidence="([01]\.\d)"', block_match.group(0))]
            if not confidences:
                continue
            blocks_evaluated += 1
            if is_ambiguous and len(confidences) > 1:
                # Small epsilon guards against log(0) for zero confidences.
                entropy = -sum(p * math.log(p + 1e-8) for p in confidences)
                max_entropy = math.log(len(confidences))
                total_score += entropy / max_entropy if max_entropy > 0 else 0
            else:
                total_score += 1.0
        return total_score / blocks_evaluated if blocks_evaluated > 0 else 0.0

    def structural_penalty(self, output: str) -> float:
        """
        Compute structural penalty (soft, subtractive).

        Unparsable XML earns the maximum penalty of 1.0.  Otherwise, each
        <domain>/<intent> missing its ``ambiguous`` attribute adds 0.1 and
        each <candidate> missing ``confidence`` adds 0.05, capped at 1.0.
        """
        penalty = 0.0
        try:
            root = ET.fromstring(output.strip())
        except ET.ParseError:
            return 1.0
        for tag in root.findall(".//domain") + root.findall(".//intent"):
            if "ambiguous" not in tag.attrib:
                penalty += 0.1
        for cand in root.findall(".//candidate"):
            if "confidence" not in cand.attrib:
                penalty += 0.05
        return min(1.0, penalty)

    def answering_penalty(self, output: str) -> float:
        """
        Compute answering penalty.

        Returns 0.0 when the (stripped) output is a pure
        ``<query_analysis>...</query_analysis>`` block, and 1.0 when any
        other text surrounds it — i.e. the model tried to answer directly.
        """
        stripped = output.strip()
        if stripped.startswith('<query_analysis>') and stripped.endswith('</query_analysis>'):
            return 0.0
        return 1.0

    def evaluate(self, output: str) -> Tuple[float, Dict[str, float]]:
        """
        Return a composite evaluation, using intermediary checks.

        Returns ``(final_score, components)`` where ``components`` maps
        each sub-score, the structural penalty, and the final score to its
        value.  Outputs that fail the answering gate are rejected outright
        with a score of 0.0.
        """
        ap = self.answering_penalty(output)
        if ap > 0:
            # Hard gate: direct answers bypass all other scoring.
            # "final_score" is included so the payload shape stays
            # consistent with the normal path.
            return 0.0, {"answering_penalty": ap, "final_score": 0.0}

        components = {
            "structure": self.evaluate_structure(output),
            "xml_valid": self.evaluate_xml_validity(output),
            "order": self.evaluate_order(output),
            "confidence": self.evaluate_confidence(output),
            "distribution": self.evaluate_distribution(output)
        }

        sp = self.structural_penalty(output)
        components["structural_penalty"] = sp

        # Weighted sum of the positive components; the structural penalty
        # is then subtracted at half strength rather than being weighted.
        reward = sum(self.composite_weights[k] * v
                     for k, v in components.items() if k != "structural_penalty")
        final_score = max(0.0, reward - 0.5 * sp)

        components["final_score"] = final_score
        return final_score, components

    def get_detailed_analysis(self, output: str) -> Dict[str, Any]:
        """
        Provides detailed analysis of the XML output including all component scores.

        The report contains the final score, every component score, the
        composite weights in use, and textual improvement recommendations.
        """
        final_score, component_scores = self.evaluate(output)

        return {
            'final_score': final_score,
            'component_scores': component_scores,
            'weights_used': self.composite_weights,
            'recommendations': self._get_recommendations(component_scores)
        }

    def _get_recommendations(self, component_scores: Dict[str, float]) -> list:
        """
        Generate improvement recommendations based on component scores.

        Missing components default to passing values, so an early-rejected
        output produces no spurious recommendations.
        """
        recommendations = []

        if component_scores.get('structure', 1.0) < 0.85:
            recommendations.append("Improve XML structure - ensure all required tags are present")

        if component_scores.get('xml_valid', 1.0) < 1.0:
            recommendations.append("Fix XML syntax errors - ensure proper tag closing and nesting")

        if component_scores.get('order', 1.0) < 0.9:
            recommendations.append("Reorder XML elements to match expected structure")

        if component_scores.get('confidence', 1.0) < 0.8:
            recommendations.append("Fix confidence values - ensure they sum to 1.0 for ambiguous cases")

        if component_scores.get('distribution', 1.0) < 0.8:
            recommendations.append("Improve confidence distribution balance for ambiguous classifications")

        if component_scores.get('structural_penalty', 0.0) > 0.2:
            recommendations.append("Address structural issues - missing attributes or malformed tags")

        return recommendations