Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Fixed SAM 2.1 Interface - Handles negative stride issues properly | |
| """ | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import gradio as gr | |
| from transformers import Sam2Model, Sam2Processor | |
| import warnings | |
| import io | |
| import base64 | |
| import os | |
| from datetime import datetime | |
| warnings.filterwarnings("ignore") | |
# Lazily-populated model singletons. initialize_sam() assigns these on its
# first call so the checkpoint is loaded only once per process.
MODEL = None
PROCESSOR = None
DEVICE = None
# Most recent segmentation result. process_sam_segmentation() stores the mask,
# image name and prompt points here; the Save button handler reads them back
# and hands them to save_binary_mask().
CURRENT_MASK = None
CURRENT_IMAGE_NAME = None
CURRENT_POINTS = None
def initialize_sam(model_size="small"):
    """Load the SAM model and processor once and return the cached instances.

    The first call picks the device (CUDA when available, otherwise CPU),
    loads ``facebook/sam2-hiera-<model_size>`` and stores the result in the
    module-level ``MODEL``, ``PROCESSOR`` and ``DEVICE`` globals. Later calls
    return the cached objects unchanged, so ``model_size`` only has an effect
    on the call that actually performs the load.

    Args:
        model_size: Size suffix of the checkpoint name (e.g. ``"small"``).

    Returns:
        Tuple ``(model, processor, device)``.
    """
    global MODEL, PROCESSOR, DEVICE
    # Check every piece of state, not just MODEL, so a half-finished earlier
    # attempt is retried instead of handing back None for the processor.
    if MODEL is None or PROCESSOR is None or DEVICE is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Initializing SAM 2.1 {model_size} on {device}...")
        model_name = f"facebook/sam2-hiera-{model_size}"
        # Load into locals first: if either from_pretrained call raises, the
        # globals stay untouched and the next call starts from scratch.
        model = Sam2Model.from_pretrained(model_name).to(device)
        processor = Sam2Processor.from_pretrained(model_name)
        MODEL, PROCESSOR, DEVICE = model, processor, device
        print("β Model loaded successfully!")
    return MODEL, PROCESSOR, DEVICE
def fix_image_array(image):
    """Normalize an image input to an RGB PIL image.

    Args:
        image: A file path (``str``), a NumPy array, or a PIL image.

    Returns:
        A ``PIL.Image.Image`` in ``"RGB"`` mode.

    Raises:
        ValueError: If ``image`` is none of the supported types.
    """
    if isinstance(image, str):
        # Filepath input from Gradio. convert() returns a new, fully loaded
        # image, so the file handle can be released as soon as it is done.
        with Image.open(image) as opened:
            return opened.convert("RGB")
    if isinstance(image, np.ndarray):
        if image.dtype != np.uint8:
            # Arrays whose maximum is <= 1.0 are treated as normalized data
            # and rescaled to the 0..255 range.
            if image.max() <= 1.0:
                image = image * 255
            # Clip before casting: a plain astype(uint8) wraps out-of-range
            # values around (e.g. 256 -> 0) instead of saturating them.
            image = np.clip(image, 0, 255).astype(np.uint8)
        # Flipped or sliced arrays can carry negative/odd strides; hand PIL a
        # plain C-contiguous buffer.
        image = np.ascontiguousarray(image)
        return Image.fromarray(image).convert("RGB")
    if isinstance(image, Image.Image):
        return image.convert("RGB")
    raise ValueError(f"Unsupported image type: {type(image)}")
def apply_mask_post_processing(mask, stability_threshold=0.95):
    """Clean up a mask with a morphological close followed by an open.

    Args:
        mask: 2-D array; every value greater than zero counts as foreground.
        stability_threshold: Accepted for interface compatibility; this
            function does not read it.

    Returns:
        ``float32`` array of the same shape holding 0.0 / 1.0 values.
    """
    import cv2

    # Structuring element scales with the image: 1% of the row count, but
    # never smaller than 3 pixels.
    side = max(3, int(mask.shape[0] * 0.01))
    element = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (side, side))

    cleaned = (mask > 0).astype(np.uint8)
    # CLOSE fills small holes first, then OPEN removes small specks.
    for operation in (cv2.MORPH_CLOSE, cv2.MORPH_OPEN):
        cleaned = cv2.morphologyEx(cleaned, operation, element)
    return cleaned.astype(np.float32)
def apply_erosion_dilation(mask, erosion_dilation_value):
    """Grow or shrink a mask with a single dilation or erosion pass.

    Args:
        mask: 2-D array; every value greater than zero counts as foreground.
        erosion_dilation_value: Signed amount. Positive values dilate,
            negative values erode, and the magnitude is used as the side
            length of the elliptical kernel. Fractional values are truncated
            towards zero.

    Returns:
        The input ``mask`` object itself when the amount is zero; otherwise a
        new ``float32`` array holding 0.0 / 1.0 values.
    """
    # UI sliders may deliver floats, but OpenCV kernel sizes must be integers.
    amount = int(erosion_dilation_value)
    if amount == 0:
        # Nothing to do: skip the cv2 import and the binarization entirely.
        return mask

    import cv2

    binary_mask = (mask > 0).astype(np.uint8)
    # NOTE(review): a magnitude of 1 gives a 1x1 kernel, which presumably
    # leaves the binarized mask unchanged - confirm that is intended.
    kernel_size = abs(amount)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
    if amount > 0:
        # Dilate (make larger)
        binary_mask = cv2.dilate(binary_mask, kernel, iterations=1)
    else:
        # Erode (make smaller)
        binary_mask = cv2.erode(binary_mask, kernel, iterations=1)
    return binary_mask.astype(np.float32)
def save_binary_mask(mask, image_name, points, mask_threshold, erosion_dilation, save_low_res=False, custom_folder_name=None):
    """Write a mask to ``masks/<image_base>/<folder_tag>/`` and report the outcome.

    Files written into the target folder:
      * ``image.jpg`` - the mask as an 8-bit grayscale JPEG (0 / 255 source values),
      * ``image.pt`` - the mask as a float tensor (0.0 / 1.0),
      * ``metadata_<timestamp>.json`` - the prompt and settings used,
      * optionally a low-resolution PNG whose side is ``int(sqrt(max(w, h)))``.

    The call also stores ``mask``, ``image_name`` and ``points`` in the
    module-level ``CURRENT_*`` globals.

    Args:
        mask: 2-D array; values greater than zero are foreground.
        image_name: Source image name or path; only its base name without
            extension is used. Falls back to a timestamped name when empty.
        points: Prompt points, recorded in the metadata file.
        mask_threshold: Threshold value recorded in metadata and file names.
        erosion_dilation: Size adjustment recorded in metadata and file names.
        save_low_res: Also save the square low-resolution PNG.
        custom_folder_name: Optional sub-folder name; ``"default"`` is used
            when it is empty.

    Returns:
        A human-readable status string. Failures are reported through the
        returned string rather than raised.
    """
    global CURRENT_MASK, CURRENT_IMAGE_NAME, CURRENT_POINTS

    def _single_component(text):
        # Neutralize separators, drive colons and spaces so the text can only
        # ever name one directory level.
        for char in ('/', '\\', ':', ' '):
            text = text.replace(char, '_')
        return text

    try:
        # Store current state for saving
        CURRENT_MASK = mask
        CURRENT_IMAGE_NAME = image_name
        CURRENT_POINTS = points

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        # Extract image name without extension and sanitize
        if image_name:
            base_name = _single_component(os.path.splitext(os.path.basename(image_name))[0])
        else:
            base_name = f"image_{timestamp}"

        # Folder tag: user-provided name if available, else 'default'. The
        # name is typed by the user, so it gets the same sanitizing as the
        # image name and may not be a bare "." / ".." that would step out of
        # the per-image folder.
        folder_tag = _single_component(str(custom_folder_name).strip()) if custom_folder_name else ""
        if folder_tag in ("", ".", ".."):
            folder_tag = "default"

        # Create folder structure: masks/<image_base>/<folder_tag>/
        folder_name = f"masks/{base_name}/{folder_tag}"
        os.makedirs(folder_name, exist_ok=True)

        # Create binary mask (0 and 255 values)
        binary_mask = (mask > 0).astype(np.uint8) * 255
        original_height, original_width = binary_mask.shape

        if save_low_res:
            # Square target whose side is the integer square root of the
            # longer image edge.
            sqrt_factor = int(np.sqrt(max(original_width, original_height)))
            low_res_width = sqrt_factor
            low_res_height = sqrt_factor
            print(f"Original mask size: {original_width}x{original_height}")
            print(f"Low-res mask size: {low_res_width}x{low_res_height}")

        # File-name friendly renderings of the settings.
        threshold_str = f"{mask_threshold:.2f}".replace('.', 'p')  # 0.30 -> 0p30
        # int() first: the 'd' format code rejects floats such as 2.0.
        adj_str = f"{int(erosion_dilation):+d}".replace('+', 'plus').replace('-', 'minus')  # +2 -> plus2, -2 -> minus2

        saved_paths = []

        # Save full resolution mask as JPEG with a simple filename
        mask_path = os.path.join(folder_name, "image.jpg")
        mask_image = Image.fromarray(binary_mask, mode='L')
        mask_image.save(mask_path, format="JPEG", quality=95, optimize=True)
        saved_paths.append(mask_path)

        # Save tensor mask (.pt) as float tensor (0.0/1.0)
        tensor_path = os.path.join(folder_name, "image.pt")
        torch.save(torch.from_numpy((mask > 0).astype(np.float32)), tensor_path)
        saved_paths.append(tensor_path)

        # Save low resolution mask if requested
        if save_low_res:
            # NEAREST keeps the downscaled mask strictly two-valued.
            low_res_mask = mask_image.resize((low_res_width, low_res_height), Image.Resampling.NEAREST)
            low_res_filename = f"mask_lowres_{sqrt_factor}x{sqrt_factor}_t{threshold_str}_adj{adj_str}_{timestamp}.png"
            low_res_path = os.path.join(folder_name, low_res_filename)
            low_res_mask.save(low_res_path)
            saved_paths.append(low_res_path)

        # Also save metadata
        metadata = {
            "timestamp": timestamp,
            "points": points,
            "mask_threshold": mask_threshold,
            "erosion_dilation": erosion_dilation,
            "image_name": image_name,
            "original_resolution": f"{original_width}x{original_height}",
            "saved_paths": saved_paths,
            "low_resolution_saved": save_low_res
        }
        if save_low_res:
            metadata["low_resolution"] = f"{low_res_width}x{low_res_height}"
            metadata["sqrt_factor"] = sqrt_factor

        import json
        metadata_path = os.path.join(folder_name, f"metadata_{timestamp}.json")
        with open(metadata_path, 'w', encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)

        # Return appropriate message
        if save_low_res:
            return f"β Masks saved:\nπ Full: {os.path.basename(mask_path)}\nπ Low-res: {os.path.basename(low_res_path)}"
        return f"β Mask saved to: {os.path.basename(mask_path)}"
    except Exception as e:
        # UI boundary: surface the failure in the status box instead of raising.
        return f"β Save failed: {str(e)}"
def _parse_point_prompts(points_data):
    """Split raw click records into parallel coordinate and label lists."""
    points = []
    labels = []
    for point_info in points_data:
        if isinstance(point_info, dict):
            points.append([point_info.get("x", 0), point_info.get("y", 0)])
            labels.append(1 if point_info.get("positive", True) else 0)  # 1 = positive, 0 = negative
        elif isinstance(point_info, (list, tuple)) and len(point_info) >= 2:
            points.append([point_info[0], point_info[1]])
            labels.append(1)  # Default to positive for old format
    return points, labels


def process_sam_segmentation(image, points_data, bbox_data, mode, image_name=None, top_k=3, mask_threshold=0.0, stability_score_threshold=0.95, erosion_dilation=0):
    """Run SAM on an image with point or bounding-box prompts.

    The prompt is validated before the model is touched, so an unusable
    request never triggers a model load. On success the resulting mask, the
    image name and the prompt points are stored in the module-level
    ``CURRENT_*`` globals for a later save.

    Args:
        image: File path, NumPy array or PIL image (see ``fix_image_array``).
        points_data: For ``"Points"`` mode, a sequence of ``{"x", "y",
            "positive"}`` dicts or ``[x, y]`` pairs (pairs count as positive).
            Entries of any other shape are skipped.
        bbox_data: For ``"Bounding Box"`` mode, ``[x1, y1, x2, y2]``.
        mode: ``"Points"`` or ``"Bounding Box"``.
        image_name: Name remembered alongside the mask.
        top_k: How many score-ranked candidates to keep; only the best one is
            used for the returned mask.
        mask_threshold: When greater than zero, the best mask is re-binarized
            with ``mask > mask_threshold``.
        stability_score_threshold: Forwarded to ``apply_mask_post_processing``.
        erosion_dilation: Non-zero values are forwarded to
            ``apply_erosion_dilation``.

    Returns:
        Always a 3-tuple ``(input_visualization, mask_visualization, status)``.
        Both images are ``None`` when validation or processing fails; the
        status string then carries the reason.
    """
    global CURRENT_MASK, CURRENT_IMAGE_NAME, CURRENT_POINTS

    if image is None:
        return None, None, "Please upload an image first."

    # Check input based on mode
    points = None
    labels = None
    if mode == "Points":
        if not points_data:
            return None, None, "Please click on the image to select points."
        points, labels = _parse_point_prompts(points_data)
        if not points:
            # Three values like every other exit: callers unpack a 3-tuple.
            return None, None, "No valid points found."
    elif mode == "Bounding Box":
        if bbox_data is None:
            return None, None, "Please click two corners to define a bounding box."

    try:
        # Initialize model
        model, processor, device = initialize_sam()

        # Fix image
        pil_image = fix_image_array(image)

        # Prepare SAM inputs based on mode
        processor_inputs = {
            "images": pil_image,
            "return_tensors": "pt"
        }
        if mode == "Points":
            print(f"Processing {len(points)} points: {points} with labels: {labels}")
            # The extra list levels wrap the prompt for a single image and a
            # single object.
            processor_inputs["input_points"] = [[points]]
            processor_inputs["input_labels"] = [[labels]]
        elif mode == "Bounding Box":
            bbox = bbox_data  # [x1, y1, x2, y2]
            print(f"Processing bounding box: {bbox}")
            processor_inputs["input_boxes"] = [[bbox]]
            # For visualization, store the bbox corners as points
            points = [[bbox[0], bbox[1]], [bbox[2], bbox[3]]]

        inputs = processor(**processor_inputs).to(device)

        # Generate masks with multiple outputs for better control
        with torch.no_grad():
            outputs = model(**inputs, multimask_output=True)

        # Get masks and scores
        masks = processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"]
        )[0]
        scores = outputs.iou_scores.cpu().numpy().flatten()

        # Rank candidates by score, best first, and keep the top-k indices
        top_indices = np.argsort(scores)[::-1][:top_k]
        best_mask = masks[0, top_indices[0]].numpy()
        best_score = scores[top_indices[0]]

        # Apply threshold to control mask size
        if mask_threshold > 0:
            best_mask = (best_mask > mask_threshold).astype(np.float32)

        # Additional mask processing for size control
        best_mask = apply_mask_post_processing(best_mask, stability_score_threshold)

        # Apply erosion/dilation for fine size control
        if erosion_dilation != 0:
            best_mask = apply_erosion_dilation(best_mask, erosion_dilation)

        # Store current state for saving
        CURRENT_MASK = best_mask
        CURRENT_IMAGE_NAME = image_name
        CURRENT_POINTS = points

        # Create dual visualizations
        original_with_input = create_original_with_input_visualization(pil_image, points, bbox_data, mode)
        mask_result = create_mask_visualization(pil_image, best_mask, best_score, mask_threshold)

        status = f"β Generated mask with score: {float(best_score):.3f}\nπ Ready to save!"
        return original_with_input, mask_result, status
    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        print(f"Error in processing: {e}")
        return None, None, f"Error: {str(e)}"
def create_original_with_input_visualization(pil_image, points, bbox, mode, negative_points=None):
    """Render the image with the user's prompt (points or box) drawn on top.

    Args:
        pil_image: Image to display (anything ``np.array`` can convert).
        points: ``[x, y]`` pairs drawn as green markers in ``"Points"`` mode.
        bbox: ``(x1, y1, x2, y2)`` drawn as a rectangle in ``"Bounding Box"``
            mode.
        mode: ``"Points"`` or ``"Bounding Box"``; anything else, or a missing
            box, yields a plot titled "No input".
        negative_points: Optional ``[x, y]`` pairs drawn as red markers in
            ``"Points"`` mode.

    Returns:
        The rendered figure as an independent, C-contiguous RGB array.
    """
    # Convert PIL to numpy for matplotlib
    img_array = np.array(pil_image)
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    try:
        # Show original image only
        ax.imshow(img_array)

        # Show input visualization based on mode
        if mode == "Points":
            # Show positive points (green)
            for point in points or []:
                ax.plot(point[0], point[1], 'go', markersize=12, markeredgewidth=3, markerfacecolor='lime')
            # Show negative points (red)
            for point in negative_points or []:
                ax.plot(point[0], point[1], 'ro', markersize=12, markeredgewidth=3, markerfacecolor='red')
            pos_count = len(points) if points else 0
            neg_count = len(negative_points) if negative_points else 0
            title_suffix = f"Points: {pos_count}+ {neg_count}-" if neg_count > 0 else f"Points: {pos_count}"
        elif mode == "Bounding Box" and bbox:
            x1, y1, x2, y2 = bbox
            width = x2 - x1
            height = y2 - y1
            # Draw bounding box rectangle
            from matplotlib.patches import Rectangle
            rect = Rectangle((x1, y1), width, height, linewidth=3, edgecolor='lime', facecolor='none')
            ax.add_patch(rect)
            # Show corner points
            ax.plot([x1, x2], [y1, y2], 'go', markersize=8, markeredgewidth=2, markerfacecolor='lime')
            title_suffix = f"BBox: {int(width)}Γ{int(height)}"
        else:
            title_suffix = "No input"

        ax.set_title(f"Input Selection ({title_suffix})", fontsize=14)
        ax.axis('off')

        fig.canvas.draw()
        rgba = np.asarray(fig.canvas.buffer_rgba())
        # Drop alpha and copy: the slice alone is a strided view into the
        # canvas buffer, which should not outlive the closed figure.
        return np.ascontiguousarray(rgba[:, :, :3])
    finally:
        # Always release the figure, even when drawing fails, so a
        # long-running server does not accumulate open figures.
        plt.close(fig)
def create_mask_visualization(pil_image, mask, score, mask_threshold=0.0):
    """Render the image with the mask overlaid in semi-transparent red.

    Args:
        pil_image: Image to display (anything ``np.array`` can convert).
        mask: 2-D array; values greater than zero are tinted.
        score: Mask score shown in the title.
        mask_threshold: Threshold shown in the title.

    Returns:
        The rendered figure as an independent, C-contiguous RGB array.
    """
    # Convert PIL to numpy for matplotlib
    img_array = np.array(pil_image)
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    try:
        # Show original image
        ax.imshow(img_array)

        # RGBA overlay: fully transparent everywhere except foreground
        # pixels, which become red at 60% opacity.
        mask_overlay = np.zeros((*mask.shape, 4))
        mask_overlay[mask > 0] = [1, 0, 0, 0.6]
        ax.imshow(mask_overlay)

        ax.set_title(f"Generated Mask (Score: {float(score):.3f}, Threshold: {mask_threshold:.2f})", fontsize=14)
        ax.axis('off')

        fig.canvas.draw()
        rgba = np.asarray(fig.canvas.buffer_rgba())
        # Drop alpha and copy: the slice alone is a strided view into the
        # canvas buffer, which should not outlive the closed figure.
        return np.ascontiguousarray(rgba[:, :, :3])
    finally:
        # Always release the figure, even when drawing fails.
        plt.close(fig)
def create_interface():
    """Create a simplified single-image annotator interface.

    Builds a ``gr.Blocks`` app with an annotatable image, a points preview,
    Generate / Clear / Save buttons, a mask preview and a status box, and
    wires the event handlers defined inside this function to them.

    Returns:
        The assembled ``gr.Blocks`` object, not yet launched.
    """
    # The multi-line CSS/HTML literals below are kept flush-left so their
    # contents stay exactly as written.
    with gr.Blocks(title="SAM 2.1 - Simple Annotator", theme=gr.themes.Soft(), css="""
.negative-mode-checkbox label {
color: #d00000 !important;
font-weight: 800 !important;
font-size: 16px !important;
}
""") as interface:
        gr.HTML("""
<div style="text-align: center;">
<h1>π― SAM 2.1 Simple Annotator</h1>
<p>Upload one image, click to add positive/negative points, generate mask, and save.</p>
</div>
""")
        # Hidden image component. It is not wired to any event below; uploads
        # go through the annotatable image component instead.
        image_input = gr.Image(
            label=None,
            type="filepath",
            height=0,
            visible=False
        )
        # Main layout: Selected Points on the left, annotatable image in the center, preview on the right
        with gr.Row():
            with gr.Column(scale=1):
                points_display = gr.JSON(label="π Selected Points", value=[], visible=True)
            with gr.Column(scale=3):
                # Negative mode toggle with clear red styling (see the CSS class above)
                negative_point_mode = gr.Checkbox(
                    label="β NEGATIVE POINT MODE",
                    value=False,
                    info="π΄ Enable to add negative points (shown in red)",
                    interactive=True,
                    elem_classes="negative-mode-checkbox"
                )
                # No type= is given, so handlers receive Gradio's default
                # image payload - presumably a NumPy array; TODO confirm.
                original_with_input = gr.Image(
                    label="π Click to Annotate (toggle negative mode to exclude)",
                    height=640,
                    interactive=True
                )
            with gr.Column(scale=1):
                points_overlay = gr.Image(label="π Points Preview (green=positive, red=negative)", height=720, interactive=False)
        # Action buttons
        with gr.Row():
            generate_btn = gr.Button("π― Generate Mask", variant="primary", size="lg")
            clear_btn = gr.Button("ποΈ Clear Points", variant="secondary", size="lg")
        # Mask result under buttons
        with gr.Row():
            mask_result = gr.Image(label="π Generated Mask", height=512)
        # Save controls under mask
        with gr.Row():
            mask_name_input = gr.Textbox(label="Folder name (optional)", placeholder="e.g., michael_phelps_bottom_left")
            save_btn = gr.Button("πΎ Save Mask", variant="stop", size="lg")
        # Status
        with gr.Row():
            status_text = gr.Textbox(label="π Status", interactive=False, lines=3)
        # State to store points only: a list of {"x", "y", "positive"} dicts
        points_state = gr.State([])

        # Event handlers
        def on_image_click(image, current_points, negative_mode, evt: gr.SelectData):
            """Handle clicks on the image for point annotations only.

            Returns ``(points_state, points_display, points_overlay, status)``.
            The overlay slot is ``None`` when nothing could be drawn.
            """
            if evt.index is not None and image is not None:
                x, y = evt.index
                try:
                    pil_image = fix_image_array(image)
                    is_negative = negative_mode
                    new_point = {"x": int(x), "y": int(y), "positive": not is_negative}
                    # Build a new list rather than mutating the stored state in place.
                    updated_points = current_points + [new_point]
                    # Split by polarity so the preview can colour the two groups differently.
                    positive_points = [[p["x"], p["y"]] for p in updated_points if p.get("positive", True)]
                    negative_points = [[p["x"], p["y"]] for p in updated_points if not p.get("positive", True)]
                    updated_visualization = create_original_with_input_visualization(
                        pil_image, positive_points, None, "Points", negative_points
                    )
                    point_type = "positive" if not is_negative else "negative"
                    pos_count = len(positive_points)
                    neg_count = len(negative_points)
                    return updated_points, updated_points, updated_visualization, (
                        f"Added {point_type} point at ({x}, {y}). Total: {pos_count} positive, {neg_count} negative points."
                    )
                except Exception as e:
                    # Keep the previous points; only the preview update failed.
                    print(f"Error in visualization: {e}")
                    return current_points, current_points, None, f"Error updating visualization: {str(e)}"
            return current_points, current_points, None, "Click on the image to add points."

        def on_image_upload(image):
            """Handle image upload and show it for annotation.

            Returns ``(annotation_image, points_overlay, points_state,
            points_display, status)``; the point lists are reset to empty.
            """
            if image is not None:
                try:
                    pil_image = fix_image_array(image)
                    img_array = np.array(pil_image)
                    # Populate both the annotation image (left) and the points preview (right)
                    return img_array, img_array, [], [], "Image uploaded. Click on the left image to add points (enable negative mode for exclusion)."
                except Exception as e:
                    return None, None, [], [], f"Error loading image: {str(e)}"
            return None, None, [], [], "No image uploaded."

        def clear_all_points(image):
            """Clear points and keep the image visible for annotation.

            Returns ``(points_state, points_display, points_overlay,
            annotation_image, mask_result, status)``.
            """
            try:
                if image is not None:
                    pil_image = fix_image_array(image)
                    img_array = np.array(pil_image)
                    return [], [], img_array, img_array, None, "All points cleared. You can continue annotating."
            except Exception:
                # Best effort: if the image cannot be re-rendered, fall through
                # and clear the image slots as well.
                pass
            return [], [], None, None, None, "All points cleared."

        def generate_segmentation(image, points):
            """Generate a single segmentation mask using points only.

            Returns ``(mask_image, status)``.
            """
            # Determine image name
            if isinstance(image, str):
                image_name = os.path.basename(image)
            else:
                # Prefer an explicit friendly default if metadata lacks a good name
                image_name = None
                if hasattr(image, 'orig_name'):
                    image_name = image.orig_name
                elif isinstance(image, dict) and 'orig_name' in image:
                    image_name = image['orig_name']
                elif hasattr(image, 'name'):
                    image_name = image.name
                # NOTE(review): the nesting of this fallback was reconstructed
                # from a source with lost indentation - confirm it belongs
                # inside the else branch rather than after it.
                if not image_name or 'tmp' in str(image_name).lower() or 'uploaded_image' in str(image_name).lower():
                    image_name = "michael_phelps_bottom_left.jpg"
            # Run segmentation (points mode): top_k=1, mask_threshold=0.0,
            # stability_score_threshold=0.95, erosion_dilation=0
            _, mask_img, status = process_sam_segmentation(
                image, points, None, "Points", image_name, 1, 0.0, 0.95, 0
            )
            if mask_img is not None:
                status += f"\nπ Image: {os.path.basename(image_name)}"
            return mask_img, status

        def save_current_mask(custom_folder_name):
            """Save the currently generated mask.

            Reads the module-level ``CURRENT_*`` globals and returns the
            status string produced by ``save_binary_mask``.
            """
            global CURRENT_MASK, CURRENT_IMAGE_NAME, CURRENT_POINTS
            if CURRENT_MASK is None:
                return "β No mask to save. Generate a mask first."
            if CURRENT_POINTS is None:
                return "β No points available. Generate a mask first."
            # Fixed settings: mask_threshold=0.0, erosion_dilation=0, no low-res copy.
            return save_binary_mask(CURRENT_MASK, CURRENT_IMAGE_NAME, CURRENT_POINTS, 0.0, 0, False, custom_folder_name=(custom_folder_name or None))

        # Wire events
        # Let the annotatable image also handle image uploads (drag & drop / click upload)
        original_with_input.upload(
            on_image_upload,
            inputs=[original_with_input],
            outputs=[original_with_input, points_overlay, points_state, points_display, status_text]
        )
        original_with_input.select(
            on_image_click,
            inputs=[original_with_input, points_state, negative_point_mode],
            outputs=[points_state, points_display, points_overlay, status_text]
        )
        generate_btn.click(
            generate_segmentation,
            inputs=[original_with_input, points_state],
            outputs=[mask_result, status_text]
        )
        clear_btn.click(
            clear_all_points,
            inputs=[original_with_input],
            outputs=[points_state, points_display, points_overlay, original_with_input, mask_result, status_text]
        )
        save_btn.click(
            save_current_mask,
            inputs=[mask_name_input],
            outputs=[status_text]
        )
    return interface
def main():
    """Build the Gradio app and serve it on the configured port.

    The port is read from the ``GRADIO_SERVER_PORT`` environment variable and
    defaults to 7860.
    """
    print("π Starting Fixed SAM 2.1 Interface...")
    app = create_interface()
    print("π Launching web interface...")
    print("π Click on objects in images to segment them!")

    launch_options = {
        "server_port": int(os.environ.get("GRADIO_SERVER_PORT", 7860)),
        "share": False,
        # Don't auto-open browser in server environment
        "inbrowser": False,
        "show_error": True,
    }
    app.launch(**launch_options)


if __name__ == "__main__":
    main()