asdfasdfdsafdsa committed on
Commit
508f678
·
verified ·
1 Parent(s): 1601325

Upload 2 files

Browse files
Files changed (2) hide show
  1. api.py +154 -14
  2. api_client.py +310 -0
api.py CHANGED
@@ -12,11 +12,18 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForToken
12
  import time
13
  import re
14
  import logging
 
15
  from contextlib import asynccontextmanager
16
 
 
 
 
 
 
17
  # Configure logging
18
  logging.basicConfig(level=logging.INFO)
19
  logger = logging.getLogger(__name__)
 
20
 
21
  # Global variables for models
22
  gec_model = None
@@ -32,7 +39,7 @@ GEC_CONFIG = {
32
  "length_penalty": 1.0,
33
  "no_repeat_ngram_size": 0,
34
  "early_stopping": True,
35
- "max_new_tokens": 1500
36
  }
37
 
38
  @asynccontextmanager
@@ -91,7 +98,7 @@ app.add_middleware(
91
 
92
  # Request/Response models
93
  class CorrectionRequest(BaseModel):
94
- text: str = Field(..., max_length=5000, description="Czech text to correct")
95
  options: Optional[Dict] = Field(default={}, description="Optional parameters")
96
 
97
  class CorrectionResponse(BaseModel):
@@ -132,7 +139,7 @@ def apply_gec_correction(text: str) -> str:
132
  inputs = gec_tokenizer(
133
  text,
134
  return_tensors="pt",
135
- max_length=1024,
136
  truncation=True
137
  )
138
  inputs = {k: v.to(device) for k, v in inputs.items()}
@@ -148,6 +155,52 @@ def apply_gec_correction(text: str) -> str:
148
  corrected = gec_tokenizer.decode(outputs[0], skip_special_tokens=True)
149
  return corrected
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  def apply_punctuation(text: str) -> str:
152
  """Apply punctuation and capitalization to text"""
153
  if not text.strip():
@@ -213,6 +266,79 @@ def apply_punctuation(text: str) -> str:
213
 
214
  return capitalized
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  def process_text(text: str) -> str:
217
  """Full pipeline: GEC + punctuation"""
218
  # Step 1: Grammar correction
@@ -235,8 +361,8 @@ async def correct_text(request: CorrectionRequest):
235
  if not request.text.strip():
236
  raise HTTPException(status_code=400, detail="Text cannot be empty")
237
 
238
- if len(request.text) > 5000:
239
- raise HTTPException(status_code=400, detail="Text too long (max 5000 characters)")
240
 
241
  # Process text
242
  corrected = process_text(request.text)
@@ -266,7 +392,7 @@ async def correct_text(request: CorrectionRequest):
266
  @app.post("/api/correct/batch", response_model=BatchCorrectionResponse)
267
  async def correct_batch(request: BatchCorrectionRequest):
268
  """
269
- Correct multiple Czech texts
270
  """
271
  try:
272
  start_time = time.time()
@@ -275,14 +401,28 @@ async def correct_batch(request: BatchCorrectionRequest):
275
  if not request.texts:
276
  raise HTTPException(status_code=400, detail="No texts provided")
277
 
278
- # Process each text
279
- corrected_texts = []
280
  for text in request.texts:
281
- if len(text) > 5000:
282
- corrected_texts.append(f"[Error: Text too long]")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
  else:
284
- corrected = process_text(text)
285
- corrected_texts.append(corrected)
286
 
287
  # Calculate processing time
288
  processing_time = (time.time() - start_time) * 1000
@@ -374,7 +514,7 @@ async def get_info():
374
  "Batch processing",
375
  "Czech language focus"
376
  ],
377
- max_input_length=5000
378
  )
379
 
380
  @app.get("/")
@@ -390,5 +530,5 @@ async def root():
390
  if __name__ == "__main__":
391
  import uvicorn
392
  import os
393
- port = int(os.environ.get("PORT", 7860))
394
  uvicorn.run(app, host="0.0.0.0", port=port)
 
12
  import time
13
  import re
14
  import logging
15
+ import os
16
  from contextlib import asynccontextmanager
17
 
18
+ # Configure CPU threads for model inference (default 12 threads for better performance)
19
+ num_threads = int(os.environ.get("OMP_NUM_THREADS", 12))
20
+ torch.set_num_threads(num_threads)
21
+ torch.set_num_interop_threads(num_threads)
22
+
23
  # Configure logging
24
  logging.basicConfig(level=logging.INFO)
25
  logger = logging.getLogger(__name__)
26
+ logger.info(f"PyTorch configured to use {num_threads} CPU threads")
27
 
28
  # Global variables for models
29
  gec_model = None
 
39
  "length_penalty": 1.0,
40
  "no_repeat_ngram_size": 0,
41
  "early_stopping": True,
42
+ "max_new_tokens": 100000
43
  }
44
 
45
  @asynccontextmanager
 
98
 
99
  # Request/Response models
100
  class CorrectionRequest(BaseModel):
101
+ text: str = Field(..., max_length=100000, description="Czech text to correct")
102
  options: Optional[Dict] = Field(default={}, description="Optional parameters")
103
 
104
  class CorrectionResponse(BaseModel):
 
139
  inputs = gec_tokenizer(
140
  text,
141
  return_tensors="pt",
142
+ max_length=100000,
143
  truncation=True
144
  )
145
  inputs = {k: v.to(device) for k, v in inputs.items()}
 
155
  corrected = gec_tokenizer.decode(outputs[0], skip_special_tokens=True)
156
  return corrected
157
 
158
def apply_gec_correction_batch(texts: List[str]) -> List[str]:
    """Apply grammar error correction to multiple texts (batched for GPU efficiency).

    Blank / whitespace-only entries are passed through unchanged; all other
    texts are corrected in a single batched generate call. The returned list
    has the same length and order as `texts`.
    """
    if not texts:
        return []

    # Start from a copy of the inputs so blank entries keep their original
    # value, then record which positions actually need correction.
    results = list(texts)
    pending = [(pos, txt) for pos, txt in enumerate(texts) if txt.strip()]

    if not pending:
        return results

    batch = [txt for _, txt in pending]

    # Encode the whole batch in one call; padding makes the tensor rectangular.
    encoded = gec_tokenizer(
        batch,
        return_tensors="pt",
        max_length=100000,
        truncation=True,
        padding=True
    )
    encoded = {key: tensor.to(device) for key, tensor in encoded.items()}

    # Single batched generation pass — inference only, so no gradients.
    with torch.no_grad():
        generated = gec_model.generate(
            **encoded,
            **GEC_CONFIG
        )

    decoded = gec_tokenizer.batch_decode(generated, skip_special_tokens=True)

    # Scatter corrected strings back to their original slots.
    for (pos, _), corrected in zip(pending, decoded):
        results[pos] = corrected

    return results
204
  def apply_punctuation(text: str) -> str:
205
  """Apply punctuation and capitalization to text"""
206
  if not text.strip():
 
266
 
267
  return capitalized
268
 
269
def apply_punctuation_batch(texts: List[str]) -> List[str]:
    """Apply punctuation and capitalization to multiple texts (batched for GPU efficiency).

    Args:
        texts: Texts to punctuate; blank/whitespace-only entries are passed
            through unchanged.

    Returns:
        List of punctuated and sentence-capitalized texts, same length and
        order as `texts`.
    """
    if not texts:
        return []

    # Token-classification label -> punctuation mark appended after the word.
    # Hoisted out of the per-token loop: the map is invariant, so rebuilding
    # it on every token was pure overhead.
    punct_marks = {
        'LABEL_0': '',
        'LABEL_1': '.',
        'LABEL_2': ',',
        'LABEL_3': '?',
        'LABEL_4': '-',
        'LABEL_5': ':'
    }

    results = []
    for text in texts:
        if not text.strip():
            results.append(text)
            continue

        # Pipeline input is lowercased; capitalization is restored afterwards.
        clean_text = text.lower()
        pipeline_results = punct_pipeline(clean_text)

        # Merge subword tokens ('▁' marks a word start) into whole words,
        # keeping the punctuation predicted at each word's first subword.
        # NOTE(review): punct_map is keyed by word text, so repeated words
        # share a single entry (last occurrence wins) — confirm acceptable.
        punct_map = {}
        current_word = ""
        current_punct = ""

        for i, result in enumerate(pipeline_results):
            word = result['word'].replace('▁', '').strip()
            punct = punct_marks.get(result['entity'], '')

            if not result['word'].startswith('▁') and i > 0:
                # Continuation subword: extend the word, keep its first punct.
                current_word += word
            else:
                if current_word:
                    punct_map[current_word] = current_punct
                current_word = word
                current_punct = punct

        # Flush the final accumulated word.
        if current_word:
            punct_map[current_word] = current_punct

        # Re-emit the lowercased words with their predicted punctuation.
        words = clean_text.split()
        punctuated = []
        for word in words:
            if word in punct_map and punct_map[word]:
                punctuated.append(word + punct_map[word])
            else:
                punctuated.append(word)

        result_text = ' '.join(punctuated)

        # Capitalize the first letter of each sentence.
        sentences = re.split(r'(?<=[.?!])\s+', result_text)
        capitalized = ' '.join(s[0].upper() + s[1:] if s else s for s in sentences)

        # Remove any stray space left before punctuation marks.
        for p in [',', '.', '?', ':', '!', ';']:
            capitalized = capitalized.replace(f' {p}', p)

        results.append(capitalized)

    return results
342
  def process_text(text: str) -> str:
343
  """Full pipeline: GEC + punctuation"""
344
  # Step 1: Grammar correction
 
361
  if not request.text.strip():
362
  raise HTTPException(status_code=400, detail="Text cannot be empty")
363
 
364
+ if len(request.text) > 100000:
365
+ raise HTTPException(status_code=400, detail="Text too long (max 100000 characters)")
366
 
367
  # Process text
368
  corrected = process_text(request.text)
 
392
  @app.post("/api/correct/batch", response_model=BatchCorrectionResponse)
393
  async def correct_batch(request: BatchCorrectionRequest):
394
  """
395
+ Correct multiple Czech texts (batched for GPU efficiency)
396
  """
397
  try:
398
  start_time = time.time()
 
401
  if not request.texts:
402
  raise HTTPException(status_code=400, detail="No texts provided")
403
 
404
+ # Validate text lengths
405
+ validated_texts = []
406
  for text in request.texts:
407
+ if len(text) > 100000:
408
+ validated_texts.append("") # Will be handled as error later
409
+ else:
410
+ validated_texts.append(text)
411
+
412
+ # Process all texts in batch (GPU efficient!)
413
+ # Step 1: Grammar correction (batched)
414
+ gec_corrected_texts = apply_gec_correction_batch(validated_texts)
415
+
416
+ # Step 2: Punctuation and capitalization (batched)
417
+ final_texts = apply_punctuation_batch(gec_corrected_texts)
418
+
419
+ # Mark texts that were too long
420
+ corrected_texts = []
421
+ for i, text in enumerate(request.texts):
422
+ if len(text) > 100000:
423
+ corrected_texts.append("[Error: Text too long]")
424
  else:
425
+ corrected_texts.append(final_texts[i])
 
426
 
427
  # Calculate processing time
428
  processing_time = (time.time() - start_time) * 1000
 
514
  "Batch processing",
515
  "Czech language focus"
516
  ],
517
+ max_input_length=100000
518
  )
519
 
520
  @app.get("/")
 
530
  if __name__ == "__main__":
531
  import uvicorn
532
  import os
533
+ port = int(os.environ.get("PORT", 8042))
534
  uvicorn.run(app, host="0.0.0.0", port=port)
api_client.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Client for Czech text correction API with local server auto-start
3
+ """
4
+
5
+ import requests
6
+ import time
7
+ import subprocess
8
+ import os
9
+ import sys
10
+ from pathlib import Path
11
+ from typing import Optional, Dict, List, Any
12
+ import logging
13
+
14
+ # Configure logging
15
+ logging.basicConfig(level=logging.INFO)
16
+ logger = logging.getLogger(__name__)
17
+
18
class CzechCorrectionClient:
    """Client for Czech text correction with automatic local server startup.

    Talks exclusively to a local API instance on port 8042; if no healthy
    server is found, it launches `api.py` from this directory as a background
    subprocess and waits for it to come up.
    """

    # Local API endpoint only — name, base URL, and per-request timeout.
    LOCAL_ENDPOINT = {
        "name": "Local",
        "base_url": "http://localhost:8042",
        "timeout": 3600  # 1 hour for local (grammar correction can be slow)
    }

    def __init__(self, prefer_local: bool = True) -> None:
        """
        Initialize the client.

        Args:
            prefer_local: Deprecated, always uses local API now.
        """
        self.endpoint = self.LOCAL_ENDPOINT
        # Endpoint last confirmed healthy (None until first successful check).
        self._working_endpoint = None
        # Timestamp of the last successful health check (epoch seconds).
        self._last_health_check = 0
        self.health_check_interval = 3600  # Cache endpoint for 1 hour
        # Popen handle of a server we started ourselves; never terminated here.
        self._server_process = None

    def _check_endpoint_health(self, endpoint: Dict) -> bool:
        """Return True iff GET /api/health answers 200 with status == 'healthy'."""
        try:
            response = requests.get(
                f"{endpoint['base_url']}/api/health",
                timeout=10  # Increased timeout for health check
            )
            if response.status_code == 200:
                data = response.json()
                return data.get('status') == 'healthy'
        except Exception as e:
            # Network/timeout errors are expected when the server is down.
            logger.debug(f"Health check failed for {endpoint['name']}: {e}")
        return False

    def _is_port_in_use(self, port: int) -> bool:
        """Check if a port is already in use by attempting to bind it locally."""
        import socket
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(('localhost', port))
                return False
            except OSError:
                # Bind failed -> something else already owns the port.
                return True

    def _start_local_server(self) -> bool:
        """Start the local API server if not already running.

        Returns True when a healthy server is reachable (pre-existing or
        freshly started), False otherwise.
        """
        try:
            # Check if port 8042 is already in use
            if self._is_port_in_use(8042):
                logger.warning("Port 8042 is already in use - server may already be running")
                # Wait a bit and check health again
                time.sleep(2)
                if self._check_endpoint_health(self.endpoint):
                    logger.info("✅ Server is already running on port 8042")
                    return True
                else:
                    logger.error("Port 8042 is in use but server is not responding to health checks")
                    return False

            # Find the api_service directory (api.py is expected next to this file)
            current_file = Path(__file__).resolve()
            api_service_dir = current_file.parent
            api_script = api_service_dir / "api.py"

            if not api_script.exists():
                logger.error(f"API script not found at {api_script}")
                return False

            logger.info("Starting local API server...")
            logger.info("This may take 1-2 minutes to load models...")

            # Start the server in the background, detached from our session.
            # NOTE(review): stdout/stderr are PIPEd but never drained — a very
            # chatty server could fill the pipe buffer and block; confirm the
            # server logs little, or redirect to DEVNULL/a file.
            env = os.environ.copy()
            env['PORT'] = '8042'  # Set port to 8042

            self._server_process = subprocess.Popen(
                [sys.executable, str(api_script)],
                cwd=str(api_service_dir),
                env=env,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                start_new_session=True
            )

            # Wait for server to be ready (up to 2 minutes), polling health.
            max_wait = 120
            start_time = time.time()

            while time.time() - start_time < max_wait:
                if self._check_endpoint_health(self.endpoint):
                    logger.info("✅ Local API server started successfully")
                    return True
                time.sleep(2)

            logger.error("Server failed to start within timeout")
            return False

        except Exception as e:
            logger.error(f"Failed to start local server: {e}")
            return False

    def _get_working_endpoint(self) -> Optional[Dict]:
        """Get working endpoint, starting server if needed.

        A successful health check is cached for `health_check_interval`
        seconds to avoid hitting /api/health on every request.
        """
        current_time = time.time()

        # Use cached endpoint if still valid
        if self._working_endpoint and (current_time - self._last_health_check < self.health_check_interval):
            return self._working_endpoint

        # Check if local server is running
        if self._check_endpoint_health(self.endpoint):
            logger.info(f"Using {self.endpoint['name']} API endpoint")
            self._working_endpoint = self.endpoint
            self._last_health_check = current_time
            return self.endpoint

        # Try to start the server ourselves as a last resort.
        logger.info("Local API server not running, attempting to start...")
        if self._start_local_server():
            self._working_endpoint = self.endpoint
            self._last_health_check = current_time
            return self.endpoint

        logger.error("Could not start or connect to local API server")
        return None

    def correct_text(self, text: str, include_timing: bool = False) -> Dict[str, Any]:
        """
        Correct Czech text (grammar and punctuation).

        Args:
            text: Text to correct.
            include_timing: Whether to include processing time in response.

        Returns:
            Dict with 'success', 'corrected_text', and optionally
            'processing_time_ms'. On any failure 'success' is False and
            'corrected_text' echoes the input unchanged.
        """
        if not text or not text.strip():
            # Nothing to correct — report success with the input as-is.
            return {
                "success": True,
                "corrected_text": text,
                "error": None
            }

        endpoint = self._get_working_endpoint()
        if not endpoint:
            return {
                "success": False,
                "corrected_text": text,
                "error": "Could not start or connect to local API server"
            }

        try:
            payload = {
                "text": text,
                "options": {"include_timing": include_timing}
            }

            response = requests.post(
                f"{endpoint['base_url']}/api/correct",
                json=payload,
                timeout=endpoint['timeout']
            )

            if response.status_code == 200:
                return response.json()
            else:
                return {
                    "success": False,
                    "corrected_text": text,
                    "error": f"API error: {response.status_code}"
                }

        except requests.exceptions.Timeout:
            logger.warning(f"Timeout on {endpoint['name']} API")
            return {
                "success": False,
                "corrected_text": text,
                "error": "Request timeout"
            }

        except Exception as e:
            logger.error(f"Error calling API: {e}")
            return {
                "success": False,
                "corrected_text": text,
                "error": str(e)
            }

    def correct_batch(self, texts: List[str], include_timing: bool = False) -> Dict[str, Any]:
        """
        Correct multiple Czech texts in batch.

        Args:
            texts: List of texts to correct (max 10).
            include_timing: Whether to include processing time.

        Returns:
            Dict with 'success', 'corrected_texts', and optionally
            'processing_time_ms'. If the batch endpoint fails, falls back to
            per-text correction and reports success (failed items keep their
            original text).
        """
        if not texts:
            return {
                "success": True,
                "corrected_texts": [],
                "error": None
            }

        if len(texts) > 10:
            # Client-side guard mirroring the server's batch size limit.
            return {
                "success": False,
                "corrected_texts": texts,
                "error": "Batch size exceeds limit (10)"
            }

        endpoint = self._get_working_endpoint()
        if not endpoint:
            return {
                "success": False,
                "corrected_texts": texts,
                "error": "Could not start or connect to local API server"
            }

        try:
            payload = {
                "texts": texts,
                "options": {"include_timing": include_timing}
            }

            response = requests.post(
                f"{endpoint['base_url']}/api/correct/batch",
                json=payload,
                timeout=endpoint['timeout'] * 2  # Longer timeout for batch
            )

            if response.status_code == 200:
                return response.json()
            else:
                # Fallback to individual corrections
                logger.warning(f"Batch API failed, falling back to individual corrections")
                corrected_texts = []
                for text in texts:
                    result = self.correct_text(text, include_timing=False)
                    corrected_texts.append(result.get('corrected_text', text))

                return {
                    "success": True,
                    "corrected_texts": corrected_texts,
                    "error": None
                }

        except Exception as e:
            logger.error(f"Error calling batch API: {e}")
            # Fallback to individual corrections
            corrected_texts = []
            for text in texts:
                result = self.correct_text(text, include_timing=False)
                corrected_texts.append(result.get('corrected_text', text))

            return {
                "success": True,
                "corrected_texts": corrected_texts,
                "error": None
            }
284
+
285
+
286
+ # Convenience functions for backward compatibility
287
+ _default_client = None
288
+
289
def get_client(prefer_local: bool = True) -> CzechCorrectionClient:
    """Return the shared module-level client, creating it on first use.

    The `prefer_local` flag is retained for backward compatibility only;
    the client always talks to the local API.
    """
    global _default_client
    client = _default_client
    if client is None:
        client = CzechCorrectionClient(prefer_local=True)
        _default_client = client
    return client
295
+
296
def correct_text(text: str, prefer_local: bool = True) -> str:
    """Convenience wrapper: correct a single text via the shared client.

    Returns the input unchanged when the API call fails. The `prefer_local`
    flag is kept for backward compatibility only.
    """
    outcome = get_client(prefer_local=True).correct_text(text)
    return outcome['corrected_text'] if outcome['success'] else text
303
+
304
def correct_batch(texts: List[str], prefer_local: bool = True) -> List[str]:
    """Convenience wrapper: correct several texts via the shared client.

    Returns the inputs unchanged when the API call fails. The `prefer_local`
    flag is kept for backward compatibility only.
    """
    outcome = get_client(prefer_local=True).correct_batch(texts)
    if not outcome['success']:
        return texts
    return outcome.get('corrected_texts', texts)