# Pixcribe / pixcribe_pipeline.py
import time
import traceback
from PIL import Image
from typing import Dict, List, Callable, Optional
from image_processor_manager import ImageProcessorManager
from yolo_detection_manager import YOLODetectionManager
from saliency_detection_manager import SaliencyDetectionManager
from openclip_semantic_manager import OpenCLIPSemanticManager
from lighting_analysis_manager import LightingAnalysisManager
from ocr_engine_manager import OCREngineManager
from prompt_library_manager import PromptLibraryManager
from brand_recognition_manager import BrandRecognitionManager
from brand_visualization_manager import BrandVisualizationManager
from brand_verification_manager import BrandVerificationManager
from scene_compatibility_manager import SceneCompatibilityManager
from caption_generation_manager import CaptionGenerationManager
from detection_fusion_manager import DetectionFusionManager
from output_processing_manager import OutputProcessingManager
from batch_processing_manager import BatchProcessingManager
class PixcribePipeline:
"""Main Facade coordinating all components (V2 with multi-language support)"""
def __init__(self, yolo_variant='l', vlm_model_name='Qwen/Qwen2.5-VL-7B-Instruct'):
"""
Args:
yolo_variant: 'm', 'l' (default), or 'x'
vlm_model_name: Vision-Language Model name (default: Qwen2.5-VL-7B-Instruct)
                Can be changed to 'Qwen/Qwen3-VL-8B-Instruct' for the latest model
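        Example (illustrative; assumes the model weights are available):
            >>> pipeline = PixcribePipeline()  # defaults: YOLO 'l' + Qwen2.5-VL-7B
            >>> pipeline = PixcribePipeline(yolo_variant='x',
            ...                             vlm_model_name='Qwen/Qwen3-VL-8B-Instruct')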
"""
print("="*60)
print("Initializing Pixcribe Pipeline V2...")
print("="*60)
start_time = time.time()
# Initialize all managers
self.image_processor = ImageProcessorManager()
self.yolo_detector = YOLODetectionManager(variant=yolo_variant)
self.saliency_detector = SaliencyDetectionManager()
self.clip_semantic = OpenCLIPSemanticManager()
self.lighting_analyzer = LightingAnalysisManager()
self.ocr_engine = OCREngineManager()
# NEW: Initialize PromptLibrary (centralized prompt management)
self.prompt_library = PromptLibraryManager()
# Initialize BrandRecognitionManager with PromptLibrary
self.brand_recognizer = BrandRecognitionManager(
self.clip_semantic, self.ocr_engine, self.prompt_library
)
# NEW: Brand visualization manager
self.brand_visualizer = BrandVisualizationManager()
self.caption_generator = CaptionGenerationManager(model_name=vlm_model_name)
# NEW: Brand verification with VLM
self.brand_verifier = BrandVerificationManager(self.caption_generator)
# NEW: Scene compatibility checker
self.scene_compatibility = SceneCompatibilityManager(self.prompt_library)
self.fusion_manager = DetectionFusionManager(self.clip_semantic)
# Initialize OutputProcessingManager with PromptLibrary for smart hashtag generation
self.output_processor = OutputProcessingManager(self.prompt_library)
# Initialize BatchProcessingManager with pipeline reference
self.batch_processor = BatchProcessingManager(pipeline=self)
elapsed = time.time() - start_time
print("="*60)
print(f"✓ Pipeline V5 initialized successfully with batch processing (Time: {elapsed:.2f}s)")
print("="*60)
def process_image(self, image, platform='instagram', yolo_variant='l', language='zh') -> Dict:
"""End-to-end image processing pipeline
Args:
image: PIL Image or path
platform: 'instagram', 'tiktok', or 'xiaohongshu'
            yolo_variant: 'm', 'l' (default), or 'x' (currently unused here; the detector variant is fixed in __init__)
language: 'zh' (Traditional Chinese), 'en' (English), 'zh-en' (Bilingual)
Returns:
Processing results dictionary with brand visualizations
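        Example (illustrative; the image path is hypothetical):
            >>> result = pipeline.process_image('cafe.jpg', platform='instagram', language='en')
            >>> for caption in result['captions']:
            ...     print(caption['hashtags'])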
"""
print(f"\nProcessing image (Platform: {platform}, Language: {language})...")
start_time = time.time()
try:
# Step 1: Preprocessing
print("[1/9] Preprocessing image...")
processed_img = self.image_processor.load_image(image)
yolo_input = self.image_processor.preprocess_for_yolo(processed_img)
            # Steps 2-3: Parallel detection
            print("[2/11] YOLO object detection...")
yolo_results = self.yolo_detector.detect(yolo_input)
print(f" Detected {len(yolo_results)} objects")
print("[3/9] Saliency detection...")
salient_regions = self.saliency_detector.detect_salient_regions(processed_img)
print(f" Found {len(salient_regions)} salient regions")
            # Step 4: Identify unknown objects
            print("[4/11] Identifying unknown objects...")
unknown_regions = self.saliency_detector.extract_unknown_regions(
salient_regions, yolo_results
)
print(f" Found {len(unknown_regions)} unknown regions")
            # Step 5: Brand recognition (with bounding boxes)
            print("[5/11] Brand recognition...")
brands = []
brand_detections = [] # For visualization
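            # Each brand_detections entry follows the schema consumed by the
            # visualizer and the verification steps below:
            #   {'name': str, 'confidence': float,
            #    'bbox': (left, upper, right, lower), 'category': str}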
# Method 1: Check YOLO-detected brand-relevant objects
brand_relevant = self.yolo_detector.filter_brand_relevant_objects(yolo_results)
if brand_relevant:
print(f" Checking {len(brand_relevant)} YOLO brand-relevant objects...")
for det in brand_relevant[:5]: # Check top 5 brand-relevant objects
region = processed_img.crop(det['bbox'])
brand_result = self.brand_recognizer.recognize_brand(
region, processed_img, region_bbox=det['bbox']
)
if brand_result:
for brand_name, confidence, bbox in brand_result[:2]: # Top 2 brands per region
brands.append((brand_name, confidence))
# Prepare for visualization
brand_info = self.prompt_library.get_brand_prompts(brand_name)
category = brand_info.get('category', 'default') if brand_info else 'default'
brand_detections.append({
'name': brand_name,
'confidence': confidence,
'bbox': bbox,
'category': category
})
            # Method 2: Full-image brand scan (required for commercial-grade coverage)
            # Run the full-image scan regardless of whether YOLO detected relevant objects
print(" Performing intelligent full-image brand scan...")
full_image_brands = self.brand_recognizer.scan_full_image_for_brands(
processed_img,
exclude_bboxes=[bd['bbox'] for bd in brand_detections if bd.get('bbox')],
                saliency_regions=salient_regions  # pass saliency regions to choose scan areas intelligently
)
            # Merge full-image scan results
if full_image_brands:
print(f" Full-image scan found {len(full_image_brands)} additional brands")
for brand_name, confidence, bbox in full_image_brands:
                    # Skip brands that were already detected to avoid duplicates
if not any(bd['name'] == brand_name for bd in brand_detections):
brands.append((brand_name, confidence))
brand_info = self.prompt_library.get_brand_prompts(brand_name)
category = brand_info.get('category', 'default') if brand_info else 'default'
brand_detections.append({
'name': brand_name,
'confidence': confidence,
'bbox': bbox,
'category': category
})
print(f" Identified {len(brands)} brand instances (before verification)")
            # Step 5.5: CLIP scene understanding (moved earlier for the compatibility check)
print("[5.5/11] Scene understanding (CLIP)...")
scene_analysis = self.clip_semantic.analyze_scene(processed_img)
print(f" Scene: {scene_analysis.get('urban', {}).get('top', 'unknown')}")
            # Step 5.6: Scene compatibility check
if brands:
print("[5.6/11] Checking scene compatibility...")
brands_with_bbox = [(b[0], b[1], brand_detections[i]['bbox'])
for i, b in enumerate(brands)]
compatible_brands = self.scene_compatibility.batch_check_compatibility(
brands_with_bbox, scene_analysis
)
print(f" {len(compatible_brands)} brands passed compatibility check")
# Update brands and brand_detections
if compatible_brands:
brands = [(b[0], b[1]) for b in compatible_brands]
brand_detections = []
for brand_name, confidence, bbox in compatible_brands:
brand_info = self.prompt_library.get_brand_prompts(brand_name)
category = brand_info.get('category', 'default') if brand_info else 'default'
brand_detections.append({
'name': brand_name,
'confidence': confidence,
'bbox': bbox,
'category': category
})
else:
brands = []
brand_detections = []
            # Step 5.7: VLM brand verification
if brand_detections:
print("[5.7/11] VLM brand verification...")
vlm_verification = self.brand_verifier.verify_brands(
processed_img, [(bd['name'], bd['confidence'], bd['bbox'])
for bd in brand_detections]
)
print(f" VLM verified {len(vlm_verification.get('verified_brands', []))} brands")
# Three-way voting: OpenCLIP + OCR + VLM
# Collect OCR matches for voting
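                # Shapes passed to three_way_voting (inferred from this call site):
                #   CLIP candidates: [(name, confidence, bbox), ...]
                #   ocr_brands:      {name: (text_score, ocr_score)}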
ocr_brands = {}
for brand_name, conf in brands:
if brand_name not in ocr_brands:
ocr_brands[brand_name] = (0.5, conf) # Approximate text/ocr split
final_brands = self.brand_verifier.three_way_voting(
[(bd['name'], bd['confidence'], bd['bbox']) for bd in brand_detections],
ocr_brands,
vlm_verification
)
print(f" Final verified brands: {len(final_brands)}")
# Update brands and brand_detections with verified results
if final_brands:
brands = [(b[0], b[1]) for b in final_brands]
brand_detections = []
for brand_name, confidence, bbox in final_brands:
brand_info = self.prompt_library.get_brand_prompts(brand_name)
category = brand_info.get('category', 'default') if brand_info else 'default'
brand_detections.append({
'name': brand_name,
'confidence': confidence,
'bbox': bbox,
'category': category
})
else:
brands = []
brand_detections = []
            # Step 6 (NEW): Visualize brand detections on the image
            print("[6/11] Visualizing brand detections...")
            if brand_detections:
visualized_image = self.brand_visualizer.draw_brand_detections(
processed_img.copy(), brand_detections
)
else:
visualized_image = processed_img
            # Step 7: CV-based lighting analysis
            print("[7/11] Analyzing lighting conditions...")
cv_lighting = self.lighting_analyzer.analyze_lighting(processed_img)
print(f" CV Lighting: {cv_lighting['lighting_type']} (confidence: {cv_lighting['confidence']:.2f})")
print(f" Details: brightness={cv_lighting['cv_features']['brightness']:.1f}, "
f"temp_ratio={cv_lighting['cv_features']['color_temp']:.2f}, "
f"contrast={cv_lighting['cv_features']['contrast']:.1f}")
            # Step 8: Additional scene analysis details
print("[8/11] Additional scene analysis...")
print(f" CLIP Lighting: {scene_analysis.get('lighting', {}).get('top', 'unknown')}")
print(f" Mood: {scene_analysis.get('mood', {}).get('top', 'unknown')}")
            # Step 9: Fusion with lighting analysis
print("[9/11] Fusing detection results...")
fused_results = self.fusion_manager.fuse_detections(
yolo_results, unknown_regions, scene_analysis, processed_img, cv_lighting
)
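            # fused_results is a dict; 'brands' and 'scene_analysis' are attached
            # here, and downstream steps read 'detections' and 'composition' from it.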
fused_results['brands'] = brands
fused_results['scene_analysis'] = scene_analysis
# Print fused lighting result
fused_lighting = fused_results['scene_analysis']['lighting']['top']
print(f" Fused Lighting: {fused_lighting}")
            # Step 10: Caption generation with language support
print("[10/11] Generating captions...")
captions = self.caption_generator.generate_captions(
fused_results, processed_img, platform, language
)
            # Step 11: Output processing with smart hashtags
print("[11/11] Output processing...")
validated_captions = []
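            # Each caption is a dict; this loop inspects only its 'hashtags' key
            # (a list of strings) before passing the full dict to validate_output.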
for caption in captions:
# Only generate hashtags if VLM didn't generate any
# DO NOT override VLM hashtags - they follow language requirements
if not caption.get('hashtags') or len(caption.get('hashtags', [])) < 3:
print(f" [DEBUG] Caption has {len(caption.get('hashtags', []))} hashtags, generating smart hashtags...")
caption['hashtags'] = self.output_processor.generate_smart_hashtags(
fused_results['detections'],
scene_analysis,
brands,
platform,
language
)
else:
print(f" [DEBUG] Caption has {len(caption['hashtags'])} VLM-generated hashtags")
                # Pass the full parameters to validate_output to enable automatic hashtag back-filling
is_valid, msg = self.output_processor.validate_output(
caption, platform,
detections=fused_results['detections'],
scene_info=scene_analysis,
brands=brands,
language=language
)
if is_valid:
validated_captions.append(caption)
else:
print(f" [DEBUG] Caption validation failed: {msg}")
elapsed = time.time() - start_time
print(f"\n✓ Processing complete (Total time: {elapsed:.2f}s)")
print(f" Generated {len(validated_captions)} caption variations")
return {
'captions': validated_captions,
'detections': fused_results['detections'],
'brands': brands,
'brand_detections': brand_detections, # NEW: For UI display
'visualized_image': visualized_image, # NEW: Image with brand boxes
'scene': scene_analysis,
'composition': fused_results.get('composition', {}),
'lighting': cv_lighting,
'processing_time': elapsed
}
except Exception as e:
print(f"\n✗ Processing error: {str(e)}")
traceback.print_exc()
# Re-raise exception so it can be caught and displayed
raise
def process_batch(
self,
images: List[Image.Image],
platform: str = 'instagram',
yolo_variant: str = 'l',
language: str = 'zh',
progress_callback: Optional[Callable] = None
) -> Dict:
"""
Process multiple images in batch with progress tracking.
This method provides a Facade interface to the BatchProcessingManager,
allowing batch processing through the main Pipeline API.
Args:
images: List of PIL Image objects to process (max 10)
platform: Target social media platform ('instagram', 'tiktok', 'xiaohongshu')
yolo_variant: YOLO model variant ('m', 'l', 'x')
            language: Caption language ('zh' for Traditional Chinese, 'en' for English, 'zh-en' for bilingual)
progress_callback: Optional callback function for progress updates
Returns:
Dictionary containing:
- results: Dict mapping image index to processing results
- total_processed: Total number of images processed
- total_success: Number of successfully processed images
- total_failed: Number of failed images
- total_time: Total processing time in seconds
- average_time_per_image: Average time per image in seconds
Raises:
ValueError: If images list is empty or exceeds 10 images
Example:
>>> images = [Image.open(f'image{i}.jpg') for i in range(1, 6)]
>>> results = pipeline.process_batch(images, platform='instagram')
>>> print(f"Processed {results['total_success']}/{results['total_processed']} images")
"""
return self.batch_processor.process_batch(
images=images,
platform=platform,
yolo_variant=yolo_variant,
language=language,
progress_callback=progress_callback
)
print("✓ PixcribePipeline V5 (with Batch Processing) defined")