# UI-DETR-1 / app.py
# paulml's picture
# Update app.py
# 6722d08 verified
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
os.environ['GRADIO_DEFAULT_LANG'] = 'en'
import gradio as gr
import torch
import cv2
import numpy as np
from PIL import Image
from typing import Tuple, List
from rfdetr.detr import RFDETRMedium
# Labels for the UI element categories.
# NOTE(review): presumably ordered to match the detector's class ids, but
# nothing in this part of the file indexes into the list -- confirm.
CLASSES = [
    'button',
    'field',
    'heading',
    'iframe',
    'image',
    'label',
    'link',
    'text',
]

# One colour shared by every box outline and label background, given as the
# BGR triple OpenCV drawing calls expect.
BOX_COLOR = (0, 255, 0)  # green

# Shared detector instance; stays None until load_model() builds it.
model = None
def load_model(model_path: str = "model.pth"):
    """Return the shared RF-DETR model, building it on the first call.

    Args:
        model_path: Weights handed to ``RFDETRMedium`` as ``pretrain_weights``.
            It is only consulted when no model has been built yet; once the
            module-level ``model`` is set, later calls return that instance
            whatever path they pass.

    Returns:
        The module-level ``model`` instance.
    """
    global model

    # Already built: hand back the cached instance untouched.
    if model is not None:
        return model

    print("Loading RF-DETR model...")
    model = RFDETRMedium(pretrain_weights=model_path, resolution=1600)
    print("Model loaded successfully!")
    return model
def draw_detections(
    image: np.ndarray,
    boxes: List[Tuple[int, int, int, int]],
    scores: List[float],
    classes: List[int],
    thickness: int = 3,
    font_scale: float = 0.6
) -> np.ndarray:
    """Return a copy of ``image`` with every detection outlined and scored.

    Each box is outlined in ``BOX_COLOR`` and gets a filled tag in the same
    colour holding the confidence formatted to two decimals in white text.
    The input array itself is never modified.

    Args:
        image: Array to annotate; all drawing goes through OpenCV.
        boxes: One ``(x1, y1, x2, y2)`` entry per detection; each value is
            passed through ``int()`` before drawing.
        scores: Confidence for each detection.
        classes: Class id for each detection. The ids are not rendered, but
            the list takes part in the ``zip``, so the shortest of the three
            lists decides how many detections are drawn.
        thickness: Line thickness of the box outline.
        font_scale: Scale factor for the score text.

    Returns:
        The annotated copy of ``image``.
    """
    canvas = image.copy()

    for corners, confidence, _class_id in zip(boxes, scores, classes):
        left, top, right, bottom = (int(value) for value in corners)

        # Box outline.
        cv2.rectangle(canvas, (left, top), (right, bottom), BOX_COLOR, thickness)

        # The tag shows the confidence only.
        tag = f"{confidence:.2f}"
        text_size, baseline = cv2.getTextSize(
            tag, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness=2
        )
        tag_width, tag_height = text_size

        # Place the tag just above the box; the max() keeps it from being
        # pushed far off-canvas when the box starts near the top edge.
        anchor_y = max(top - 10, tag_height + 10)
        text_y = anchor_y - baseline - 5

        # Filled background sized to the measured text.
        background_top_left = (left, text_y - tag_height)
        background_bottom_right = (left + tag_width + 5, anchor_y + baseline - 5)
        cv2.rectangle(
            canvas,
            background_top_left,
            background_bottom_right,
            BOX_COLOR,
            -1
        )

        # Score text on top of the background.
        cv2.putText(
            canvas,
            tag,
            (left + 2, text_y),
            cv2.FONT_HERSHEY_SIMPLEX,
            font_scale,
            (255, 255, 255),
            thickness=2
        )

    return canvas
@torch.inference_mode()
def detect_ui_elements(
    image: Image.Image,
    confidence_threshold: float,
    line_thickness: int
) -> Tuple[Image.Image, str]:
    """Detect UI elements in an uploaded image and draw them on a copy.

    Args:
        image: Input PIL image, or ``None`` when nothing has been uploaded.
            PIL images in any other mode (grayscale, palette, RGBA, ...) are
            converted to RGB before detection.
        confidence_threshold: Minimum confidence score, forwarded to the
            model's ``predict`` call as ``threshold``.
        line_thickness: Thickness of the bounding-box outline; cast to ``int``
            before it reaches OpenCV.

    Returns:
        ``(annotated_image, summary_markdown)`` on success. The image slot is
        ``None`` when no image was supplied or when detection raised; in the
        latter case the text slot carries the error and its traceback.
    """
    # Nothing to do without an image; this check cannot raise, so it sits
    # outside the error boundary below.
    if image is None:
        return None, "Please upload an image first."

    try:
        detector = load_model()

        # The colour conversion below and the drawing code assume a
        # 3-channel RGB array, so bring any other PIL mode to RGB first.
        # Already-RGB images are used as-is, without an extra copy.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")
        rgb_pixels = np.array(image)

        # OpenCV drawing works on BGR data.
        bgr_pixels = cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)

        # The model is given the RGB array; the result exposes xyxy boxes,
        # confidences and class ids (a supervision Detections object,
        # according to the original comment).
        detections = detector.predict(rgb_pixels, threshold=confidence_threshold)
        found_boxes = detections.xyxy
        found_scores = detections.confidence
        found_classes = detections.class_id

        annotated_bgr = draw_detections(
            bgr_pixels,
            found_boxes.tolist(),
            found_scores.tolist(),
            found_classes.tolist(),
            # OpenCV drawing calls require an integer thickness; the cast
            # guards against a UI control delivering e.g. 2.0.
            thickness=int(line_thickness),
        )

        # Back to RGB so the PIL image shows the right colours.
        annotated_rgb = cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB)
        annotated_pil = Image.fromarray(annotated_rgb)

        summary_text = f"**Total detections:** {len(found_boxes)}"
        return annotated_pil, summary_text

    except Exception as e:
        # UI boundary: report the failure in the output panel instead of
        # letting the exception escape, and mirror it to the logs.
        import traceback
        error_msg = f"**Error during detection:**\n\n```\n{str(e)}\n\n{traceback.format_exc()}\n```"
        print(error_msg)  # Also print to logs
        return None, error_msg
# Gradio interface
# NOTE(review): the nesting below was reconstructed from a flattened copy of
# the file -- confirm it matches the intended layout.
with gr.Blocks(theme=gr.themes.Soft(), title="UI-DETR-1 UI Element Detector") as demo:
    # Page header.
    gr.Markdown(
        "\n"
        "# UI-DETR-1 UI Element Detector\n"
        "Upload a screenshot or UI mockup to automatically detect elements.\n"
    )

    with gr.Row():
        # Left column: upload, detection settings and the trigger button.
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="Upload Screenshot",
                type="pil",
                sources=["upload"],
                height=400,
            )

            with gr.Accordion("Detection Settings", open=True):
                confidence_slider = gr.Slider(
                    label="Confidence Threshold",
                    info="Higher values = fewer but more confident detections",
                    minimum=0.1,
                    maximum=0.9,
                    step=0.05,
                    value=0.35,
                )
                thickness_slider = gr.Slider(
                    label="Box Line Thickness",
                    minimum=1,
                    maximum=6,
                    step=1,
                    value=2,
                )

            detect_button = gr.Button("Detect Elements", variant="primary", size="lg")

        # Right column: annotated result and the text summary.
        with gr.Column(scale=1):
            output_image = gr.Image(
                label="Detected Elements",
                type="pil",
                height=400,
            )
            summary_output = gr.Markdown(label="Detection Summary")

    # Clicking the button runs detection on the current image and settings.
    detect_button.click(
        fn=detect_ui_elements,
        inputs=[input_image, confidence_slider, thickness_slider],
        outputs=[output_image, summary_output],
    )

# Launch only when executed as a script.
if __name__ == "__main__":
    demo.queue().launch(share=False)