import gc
import os
import sys
import time
from pathlib import Path
from typing import Dict, List, Tuple

import torch
import yaml
from numpy import ndarray
from pydantic import BaseModel

# Make sibling modules importable regardless of the working directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from ultralytics import YOLO

from team_cluster import TeamClassifier
from pitch import process_batch_input, get_cls_net
# Note: utils also exports a BoundingBox, but this module defines its own
# below, so only Constants is imported to avoid shadowing.
from utils import Constants


class BoundingBox(BaseModel):
    """A single detection: pixel corner coordinates, class id, and confidence."""

    x1: int
    y1: int
    x2: int
    y2: int
    cls_id: int
    conf: float


class TVFrameResult(BaseModel):
    """Per-frame output: detections plus pitch keypoints in pixel coordinates."""

    frame_id: int
    boxes: List[BoundingBox]
    keypoints: List[Tuple[int, int]]


class Miner:
    SMALL_CONTAINED_IOA = Constants.SMALL_CONTAINED_IOA
    SMALL_RATIO_MAX = Constants.SMALL_RATIO_MAX
    SINGLE_PLAYER_HUE_PIVOT = Constants.SINGLE_PLAYER_HUE_PIVOT
    CORNER_INDICES = Constants.CORNER_INDICES
    KEYPOINTS_CONFIDENCE = Constants.KEYPOINTS_CONFIDENCE
    CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE
    GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN
    MIN_SAMPLES_FOR_FIT = 16
    MAX_SAMPLES_FOR_FIT = 600

    def __init__(self, path_hf_repo: Path) -> None:
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"

            # Object detection model (ball, players, and other on-pitch classes).
            model_path = path_hf_repo / "football_object_detection.onnx"
            self.bbox_model = YOLO(model_path)
            print("BBox Model Loaded")

            # OSNet-based classifier used to split player crops into two teams.
            team_model_path = path_hf_repo / "osnet_model.pth.tar-100"
            self.team_classifier = TeamClassifier(
                device=device,
                batch_size=32,
                model_name=str(team_model_path),
            )
            print("Team Classifier Loaded")

            self.team_classifier_fitted = False
            self.player_crops_for_fit = []

            # Pitch keypoint model and its HRNet config.
            model_kp_path = path_hf_repo / 'keypoint'
            config_kp_path = path_hf_repo / 'hrnetv2_w48.yaml'
            with open(config_kp_path, 'r') as f:
                cfg_kp = yaml.safe_load(f)

            loaded_state_kp = torch.load(model_kp_path, map_location=device)
            model = get_cls_net(cfg_kp)
            model.load_state_dict(loaded_state_kp)
            model.to(device)
            model.eval()

            self.keypoints_model = model
            self.kp_threshold = 0.1
            self.pitch_batch_size = 4
            self.health = "healthy"
            print("✅ Keypoints Model Loaded")
        except Exception as e:
            self.health = "❌ Miner initialization failed: " + str(e)
            print(self.health)

    def __repr__(self) -> str:
        if self.health == 'healthy':
            return (
                f"health: {self.health}\n"
                f"BBox Model: {type(self.bbox_model).__name__}\n"
                f"Keypoints Model: {type(self.keypoints_model).__name__}"
            )
        else:
            return self.health

    def _calculate_iou(self, box1: Tuple[float, float, float, float],
                       box2: Tuple[float, float, float, float]) -> float:
        """
        Calculate Intersection over Union (IoU) between two bounding boxes.

        Args:
            box1: (x1, y1, x2, y2)
            box2: (x1, y1, x2, y2)

        Returns:
            IoU score (0-1)
        """
        x1_1, y1_1, x2_1, y2_1 = box1
        x1_2, y1_2, x2_2, y2_2 = box2

        x_left = max(x1_1, x1_2)
        y_top = max(y1_1, y1_2)
        x_right = min(x2_1, x2_2)
        y_bottom = min(y2_1, y2_2)

        if x_right < x_left or y_bottom < y_top:
            return 0.0

        intersection_area = (x_right - x_left) * (y_bottom - y_top)

        box1_area = (x2_1 - x1_1) * (y2_1 - y1_1)
        box2_area = (x2_2 - x1_2) * (y2_2 - y1_2)
        union_area = box1_area + box2_area - intersection_area

        if union_area == 0:
            return 0.0

        return intersection_area / union_area

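    # Worked example (a sanity check, not part of the pipeline): for
    # box1 = (0, 0, 10, 10) and box2 = (5, 5, 15, 15) the intersection is
    # 5 * 5 = 25 and the union is 100 + 100 - 25 = 175, so
    # _calculate_iou(box1, box2) == 25 / 175 ≈ 0.143.
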
    def _detect_objects_batch(self, decoded_images: List[ndarray]) -> list:
        """Run the detector over the frames in fixed-size chunks.

        Returns one ultralytics result object per frame, in frame order.
        """
        batch_size = 16
        detection_results = []
        n_frames = len(decoded_images)
        for frame_number in range(0, n_frames, batch_size):
            batch_images = decoded_images[frame_number: frame_number + batch_size]
            detections = self.bbox_model(batch_images, verbose=False, save=False)
            detection_results.extend(detections)

        return detection_results

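    # Each ultralytics result exposes its detections as `result.boxes.data`, a
    # tensor with one row per detection laid out as
    # (x1, y1, x2, y2, confidence, class_id); the methods below unpack rows
    # with `box.tolist()` in exactly that order.
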
    def _team_classify(self, detection_results: list, decoded_images: List[ndarray],
                       offset: int) -> Dict[int, List[BoundingBox]]:
        self.team_classifier_fitted = False
        start = time.time()

        # Phase 1: collect player crops across frames and fit the team classifier.
        fit_sample_size = self.MAX_SAMPLES_FOR_FIT
        player_crops_for_fit = []

        for frame_id in range(len(detection_results)):
            detection_box = detection_results[frame_id].boxes.data
            if len(detection_box) < 4:
                continue

            if len(player_crops_for_fit) < fit_sample_size:
                frame_image = decoded_images[frame_id]
                for box in detection_box:
                    x1, y1, x2, y2, conf, cls_id = box.tolist()
                    if conf < 0.5:
                        continue
                    mapped_cls_id = str(int(cls_id))

                    # Class '2' is a player; keep its crop for fitting.
                    if mapped_cls_id == '2':
                        crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
                        if crop.size > 0:
                            player_crops_for_fit.append(crop)

            if self.team_classifier and not self.team_classifier_fitted and len(player_crops_for_fit) >= fit_sample_size:
                print(f"Fitting TeamClassifier with {len(player_crops_for_fit)} player crops")
                self.team_classifier.fit(player_crops_for_fit)
                self.team_classifier_fitted = True
                break

        # Fallback: fit on whatever was collected, given a minimum sample size.
        if not self.team_classifier_fitted and len(player_crops_for_fit) >= self.MIN_SAMPLES_FOR_FIT:
            print(f"Fallback: Fitting TeamClassifier with {len(player_crops_for_fit)} player crops")
            self.team_classifier.fit(player_crops_for_fit)
            self.team_classifier_fitted = True
        end = time.time()
        print(f"KMeans fitting time: {end - start}")

        start = time.time()

        # Phase 2: predict team ids on a subset of frames; intermediate frames
        # reuse those predictions via IoU matching. With prediction_interval = 1
        # every frame is predicted directly and the matching path is skipped.
        prediction_interval = 1
        iou_threshold = 0.3

        print(f"Team classification - prediction_interval: {prediction_interval}, iou_threshold: {iou_threshold}")

        predicted_frame_data = {}

        frames_to_predict = []
        for frame_id in range(len(detection_results)):
            if frame_id % prediction_interval == 0:
                frames_to_predict.append(frame_id)

        if detection_results:
            print(f"Predicting teams for {len(frames_to_predict)}/{len(detection_results)} frames "
                  f"(saving {100 - (len(frames_to_predict) * 100 // len(detection_results))}% compute)")

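        # Example of the matching scheme (only exercised when
        # prediction_interval > 1): with prediction_interval = 5, frames
        # 0, 5, 10, ... are predicted directly; a player box on frame 7
        # inherits the team id of the best-IoU match (IoU >= iou_threshold)
        # from frame 5, then from frame 10, and only falls back to classifying
        # its own crop when neither neighbour yields a match.
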
        for frame_id in frames_to_predict:
            detection_box = detection_results[frame_id].boxes.data
            frame_image = decoded_images[frame_id]

            frame_player_crops = []
            frame_player_indices = []
            frame_player_boxes = []

            for idx, box in enumerate(detection_box):
                x1, y1, x2, y2, conf, cls_id = box.tolist()
                if cls_id == 2 and conf < 0.6:
                    continue
                mapped_cls_id = str(int(cls_id))

                if self.team_classifier and self.team_classifier_fitted and mapped_cls_id == '2':
                    crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
                    if crop.size > 0:
                        frame_player_crops.append(crop)
                        frame_player_indices.append(idx)
                        frame_player_boxes.append((x1, y1, x2, y2))

            if len(frame_player_crops) > 0:
                team_ids = self.team_classifier.predict(frame_player_crops)
                predicted_frame_data[frame_id] = {}
                for idx, bbox, team_id in zip(frame_player_indices, frame_player_boxes, team_ids):
                    # Offset team ids so they occupy class ids '6', '7', ...,
                    # distinct from the detector's raw class ids.
                    team_cls_id = str(6 + int(team_id))
                    predicted_frame_data[frame_id][idx] = (bbox, team_cls_id)

        # Phase 3: resolve a team id for every player box in every frame,
        # directly on predicted frames, by IoU matching on intermediate frames,
        # and by per-crop classification as a last resort.
        fallback_count = 0
        interpolated_count = 0
        bboxes: Dict[int, List[BoundingBox]] = {}
        for frame_id in range(len(detection_results)):
            detection_box = detection_results[frame_id].boxes.data
            frame_image = decoded_images[frame_id]
            boxes = []

            team_predictions = {}

            if frame_id % prediction_interval == 0:
                # Directly predicted frame: reuse the stored team ids.
                if frame_id in predicted_frame_data:
                    for idx, (bbox, team_cls_id) in predicted_frame_data[frame_id].items():
                        team_predictions[idx] = team_cls_id
            else:
                # Intermediate frame: borrow team ids from the nearest predicted
                # frames on either side.
                prev_predicted_frame = (frame_id // prediction_interval) * prediction_interval
                next_predicted_frame = prev_predicted_frame + prediction_interval

                for idx, box in enumerate(detection_box):
                    x1, y1, x2, y2, conf, cls_id = box.tolist()
                    if cls_id == 2 and conf < 0.6:
                        continue
                    mapped_cls_id = str(int(cls_id))

                    if self.team_classifier and self.team_classifier_fitted and mapped_cls_id == '2':
                        target_box = (x1, y1, x2, y2)

                        best_team_id = None
                        best_iou = 0.0

                        if prev_predicted_frame in predicted_frame_data:
                            team_id, iou = self._find_best_match(
                                target_box,
                                predicted_frame_data[prev_predicted_frame],
                                iou_threshold
                            )
                            if team_id is not None:
                                best_team_id = team_id
                                best_iou = iou

                        if best_team_id is None and next_predicted_frame < len(detection_results):
                            if next_predicted_frame in predicted_frame_data:
                                team_id, iou = self._find_best_match(
                                    target_box,
                                    predicted_frame_data[next_predicted_frame],
                                    iou_threshold
                                )
                                if team_id is not None and iou > best_iou:
                                    best_team_id = team_id
                                    best_iou = iou

                        if best_team_id is not None:
                            interpolated_count += 1
                        else:
                            # No temporal match: classify this crop directly.
                            crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
                            if crop.size > 0:
                                team_id = self.team_classifier.predict([crop])[0]
                                best_team_id = str(6 + int(team_id))
                                fallback_count += 1

                        if best_team_id is not None:
                            team_predictions[idx] = best_team_id

            for idx, box in enumerate(detection_box):
                x1, y1, x2, y2, conf, cls_id = box.tolist()
                if cls_id == 2 and conf < 0.6:
                    continue

                # Drop player boxes that almost entirely overlap a staff box
                # (class 4), since staff boxes are excluded from the output below.
                overlap_staff = False
                for boxy in detection_box:
                    s_x1, s_y1, s_x2, s_y2, s_conf, s_cls_id = boxy.tolist()
                    if cls_id == 2 and s_cls_id == 4:
                        staff_iou = self._calculate_iou(box[:4], boxy[:4])
                        if staff_iou >= 0.8:
                            overlap_staff = True
                            break
                if overlap_staff:
                    continue

                mapped_cls_id = str(int(cls_id))

                # Replace the generic player class with the team-specific id.
                if idx in team_predictions:
                    mapped_cls_id = team_predictions[idx]
                if mapped_cls_id != '4':
                    if int(mapped_cls_id) == 3 and conf < 0.5:
                        continue
                    boxes.append(
                        BoundingBox(
                            x1=int(x1),
                            y1=int(y1),
                            x2=int(x2),
                            y2=int(y2),
                            cls_id=int(mapped_cls_id),
                            conf=float(conf),
                        )
                    )

            # Keep at most one football (class 0) per frame: the most confident.
            footballs = [bb for bb in boxes if int(bb.cls_id) == 0]
            if len(footballs) > 1:
                best_ball = max(footballs, key=lambda b: b.conf)
                boxes = [bb for bb in boxes if int(bb.cls_id) != 0]
                boxes.append(best_ball)

            bboxes[offset + frame_id] = boxes
        return bboxes

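    # NOTE: _team_classify calls self._find_best_match, which is not defined in
    # this section. The sketch below is an assumed reconstruction from the call
    # sites: frame_predictions maps a detection index to
    # ((x1, y1, x2, y2), team_cls_id), as stored in predicted_frame_data, and
    # the method returns the team id of the highest-IoU box at or above
    # iou_threshold together with that IoU, or (None, 0.0) when nothing matches.
    def _find_best_match(self, target_box, frame_predictions, iou_threshold):
        best_team_id, best_iou = None, 0.0
        for bbox, team_cls_id in frame_predictions.values():
            iou = self._calculate_iou(target_box, bbox)
            if iou >= iou_threshold and iou > best_iou:
                best_team_id, best_iou = team_cls_id, iou
        return best_team_id, best_iou
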
    def predict_batch(self, batch_images: List[ndarray], offset: int, n_keypoints: int) -> List[TVFrameResult]:
        start = time.time()
        detection_results = self._detect_objects_batch(batch_images)
        end = time.time()
        print(f"Detection time: {end - start}")
        start = time.time()
        bboxes = self._team_classify(detection_results, batch_images, offset)
        end = time.time()
        print(f"Team classify time: {end - start}")

        pitch_batch_size = min(self.pitch_batch_size, len(batch_images))
        keypoints: Dict[int, List[Tuple[int, int]]] = {}

        start = time.time()
        # Free as much GPU memory as possible before keypoint inference.
        gc.collect()
        device_str = "cpu"
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            device_str = "cuda"
        keypoints_result = process_batch_input(
            batch_images,
            self.keypoints_model,
            self.kp_threshold,
            device_str,
            batch_size=pitch_batch_size,
        )
        if keypoints_result is not None and len(keypoints_result) > 0:
            for frame_number_in_batch, kp_dict in enumerate(keypoints_result):
                if frame_number_in_batch >= len(batch_images):
                    break
                frame_keypoints: List[Tuple[int, int]] = []
                try:
                    height, width = batch_images[frame_number_in_batch].shape[:2]
                    if kp_dict is not None and isinstance(kp_dict, dict):
                        # The keypoint model reports up to 32 pitch landmarks,
                        # 1-indexed, with coordinates normalized to [0, 1];
                        # undetected landmarks default to (0, 0).
                        for idx in range(32):
                            x, y = 0, 0
                            kp_idx = idx + 1
                            if kp_idx in kp_dict:
                                try:
                                    kp_data = kp_dict[kp_idx]
                                    if isinstance(kp_data, dict) and "x" in kp_data and "y" in kp_data:
                                        x = int(kp_data["x"] * width)
                                        y = int(kp_data["y"] * height)
                                except (KeyError, TypeError, ValueError):
                                    pass
                            frame_keypoints.append((x, y))
                except (IndexError, ValueError, AttributeError):
                    frame_keypoints = [(0, 0)] * 32
                # Pad or truncate to exactly n_keypoints entries.
                if len(frame_keypoints) < n_keypoints:
                    frame_keypoints.extend([(0, 0)] * (n_keypoints - len(frame_keypoints)))
                else:
                    frame_keypoints = frame_keypoints[:n_keypoints]
                keypoints[offset + frame_number_in_batch] = frame_keypoints
        end = time.time()
        print(f"Keypoint time: {end - start}")

        results: List[TVFrameResult] = []
        for frame_number in range(offset, offset + len(batch_images)):
            frame_boxes = bboxes.get(frame_number, [])
            frame_keypoints = keypoints.get(frame_number, [(0, 0) for _ in range(n_keypoints)])
            result = TVFrameResult(
                frame_id=frame_number,
                boxes=frame_boxes,
                keypoints=frame_keypoints,
            )
            results.append(result)

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

        return results
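

# Example usage (a sketch; the repo path and frames are placeholders, and the
# frames are assumed to be decoded numpy image arrays, e.g. from cv2):
#
#   miner = Miner(Path("/path/to/hf_repo"))
#   if miner.health == "healthy":
#       results = miner.predict_batch(frames, offset=0, n_keypoints=32)
#       for r in results:
#           print(r.frame_id, len(r.boxes), r.keypoints[:4])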