|
|
import time |
|
|
import json |
|
|
import csv |
|
|
import zipfile |
|
|
from io import BytesIO |
|
|
from typing import List, Dict, Optional, Callable |
|
|
from PIL import Image |
|
|
import traceback |
|
|
|
|
|
class BatchProcessingManager:
    """
    Manages batch processing of multiple images with progress tracking,
    error handling, and result export functionality.

    Follows the Facade pattern by delegating actual image processing
    to the PixcribePipeline instance.
    """

    # Hard upper bound on a single batch; keeps memory use predictable.
    MAX_BATCH_SIZE = 10

    def __init__(self, pipeline=None):
        """
        Initialize the Batch Processing Manager.

        Args:
            pipeline: Reference to PixcribePipeline instance for processing images
        """
        self.pipeline = pipeline
        # Maps 1-based image index -> per-image result record.
        self.results = {}
        # Per-image elapsed seconds; used for the remaining-time estimate.
        self.timing_data = []

    @staticmethod
    def _object_names(detections) -> List[str]:
        """Extract the class names from a list of detection dictionaries."""
        return [det['class_name'] for det in detections]

    @staticmethod
    def _brand_names(brands) -> List[str]:
        """Normalize brand entries: (name, score) tuples or bare names -> names."""
        return [brand[0] if isinstance(brand, tuple) else brand for brand in brands]

    def process_batch(
        self,
        images: List["Image.Image"],
        platform: str = 'instagram',
        yolo_variant: str = 'l',
        language: str = 'zh',
        progress_callback: Optional[Callable] = None
    ) -> Dict:
        """
        Process a batch of images with progress tracking.

        Args:
            images: List of PIL Image objects to process (max 10)
            platform: Target social media platform
            yolo_variant: YOLO model variant ('m', 'l', 'x')
            language: Caption language ('zh', 'en')
            progress_callback: Optional callback function for progress updates

        Returns:
            Dictionary containing batch processing summary and results

        Raises:
            ValueError: If images list is empty, exceeds 10 images, or no
                pipeline instance was supplied at construction time
        """
        if not images:
            raise ValueError("Images list cannot be empty")

        if len(images) > self.MAX_BATCH_SIZE:
            raise ValueError("Maximum 10 images allowed per batch")

        # Fail fast on misconfiguration instead of recording every image
        # as an opaque AttributeError failure.
        if self.pipeline is None:
            raise ValueError("A pipeline instance is required for batch processing")

        self.results = {}
        self.timing_data = []
        total_images = len(images)

        batch_start_time = time.time()

        print(f"\n{'='*60}")
        print(f"Starting batch processing: {total_images} images")
        print(f"Platform: {platform} | Variant: {yolo_variant} | Language: {language}")
        print(f"{'='*60}\n")

        for idx, image in enumerate(images):
            image_start_time = time.time()
            image_index = idx + 1  # results are keyed 1-based

            try:
                print(f"[{image_index}/{total_images}] Processing image {image_index}...")

                result = self.pipeline.process_image(
                    image=image,
                    platform=platform,
                    yolo_variant=yolo_variant,
                    language=language
                )

                self.results[image_index] = {
                    'status': 'success',
                    'result': result,
                    'image_index': image_index,
                    'error': None
                }

                print(f"β Image {image_index} processed successfully")

            except Exception as e:
                # Best-effort semantics: record the failure and keep going
                # with the remaining images.
                error_trace = traceback.format_exc()
                self.results[image_index] = {
                    'status': 'failed',
                    'result': None,
                    'image_index': image_index,
                    'error': {
                        'type': type(e).__name__,
                        'message': str(e),
                        'traceback': error_trace
                    }
                }

                print(f"β Image {image_index} failed: {str(e)}")

            image_elapsed = time.time() - image_start_time
            self.timing_data.append(image_elapsed)

            completed = image_index
            percent = (completed / total_images) * 100

            # Naive ETA: mean per-image time extrapolated over the remainder.
            avg_time = sum(self.timing_data) / len(self.timing_data)
            remaining_images = total_images - completed
            estimated_remaining = avg_time * remaining_images

            if progress_callback:
                progress_callback({
                    'current': completed,
                    'total': total_images,
                    'percent': percent,
                    'estimated_remaining': estimated_remaining,
                    'latest_result': self.results[image_index],
                    'image_index': image_index
                })

        batch_elapsed = time.time() - batch_start_time
        total_processed = len(self.results)
        total_failed = sum(1 for r in self.results.values() if r['status'] == 'failed')
        total_success = total_processed - total_failed

        print(f"\n{'='*60}")
        print("Batch processing completed!")
        print(f"Total: {total_processed} | Success: {total_success} | Failed: {total_failed}")
        print(f"Total time: {batch_elapsed:.2f}s | Avg per image: {batch_elapsed/total_processed:.2f}s")
        print(f"{'='*60}\n")

        return {
            'results': self.results,
            'total_processed': total_processed,
            'total_success': total_success,
            'total_failed': total_failed,
            'total_time': batch_elapsed,
            'average_time_per_image': batch_elapsed / total_processed if total_processed > 0 else 0
        }

    def get_result(self, image_index: int) -> Optional[Dict]:
        """
        Get processing result for a specific image.

        Args:
            image_index: Index of the image (1-based)

        Returns:
            Result dictionary or None if index doesn't exist
        """
        return self.results.get(image_index)

    def get_all_results(self) -> Dict:
        """
        Get all processing results.

        Returns:
            Complete results dictionary
        """
        return self.results

    def clear_results(self):
        """Clear all stored results to free memory."""
        self.results = {}
        self.timing_data = []
        print("β Batch results cleared")

    def export_to_json(self, results: Dict, output_path: str) -> str:
        """
        Export batch results to JSON format.

        Args:
            results: Results dictionary from process_batch
            output_path: Path to save JSON file

        Returns:
            Path to the saved JSON file
        """
        export_data = {
            'batch_summary': {
                'total_processed': results.get('total_processed', 0),
                'total_success': results.get('total_success', 0),
                'total_failed': results.get('total_failed', 0),
                'total_time': results.get('total_time', 0),
                'average_time_per_image': results.get('average_time_per_image', 0)
            },
            'images': []
        }

        for img_idx, img_result in results.get('results', {}).items():
            if img_result['status'] == 'success':
                result_data = img_result['result']
                image_export = {
                    'image_index': img_idx,
                    'status': 'success',
                    'captions': result_data.get('captions', []),
                    'detected_objects': self._object_names(result_data.get('detections', [])),
                    'detected_brands': self._brand_names(result_data.get('brands', [])),
                    'scene_info': result_data.get('scene', {}),
                    'lighting': result_data.get('lighting', {})
                }
            else:
                image_export = {
                    'image_index': img_idx,
                    'status': 'failed',
                    'error': img_result.get('error', {})
                }

            export_data['images'].append(image_export)

        # ensure_ascii=False keeps zh captions human-readable in the file.
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(export_data, f, ensure_ascii=False, indent=2)

        print(f"β Batch results exported to JSON: {output_path}")
        return output_path

    def export_to_csv(self, results: Dict, output_path: str) -> str:
        """
        Export batch results to CSV format.

        Args:
            results: Results dictionary from process_batch
            output_path: Path to save CSV file

        Returns:
            Path to the saved CSV file
        """
        headers = [
            'image_index',
            'status',
            'caption_professional',
            'caption_creative',
            'caption_authentic',
            'detected_objects',
            'detected_brands',
            'hashtags'
        ]

        rows = []
        for img_idx, img_result in results.get('results', {}).items():
            if img_result['status'] == 'success':
                result_data = img_result['result']
                captions = result_data.get('captions', [])

                caption_professional = ''
                caption_creative = ''
                caption_authentic = ''
                all_hashtags = []

                # Route each caption into its column by tone keyword.
                for cap in captions:
                    tone = cap.get('tone', '').lower()
                    caption_text = cap.get('caption', '')
                    hashtags = cap.get('hashtags', [])

                    if 'professional' in tone:
                        caption_professional = caption_text
                    elif 'creative' in tone:
                        caption_creative = caption_text
                    elif 'authentic' in tone or 'casual' in tone:
                        caption_authentic = caption_text

                    all_hashtags.extend(hashtags)

                # Dedupe while preserving first-seen order (set() made the
                # column order nondeterministic between runs).
                all_hashtags = list(dict.fromkeys(all_hashtags))

                row = {
                    'image_index': img_idx,
                    'status': 'success',
                    'caption_professional': caption_professional,
                    'caption_creative': caption_creative,
                    'caption_authentic': caption_authentic,
                    'detected_objects': ', '.join(self._object_names(result_data.get('detections', []))),
                    'detected_brands': ', '.join(self._brand_names(result_data.get('brands', []))),
                    'hashtags': ' '.join([f'#{tag}' for tag in all_hashtags])
                }
            else:
                row = {
                    'image_index': img_idx,
                    'status': 'failed',
                    'caption_professional': '',
                    'caption_creative': '',
                    'caption_authentic': '',
                    'detected_objects': '',
                    'detected_brands': '',
                    'hashtags': ''
                }

            rows.append(row)

        with open(output_path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=headers)
            writer.writeheader()
            writer.writerows(rows)

        print(f"β Batch results exported to CSV: {output_path}")
        return output_path

    def export_to_zip(self, results: Dict, images: List["Image.Image"], output_path: str) -> str:
        """
        Export batch results to ZIP archive with images and text files.

        Args:
            results: Results dictionary from process_batch
            images: List of original PIL Image objects
            output_path: Path to save ZIP file

        Returns:
            Path to the saved ZIP file
        """
        with zipfile.ZipFile(output_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for img_idx, img_result in results.get('results', {}).items():
                if img_result['status'] == 'success':
                    image_filename = f"image_{img_idx:03d}.jpg"

                    # JPEG cannot store an alpha channel or palette images;
                    # convert such modes to RGB before saving (PIL would
                    # otherwise raise OSError for e.g. RGBA input).
                    source_image = images[img_idx - 1]
                    if source_image.mode not in ('RGB', 'L'):
                        source_image = source_image.convert('RGB')

                    img_buffer = BytesIO()
                    source_image.save(img_buffer, format='JPEG', quality=95)
                    img_buffer.seek(0)

                    zipf.writestr(image_filename, img_buffer.read())

                    text_filename = f"image_{img_idx:03d}.txt"
                    text_content = self._format_result_as_text(img_result['result'])
                    zipf.writestr(text_filename, text_content)

                    print(f"β Added to ZIP: {image_filename} and {text_filename}")

        print(f"β Batch results exported to ZIP: {output_path}")
        return output_path

    def _format_result_as_text(self, result: Dict) -> str:
        """
        Format a single image result as plain text for ZIP export.

        Args:
            result: Single image processing result dictionary

        Returns:
            Formatted text string
        """
        lines = []
        lines.append("=" * 60)
        lines.append("PIXCRIBE - AI GENERATED SOCIAL MEDIA CONTENT")
        lines.append("=" * 60)
        lines.append("")

        captions = result.get('captions', [])
        for i, cap in enumerate(captions, 1):
            tone = cap.get('tone', 'Unknown').upper()
            caption_text = cap.get('caption', '')
            hashtags = cap.get('hashtags', [])

            lines.append(f"CAPTION {i} - {tone} STYLE")
            lines.append("-" * 60)
            lines.append(caption_text)
            lines.append("")
            lines.append("Hashtags:")
            lines.append(' '.join([f'#{tag}' for tag in hashtags]))
            lines.append("")
            lines.append("")

        detections = result.get('detections', [])
        if detections:
            lines.append("DETECTED OBJECTS")
            lines.append("-" * 60)
            lines.append(', '.join(self._object_names(detections)))
            lines.append("")

        brands = result.get('brands', [])
        if brands:
            lines.append("DETECTED BRANDS")
            lines.append("-" * 60)
            lines.append(', '.join(self._brand_names(brands)))
            lines.append("")

        # NOTE(review): assumes scene_info['lighting'] / ['mood'] are dicts
        # with a 'top' key — confirm against the pipeline's scene analyzer.
        scene_info = result.get('scene', {})
        if scene_info:
            lines.append("SCENE ANALYSIS")
            lines.append("-" * 60)

            if 'lighting' in scene_info:
                lighting = scene_info['lighting'].get('top', 'Unknown')
                lines.append(f"Lighting: {lighting}")

            if 'mood' in scene_info:
                mood = scene_info['mood'].get('top', 'Unknown')
                lines.append(f"Mood: {mood}")

            lines.append("")

        lines.append("=" * 60)
        lines.append("Generated by Pixcribe V5 - AI Social Media Caption Generator")
        lines.append("=" * 60)

        return '\n'.join(lines)
|
|
|
|
|
|
|
|
# Cell-execution confirmation (notebook-style script): signals the class
# definition ran without error.
print("β BatchProcessingManager defined")
|
|
|