Pixcribe / yolo_detection_manager.py
DawnC's picture
Upload 22 files
6a3bd1f verified
raw
history blame
2.01 kB
from ultralytics import YOLO
import numpy as np
from typing import List, Dict
from PIL import Image
class YOLODetectionManager:
    """Object detection using YOLOv11.

    Wraps an ultralytics ``YOLO`` model and turns its predictions into plain
    dictionaries, flagging each one whose class name is listed in
    ``brand_relevant_classes``.
    """

    def __init__(self, variant='m'):
        """Load the ``yolo11{variant}.pt`` weights and set inference defaults.

        Args:
            variant: Suffix inserted into the weights file name, e.g. ``'m'``
                loads ``yolo11m.pt``.
        """
        print(f"Loading YOLOv11{variant} model...")
        self.model = YOLO(f'yolo11{variant}.pt')
        self.variant = variant

        # Forwarded to ``model.predict`` as ``conf``, ``iou`` and ``max_det``.
        self.conf_threshold = 0.25
        self.iou_threshold = 0.45
        self.max_detections = 100

        # Brand-relevant classes (compared against the lower-cased class name).
        # NOTE(review): 'watch', 'shoe', 'sneaker' and 'boot' do not look like
        # COCO label names, which the stock yolo11 weights presumably use --
        # confirm these entries can ever match.
        self.brand_relevant_classes = [
            'handbag', 'bottle', 'cell phone', 'laptop',
            'backpack', 'tie', 'suitcase', 'cup', 'watch',
            'shoe', 'sneaker', 'boot',
        ]
        print(f"✓ YOLOv11{variant} loaded")

    def detect(self, image: np.ndarray) -> List[Dict]:
        """Detect objects in image.

        Args:
            image: Image handed unchanged to ``self.model.predict``.

        Returns:
            One dict per predicted box with the keys ``class_id``,
            ``class_name``, ``bbox`` (the box's ``xyxy`` values as a list),
            ``confidence``, ``is_brand_relevant`` and ``source`` (always
            ``'yolo'``).
        """
        results = self.model.predict(
            image,
            conf=self.conf_threshold,
            iou=self.iou_threshold,
            max_det=self.max_detections,
            verbose=False,
        )

        detections = []
        for result in results:
            for box in result.boxes:
                class_id = int(box.cls[0])
                class_name = result.names[class_id]
                detections.append({
                    'class_id': class_id,
                    'class_name': class_name,
                    # Moved to CPU and converted so the dict holds plain
                    # Python values rather than tensors.
                    'bbox': box.xyxy[0].cpu().numpy().tolist(),
                    'confidence': float(box.conf[0]),
                    'is_brand_relevant': (
                        class_name.lower() in self.brand_relevant_classes
                    ),
                    'source': 'yolo',
                })
        return detections

    def filter_brand_relevant_objects(self, detections: List[Dict]) -> List[Dict]:
        """Filter brand-relevant objects.

        Args:
            detections: Detection dicts, such as those returned by ``detect``.

        Returns:
            The detections whose ``is_brand_relevant`` value is truthy, in
            their original order. Dicts without that key are treated as not
            relevant instead of raising ``KeyError``.
        """
        return [
            det for det in detections
            if det.get('is_brand_relevant', False)
        ]
# Status message emitted after the class definition above; presumably runs at
# module level on import/execution (original indentation is not visible here).
print("✓ YOLODetectionManager defined")