Spaces:
Sleeping
Sleeping
| from paddleocr import PaddleOCR | |
| from gliner import GLiNER | |
| from PIL import Image | |
| import gradio as gr | |
| import numpy as np | |
| import cv2 | |
| import logging | |
| import os | |
| import tempfile | |
| import pandas as pd | |
| import re | |
| import traceback | |
| import zxingcpp # QR decoding | |
# Logging: INFO-level root configuration and a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Environment: GLINER_HOME is set before the model is loaded
# (presumably the directory GLiNER uses for model files -- confirm).
os.environ['GLINER_HOME'] = './gliner_models'

# Model: loaded once at import time. A failure is logged with its traceback
# and re-raised, so the module cannot be imported without a usable model.
try:
    logger.info("Loading GLiNER model...")
    gliner_model = GLiNER.from_pretrained("urchade/gliner_large-v2.1")
except Exception:
    logger.exception("Failed to load GLiNER model")
    raise
# Compiled patterns shared by the regex-based extraction helpers.
# EMAIL_REGEX matches a whole e-mail address (no capture groups).
EMAIL_REGEX = re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b")
# WEBSITE_REGEX captures "<label>.<tld>" in group 1; an optional scheme and
# lower-case "www." prefix are consumed but not captured.
WEBSITE_REGEX = re.compile(r"(?:https?://)?(?:www\.)?([A-Za-z0-9-]+\.[A-Za-z]{2,})")

# Country calling code prepended to normalised UAE phone numbers.
UAE_CODE = '+971'
# Utility functions
def extract_emails(text: str) -> list[str]:
    """Return every EMAIL_REGEX match in *text*, lower-cased, in order of appearance."""
    matches = EMAIL_REGEX.finditer(text)
    return [found.group(0).lower() for found in matches]
def extract_websites(text: str) -> list[str]:
    """Return the lower-cased domain captured by WEBSITE_REGEX for each match in *text*."""
    domains = []
    for found in WEBSITE_REGEX.finditer(text):
        # Group 1 is the "<label>.<tld>" part, without scheme or "www." prefix.
        domains.append(found.group(1).lower())
    return domains
| def normalize_website(url: str) -> str | None: | |
| u = url.lower().replace('www.', '').split('/')[0] | |
| return f"www.{u}" if re.match(r"^[a-z0-9-]+\.[a-z]{2,}$", u) else None | |
| # Phone cleaning: treat all local '0XXXXXXXXX' as UAE | |
| def clean_phone_number(phone: str) -> str | None: | |
| cleaned = re.sub(r"\D", "", phone) | |
| # Local UAE numbers (10 digits starting with 0) | |
| if len(cleaned) == 10 and cleaned.startswith('0'): | |
| return UAE_CODE + cleaned[1:] | |
| # International UAE numbers without plus (12 digits starting '971') | |
| if len(cleaned) == 12 and cleaned.startswith('971'): | |
| return '+' + cleaned | |
| # Already plus-prefixed UAE number | |
| if phone.strip().startswith('+971') and len(cleaned) == 12: | |
| return phone.strip() | |
| return None | |
# Extract phone numbers from text
def process_phone_numbers(text: str) -> list[str]:
    """Find phone-number candidates in *text* and return the valid ones.

    Candidates are contiguous digit runs: ``05`` followed by 8 digits, or an
    optional ``+`` followed by 8-12 digits. Each candidate is passed through
    ``clean_phone_number``; candidates it rejects are dropped.

    Args:
        text: Free text, typically the joined OCR output.

    Returns:
        Normalised numbers without duplicates, in order of first appearance.
        The order is deterministic, unlike a ``set`` round-trip.
    """
    candidates = re.finditer(r'(?:05\d{8}|\+?\d{8,12})', text)
    normalised = (clean_phone_number(match.group()) for match in candidates)
    # dict.fromkeys de-duplicates while preserving first-seen order.
    return list(dict.fromkeys(number for number in normalised if number))
| # Address extraction | |
| def extract_address(ocr_texts: list[str]) -> str | None: | |
| keywords = ["block","street","ave","area","industrial","road"] | |
| parts = [t for t in ocr_texts if any(kw in t.lower() for kw in keywords)] | |
| return " ".join(parts) if parts else None | |
# QR scanning
def scan_qr_code(image: Image.Image) -> str | None:
    """Decode the first barcode/QR code found in *image*.

    Two attempts are made, both fully in memory (no temporary files):

    1. Decode the image as-is (converted to a BGR array).
    2. Fallback: recolor the image to pure black/white -- every pixel whose
       R, G and B values are all within ``tol`` of ``default_color`` becomes
       ``default_color``, everything else becomes white -- and decode again.

    Args:
        image: The card image as a PIL image.

    Returns:
        The stripped text of the first decoded result, or ``None`` when
        nothing could be decoded or an error occurred (errors are logged).
    """
    try:
        rgb = np.array(image.convert('RGB'))

        # Direct decoding
        bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
        try:
            res = zxingcpp.read_barcodes(bgr)
            if res and res[0].text:
                return res[0].text.strip()
        except Exception:
            # Best effort: fall through to the recolor attempt.
            logger.warning("Direct QR decode failed", exc_info=True)

        # Fallback recolor
        default_color = (0, 0, 0)
        tol = 50
        # Signed arithmetic so the per-channel difference cannot wrap around.
        diff = np.abs(rgb.astype(np.int16) - np.array(default_color, dtype=np.int16))
        near_default = (diff <= tol).all(axis=-1)
        recolored = np.where(
            near_default[..., None],
            np.array(default_color, dtype=np.uint8),
            np.uint8(255),
        )
        res = zxingcpp.read_barcodes(cv2.cvtColor(recolored, cv2.COLOR_RGB2BGR))
        if res and res[0].text:
            return res[0].text.strip()
    except Exception:
        logger.exception("QR scan error")
    return None
# Deduplication
def deduplicate_data(results: dict[str, list[str]]) -> None:
    """Normalise and de-duplicate every field of *results* in place.

    E-mail, website and phone entries are split on ``;`` / ``,``, passed
    through a per-field normaliser, and kept only when the normalised value is
    truthy and not seen before. The remaining text fields are stripped and
    de-duplicated by exact match. First-seen order is preserved throughout,
    and missing keys are created with an empty list.
    """
    def _split_normalise_unique(entries, normalise):
        unique = {}
        for entry in entries:
            for piece in re.split(r'[;,]\s*', entry):
                piece = piece.strip()
                if not piece:
                    continue
                value = normalise(piece)
                if value:
                    # Insertion-ordered dict doubles as an ordered set.
                    unique.setdefault(value, None)
        return list(unique)

    normalisers = (
        ('Email Address', str.lower),
        ('Website', normalize_website),
        ('Phone Number', clean_phone_number),
    )
    for field, normalise in normalisers:
        results[field] = _split_normalise_unique(results.get(field, []), normalise)

    for field in ('Person Name', 'Company Name', 'Job Title', 'Address', 'QR Code'):
        stripped = (value.strip() for value in results.get(field, []))
        results[field] = list(dict.fromkeys(text for text in stripped if text))
# Inference pipeline
_RESULT_FIELDS = (
    'Person Name', 'Company Name', 'Job Title', 'Phone Number',
    'Email Address', 'Address', 'Website', 'QR Code',
)
_OCR_ENGINE = None


def _get_ocr_engine():
    """Return the shared PaddleOCR engine, creating it on first use."""
    global _OCR_ENGINE
    if _OCR_ENGINE is None:
        # Built once and reused: constructing it per request repeats the
        # model initialisation on every call.
        # NOTE(review): the instance is now shared across requests -- confirm
        # this is acceptable if requests are ever processed concurrently.
        _OCR_ENGINE = PaddleOCR(use_angle_cls=True, lang='en', use_gpu=False)
    return _OCR_ENGINE


def inference(img: Image.Image, confidence: float):
    """Run the full business-card pipeline on one image.

    Steps: OCR, GLiNER entity extraction, regex fallbacks for e-mail /
    website / phone, QR scan, keyword-based address fallback, de-duplication,
    company-name and person-name fallbacks, then CSV export.

    Args:
        img: The uploaded card as a PIL image.
        confidence: Threshold passed to ``gliner_model.predict_entities``.

    Returns:
        A 4-tuple ``(ocr_text, results, csv_path, error)``:
        the joined OCR text, a dict mapping each field name to a list of
        values, the path of a temporary CSV file (not deleted here), and an
        empty error string. On any failure the tuple is
        ``('', <dict of empty lists>, None, "Error:\\n<traceback>")``.
    """
    try:
        if img is None:
            raise ValueError("No image provided")

        pages = _get_ocr_engine().ocr(np.array(img), cls=True)
        # The first page can be None/empty when no text is detected; treat
        # that as "no lines" instead of failing on iteration.
        raw = (pages[0] if pages else None) or []
        ocr_texts = [ln[1][0] for ln in raw]
        full_text = ' '.join(ocr_texts)

        labels = ['person name', 'company name', 'job title', 'phone number',
                  'email address', 'address', 'website']
        # Nothing to tag when OCR found no text; the QR scan below still runs.
        entities = (
            gliner_model.predict_entities(full_text, labels, threshold=confidence, flat_ner=True)
            if full_text else []
        )
        results = {k: [] for k in _RESULT_FIELDS}

        # Process NER entities
        for ent in entities:
            txt, lbl = ent['text'].strip(), ent['label'].lower()
            if lbl == 'person name':
                results['Person Name'].append(txt)
            elif lbl == 'company name':
                results['Company Name'].append(txt)
            elif lbl == 'job title':
                results['Job Title'].append(txt.title())
            elif lbl == 'phone number':
                if (c := clean_phone_number(txt)):
                    results['Phone Number'].append(c)
            elif lbl == 'email address' and EMAIL_REGEX.fullmatch(txt):
                results['Email Address'].append(txt.lower())
            elif lbl == 'website' and WEBSITE_REGEX.search(txt):
                if (n := normalize_website(txt)):
                    results['Website'].append(n)
            elif lbl == 'address':
                results['Address'].append(txt)

        # Regex fallbacks
        results['Email Address'] += extract_emails(full_text)
        results['Website'] += extract_websites(full_text)
        results['Phone Number'] += process_phone_numbers(full_text)

        # QR code
        if qr := scan_qr_code(img):
            results['QR Code'].append(qr)

        # Address fallback
        if not results['Address'] and (addr := extract_address(ocr_texts)):
            results['Address'].append(addr)

        # Deduplicate all fields
        deduplicate_data(results)

        # Company fallback: first label of the e-mail domain, else of the website.
        if not results['Company Name'] and (dom := (results['Email Address'] or results['Website'])):
            # Websites are normalised to "www.<domain>"; drop that prefix so
            # the fallback does not yield "Www" as the company name.
            host = dom[0].split('@')[-1].removeprefix('www.')
            results['Company Name'].append(host.split('.')[0].title())

        # Name fallback: first OCR line made only of two or more capitalised words.
        if not results['Person Name']:
            for t in ocr_texts:
                if re.match(r'^(?:[A-Z][a-z]+\s?){2,}$', t):
                    results['Person Name'].append(t)
                    break

        # Prepare CSV (delete=False: the file must outlive this call so it can
        # be offered for download).
        csv_map = {k: '; '.join(v) for k, v in results.items()}
        with tempfile.NamedTemporaryFile(suffix='.csv', delete=False, mode='w') as f:
            pd.DataFrame([csv_map]).to_csv(f, index=False)
            csv_path = f.name
        return full_text, results, csv_path, ''
    except Exception:
        err = traceback.format_exc()
        logger.error("Processing failed: %s", err)
        empty = {k: [] for k in _RESULT_FIELDS}
        return '', empty, None, f"Error:\n{err}"
# Gradio Interface
if __name__ == '__main__':
    # Inputs: the card image (delivered to `inference` as a PIL image) and the
    # NER confidence threshold.
    card_input = gr.Image(type='pil', label='Upload Business Card')
    threshold_input = gr.Slider(0.1, 1, 0.4, step=0.1, label='Confidence Threshold')

    # Outputs, in the order of the tuple returned by `inference`.
    output_components = [
        gr.Textbox(label="OCR Result"),
        gr.JSON(label="Structured Data"),
        gr.File(label="Download CSV"),
        gr.Textbox(label="Error Log"),
    ]

    demo = gr.Interface(
        fn=inference,
        inputs=[card_input, threshold_input],
        outputs=output_components,
        title='Enhanced Business Card Parser',
        description='Entity extraction with AI and regex validation (UAE-focused phone support)',
        css=".gr-interface {max-width: 800px !important;}",
    )
    demo.launch()