import cv2
import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image
from transformers import (
    RTDetrForObjectDetection,
    RTDetrImageProcessor,
    VitPoseForPoseEstimation,
    VitPoseImageProcessor,
    pipeline,
)
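
# COCO-17 keypoint ordering produced by VitPose; indices 9 (L_Wrist) and
# 10 (R_Wrist) are used below as stand-ins for the hands.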
KEYPOINT_LABEL_MAP = {
    0: "Nose",
    1: "L_Eye",
    2: "R_Eye",
    3: "L_Ear",
    4: "R_Ear",
    5: "L_Shoulder",
    6: "R_Shoulder",
    7: "L_Elbow",
    8: "R_Elbow",
    9: "L_Wrist",
    10: "R_Wrist",
    11: "L_Hip",
    12: "R_Hip",
    13: "L_Knee",
    14: "R_Knee",
    15: "L_Ankle",
    16: "R_Ankle",
}
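
# Pipeline overview: RT-DETR finds people, VitPose locates their wrists,
# Depth-Anything provides a relative depth map, and MaskFormer (ADE20K)
# labels every pixel. A wrist counts as "touching" an object when a nearby
# non-person pixel sits at almost the same depth as the wrist itself.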
class InteractionDetector:
    def __init__(self):
        self.person_detector = None
        self.person_processor = None
        self.pose_model = None
        self.pose_processor = None
        self.depth_model = None
        self.segmentation_model = None
        # Maximum depth difference (on the 0-255 relative-depth scale) between
        # a wrist and a nearby object pixel for the hand to count as touching.
        self.interaction_threshold = 2
        self.load_models()

    def load_models(self):
        """Load all required models"""
        # Person detection model
        self.person_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
        self.person_detector = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
        # Pose estimation model
        self.pose_processor = VitPoseImageProcessor.from_pretrained("nielsr/vitpose-base-simple")
        self.pose_model = VitPoseForPoseEstimation.from_pretrained("nielsr/vitpose-base-simple")
        # Depth estimation model
        self.depth_model = pipeline("depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf")
        # Semantic segmentation model
        self.segmentation_model = pipeline("image-segmentation", model="facebook/maskformer-swin-base-ade")
        self.segmentation_id2label = self.segmentation_model.model.config.id2label
        self.segmentation_label2id = {v: k for k, v in self.segmentation_id2label.items()}

    def get_nearest_pixel_class(self, joint, depth_map, segmentation_map):
        """
        Find the non-person pixel whose depth is closest to the depth at a joint.

        Only pixels within a 50-pixel Manhattan radius of the joint are
        considered, and pixels segmented as "person" are excluded.

        Args:
            joint: (x, y) pixel coordinates of the joint
            depth_map: relative depth map (2D array)
            segmentation_map: per-pixel ADE20K class ids

        Returns:
            tuple: (class id of the nearest-depth pixel, depth difference to it)
        """
        PERSON_ID = 12  # "person" class id in ADE20K
        # Manhattan distance of every pixel to the joint
        grid_x, grid_y = np.meshgrid(np.arange(depth_map.shape[0]), np.arange(depth_map.shape[1]))
        dist_x = np.abs(grid_x.T - joint[1])
        dist_y = np.abs(grid_y.T - joint[0])
        dist_coord = dist_x + dist_y
        # Depth difference to the joint; cast to int to avoid uint8 wrap-around
        depth_dist = np.abs(depth_map.astype(np.int32) - int(depth_map[joint[1], joint[0]]))
        # Mask out person pixels and pixels farther than 50 px from the joint
        depth_dist[(segmentation_map == PERSON_ID) | (dist_coord > 50)] = 255
        min_dist = np.unravel_index(np.argmin(depth_dist), depth_dist.shape)
        return segmentation_map[min_dist], depth_dist[min_dist]

    def detect_persons(self, image: Image.Image):
        """Detect persons in the image and return their boxes and scores"""
        inputs = self.person_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = self.person_detector(**inputs)
        results = self.person_processor.post_process_object_detection(
            outputs,
            target_sizes=torch.tensor([(image.height, image.width)]),
            threshold=0.3
        )
        # Keep only detections of the COCO "person" class (label 0)
        boxes = results[0]["boxes"][results[0]["labels"] == 0]
        scores = results[0]["scores"][results[0]["labels"] == 0]
        return boxes.cpu().numpy(), scores.cpu().numpy()

    def detect_keypoints(self, image: Image.Image):
        """Detect pose keypoints for every detected person"""
        boxes, scores = self.detect_persons(image)
        # RT-DETR returns (x1, y1, x2, y2) boxes, while the VitPose processor
        # expects COCO-style (x, y, w, h); convert a copy for pose estimation
        pose_boxes = boxes.copy()
        pose_boxes[:, 2] = pose_boxes[:, 2] - pose_boxes[:, 0]
        pose_boxes[:, 3] = pose_boxes[:, 3] - pose_boxes[:, 1]
        pixel_values = self.pose_processor(image, boxes=[pose_boxes], return_tensors="pt").pixel_values
        with torch.no_grad():
            outputs = self.pose_model(pixel_values)
        pose_results = self.pose_processor.post_process_pose_estimation(outputs, boxes=[pose_boxes])[0]
        return pose_results, boxes, scores

    def estimate_depth(self, image: Image.Image):
        """Estimate a relative depth map for the image (returned as a 0-255 array)"""
        with torch.no_grad():
            depth_map = np.array(self.depth_model(image)['depth'])
        return depth_map

    def segment_image(self, image: Image.Image):
        """Run semantic segmentation and return a per-pixel map of ADE20K class ids"""
        with torch.no_grad():
            segmentation_map = self.segmentation_model(image)
        result = np.zeros(np.array(image).shape[:2], dtype=np.uint8)
        print("Found", [l['label'] for l in segmentation_map])
        # Paint larger masks first so that smaller objects overwrite them
        for cls_item in sorted(segmentation_map, key=lambda l: np.sum(l['mask']), reverse=True):
            result[np.array(cls_item['mask']) > 0] = self.segmentation_label2id[cls_item['label']]
        return result

    def detect_wall_interaction(self, image: Image.Image):
        """Detect whether each person's hands are touching a nearby object"""
        # Get all necessary information
        pose_results, boxes, scores = self.detect_keypoints(image)
        depth_map = self.estimate_depth(image)
        segmentation_map = self.segment_image(image)
        interactions = []
        for person_idx, pose_result in enumerate(pose_results):
            # Wrist keypoints stand in for the hands (COCO indices 9 and 10)
            right_hand = pose_result["keypoints"][10].numpy().astype(int)
            left_hand = pose_result["keypoints"][9].numpy().astype(int)
            # Find the nearest non-person pixel (in depth) around each wrist
            right_cls, r_distance = self.get_nearest_pixel_class(right_hand[:2], depth_map, segmentation_map)
            left_cls, l_distance = self.get_nearest_pixel_class(left_hand[:2], depth_map, segmentation_map)
            # A hand is touching when the depth difference is below the threshold
            right_touching = r_distance < self.interaction_threshold
            left_touching = l_distance < self.interaction_threshold
            interactions.append({
                "person_id": person_idx,
                "right_hand_touching_object": self.segmentation_id2label[right_cls],
                "left_hand_touching_object": self.segmentation_id2label[left_cls],
                "right_hand_touching": right_touching,
                "left_hand_touching": left_touching,
                "right_hand_distance": r_distance,
                "left_hand_distance": l_distance
            })
        return interactions, pose_results, segmentation_map, depth_map

    def visualize_results(self, image: Image.Image, interactions, pose_results):
        """Visualize detection results"""
        # Create base visualization from original image
        vis_image = np.array(image).copy()
        # Add pose keypoints
        edge_annotator = sv.EdgeAnnotator(color=sv.Color.GREEN, thickness=2)
        key_points = sv.KeyPoints(
            xy=torch.cat([pose_result['keypoints'].unsqueeze(0) for pose_result in pose_results]).cpu().numpy()
        )
        vis_image = edge_annotator.annotate(scene=vis_image, key_points=key_points)
        # Add interaction indicators
        for interaction in interactions:
            person_id = interaction["person_id"]
            pose_result = pose_results[person_id]
            # Draw indicators for touching hands
            if interaction["right_hand_touching"]:
                cv2.circle(vis_image,
                           tuple(map(int, pose_result["keypoints"][10][:2])),
                           10, (0, 0, 255), -1)
            if interaction["left_hand_touching"]:
                cv2.circle(vis_image,
                           tuple(map(int, pose_result["keypoints"][9][:2])),
                           10, (0, 0, 255), -1)
        return Image.fromarray(vis_image)

    def process_image(self, input_image):
        """Process an image and return the annotated image, segmentation, depth map, and summary text"""
        if input_image is None:
            return None, None, None, ""
        # Convert to PIL Image if necessary
        if isinstance(input_image, np.ndarray):
            image = Image.fromarray(input_image)
        else:
            image = input_image
        image = image.resize((1280, 720))
        # Detect interactions
        interactions, pose_results, segmentation_map, depth_map = self.detect_wall_interaction(image)
        # Visualize results
        result_image = self.visualize_results(image, interactions, pose_results)
        # Create interaction information text
        info_text = []
        for interaction in interactions:
            info_text.append(f"\nPerson {interaction['person_id'] + 1}:")
            if interaction["right_hand_touching"]:
                info_text.append(f"Right hand is touching {interaction['right_hand_touching_object']}")
            if interaction["left_hand_touching"]:
                info_text.append(f"Left hand is touching {interaction['left_hand_touching_object']}")
            info_text.append(f"Right hand distance to nearest object: {interaction['right_hand_distance']:.2f}")
            info_text.append(f"Left hand distance to nearest object: {interaction['left_hand_distance']:.2f}")
        # Add color to segmentation
        mask = np.zeros((*segmentation_map.shape, 3), dtype=np.uint8)
        colors = np.random.randint(0, 255, size=(100, 3))
        for cl_id in np.unique(segmentation_map):
            mask_array = np.array(segmentation_map == cl_id)
            color = colors[cl_id % len(colors)]
            mask[mask_array] = color
        return result_image, mask, depth_map, "\n".join(info_text)
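
# process_image returns four values (annotated image, colored segmentation map,
# depth map, summary text) that feed the four output components declared below.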
def create_gradio_interface():
    """Create Gradio interface"""
    detector = InteractionDetector()
    with gr.Blocks() as interface:
        gr.Markdown("# Object Interaction Detection")
        gr.Markdown("Upload an image to detect when people are touching objects.")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image")
                process_button = gr.Button("Detect Interactions")
            with gr.Column():
                output_image = gr.Image(label="Detection Results")
                interaction_info = gr.Textbox(
                    label="Interaction Information",
                    lines=10,
                    placeholder="Interaction details will appear here..."
                )
                segmentation_im = gr.Image(label="Segmentation Results")
                depth_im = gr.Image(label="Depth Results")
        process_button.click(
            fn=detector.process_image,
            inputs=input_image,
            outputs=[output_image, segmentation_im, depth_im, interaction_info]
        )
        gr.Examples(
            examples=[
                "images/1-8ea4418f.jpg",
                "images/276757975.jpg"
            ],
            inputs=input_image
        )
    return interface

interface = create_gradio_interface()

if __name__ == "__main__":
    interface.launch(debug=True)