# ScoreVision11 / miner.py
import gc
import os
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
import yaml
from numpy import ndarray
from pydantic import BaseModel
from ultralytics import YOLO

# Make sibling modules importable regardless of the working directory
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from team_cluster import TeamClassifier
from pitch import process_batch_input, get_cls_net
from utils import Constants  # BoundingBox is redefined locally below, so only Constants is imported
class BoundingBox(BaseModel):
x1: int
y1: int
x2: int
y2: int
cls_id: int
conf: float
class TVFrameResult(BaseModel):
frame_id: int
boxes: List[BoundingBox]
keypoints: List[Tuple[int, int]]
class Miner:
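    """
    Inference pipeline for football broadcast frames: YOLO object detection,
    OSNet-based team clustering (TeamClassifier), and HRNetV2 pitch keypoint
    estimation. The entry point is predict_batch().
    """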
SMALL_CONTAINED_IOA = Constants.SMALL_CONTAINED_IOA
SMALL_RATIO_MAX = Constants.SMALL_RATIO_MAX
SINGLE_PLAYER_HUE_PIVOT = Constants.SINGLE_PLAYER_HUE_PIVOT
CORNER_INDICES = Constants.CORNER_INDICES
KEYPOINTS_CONFIDENCE = Constants.KEYPOINTS_CONFIDENCE
CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE
GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN
MIN_SAMPLES_FOR_FIT = 16 # Minimum player crops needed before fitting TeamClassifier
MAX_SAMPLES_FOR_FIT = 600 # Maximum samples to avoid overfitting
def __init__(self, path_hf_repo: Path) -> None:
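        """
        Load the three models from the local HuggingFace repo directory:
        the ONNX YOLO detector, the OSNet team classifier, and the HRNetV2
        keypoint model. On failure, self.health records the error instead of
        raising, so callers can check readiness via __repr__.
        """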
try:
device = "cuda" if torch.cuda.is_available() else "cpu"
model_path = path_hf_repo / "football_object_detection.onnx"
self.bbox_model = YOLO(model_path)
print("BBox Model Loaded")
team_model_path = path_hf_repo / "osnet_model.pth.tar-100"
self.team_classifier = TeamClassifier(
device=device,
batch_size=32,
model_name=str(team_model_path)
)
print("Team Classifier Loaded")
# Team classification state
self.team_classifier_fitted = False
self.player_crops_for_fit = []
model_kp_path = path_hf_repo / 'keypoint'
config_kp_path = path_hf_repo / 'hrnetv2_w48.yaml'
            with open(config_kp_path, 'r') as f:
                cfg_kp = yaml.safe_load(f)
loaded_state_kp = torch.load(model_kp_path, map_location=device)
model = get_cls_net(cfg_kp)
model.load_state_dict(loaded_state_kp)
model.to(device)
model.eval()
self.keypoints_model = model
self.kp_threshold = 0.1
self.pitch_batch_size = 4
self.health = "healthy"
print("✅ Keypoints Model Loaded")
except Exception as e:
self.health = "❌ Miner initialization failed: " + str(e)
print(self.health)
def __repr__(self) -> str:
if self.health == 'healthy':
return (
f"health: {self.health}\n"
f"BBox Model: {type(self.bbox_model).__name__}\n"
f"Keypoints Model: {type(self.keypoints_model).__name__}"
)
else:
return self.health
def _calculate_iou(self, box1: Tuple[float, float, float, float],
box2: Tuple[float, float, float, float]) -> float:
"""
Calculate Intersection over Union (IoU) between two bounding boxes.
Args:
box1: (x1, y1, x2, y2)
box2: (x1, y1, x2, y2)
Returns:
IoU score (0-1)
"""
x1_1, y1_1, x2_1, y2_1 = box1
x1_2, y1_2, x2_2, y2_2 = box2
# Calculate intersection area
x_left = max(x1_1, x1_2)
y_top = max(y1_1, y1_2)
x_right = min(x2_1, x2_2)
y_bottom = min(y2_1, y2_2)
if x_right < x_left or y_bottom < y_top:
return 0.0
intersection_area = (x_right - x_left) * (y_bottom - y_top)
# Calculate union area
box1_area = (x2_1 - x1_1) * (y2_1 - y1_1)
box2_area = (x2_2 - x1_2) * (y2_2 - y1_2)
union_area = box1_area + box2_area - intersection_area
if union_area == 0:
return 0.0
return intersection_area / union_area
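    # _find_best_match is called from _team_classify but was missing from the
    # uploaded file; the body below is a minimal reconstruction inferred from
    # its call sites, not necessarily the original implementation.
    def _find_best_match(self, target_box: Tuple[float, float, float, float],
                         frame_data: Dict[int, Tuple[Tuple[float, float, float, float], str]],
                         iou_threshold: float) -> Tuple[Optional[str], float]:
        """
        Find the best IoU match for target_box among a predicted frame's boxes.

        Args:
            target_box: (x1, y1, x2, y2) of the box to match.
            frame_data: {box_idx: ((x1, y1, x2, y2), team_cls_id)} for one
                predicted frame, as stored in predicted_frame_data.
            iou_threshold: minimum IoU for a match to count.

        Returns:
            (team_cls_id, iou) of the best match above the threshold,
            or (None, 0.0) if no box qualifies.
        """
        best_team_id: Optional[str] = None
        best_iou = 0.0
        for bbox, team_cls_id in frame_data.values():
            iou = self._calculate_iou(target_box, bbox)
            if iou >= iou_threshold and iou > best_iou:
                best_iou = iou
                best_team_id = team_cls_id
        return best_team_id, best_iou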
    def _detect_objects_batch(self, decoded_images: List[ndarray]) -> List:
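        """
        Run YOLO detection over all frames in chunks of 16 and return the
        flat list of ultralytics Results objects, one per frame.
        """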
batch_size = 16
detection_results = []
n_frames = len(decoded_images)
for frame_number in range(0, n_frames, batch_size):
batch_images = decoded_images[frame_number: frame_number + batch_size]
detections = self.bbox_model(batch_images, verbose=False, save=False)
detection_results.extend(detections)
return detection_results
    def _team_classify(self, detection_results: List, decoded_images: List[ndarray], offset: int) -> Dict[int, List[BoundingBox]]:
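        """
        Two-pass team assignment. Pass 1: collect high-confidence player crops
        (cls_id == 2) and fit the TeamClassifier. Pass 2: predict team ids for
        frames at prediction_interval, then assign skipped frames by IoU
        matching against the nearest predicted frames, falling back to a
        per-crop prediction when no match clears the threshold. Returns
        {absolute frame_id: [BoundingBox, ...]} with player cls_ids remapped
        to team ids 6/7.
        """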
self.team_classifier_fitted = False
start = time.time()
# Collect player crops from first batch for fitting
        fit_sample_size = self.MAX_SAMPLES_FOR_FIT
player_crops_for_fit = []
for frame_id in range(len(detection_results)):
detection_box = detection_results[frame_id].boxes.data
if len(detection_box) < 4:
continue
# Collect player boxes for team classification fitting (first batch only)
if len(player_crops_for_fit) < fit_sample_size:
frame_image = decoded_images[frame_id]
for box in detection_box:
x1, y1, x2, y2, conf, cls_id = box.tolist()
if conf < 0.5:
continue
mapped_cls_id = str(int(cls_id))
# Only collect player crops (cls_id = 2)
if mapped_cls_id == '2':
crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
if crop.size > 0:
player_crops_for_fit.append(crop)
# Fit team classifier after collecting samples
if self.team_classifier and not self.team_classifier_fitted and len(player_crops_for_fit) >= fit_sample_size:
print(f"Fitting TeamClassifier with {len(player_crops_for_fit)} player crops")
self.team_classifier.fit(player_crops_for_fit)
self.team_classifier_fitted = True
break
        if not self.team_classifier_fitted and len(player_crops_for_fit) >= self.MIN_SAMPLES_FOR_FIT:
print(f"Fallback: Fitting TeamClassifier with {len(player_crops_for_fit)} player crops")
self.team_classifier.fit(player_crops_for_fit)
self.team_classifier_fitted = True
end = time.time()
print(f"Fitting Kmeans time: {end - start}")
# Second pass: predict teams with configurable frame skipping optimization
start = time.time()
# Get configuration for frame skipping
        prediction_interval = 1  # Predict every frame; set to N to predict every Nth frame and interpolate the rest
iou_threshold = 0.3
print(f"Team classification - prediction_interval: {prediction_interval}, iou_threshold: {iou_threshold}")
# Storage for predicted frame results: {frame_id: {box_idx: (bbox, team_id)}}
predicted_frame_data = {}
# Step 1: Predict for frames at prediction_interval only
frames_to_predict = []
for frame_id in range(len(detection_results)):
if frame_id % prediction_interval == 0:
frames_to_predict.append(frame_id)
print(f"Predicting teams for {len(frames_to_predict)}/{len(detection_results)} frames "
f"(saving {100 - (len(frames_to_predict) * 100 // len(detection_results))}% compute)")
for frame_id in frames_to_predict:
detection_box = detection_results[frame_id].boxes.data
frame_image = decoded_images[frame_id]
# Collect player crops for this frame
frame_player_crops = []
frame_player_indices = []
frame_player_boxes = []
for idx, box in enumerate(detection_box):
x1, y1, x2, y2, conf, cls_id = box.tolist()
if cls_id == 2 and conf < 0.6:
continue
mapped_cls_id = str(int(cls_id))
# Collect player crops for prediction
if self.team_classifier and self.team_classifier_fitted and mapped_cls_id == '2':
crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
if crop.size > 0:
frame_player_crops.append(crop)
frame_player_indices.append(idx)
frame_player_boxes.append((x1, y1, x2, y2))
# Predict teams for all players in this frame
if len(frame_player_crops) > 0:
team_ids = self.team_classifier.predict(frame_player_crops)
predicted_frame_data[frame_id] = {}
for idx, bbox, team_id in zip(frame_player_indices, frame_player_boxes, team_ids):
# Map team_id (0,1) to cls_id (6,7)
team_cls_id = str(6 + int(team_id))
predicted_frame_data[frame_id][idx] = (bbox, team_cls_id)
# Step 2: Process all frames (interpolate skipped frames)
fallback_count = 0
interpolated_count = 0
        bboxes: Dict[int, List[BoundingBox]] = {}
for frame_id in range(len(detection_results)):
detection_box = detection_results[frame_id].boxes.data
frame_image = decoded_images[frame_id]
boxes = []
team_predictions = {}
if frame_id % prediction_interval == 0:
# Predicted frame: use pre-computed predictions
if frame_id in predicted_frame_data:
for idx, (bbox, team_cls_id) in predicted_frame_data[frame_id].items():
team_predictions[idx] = team_cls_id
else:
# Skipped frame: interpolate from neighboring predicted frames
# Find nearest predicted frames
prev_predicted_frame = (frame_id // prediction_interval) * prediction_interval
next_predicted_frame = prev_predicted_frame + prediction_interval
# Collect current frame player boxes
for idx, box in enumerate(detection_box):
x1, y1, x2, y2, conf, cls_id = box.tolist()
if cls_id == 2 and conf < 0.6:
continue
mapped_cls_id = str(int(cls_id))
if self.team_classifier and self.team_classifier_fitted and mapped_cls_id == '2':
target_box = (x1, y1, x2, y2)
# Try to match with previous predicted frame
best_team_id = None
best_iou = 0.0
if prev_predicted_frame in predicted_frame_data:
team_id, iou = self._find_best_match(
target_box,
predicted_frame_data[prev_predicted_frame],
iou_threshold
)
if team_id is not None:
best_team_id = team_id
best_iou = iou
# Try to match with next predicted frame if available and no good match yet
if best_team_id is None and next_predicted_frame < len(detection_results):
if next_predicted_frame in predicted_frame_data:
team_id, iou = self._find_best_match(
target_box,
predicted_frame_data[next_predicted_frame],
iou_threshold
)
if team_id is not None and iou > best_iou:
best_team_id = team_id
best_iou = iou
# Track interpolation success
if best_team_id is not None:
interpolated_count += 1
else:
# Fallback: if no match found, predict individually
crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
if crop.size > 0:
team_id = self.team_classifier.predict([crop])[0]
best_team_id = str(6 + int(team_id))
fallback_count += 1
if best_team_id is not None:
team_predictions[idx] = best_team_id
# Parse boxes with team classification
for idx, box in enumerate(detection_box):
x1, y1, x2, y2, conf, cls_id = box.tolist()
if cls_id == 2 and conf < 0.6:
continue
# Check overlap with staff box
overlap_staff = False
                if cls_id == 2:
                    for boxy in detection_box:
                        s_x1, s_y1, s_x2, s_y2, s_conf, s_cls_id = boxy.tolist()
                        if s_cls_id == 4:
                            staff_iou = self._calculate_iou(box[:4], boxy[:4])
                            if staff_iou >= 0.8:
                                overlap_staff = True
                                break
if overlap_staff:
continue
mapped_cls_id = str(int(cls_id))
# Override cls_id for players with team prediction
if idx in team_predictions:
mapped_cls_id = team_predictions[idx]
                # Drop staff boxes (cls_id 4) and low-confidence class-3 detections
                if mapped_cls_id != '4':
                    if int(mapped_cls_id) == 3 and conf < 0.5:
                        continue
boxes.append(
BoundingBox(
x1=int(x1),
y1=int(y1),
x2=int(x2),
y2=int(y2),
cls_id=int(mapped_cls_id),
conf=float(conf),
)
)
# Handle footballs - keep only the best one
footballs = [bb for bb in boxes if int(bb.cls_id) == 0]
if len(footballs) > 1:
best_ball = max(footballs, key=lambda b: b.conf)
boxes = [bb for bb in boxes if int(bb.cls_id) != 0]
boxes.append(best_ball)
bboxes[offset + frame_id] = boxes
return bboxes
def predict_batch(self, batch_images: List[ndarray], offset: int, n_keypoints: int) -> List[TVFrameResult]:
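        """
        Full per-batch pipeline: object detection, team classification, and
        pitch keypoint extraction. Returns one TVFrameResult per frame, with
        keypoints padded or truncated to n_keypoints and missing entries
        filled with (0, 0).
        """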
start = time.time()
detection_results = self._detect_objects_batch(batch_images)
end = time.time()
print(f"Detection time: {end - start}")
start = time.time()
bboxes = self._team_classify(detection_results, batch_images, offset)
end = time.time()
print(f"Team classify time: {end - start}")
pitch_batch_size = min(self.pitch_batch_size, len(batch_images))
keypoints: Dict[int, List[Tuple[int, int]]] = {}
start = time.time()
        # This loop runs exactly once (unconditional break below); it is kept
        # as a scaffold so a retry on keypoint failure could be added later.
        while True:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
            device_str = "cuda" if torch.cuda.is_available() else "cpu"
keypoints_result = process_batch_input(
batch_images,
self.keypoints_model,
self.kp_threshold,
device_str,
batch_size=pitch_batch_size,
)
if keypoints_result is not None and len(keypoints_result) > 0:
for frame_number_in_batch, kp_dict in enumerate(keypoints_result):
if frame_number_in_batch >= len(batch_images):
break
frame_keypoints: List[Tuple[int, int]] = []
try:
height, width = batch_images[frame_number_in_batch].shape[:2]
if kp_dict is not None and isinstance(kp_dict, dict):
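                            # kp_dict maps 1-based keypoint ids to {"x", "y"} in normalized
                            # [0, 1] coordinates; scale to pixel space using the frame size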
for idx in range(32):
x, y = 0, 0
kp_idx = idx + 1
if kp_idx in kp_dict:
try:
kp_data = kp_dict[kp_idx]
if isinstance(kp_data, dict) and "x" in kp_data and "y" in kp_data:
x = int(kp_data["x"] * width)
y = int(kp_data["y"] * height)
except (KeyError, TypeError, ValueError):
pass
frame_keypoints.append((x, y))
except (IndexError, ValueError, AttributeError):
frame_keypoints = [(0, 0)] * 32
if len(frame_keypoints) < n_keypoints:
frame_keypoints.extend([(0, 0)] * (n_keypoints - len(frame_keypoints)))
else:
frame_keypoints = frame_keypoints[:n_keypoints]
keypoints[offset + frame_number_in_batch] = frame_keypoints
break
end = time.time()
print(f"Keypoint time: {end - start}")
results: List[TVFrameResult] = []
for frame_number in range(offset, offset + len(batch_images)):
frame_boxes = bboxes.get(frame_number, [])
frame_keypoints = keypoints.get(frame_number, [(0, 0) for _ in range(n_keypoints)])
result = TVFrameResult(
frame_id=frame_number,
boxes=frame_boxes,
keypoints=frame_keypoints,
)
results.append(result)
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
return results
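# Minimal usage sketch (added for illustration; not part of the original
# upload). The repo path and test frame below are assumptions: the directory
# must contain the model files loaded in __init__, and cv2 must be installed.
if __name__ == "__main__":
    import cv2

    miner = Miner(Path("ScoreVision11"))  # hypothetical local path to the HF repo
    print(miner)
    frame = cv2.imread("frame_0000.jpg")  # hypothetical test frame
    if frame is not None:
        results = miner.predict_batch([frame], offset=0, n_keypoints=32)
        print(f"frame {results[0].frame_id}: {len(results[0].boxes)} boxes, "
              f"{len(results[0].keypoints)} keypoints")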