Upload 2 files
Browse files
- api.py (+154 −14)
- api_client.py (+310 −0)
api.py
CHANGED
|
@@ -12,11 +12,18 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForToken
|
|
| 12 |
import time
|
| 13 |
import re
|
| 14 |
import logging
|
|
|
|
| 15 |
from contextlib import asynccontextmanager
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
# Configure logging
|
| 18 |
logging.basicConfig(level=logging.INFO)
|
| 19 |
logger = logging.getLogger(__name__)
|
|
|
|
| 20 |
|
| 21 |
# Global variables for models
|
| 22 |
gec_model = None
|
|
@@ -32,7 +39,7 @@ GEC_CONFIG = {
|
|
| 32 |
"length_penalty": 1.0,
|
| 33 |
"no_repeat_ngram_size": 0,
|
| 34 |
"early_stopping": True,
|
| 35 |
-
"max_new_tokens":
|
| 36 |
}
|
| 37 |
|
| 38 |
@asynccontextmanager
|
|
@@ -91,7 +98,7 @@ app.add_middleware(
|
|
| 91 |
|
| 92 |
# Request/Response models
|
| 93 |
class CorrectionRequest(BaseModel):
|
| 94 |
-
text: str = Field(..., max_length=
|
| 95 |
options: Optional[Dict] = Field(default={}, description="Optional parameters")
|
| 96 |
|
| 97 |
class CorrectionResponse(BaseModel):
|
|
@@ -132,7 +139,7 @@ def apply_gec_correction(text: str) -> str:
|
|
| 132 |
inputs = gec_tokenizer(
|
| 133 |
text,
|
| 134 |
return_tensors="pt",
|
| 135 |
-
max_length=
|
| 136 |
truncation=True
|
| 137 |
)
|
| 138 |
inputs = {k: v.to(device) for k, v in inputs.items()}
|
|
@@ -148,6 +155,52 @@ def apply_gec_correction(text: str) -> str:
|
|
| 148 |
corrected = gec_tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 149 |
return corrected
|
| 150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
def apply_punctuation(text: str) -> str:
|
| 152 |
"""Apply punctuation and capitalization to text"""
|
| 153 |
if not text.strip():
|
|
@@ -213,6 +266,79 @@ def apply_punctuation(text: str) -> str:
|
|
| 213 |
|
| 214 |
return capitalized
|
| 215 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
def process_text(text: str) -> str:
|
| 217 |
"""Full pipeline: GEC + punctuation"""
|
| 218 |
# Step 1: Grammar correction
|
|
@@ -235,8 +361,8 @@ async def correct_text(request: CorrectionRequest):
|
|
| 235 |
if not request.text.strip():
|
| 236 |
raise HTTPException(status_code=400, detail="Text cannot be empty")
|
| 237 |
|
| 238 |
-
if len(request.text) >
|
| 239 |
-
raise HTTPException(status_code=400, detail="Text too long (max
|
| 240 |
|
| 241 |
# Process text
|
| 242 |
corrected = process_text(request.text)
|
|
@@ -266,7 +392,7 @@ async def correct_text(request: CorrectionRequest):
|
|
| 266 |
@app.post("/api/correct/batch", response_model=BatchCorrectionResponse)
|
| 267 |
async def correct_batch(request: BatchCorrectionRequest):
|
| 268 |
"""
|
| 269 |
-
Correct multiple Czech texts
|
| 270 |
"""
|
| 271 |
try:
|
| 272 |
start_time = time.time()
|
|
@@ -275,14 +401,28 @@ async def correct_batch(request: BatchCorrectionRequest):
|
|
| 275 |
if not request.texts:
|
| 276 |
raise HTTPException(status_code=400, detail="No texts provided")
|
| 277 |
|
| 278 |
-
#
|
| 279 |
-
|
| 280 |
for text in request.texts:
|
| 281 |
-
if len(text) >
|
| 282 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
else:
|
| 284 |
-
|
| 285 |
-
corrected_texts.append(corrected)
|
| 286 |
|
| 287 |
# Calculate processing time
|
| 288 |
processing_time = (time.time() - start_time) * 1000
|
|
@@ -374,7 +514,7 @@ async def get_info():
|
|
| 374 |
"Batch processing",
|
| 375 |
"Czech language focus"
|
| 376 |
],
|
| 377 |
-
max_input_length=
|
| 378 |
)
|
| 379 |
|
| 380 |
@app.get("/")
|
|
@@ -390,5 +530,5 @@ async def root():
|
|
| 390 |
if __name__ == "__main__":
|
| 391 |
import uvicorn
|
| 392 |
import os
|
| 393 |
-
port = int(os.environ.get("PORT",
|
| 394 |
uvicorn.run(app, host="0.0.0.0", port=port)
|
|
|
|
| 12 |
import time
|
| 13 |
import re
|
| 14 |
import logging
|
| 15 |
+
import os
|
| 16 |
from contextlib import asynccontextmanager
|
| 17 |
|
| 18 |
+
# Configure CPU threads for model inference (default 12 threads for better performance)
|
| 19 |
+
num_threads = int(os.environ.get("OMP_NUM_THREADS", 12))
|
| 20 |
+
torch.set_num_threads(num_threads)
|
| 21 |
+
torch.set_num_interop_threads(num_threads)
|
| 22 |
+
|
| 23 |
# Configure logging
|
| 24 |
logging.basicConfig(level=logging.INFO)
|
| 25 |
logger = logging.getLogger(__name__)
|
| 26 |
+
logger.info(f"PyTorch configured to use {num_threads} CPU threads")
|
| 27 |
|
| 28 |
# Global variables for models
|
| 29 |
gec_model = None
|
|
|
|
| 39 |
"length_penalty": 1.0,
|
| 40 |
"no_repeat_ngram_size": 0,
|
| 41 |
"early_stopping": True,
|
| 42 |
+
"max_new_tokens": 100000
|
| 43 |
}
|
| 44 |
|
| 45 |
@asynccontextmanager
|
|
|
|
| 98 |
|
| 99 |
# Request/Response models
|
| 100 |
class CorrectionRequest(BaseModel):
|
| 101 |
+
text: str = Field(..., max_length=100000, description="Czech text to correct")
|
| 102 |
options: Optional[Dict] = Field(default={}, description="Optional parameters")
|
| 103 |
|
| 104 |
class CorrectionResponse(BaseModel):
|
|
|
|
| 139 |
inputs = gec_tokenizer(
|
| 140 |
text,
|
| 141 |
return_tensors="pt",
|
| 142 |
+
max_length=100000,
|
| 143 |
truncation=True
|
| 144 |
)
|
| 145 |
inputs = {k: v.to(device) for k, v in inputs.items()}
|
|
|
|
| 155 |
corrected = gec_tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 156 |
return corrected
|
| 157 |
|
| 158 |
+
def apply_gec_correction_batch(texts: List[str]) -> List[str]:
    """Run grammar error correction over several texts in one model call.

    Blank or whitespace-only entries are passed through untouched; all other
    entries are corrected in a single batched ``generate()`` call and written
    back to their original positions.
    """
    if not texts:
        return []

    # Partition the input: remember which slots actually need correction.
    corrected_out = [""] * len(texts)
    pending = []    # texts that will go through the model
    positions = []  # original index of each pending text
    for idx, candidate in enumerate(texts):
        if candidate.strip():
            pending.append(candidate)
            positions.append(idx)
        else:
            # Blank entries are returned exactly as given.
            corrected_out[idx] = candidate

    if not pending:
        return corrected_out

    # One tokenizer call for the whole batch; padding aligns sequence lengths.
    encoded = gec_tokenizer(
        pending,
        return_tensors="pt",
        max_length=100000,
        truncation=True,
        padding=True
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    # Single batched generation pass — no gradients needed at inference time.
    with torch.no_grad():
        generated = gec_model.generate(
            **encoded,
            **GEC_CONFIG
        )

    decoded = gec_tokenizer.batch_decode(generated, skip_special_tokens=True)

    # Scatter the corrected strings back into their original slots.
    for idx, fixed in zip(positions, decoded):
        corrected_out[idx] = fixed

    return corrected_out
| 203 |
+
|
| 204 |
def apply_punctuation(text: str) -> str:
|
| 205 |
"""Apply punctuation and capitalization to text"""
|
| 206 |
if not text.strip():
|
|
|
|
| 266 |
|
| 267 |
return capitalized
|
| 268 |
|
| 269 |
+
def apply_punctuation_batch(texts: List[str]) -> List[str]:
    """Restore punctuation and capitalization for a list of texts.

    Each text is lower-cased, run through the token-classification pipeline,
    reassembled with the predicted punctuation marks, and finally
    sentence-cased. Blank entries are returned unchanged.
    """
    if not texts:
        return []

    # Token-classification label -> punctuation mark it stands for.
    # (Hoisted out of the per-token loop; the mapping is constant.)
    label_to_mark = {
        'LABEL_0': '',
        'LABEL_1': '.',
        'LABEL_2': ',',
        'LABEL_3': '?',
        'LABEL_4': '-',
        'LABEL_5': ':'
    }

    output = []
    for original in texts:
        if not original.strip():
            output.append(original)  # nothing to punctuate
            continue

        lowered = original.lower()
        token_preds = punct_pipeline(lowered)

        # Merge subword tokens ('▁' marks a word start) and remember the
        # punctuation predicted at each word's first token.
        word_punct = {}
        pending_word = ""
        pending_mark = ""
        for pos, pred in enumerate(token_preds):
            piece = pred['word'].replace('▁', '').strip()
            mark = label_to_mark.get(pred['entity'], '')

            if pos > 0 and not pred['word'].startswith('▁'):
                pending_word += piece  # subword continuation
            else:
                if pending_word:
                    word_punct[pending_word] = pending_mark
                pending_word = piece
                pending_mark = mark

        # Flush the final word.
        if pending_word:
            word_punct[pending_word] = pending_mark

        # Re-attach marks to the words of the lower-cased text.
        rebuilt = ' '.join(
            w + word_punct[w] if word_punct.get(w) else w
            for w in lowered.split()
        )

        # Capitalize the first letter of every sentence.
        pieces = re.split(r'(?<=[.?!])\s+', rebuilt)
        sentence_cased = ' '.join(p[0].upper() + p[1:] if p else p for p in pieces)

        # Remove any stray space left in front of punctuation.
        for mark in [',', '.', '?', ':', '!', ';']:
            sentence_cased = sentence_cased.replace(f' {mark}', mark)

        output.append(sentence_cased)

    return output
|
| 341 |
+
|
| 342 |
def process_text(text: str) -> str:
|
| 343 |
"""Full pipeline: GEC + punctuation"""
|
| 344 |
# Step 1: Grammar correction
|
|
|
|
| 361 |
if not request.text.strip():
|
| 362 |
raise HTTPException(status_code=400, detail="Text cannot be empty")
|
| 363 |
|
| 364 |
+
if len(request.text) > 100000:
|
| 365 |
+
raise HTTPException(status_code=400, detail="Text too long (max 100000 characters)")
|
| 366 |
|
| 367 |
# Process text
|
| 368 |
corrected = process_text(request.text)
|
|
|
|
| 392 |
@app.post("/api/correct/batch", response_model=BatchCorrectionResponse)
|
| 393 |
async def correct_batch(request: BatchCorrectionRequest):
|
| 394 |
"""
|
| 395 |
+
Correct multiple Czech texts (batched for GPU efficiency)
|
| 396 |
"""
|
| 397 |
try:
|
| 398 |
start_time = time.time()
|
|
|
|
| 401 |
if not request.texts:
|
| 402 |
raise HTTPException(status_code=400, detail="No texts provided")
|
| 403 |
|
| 404 |
+
# Validate text lengths
|
| 405 |
+
validated_texts = []
|
| 406 |
for text in request.texts:
|
| 407 |
+
if len(text) > 100000:
|
| 408 |
+
validated_texts.append("") # Will be handled as error later
|
| 409 |
+
else:
|
| 410 |
+
validated_texts.append(text)
|
| 411 |
+
|
| 412 |
+
# Process all texts in batch (GPU efficient!)
|
| 413 |
+
# Step 1: Grammar correction (batched)
|
| 414 |
+
gec_corrected_texts = apply_gec_correction_batch(validated_texts)
|
| 415 |
+
|
| 416 |
+
# Step 2: Punctuation and capitalization (batched)
|
| 417 |
+
final_texts = apply_punctuation_batch(gec_corrected_texts)
|
| 418 |
+
|
| 419 |
+
# Mark texts that were too long
|
| 420 |
+
corrected_texts = []
|
| 421 |
+
for i, text in enumerate(request.texts):
|
| 422 |
+
if len(text) > 100000:
|
| 423 |
+
corrected_texts.append("[Error: Text too long]")
|
| 424 |
else:
|
| 425 |
+
corrected_texts.append(final_texts[i])
|
|
|
|
| 426 |
|
| 427 |
# Calculate processing time
|
| 428 |
processing_time = (time.time() - start_time) * 1000
|
|
|
|
| 514 |
"Batch processing",
|
| 515 |
"Czech language focus"
|
| 516 |
],
|
| 517 |
+
max_input_length=100000
|
| 518 |
)
|
| 519 |
|
| 520 |
@app.get("/")
|
|
|
|
| 530 |
if __name__ == "__main__":
|
| 531 |
import uvicorn
|
| 532 |
import os
|
| 533 |
+
port = int(os.environ.get("PORT", 8042))
|
| 534 |
uvicorn.run(app, host="0.0.0.0", port=port)
|
api_client.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Client for Czech text correction API with local server auto-start
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import requests
|
| 6 |
+
import time
|
| 7 |
+
import subprocess
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Optional, Dict, List, Any
|
| 12 |
+
import logging
|
| 13 |
+
|
| 14 |
+
# Configure logging
|
| 15 |
+
logging.basicConfig(level=logging.INFO)
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
class CzechCorrectionClient:
    """Client for Czech text correction with automatic local server startup.

    All traffic goes to a locally hosted API (port 8042). When the server is
    not reachable, the client attempts to launch ``api.py`` from this file's
    directory and waits for it to pass a health check before sending requests.
    """

    # Local API endpoint only
    LOCAL_ENDPOINT = {
        "name": "Local",
        "base_url": "http://localhost:8042",
        "timeout": 3600  # 1 hour for local (grammar correction can be slow)
    }

    def __init__(self, prefer_local: bool = True):
        """
        Initialize the client

        Args:
            prefer_local: Deprecated, always uses local API now
        """
        self.endpoint = self.LOCAL_ENDPOINT
        self._working_endpoint = None        # cached endpoint known to be healthy
        self._last_health_check = 0          # wall-clock time of last successful check
        self.health_check_interval = 3600    # Cache endpoint for 1 hour
        self._server_process = None          # Popen handle when we started the server

    def _check_endpoint_health(self, endpoint: Dict) -> bool:
        """Return True when the endpoint's /api/health reports status 'healthy'."""
        try:
            response = requests.get(
                f"{endpoint['base_url']}/api/health",
                timeout=10  # Increased timeout for health check
            )
            if response.status_code == 200:
                data = response.json()
                return data.get('status') == 'healthy'
        except Exception as e:
            logger.debug(f"Health check failed for {endpoint['name']}: {e}")
        # Non-200 responses and connection errors both count as unhealthy.
        return False

    def _is_port_in_use(self, port: int) -> bool:
        """Check if a port is already in use (bind probe on localhost)."""
        import socket
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(('localhost', port))
                return False
            except OSError:
                return True

    def _start_local_server(self) -> bool:
        """Start the local API server if not already running.

        Returns:
            True once the server answers health checks, False otherwise.
        """
        try:
            # Check if port 8042 is already in use
            if self._is_port_in_use(8042):
                logger.warning("Port 8042 is already in use - server may already be running")
                # Wait a bit and check health again
                time.sleep(2)
                if self._check_endpoint_health(self.endpoint):
                    logger.info("✅ Server is already running on port 8042")
                    return True
                else:
                    logger.error("Port 8042 is in use but server is not responding to health checks")
                    return False

            # Find the api_service directory
            current_file = Path(__file__).resolve()
            api_service_dir = current_file.parent
            api_script = api_service_dir / "api.py"

            if not api_script.exists():
                logger.error(f"API script not found at {api_script}")
                return False

            logger.info("Starting local API server...")
            logger.info("This may take 1-2 minutes to load models...")

            # Start the server in the background
            env = os.environ.copy()
            env['PORT'] = '8042'  # Set port to 8042

            # Discard the child's output instead of holding PIPE handles:
            # nobody ever reads those pipes, so once the server logs enough
            # the OS pipe buffer fills and the server blocks on write — the
            # classic Popen deadlock.
            self._server_process = subprocess.Popen(
                [sys.executable, str(api_script)],
                cwd=str(api_service_dir),
                env=env,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                start_new_session=True
            )

            # Wait for server to be ready (up to 2 minutes)
            max_wait = 120
            start_time = time.time()

            while time.time() - start_time < max_wait:
                if self._check_endpoint_health(self.endpoint):
                    logger.info("✅ Local API server started successfully")
                    return True
                time.sleep(2)

            logger.error("Server failed to start within timeout")
            return False

        except Exception as e:
            logger.error(f"Failed to start local server: {e}")
            return False

    def _get_working_endpoint(self) -> Optional[Dict]:
        """Get a working endpoint, starting the server if needed.

        Returns the endpoint dict, or None when the server cannot be
        reached or started. A healthy endpoint is cached for
        ``health_check_interval`` seconds to avoid per-request probes.
        """
        current_time = time.time()

        # Use cached endpoint if still valid
        if self._working_endpoint and (current_time - self._last_health_check < self.health_check_interval):
            return self._working_endpoint

        # Check if local server is running
        if self._check_endpoint_health(self.endpoint):
            logger.info(f"Using {self.endpoint['name']} API endpoint")
            self._working_endpoint = self.endpoint
            self._last_health_check = current_time
            return self.endpoint

        # Try to start the server
        logger.info("Local API server not running, attempting to start...")
        if self._start_local_server():
            self._working_endpoint = self.endpoint
            self._last_health_check = current_time
            return self.endpoint

        logger.error("Could not start or connect to local API server")
        return None

    def correct_text(self, text: str, include_timing: bool = False) -> Dict[str, Any]:
        """
        Correct Czech text (grammar and punctuation)

        Args:
            text: Text to correct
            include_timing: Whether to include processing time in response

        Returns:
            Dict with 'success', 'corrected_text', and optionally 'processing_time_ms'
        """
        # Blank input needs no round-trip to the server.
        if not text or not text.strip():
            return {
                "success": True,
                "corrected_text": text,
                "error": None
            }

        endpoint = self._get_working_endpoint()
        if not endpoint:
            return {
                "success": False,
                "corrected_text": text,
                "error": "Could not start or connect to local API server"
            }

        try:
            payload = {
                "text": text,
                "options": {"include_timing": include_timing}
            }

            response = requests.post(
                f"{endpoint['base_url']}/api/correct",
                json=payload,
                timeout=endpoint['timeout']
            )

            if response.status_code == 200:
                return response.json()
            else:
                return {
                    "success": False,
                    "corrected_text": text,
                    "error": f"API error: {response.status_code}"
                }

        except requests.exceptions.Timeout:
            logger.warning(f"Timeout on {endpoint['name']} API")
            return {
                "success": False,
                "corrected_text": text,
                "error": "Request timeout"
            }

        except Exception as e:
            logger.error(f"Error calling API: {e}")
            return {
                "success": False,
                "corrected_text": text,
                "error": str(e)
            }

    def correct_batch(self, texts: List[str], include_timing: bool = False) -> Dict[str, Any]:
        """
        Correct multiple Czech texts in batch

        Args:
            texts: List of texts to correct (max 10)
            include_timing: Whether to include processing time

        Returns:
            Dict with 'success', 'corrected_texts', and optionally 'processing_time_ms'
        """
        if not texts:
            return {
                "success": True,
                "corrected_texts": [],
                "error": None
            }

        if len(texts) > 10:
            return {
                "success": False,
                "corrected_texts": texts,
                "error": "Batch size exceeds limit (10)"
            }

        endpoint = self._get_working_endpoint()
        if not endpoint:
            return {
                "success": False,
                "corrected_texts": texts,
                "error": "Could not start or connect to local API server"
            }

        try:
            payload = {
                "texts": texts,
                "options": {"include_timing": include_timing}
            }

            response = requests.post(
                f"{endpoint['base_url']}/api/correct/batch",
                json=payload,
                timeout=endpoint['timeout'] * 2  # Longer timeout for batch
            )

            if response.status_code == 200:
                return response.json()
            else:
                # Fallback to individual corrections
                logger.warning("Batch API failed, falling back to individual corrections")
                corrected_texts = []
                for text in texts:
                    result = self.correct_text(text, include_timing=False)
                    corrected_texts.append(result.get('corrected_text', text))

                return {
                    "success": True,
                    "corrected_texts": corrected_texts,
                    "error": None
                }

        except Exception as e:
            logger.error(f"Error calling batch API: {e}")
            # Fallback to individual corrections
            corrected_texts = []
            for text in texts:
                result = self.correct_text(text, include_timing=False)
                corrected_texts.append(result.get('corrected_text', text))

            return {
                "success": True,
                "corrected_texts": corrected_texts,
                "error": None
            }
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
# Convenience functions for backward compatibility
|
| 287 |
+
_default_client = None
|
| 288 |
+
|
| 289 |
+
def get_client(prefer_local: bool = True) -> CzechCorrectionClient:
    """Return the process-wide client, creating it on first use.

    ``prefer_local`` is retained for backward compatibility only; the client
    always talks to the local API.
    """
    global _default_client
    client = _default_client
    if client is None:
        client = CzechCorrectionClient(prefer_local=True)
        _default_client = client
    return client
|
| 295 |
+
|
| 296 |
+
def correct_text(text: str, prefer_local: bool = True) -> str:
    """Correct a single text, returning the input unchanged on failure.

    ``prefer_local`` is accepted for backward compatibility only; the local
    API is always used.
    """
    outcome = get_client(prefer_local=True).correct_text(text)
    return outcome['corrected_text'] if outcome['success'] else text
|
| 303 |
+
|
| 304 |
+
def correct_batch(texts: List[str], prefer_local: bool = True) -> List[str]:
    """Correct several texts at once, returning the originals on failure.

    ``prefer_local`` is accepted for backward compatibility only; the local
    API is always used.
    """
    outcome = get_client(prefer_local=True).correct_batch(texts)
    if not outcome['success']:
        return texts
    return outcome.get('corrected_texts', texts)
|