Pixcribe / ocr_engine_manager.py
DawnC's picture
Upload 22 files
6a3bd1f verified
import torch
import easyocr
import numpy as np
import cv2
from PIL import Image
from typing import List, Dict
import re
class OCREngineManager:
    """Text extraction using EasyOCR with brand-optimized preprocessing.

    A single EasyOCR reader (English + Traditional Chinese) is created in
    ``__init__`` and reused by every ``extract_text`` call.
    """

    # Languages handed to ``easyocr.Reader``.
    _LANGUAGES = ('en', 'ch_tra')

    # PIL modes that are converted to arrays untouched, so results for these
    # inputs stay exactly as they were. Every other mode (palette, 1-bit,
    # LA, CMYK, 16/32-bit, ...) is converted to RGB first because its raw
    # array layout is not an 8-bit gray / RGB / RGBA image.
    _PASSTHROUGH_MODES = ('L', 'RGB', 'RGBA')

    # Everything that is neither a word character nor whitespace is removed.
    # The explicit CJK range is kept from the original pattern.
    _STRIP_PATTERN = re.compile(r'[^\w\s\u4e00-\u9fff]')

    # ``readtext`` settings for ordinary images.
    _DEFAULT_READ_OPTIONS = {
        'min_size': 20,
        'text_threshold': 0.7,
        'link_threshold': 0.4,
    }

    # More aggressive ``readtext`` settings used for brand detection.
    _BRAND_READ_OPTIONS = {
        'min_size': 10,          # Lower to catch small brand text
        'text_threshold': 0.5,   # Lower threshold for brand logos
        'link_threshold': 0.3,
        'contrast_ths': 0.1,     # Lower to handle metallic/reflective text
        'adjust_contrast': 0.8,  # Enhance contrast for logos
    }

    def __init__(self):
        """Load the EasyOCR reader, preferring GPU and falling back to CPU.

        Raises:
            Exception: whatever ``easyocr.Reader`` raises when the CPU load
                itself fails; only GPU failures are caught and retried.
        """
        print("Loading EasyOCR (English + Traditional Chinese)...")
        self.reader = self._load_reader()
        print("✓ EasyOCR loaded")

    def _load_reader(self):
        """Create the EasyOCR reader on GPU when possible, otherwise on CPU."""
        languages = list(self._LANGUAGES)

        # Only the GPU attempt is guarded: a failure on the plain CPU path is
        # a real error and must not be reported as a GPU problem or retried.
        try:
            if torch.cuda.is_available():
                print(" Attempting GPU initialization...")
                reader = easyocr.Reader(languages, gpu=True)
                print(" ✓ EasyOCR loaded with GPU")
                return reader
        except Exception as e:
            print(f" ⚠️ GPU initialization failed: {e}")
            print(" Falling back to CPU...")
            reader = easyocr.Reader(languages, gpu=False)
            print(" ✓ EasyOCR loaded with CPU (fallback)")
            return reader

        print(" CUDA not available, using CPU...")
        reader = easyocr.Reader(languages, gpu=False)
        print(" ✓ EasyOCR loaded with CPU")
        return reader

    @classmethod
    def _to_array(cls, image) -> np.ndarray:
        """Return *image* as a numpy array suitable for OpenCV / EasyOCR.

        Objects exposing a PIL-style ``mode`` outside ``_PASSTHROUGH_MODES``
        are converted to RGB first; anything else (including numpy arrays,
        which have no ``mode``) goes through ``np.array`` unchanged.
        """
        mode = getattr(image, 'mode', None)
        if mode is not None and mode not in cls._PASSTHROUGH_MODES:
            image = image.convert('RGB')
        return np.array(image)

    def extract_text(self, image: "Image.Image", use_brand_preprocessing: bool = False) -> List[Dict]:
        """Extract text from image with optional brand-optimized preprocessing.

        Args:
            image: PIL Image to read.
            use_brand_preprocessing: When True, the image is first run through
                ``preprocess_for_brand_ocr`` and the reader is called with the
                more aggressive brand settings.

        Returns:
            One dict per detection reported by the reader, with keys
            ``'bbox'``, ``'text'`` (cleaned and upper-cased), ``'confidence'``
            and ``'raw_text'`` (the reader's unmodified string).
        """
        if use_brand_preprocessing:
            source = self.preprocess_for_brand_ocr(image)
            read_options = self._BRAND_READ_OPTIONS
        else:
            source = image
            read_options = self._DEFAULT_READ_OPTIONS

        detections = self.reader.readtext(
            self._to_array(source),
            detail=1,
            paragraph=False,
            **read_options,
        )

        return [
            {
                'bbox': bbox,
                'text': self.clean_and_normalize(text),
                'confidence': confidence,
                'raw_text': text,
            }
            for bbox, text, confidence in detections
        ]

    def clean_and_normalize(self, text: str) -> str:
        """Clean and normalize text.

        Removes every character that is not a word character, whitespace or
        CJK ideograph, collapses whitespace runs to single spaces, trims the
        ends and upper-cases the result.
        """
        stripped = self._STRIP_PATTERN.sub('', text)
        return ' '.join(stripped.split()).upper()

    def preprocess_for_brand_ocr(self, image_region: "Image.Image") -> "Image.Image":
        """
        Preprocess image for brand OCR recognition

        Optimizes for detecting brand logos and text on products (especially metallic logos)

        Args:
            image_region: PIL Image (typically a cropped region)

        Returns:
            Preprocessed PIL Image
        """
        # Convert to numpy array (unusual PIL modes are normalized to RGB)
        img_array = self._to_array(image_region)

        # Convert to grayscale
        if img_array.ndim == 3:
            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
        else:
            gray = img_array

        # Apply CLAHE (Contrast Limited Adaptive Histogram Equalization)
        # Increased clipLimit for metallic logos (2.0 -> 3.0)
        clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
        enhanced = clahe.apply(gray)

        # Denoise (slightly reduced strength to preserve logo edges)
        denoised = cv2.fastNlMeansDenoising(
            enhanced, None, h=8, templateWindowSize=7, searchWindowSize=21
        )

        # Adaptive thresholding to handle varying lighting
        # Adjusted blockSize for better logo detection (11 -> 15)
        binary = cv2.adaptiveThreshold(
            denoised, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY, 15, 2
        )

        # Morphological operations to connect broken characters
        # Slightly larger kernel for logo text (2x2 -> 3x3)
        kernel = np.ones((3, 3), np.uint8)
        morph = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)

        # Sharpen to enhance edges (increased center weight 9 -> 11)
        # NOTE(review): the input here is already strictly 0/255, so after
        # uint8 saturation this kernel looks like it leaves every pixel
        # unchanged - confirm whether sharpening was meant to run before
        # thresholding. Kept as-is to preserve the current output.
        kernel_sharp = np.array([[-1, -1, -1], [-1, 11, -1], [-1, -1, -1]])
        sharpened = cv2.filter2D(morph, -1, kernel_sharp)

        # Convert back to PIL Image
        return Image.fromarray(sharpened)
# Import-time confirmation message: printed once whenever this module is loaded.
# The leading check mark was mis-encoded in the original string and is repaired here.
print("✓ OCREngineManager (with brand OCR preprocessing) defined")