"""
FastAPI service for Czech text correction pipeline.

Combines grammar error correction (GEC) and punctuation restoration.
"""
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import Optional, List, Dict
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForTokenClassification, pipeline
import time
import re
import logging
import os
from contextlib import asynccontextmanager

# Configure CPU threads for model inference (12 unless OMP_NUM_THREADS is set).
num_threads = int(os.environ.get("OMP_NUM_THREADS", 12))
torch.set_num_threads(num_threads)
torch.set_num_interop_threads(num_threads)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.info(f"PyTorch configured to use {num_threads} CPU threads")

# Hugging Face model identifiers (also reported by /api/info).
GEC_MODEL_NAME = "ufal/byt5-large-geccc-mate"
PUNCT_MODEL_NAME = "kredor/punctuate-all"

# Maximum accepted input length in characters (request validation and /api/info).
MAX_TEXT_LENGTH = 100000

# Global variables for models; populated by `lifespan` at startup.
gec_model = None
gec_tokenizer = None
punct_pipeline = None
device = None

# Generation settings for GEC: deterministic beam search.
GEC_CONFIG = {
    "num_beams": 8,
    "do_sample": False,
    "repetition_penalty": 1.0,
    "length_penalty": 1.0,
    "no_repeat_ngram_size": 0,
    "early_stopping": True,
    "max_new_tokens": 100000,
}

# Punctuation model label -> mark inserted after the word ('' = no mark).
# Both generic "LABEL_n" names and raw punctuation labels are accepted because
# the emitted names depend on the model config's id2label.
# NOTE(review): confirm which form kredor/punctuate-all actually emits.
_PUNCT_MARKS = {
    "LABEL_0": "", "LABEL_1": ".", "LABEL_2": ",", "LABEL_3": "?", "LABEL_4": "-", "LABEL_5": ":",
    "0": "", ".": ".", ",": ",", "?": "?", "-": "-", ":": ":",
}

# Words fed to the punctuation model per call. The encoder has a bounded input
# length (512 tokens for XLM-R-style models -- confirm for the model in use), so
# long texts are split into windows; a window of unusually long words (URLs etc.)
# may still exceed the limit.
_PUNCT_WINDOW_WORDS = 150
# Extra right-hand context words appended to each window; their predictions are
# discarded (they are predicted again, with more context, by the next window).
_PUNCT_CONTEXT_WORDS = 20

# A word already ending in one of these does not get a second mark appended.
_EXISTING_PUNCT = (".", ",", "?", "!", ":", ";", "-")
# Spaces directly before these marks are removed in the final text.
_SPACING_PUNCT = (",", ".", "?", ":", "!", ";")
# Sentence boundary used for capitalization: whitespace following . ? or !
_SENTENCE_SPLIT = re.compile(r"(?<=[.?!])\s+")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load models on startup, cleanup on shutdown."""
    global gec_model, gec_tokenizer, punct_pipeline, device

    logger.info("Loading models...")

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # Load GEC model
    logger.info("Loading Czech GEC model...")
    gec_tokenizer = AutoTokenizer.from_pretrained(GEC_MODEL_NAME)
    gec_model = AutoModelForSeq2SeqLM.from_pretrained(GEC_MODEL_NAME)
    gec_model = gec_model.to(device)
    logger.info("GEC model loaded successfully")

    # Load punctuation model
    logger.info("Loading punctuation model...")
    punct_tokenizer = AutoTokenizer.from_pretrained(PUNCT_MODEL_NAME)
    punct_model = AutoModelForTokenClassification.from_pretrained(PUNCT_MODEL_NAME)
    punct_pipeline = pipeline(
        "token-classification",
        model=punct_model,
        tokenizer=punct_tokenizer,
        device=0 if torch.cuda.is_available() else -1,
    )
    logger.info("Punctuation model loaded successfully")

    logger.info("All models loaded and ready")

    yield

    # Cleanup (if needed)
    logger.info("Shutting down...")


# Create FastAPI app with lifespan
app = FastAPI(
    title="Czech Text Correction API",
    description="API for Czech grammar error correction and punctuation restoration",
    version="1.0.0",
    lifespan=lifespan,
)

# Enable CORS.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True lets any
# origin make credentialed requests; acceptable only while the API has no
# cookie/session-based auth.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Request/Response models
class CorrectionRequest(BaseModel):
    text: str = Field(..., max_length=MAX_TEXT_LENGTH, description="Czech text to correct")
    # Recognized key: "include_timing" (bool). May be sent as null.
    options: Optional[Dict] = Field(default_factory=dict, description="Optional parameters")


class CorrectionResponse(BaseModel):
    success: bool
    corrected_text: str
    processing_time_ms: Optional[float] = None
    error: Optional[str] = None


class BatchCorrectionRequest(BaseModel):
    texts: List[str] = Field(..., max_items=10, description="List of texts to correct")
    # Recognized key: "include_timing" (bool). May be sent as null.
    options: Optional[Dict] = Field(default_factory=dict, description="Optional parameters")


class BatchCorrectionResponse(BaseModel):
    success: bool
    corrected_texts: List[str]
    processing_time_ms: Optional[float] = None
    error: Optional[str] = None


class HealthResponse(BaseModel):
    status: str
    models_loaded: bool
    gpu_available: bool
    device: str


class InfoResponse(BaseModel):
    name: str
    version: str
    models: Dict[str, str]
    capabilities: List[str]
    max_input_length: int


def apply_gec_correction(text: str) -> str:
    """Apply grammar error correction to a single text.

    Whitespace-only input is returned unchanged. Delegates to the batch path
    (a one-element batch needs no padding, so results are identical).
    """
    return apply_gec_correction_batch([text])[0]


def apply_gec_correction_batch(texts: List[str]) -> List[str]:
    """Apply grammar error correction to multiple texts in one generate() call.

    Args:
        texts: Input texts; whitespace-only entries are passed through unchanged.

    Returns:
        Corrected texts in the same order as ``texts``.
    """
    if not texts:
        return []

    # Only non-blank texts go to the model; blank ones keep their slot as-is.
    results = list(texts)
    non_empty_indices = [i for i, text in enumerate(texts) if text.strip()]
    if not non_empty_indices:
        return results
    non_empty_texts = [texts[i] for i in non_empty_indices]

    # Tokenize all texts at once (padded to the longest for batching)
    inputs = gec_tokenizer(
        non_empty_texts,
        return_tensors="pt",
        max_length=100000,
        truncation=True,
        padding=True,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = gec_model.generate(**inputs, **GEC_CONFIG)

    corrected_texts = gec_tokenizer.batch_decode(outputs, skip_special_tokens=True)

    # Map back to original positions
    for i, corrected in zip(non_empty_indices, corrected_texts):
        results[i] = corrected

    return results


def _predict_window_marks(window: List[str]) -> List[str]:
    """Run the punctuation model on one window of words.

    Returns exactly one mark ('' for none) per word in ``window``. A word's mark
    comes from the label of its first SentencePiece token (a token starting with
    '▁', or the very first token); continuation pieces are ignored.
    """
    marks: List[str] = []
    for i, token in enumerate(punct_pipeline(" ".join(window))):
        if i == 0 or token["word"].startswith("▁"):
            marks.append(_PUNCT_MARKS.get(token["entity"], ""))

    if len(marks) != len(window):
        # Tokenizer word boundaries disagreed with str.split(); keep the output
        # aligned to `window` so every word still gets exactly one slot.
        logger.warning(
            "Punctuation alignment mismatch: %d predictions for %d words",
            len(marks), len(window),
        )
        marks.extend([""] * (len(window) - len(marks)))
    return marks[:len(window)]


def _predict_marks(words: List[str]) -> List[str]:
    """Predict the mark following each word, processing long inputs in windows."""
    marks: List[str] = []
    for start in range(0, len(words), _PUNCT_WINDOW_WORDS):
        window = words[start:start + _PUNCT_WINDOW_WORDS + _PUNCT_CONTEXT_WORDS]
        marks.extend(_predict_window_marks(window)[:_PUNCT_WINDOW_WORDS])
    return marks


def _punctuate_text(text: str) -> str:
    """Lowercase ``text``, insert predicted punctuation and capitalize sentences.

    Whitespace-only input is returned unchanged. Original whitespace (newlines,
    repeated spaces) is collapsed to single spaces.
    """
    if not text.strip():
        return text

    words = text.lower().split()
    marks = _predict_marks(words)

    # Positional pairing: each occurrence of a word gets its own prediction.
    # Skip the mark when the word already ends in punctuation (avoids "ahoj..").
    punctuated = [
        word + mark if mark and not word.endswith(_EXISTING_PUNCT) else word
        for word, mark in zip(words, marks)
    ]
    result = " ".join(punctuated)

    # Capitalize first letter and after sentence endings
    sentences = _SENTENCE_SPLIT.split(result)
    capitalized = " ".join(s[0].upper() + s[1:] if s else s for s in sentences)

    # Clean spacing around punctuation
    for p in _SPACING_PUNCT:
        capitalized = capitalized.replace(f" {p}", p)

    return capitalized


def apply_punctuation(text: str) -> str:
    """Apply punctuation and capitalization to a single text."""
    return _punctuate_text(text)


def apply_punctuation_batch(texts: List[str]) -> List[str]:
    """Apply punctuation and capitalization to multiple texts (one by one)."""
    return [_punctuate_text(text) for text in texts]


def process_text(text: str) -> str:
    """Full pipeline: GEC followed by punctuation and capitalization."""
    gec_corrected = apply_gec_correction(text)
    return apply_punctuation(gec_corrected)


def _include_timing(options: Optional[Dict]) -> bool:
    """Return True when the client asked for processing_time_ms (options may be None)."""
    return bool((options or {}).get("include_timing", False))


# NOTE(review): the endpoints below are `async def` but run model inference
# synchronously, which blocks the event loop (including /api/health) for the
# duration of each request. Consider offloading to a worker thread while
# serializing access to the shared models.

@app.post("/api/correct", response_model=CorrectionResponse)
async def correct_text(request: CorrectionRequest):
    """
    Correct Czech text (grammar + punctuation).

    Raises HTTP 400 for empty or too-long input; processing failures are
    reported as ``success: false`` with the error message.
    """
    start_time = time.perf_counter()

    # Validate outside the try block so the 400 reaches the client instead of
    # being converted into a 200 "success": false body.
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    if len(request.text) > MAX_TEXT_LENGTH:
        raise HTTPException(status_code=400, detail=f"Text too long (max {MAX_TEXT_LENGTH} characters)")

    try:
        logger.info(f"📝 Single text request received ({len(request.text)} chars)")

        corrected = process_text(request.text)

        processing_time = (time.perf_counter() - start_time) * 1000
        logger.info(f"✅ Completed in {processing_time:.1f}ms")

        return CorrectionResponse(
            success=True,
            corrected_text=corrected,
            processing_time_ms=processing_time if _include_timing(request.options) else None,
        )
    except Exception as e:
        logger.exception(f"Error processing text: {str(e)}")
        return CorrectionResponse(success=False, corrected_text="", error=str(e))


@app.post("/api/correct/batch", response_model=BatchCorrectionResponse)
async def correct_batch(request: BatchCorrectionRequest):
    """
    Correct multiple Czech texts (batched for GPU efficiency).

    Texts longer than MAX_TEXT_LENGTH are not processed; their slot contains
    "[Error: Text too long]". An empty list yields HTTP 400.
    """
    start_time = time.perf_counter()

    if not request.texts:
        raise HTTPException(status_code=400, detail="No texts provided")

    try:
        logger.info(f"📦 Batch request received: {len(request.texts)} texts")

        # Too-long texts are replaced by "" (skipped by both stages) and marked below.
        validated_texts = ["" if len(text) > MAX_TEXT_LENGTH else text for text in request.texts]

        # Step 1: Grammar correction (batched)
        logger.info(f"🔧 Starting GEC batch processing ({len(validated_texts)} texts)...")
        gec_start = time.perf_counter()
        gec_corrected_texts = apply_gec_correction_batch(validated_texts)
        gec_time = (time.perf_counter() - gec_start) * 1000
        logger.info(f"✓ GEC completed in {gec_time:.1f}ms")

        # Step 2: Punctuation and capitalization
        logger.info(f"📝 Starting punctuation batch processing...")
        punct_start = time.perf_counter()
        final_texts = apply_punctuation_batch(gec_corrected_texts)
        punct_time = (time.perf_counter() - punct_start) * 1000
        logger.info(f"✓ Punctuation completed in {punct_time:.1f}ms")

        corrected_texts = [
            "[Error: Text too long]" if len(text) > MAX_TEXT_LENGTH else final
            for text, final in zip(request.texts, final_texts)
        ]

        processing_time = (time.perf_counter() - start_time) * 1000
        logger.info(f"✅ Batch completed: {len(corrected_texts)} texts in {processing_time:.1f}ms (avg {processing_time/len(corrected_texts):.1f}ms/text)")

        return BatchCorrectionResponse(
            success=True,
            corrected_texts=corrected_texts,
            processing_time_ms=processing_time if _include_timing(request.options) else None,
        )
    except Exception as e:
        logger.exception(f"Error processing batch: {str(e)}")
        return BatchCorrectionResponse(success=False, corrected_texts=[], error=str(e))


@app.post("/api/correct/gec-only", response_model=CorrectionResponse)
async def correct_gec_only(request: CorrectionRequest):
    """
    Apply only grammar error correction (no punctuation).
    """
    try:
        corrected = apply_gec_correction(request.text)
        return CorrectionResponse(success=True, corrected_text=corrected)
    except Exception as e:
        logger.exception(f"Error processing text: {str(e)}")
        return CorrectionResponse(success=False, corrected_text="", error=str(e))


@app.post("/api/correct/punct-only", response_model=CorrectionResponse)
async def correct_punct_only(request: CorrectionRequest):
    """
    Apply only punctuation restoration (no grammar correction).
    """
    try:
        corrected = apply_punctuation(request.text)
        return CorrectionResponse(success=True, corrected_text=corrected)
    except Exception as e:
        logger.exception(f"Error processing text: {str(e)}")
        return CorrectionResponse(success=False, corrected_text="", error=str(e))


@app.get("/api/health", response_model=HealthResponse)
async def health_check():
    """
    Check API health and model status.
    """
    models_loaded = gec_model is not None and punct_pipeline is not None
    return HealthResponse(
        status="healthy" if models_loaded else "loading",
        models_loaded=models_loaded,
        gpu_available=torch.cuda.is_available(),
        device=str(device) if device else "not initialized",
    )


@app.get("/api/info", response_model=InfoResponse)
async def get_info():
    """
    Get API information and capabilities.
    """
    return InfoResponse(
        name="Czech Text Correction API",
        version="1.0.0",
        models={
            "gec": GEC_MODEL_NAME,
            "punctuation": PUNCT_MODEL_NAME,
        },
        capabilities=[
            "Grammar error correction",
            "Punctuation restoration",
            "Capitalization",
            "Batch processing",
            "Czech language focus",
        ],
        max_input_length=MAX_TEXT_LENGTH,
    )


@app.get("/")
async def root():
    """Root endpoint with API documentation link."""
    return {
        "message": "Czech Text Correction API",
        "docs": "/docs",
        "health": "/api/health",
        "info": "/api/info",
    }


if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 8042))
    uvicorn.run(app, host="0.0.0.0", port=port)