from pathlib import Path
from typing import Dict, List, Tuple
import gc
import os
import sys
import time

import torch
import yaml
from numpy import ndarray
from pydantic import BaseModel
from ultralytics import YOLO

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from team_cluster import TeamClassifier
from utils import Constants
from pitch import process_batch_input, get_cls_net


class BoundingBox(BaseModel):
    x1: int
    y1: int
    x2: int
    y2: int
    cls_id: int
    conf: float


class TVFrameResult(BaseModel):
    frame_id: int
    boxes: List[BoundingBox]
    keypoints: List[Tuple[int, int]]

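# Detector class-id conventions, as inferred from the logic in Miner._team_classify
# below (an assumption; this file contains no explicit class map): 0 = football,
# 2 = player, 3 = a class kept only above 0.5 confidence, 4 = staff. The team
# classifier relabels player boxes to cls_id 6 or 7 (team 0 / team 1).
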
class Miner:
    SMALL_CONTAINED_IOA = Constants.SMALL_CONTAINED_IOA
    SMALL_RATIO_MAX = Constants.SMALL_RATIO_MAX
    SINGLE_PLAYER_HUE_PIVOT = Constants.SINGLE_PLAYER_HUE_PIVOT
    CORNER_INDICES = Constants.CORNER_INDICES
    KEYPOINTS_CONFIDENCE = Constants.KEYPOINTS_CONFIDENCE
    CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE
    GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN
    MIN_SAMPLES_FOR_FIT = 16   # Minimum player crops needed before fitting TeamClassifier
    MAX_SAMPLES_FOR_FIT = 600  # Maximum samples to avoid overfitting

    def __init__(self, path_hf_repo: Path) -> None:
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"

            model_path = path_hf_repo / "football_object_detection.onnx"
            self.bbox_model = YOLO(model_path)
            print("BBox Model Loaded")

            team_model_path = path_hf_repo / "osnet_model.pth.tar-100"
            self.team_classifier = TeamClassifier(
                device=device, batch_size=32, model_name=str(team_model_path)
            )
            print("Team Classifier Loaded")

            # Team classification state
            self.team_classifier_fitted = False
            self.player_crops_for_fit = []

            model_kp_path = path_hf_repo / "keypoint"
            config_kp_path = path_hf_repo / "hrnetv2_w48.yaml"
            with open(config_kp_path, "r") as f:
                cfg_kp = yaml.safe_load(f)
            loaded_state_kp = torch.load(model_kp_path, map_location=device)
            model = get_cls_net(cfg_kp)
            model.load_state_dict(loaded_state_kp)
            model.to(device)
            model.eval()
            self.keypoints_model = model
            self.kp_threshold = 0.1
            self.pitch_batch_size = 4

            self.health = "healthy"
            print("✅ Keypoints Model Loaded")
        except Exception as e:
            self.health = "❌ Miner initialization failed: " + str(e)
            print(self.health)

    def __repr__(self) -> str:
        if self.health == "healthy":
            return (
                f"health: {self.health}\n"
                f"BBox Model: {type(self.bbox_model).__name__}\n"
                f"Keypoints Model: {type(self.keypoints_model).__name__}"
            )
        return self.health

    def _calculate_iou(
        self,
        box1: Tuple[float, float, float, float],
        box2: Tuple[float, float, float, float],
    ) -> float:
        """Calculate Intersection over Union (IoU) between two bounding boxes.

        Args:
            box1: (x1, y1, x2, y2)
            box2: (x1, y1, x2, y2)

        Returns:
            IoU score (0-1)
        """
        x1_1, y1_1, x2_1, y2_1 = box1
        x1_2, y1_2, x2_2, y2_2 = box2

        # Calculate intersection area
        x_left = max(x1_1, x1_2)
        y_top = max(y1_1, y1_2)
        x_right = min(x2_1, x2_2)
        y_bottom = min(y2_1, y2_2)
        if x_right < x_left or y_bottom < y_top:
            return 0.0
        intersection_area = (x_right - x_left) * (y_bottom - y_top)

        # Calculate union area
        box1_area = (x2_1 - x1_1) * (y2_1 - y1_1)
        box2_area = (x2_2 - x1_2) * (y2_2 - y1_2)
        union_area = box1_area + box2_area - intersection_area
        if union_area == 0:
            return 0.0
        return intersection_area / union_area

    def _detect_objects_batch(self, decoded_images: List[ndarray]) -> List:
        # Run YOLO detection in chunks of 16 frames; returns one ultralytics
        # Results object per frame.
        batch_size = 16
        detection_results = []
        n_frames = len(decoded_images)
        for frame_number in range(0, n_frames, batch_size):
            batch_images = decoded_images[frame_number: frame_number + batch_size]
            detections = self.bbox_model(batch_images, verbose=False, save=False)
            detection_results.extend(detections)
        return detection_results

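    # _find_best_match is called from _team_classify below but is not defined in
    # this file. The version here is a minimal sketch reconstructed from its call
    # sites (an assumption, not the original implementation): given a target
    # (x1, y1, x2, y2) box and a predicted frame's {box_idx: (bbox, team_cls_id)}
    # dict, return the team_cls_id of the highest-IoU box at or above
    # iou_threshold together with that IoU, or (None, 0.0) if nothing matches.
    def _find_best_match(self, target_box, predicted_boxes, iou_threshold):
        best_team_id = None
        best_iou = 0.0
        for _, (bbox, team_cls_id) in predicted_boxes.items():
            iou = self._calculate_iou(target_box, bbox)
            if iou >= iou_threshold and iou > best_iou:
                best_iou = iou
                best_team_id = team_cls_id
        return best_team_id, best_iou
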
    def _team_classify(self, detection_results, decoded_images, offset):
        self.team_classifier_fitted = False
        start = time.time()

        # Collect player crops from the first frames for fitting
        fit_sample_size = self.MAX_SAMPLES_FOR_FIT
        player_crops_for_fit = []
        for frame_id in range(len(detection_results)):
            detection_box = detection_results[frame_id].boxes.data
            if len(detection_box) < 4:
                continue
            # Collect player boxes for team classification fitting
            if len(player_crops_for_fit) < fit_sample_size:
                frame_image = decoded_images[frame_id]
                for box in detection_box:
                    x1, y1, x2, y2, conf, cls_id = box.tolist()
                    if conf < 0.5:
                        continue
                    mapped_cls_id = str(int(cls_id))
                    # Only collect player crops (cls_id = 2)
                    if mapped_cls_id == '2':
                        crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
                        if crop.size > 0:
                            player_crops_for_fit.append(crop)
            # Fit team classifier once enough samples have been collected
            if self.team_classifier and not self.team_classifier_fitted and len(player_crops_for_fit) >= fit_sample_size:
                print(f"Fitting TeamClassifier with {len(player_crops_for_fit)} player crops")
                self.team_classifier.fit(player_crops_for_fit)
                self.team_classifier_fitted = True
                break
        if not self.team_classifier_fitted and len(player_crops_for_fit) >= self.MIN_SAMPLES_FOR_FIT:
            print(f"Fallback: Fitting TeamClassifier with {len(player_crops_for_fit)} player crops")
            self.team_classifier.fit(player_crops_for_fit)
            self.team_classifier_fitted = True
        end = time.time()
        print(f"Fitting Kmeans time: {end - start}")

        # Second pass: predict teams with configurable frame skipping optimization
        start = time.time()

        # Configuration for frame skipping
        prediction_interval = 1  # Predict every frame; raise to skip frames and interpolate between them
        iou_threshold = 0.3
        print(f"Team classification - prediction_interval: {prediction_interval}, iou_threshold: {iou_threshold}")

        # Storage for predicted frame results: {frame_id: {box_idx: (bbox, team_id)}}
        predicted_frame_data = {}

        # Step 1: Predict only for frames at prediction_interval
        frames_to_predict = []
        for frame_id in range(len(detection_results)):
            if frame_id % prediction_interval == 0:
                frames_to_predict.append(frame_id)
        print(f"Predicting teams for {len(frames_to_predict)}/{len(detection_results)} frames "
              f"(saving {100 - (len(frames_to_predict) * 100 // len(detection_results))}% compute)")

        for frame_id in frames_to_predict:
            detection_box = detection_results[frame_id].boxes.data
            frame_image = decoded_images[frame_id]

            # Collect player crops for this frame
            frame_player_crops = []
            frame_player_indices = []
            frame_player_boxes = []
            for idx, box in enumerate(detection_box):
                x1, y1, x2, y2, conf, cls_id = box.tolist()
                if cls_id == 2 and conf < 0.6:
                    continue
                mapped_cls_id = str(int(cls_id))
                # Collect player crops for prediction
                if self.team_classifier and self.team_classifier_fitted and mapped_cls_id == '2':
                    crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
                    if crop.size > 0:
                        frame_player_crops.append(crop)
                        frame_player_indices.append(idx)
                        frame_player_boxes.append((x1, y1, x2, y2))

            # Predict teams for all players in this frame
            if len(frame_player_crops) > 0:
                team_ids = self.team_classifier.predict(frame_player_crops)
                predicted_frame_data[frame_id] = {}
                for idx, bbox, team_id in zip(frame_player_indices, frame_player_boxes, team_ids):
                    # Map team_id (0, 1) to cls_id (6, 7)
                    team_cls_id = str(6 + int(team_id))
                    predicted_frame_data[frame_id][idx] = (bbox, team_cls_id)

        # Step 2: Process all frames (interpolate skipped frames)
        fallback_count = 0
        interpolated_count = 0
        bboxes: Dict[int, List[BoundingBox]] = {}
        for frame_id in range(len(detection_results)):
            detection_box = detection_results[frame_id].boxes.data
            frame_image = decoded_images[frame_id]
            boxes = []
            team_predictions = {}

            if frame_id % prediction_interval == 0:
                # Predicted frame: use pre-computed predictions
                if frame_id in predicted_frame_data:
                    for idx, (bbox, team_cls_id) in predicted_frame_data[frame_id].items():
                        team_predictions[idx] = team_cls_id
            else:
                # Skipped frame: interpolate from neighboring predicted frames
                prev_predicted_frame = (frame_id // prediction_interval) * prediction_interval
                next_predicted_frame = prev_predicted_frame + prediction_interval

                # Collect current frame player boxes
                for idx, box in enumerate(detection_box):
                    x1, y1, x2, y2, conf, cls_id = box.tolist()
                    if cls_id == 2 and conf < 0.6:
                        continue
                    mapped_cls_id = str(int(cls_id))
                    if self.team_classifier and self.team_classifier_fitted and mapped_cls_id == '2':
                        target_box = (x1, y1, x2, y2)
                        # Try to match with the previous predicted frame
                        best_team_id = None
                        best_iou = 0.0
                        if prev_predicted_frame in predicted_frame_data:
                            team_id, iou = self._find_best_match(
                                target_box, predicted_frame_data[prev_predicted_frame], iou_threshold
                            )
                            if team_id is not None:
                                best_team_id = team_id
                                best_iou = iou
                        # Try the next predicted frame if available and no good match yet
                        if best_team_id is None and next_predicted_frame < len(detection_results):
                            if next_predicted_frame in predicted_frame_data:
                                team_id, iou = self._find_best_match(
                                    target_box, predicted_frame_data[next_predicted_frame], iou_threshold
                                )
                                if team_id is not None and iou > best_iou:
                                    best_team_id = team_id
                                    best_iou = iou
                        # Track interpolation success
                        if best_team_id is not None:
                            interpolated_count += 1
                        else:
                            # Fallback: if no match found, predict individually
                            crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
                            if crop.size > 0:
                                team_id = self.team_classifier.predict([crop])[0]
                                best_team_id = str(6 + int(team_id))
                                fallback_count += 1
                        if best_team_id is not None:
                            team_predictions[idx] = best_team_id

            # Parse boxes with team classification
            for idx, box in enumerate(detection_box):
                x1, y1, x2, y2, conf, cls_id = box.tolist()
                if cls_id == 2 and conf < 0.6:
                    continue
                # Drop player boxes that mostly overlap a staff box
                overlap_staff = False
                for boxy in detection_box:
                    s_x1, s_y1, s_x2, s_y2, s_conf, s_cls_id = boxy.tolist()
                    if cls_id == 2 and s_cls_id == 4:
                        staff_iou = self._calculate_iou(box[:4], boxy[:4])
                        if staff_iou >= 0.8:
                            overlap_staff = True
                            break
                if overlap_staff:
                    continue
                mapped_cls_id = str(int(cls_id))
                # Override cls_id for players with the team prediction
                if idx in team_predictions:
                    mapped_cls_id = team_predictions[idx]
                if mapped_cls_id != '4':
                    if int(mapped_cls_id) == 3 and conf < 0.5:
                        continue
                    boxes.append(
                        BoundingBox(
                            x1=int(x1),
                            y1=int(y1),
                            x2=int(x2),
                            y2=int(y2),
                            cls_id=int(mapped_cls_id),
                            conf=float(conf),
                        )
                    )

            # Handle footballs - keep only the best one
            footballs = [bb for bb in boxes if int(bb.cls_id) == 0]
            if len(footballs) > 1:
                best_ball = max(footballs, key=lambda b: b.conf)
                boxes = [bb for bb in boxes if int(bb.cls_id) != 0]
                boxes.append(best_ball)

            bboxes[offset + frame_id] = boxes
        return bboxes

    def predict_batch(
        self, batch_images: List[ndarray], offset: int, n_keypoints: int
    ) -> List[TVFrameResult]:
        start = time.time()
        detection_results = self._detect_objects_batch(batch_images)
        end = time.time()
        print(f"Detection time: {end - start}")

        start = time.time()
        bboxes = self._team_classify(detection_results, batch_images, offset)
        end = time.time()
        print(f"Team classify time: {end - start}")

        pitch_batch_size = min(self.pitch_batch_size, len(batch_images))
        keypoints: Dict[int, List[Tuple[int, int]]] = {}
        start = time.time()
        while True:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
            device_str = "cuda" if torch.cuda.is_available() else "cpu"
            keypoints_result = process_batch_input(
                batch_images,
                self.keypoints_model,
                self.kp_threshold,
                device_str,
                batch_size=pitch_batch_size,
            )
            if keypoints_result is not None and len(keypoints_result) > 0:
                for frame_number_in_batch, kp_dict in enumerate(keypoints_result):
                    if frame_number_in_batch >= len(batch_images):
                        break
                    frame_keypoints: List[Tuple[int, int]] = []
                    try:
                        height, width = batch_images[frame_number_in_batch].shape[:2]
                        if kp_dict is not None and isinstance(kp_dict, dict):
                            for idx in range(32):
                                x, y = 0, 0
                                kp_idx = idx + 1
                                if kp_idx in kp_dict:
                                    try:
                                        kp_data = kp_dict[kp_idx]
                                        if isinstance(kp_data, dict) and "x" in kp_data and "y" in kp_data:
                                            x = int(kp_data["x"] * width)
                                            y = int(kp_data["y"] * height)
                                    except (KeyError, TypeError, ValueError):
                                        pass
                                frame_keypoints.append((x, y))
                    except (IndexError, ValueError, AttributeError):
                        frame_keypoints = [(0, 0)] * 32
                    # Pad or truncate to exactly n_keypoints
                    if len(frame_keypoints) < n_keypoints:
                        frame_keypoints.extend([(0, 0)] * (n_keypoints - len(frame_keypoints)))
                    else:
                        frame_keypoints = frame_keypoints[:n_keypoints]
                    keypoints[offset + frame_number_in_batch] = frame_keypoints
            # Single pass: exit after the first attempt
            break
        end = time.time()
        print(f"Keypoint time: {end - start}")

        results: List[TVFrameResult] = []
        for frame_number in range(offset, offset + len(batch_images)):
            frame_boxes = bboxes.get(frame_number, [])
            frame_keypoints = keypoints.get(frame_number, [(0, 0) for _ in range(n_keypoints)])
            result = TVFrameResult(
                frame_id=frame_number,
                boxes=frame_boxes,
                keypoints=frame_keypoints,
            )
            results.append(result)

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        return results
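
# Minimal usage sketch. How this module is driven is not shown in this file, so the
# paths, the OpenCV-based frame decoding, and the batch size below are illustrative
# assumptions; only the checkpoint filenames follow what Miner.__init__ expects.
if __name__ == "__main__":
    import cv2  # assumed available for decoding frames

    miner = Miner(Path("./checkpoints"))  # hypothetical local snapshot of the model repo
    print(miner)

    cap = cv2.VideoCapture("match.mp4")  # hypothetical input clip
    frames: List[ndarray] = []
    while len(frames) < 32:
        ok, frame = cap.read()
        if not ok:
            break
        frames.append(frame)
    cap.release()

    if frames and miner.health == "healthy":
        # 32 keypoints matches the range used in predict_batch's keypoint loop.
        results = miner.predict_batch(frames, offset=0, n_keypoints=32)
        for r in results[:3]:
            print(r.frame_id, len(r.boxes), len(r.keypoints))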