Spaces:
Runtime error
Runtime error
| import difflib | |
| import logging | |
| from typing import List | |
| import torch | |
| from app.services.base import load_hf_pipeline | |
| from app.core.config import settings | |
| from app.core.exceptions import ServiceError | |
| logger = logging.getLogger(f"{settings.APP_NAME}.services.grammar") | |
| class GrammarCorrector: | |
| def __init__(self): | |
| self._pipeline = None | |
| def _get_pipeline(self): | |
| if self._pipeline is None: | |
| logger.info("Loading grammar correction pipeline...") | |
| self._pipeline = load_hf_pipeline( | |
| model_id=settings.GRAMMAR_MODEL_ID, | |
| task="text2text-generation", | |
| feature_name="Grammar Correction" | |
| ) | |
| return self._pipeline | |
| async def correct(self, text: str) -> dict: | |
| text = text.strip() | |
| if not text: | |
| raise ServiceError(status_code=400, detail="Input text is empty for grammar correction.") | |
| try: | |
| pipeline = self._get_pipeline() | |
| result = pipeline(text, max_length=512, num_beams=4, early_stopping=True) | |
| corrected = result[0]["generated_text"].strip() | |
| if not corrected: | |
| raise ServiceError(status_code=500, detail="Failed to decode grammar correction output.") | |
| issues = self.get_diff_issues(text, corrected) | |
| return { | |
| "original_text": text, | |
| "corrected_text_suggestion": corrected, | |
| "issues": issues | |
| } | |
| except Exception as e: | |
| logger.error(f"Grammar correction error for input: '{text[:50]}...'", exc_info=True) | |
| raise ServiceError(status_code=500, detail="An internal error occurred during grammar correction.") from e | |
| def get_diff_issues(self, original: str, corrected: str) -> List[dict]: | |
| def safe_slice(s: str, start: int, end: int) -> str: | |
| return s[max(0, start):min(len(s), end)] | |
| matcher = difflib.SequenceMatcher(None, original, corrected) | |
| issues = [] | |
| for tag, i1, i2, j1, j2 in matcher.get_opcodes(): | |
| if tag == "equal": | |
| continue | |
| issues.append({ | |
| "offset": i1, | |
| "length": i2 - i1, | |
| "original_segment": original[i1:i2], | |
| "suggested_segment": corrected[j1:j2], | |
| "context_before": safe_slice(original, i1 - 15, i1), | |
| "context_after": safe_slice(original, i2, i2 + 15), | |
| "message": "Grammar correction", | |
| "line": original[:i1].count("\n") + 1, | |
| "column": (i1 - original[:i1].rfind("\n") - 1) if "\n" in original[:i1] else i1 + 1 | |
| }) | |
| return issues | |