Pixcribe / saliency_detection_manager.py
DawnC's picture
Upload 22 files
6a3bd1f verified
raw
history blame
3.52 kB
import torch
import numpy as np
from PIL import Image
import cv2
from typing import List, Dict
import torchvision.transforms as transforms
class SaliencyDetectionManager:
    """Propose salient image regions and filter out those already detected.

    Region proposal is classical: the image is converted to grayscale,
    binarized with Otsu's method, and large outer contours are returned as
    candidate regions. ``extract_unknown_regions`` then keeps the candidates
    that do not overlap any detector box (IoU below a threshold).

    NOTE(review): despite earlier "U2-Net" labelling, no U2-Net is used. A
    torchvision DeepLabV3-ResNet50 is optionally loaded into ``self.model``,
    but no method of this class currently reads it.
    """

    def __init__(self, load_model: bool = True):
        """Initialise thresholds and, optionally, a pretrained segmentation model.

        Args:
            load_model: When True (default, the original behaviour), try to load
                torchvision's pretrained DeepLabV3-ResNet50 into ``self.model``
                (moved to CUDA when available). Any failure is reported with a
                warning and leaves ``self.model`` as None. When False, the load
                is skipped entirely and ``self.model`` is None.
        """
        self.model = None
        if load_model:
            print("Loading DeepLabV3-ResNet50 model...")
            try:
                from torchvision.models.segmentation import deeplabv3_resnet50
                try:
                    # torchvision >= 0.13 API; the ``pretrained=`` flag is deprecated there.
                    model = deeplabv3_resnet50(weights="DEFAULT")
                except TypeError:
                    # Older torchvision that does not know the ``weights`` keyword.
                    model = deeplabv3_resnet50(pretrained=True)
                model.eval()
                if torch.cuda.is_available():
                    model = model.cuda()
                self.model = model
            except Exception as e:
                # Best-effort: the detection path below does not need the model.
                print(f"Warning: Cannot load deep learning model, using fallback: {e}")
                self.model = None

        # NOTE(review): threshold, min_saliency and transform are not read by any
        # method of this class; presumably kept for external callers — confirm.
        self.threshold = 0.5
        # Minimum contour area (in pixels, as measured by cv2.contourArea) for a
        # contour to be reported as a region.
        self.min_area = 1600
        self.min_saliency = 0.6
        self.transform = transforms.Compose([
            transforms.Resize((320, 320)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        print("✓ SaliencyDetectionManager initialized")

    def detect_salient_regions(self, image: "Image.Image", max_regions: int = 10) -> List[Dict]:
        """Find large bright regions in ``image`` via Otsu thresholding.

        The image is converted to grayscale and binarized with Otsu's method
        (pixels above the automatically chosen threshold become foreground).
        The outer contours of the foreground are then extracted, and contours
        whose area is below ``self.min_area`` are discarded.

        Args:
            image: PIL image. Modes other than RGB are converted to RGB for the
                grayscale computation; crops are taken from ``image`` itself.
            max_regions: Maximum number of regions to return (default 10).

        Returns:
            Up to ``max_regions`` dicts sorted by ``saliency_score`` (descending),
            each containing:
              - 'bbox': ``[x_min, y_min, x_max, y_max]`` as floats; the max edges
                are one past the contour's bounding rectangle, matching PIL crop.
              - 'area': contour area from ``cv2.contourArea``.
              - 'saliency_score': contour area / image area, capped at 1.0.
              - 'image': ``image.crop(bbox)``.
            Returns an empty list for a zero-sized image.
        """
        # COLOR_RGB2GRAY needs a 3/4-channel array: modes such as 'L', 'P' or
        # 'LA' would crash cvtColor, and 'CMYK' would be misread as RGB.
        rgb = image if image.mode == "RGB" else image.convert("RGB")
        img_array = np.array(rgb)
        height, width = img_array.shape[:2]
        if height == 0 or width == 0:
            return []

        gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
        _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        # findContours returns (image, contours, hierarchy) in OpenCV 3.x and
        # (contours, hierarchy) in 4.x; index -2 is the contour list in both.
        contours = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

        image_area = width * height
        regions = []
        for contour in contours:
            area = cv2.contourArea(contour)
            if area < self.min_area:
                continue
            x, y, w, h = cv2.boundingRect(contour)
            regions.append({
                'bbox': [float(x), float(y), float(x + w), float(y + h)],
                'area': area,
                'saliency_score': min(area / image_area, 1.0),
            })

        # sort() is stable, so ties keep contour order.
        regions.sort(key=lambda region: region['saliency_score'], reverse=True)
        top_regions = regions[:max_regions]
        # Crop only the regions actually returned.
        for region in top_regions:
            region['image'] = image.crop(region['bbox'])
        return top_regions

    def extract_unknown_regions(self, salient_regions: List[Dict], yolo_detections: List[Dict],
                                iou_threshold: float = 0.3) -> List[Dict]:
        """Return the salient regions that no detection overlaps significantly.

        Args:
            salient_regions: Region dicts, each with a ``'bbox'`` entry
                ``[x_min, y_min, x_max, y_max]``.
            yolo_detections: Detection dicts, each with a ``'bbox'`` entry in the
                same format.
            iou_threshold: A region is kept only if its best IoU against all
                detections is strictly below this value (default 0.3).

        Returns:
            The kept region dicts (same objects, original order). With no
            detections, every region is kept for any positive threshold.
        """
        unknown_regions = []
        for region in salient_regions:
            best_iou = max(
                (self._calculate_iou(region['bbox'], det['bbox']) for det in yolo_detections),
                default=0.0,
            )
            if best_iou < iou_threshold:
                unknown_regions.append(region)
        return unknown_regions

    def _calculate_iou(self, box1: List[float], box2: List[float]) -> float:
        """Compute the Intersection over Union of two ``[x_min, y_min, x_max, y_max]`` boxes.

        Returns 0.0 when the boxes do not overlap or when the union area is not
        positive (e.g. two degenerate boxes).
        """
        x1_min, y1_min, x1_max, y1_max = box1
        x2_min, y2_min, x2_max, y2_max = box2

        inter_xmin = max(x1_min, x2_min)
        inter_ymin = max(y1_min, y2_min)
        inter_xmax = min(x1_max, x2_max)
        inter_ymax = min(y1_max, y2_max)
        if inter_xmax < inter_xmin or inter_ymax < inter_ymin:
            return 0.0

        inter_area = (inter_xmax - inter_xmin) * (inter_ymax - inter_ymin)
        box1_area = (x1_max - x1_min) * (y1_max - y1_min)
        box2_area = (x2_max - x2_min) * (y2_max - y2_min)
        union_area = box1_area + box2_area - inter_area
        return inter_area / union_area if union_area > 0 else 0.0
# Module-level statement: prints this notice each time the module is executed (i.e. on import).
print("✓ SaliencyDetectionManager defined")