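"""Gradio demo for detecting UI elements in screenshots with an RF-DETR model."""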
import os

# Let PyTorch's MPS backend fall back to CPU for operators it does not support
# (relevant on Apple Silicon); set before torch is imported below.
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

# Default the Gradio interface language to English.
os.environ['GRADIO_DEFAULT_LANG'] = 'en'

import gradio as gr
import torch
import cv2
import numpy as np
from PIL import Image
from typing import Tuple, List

from rfdetr.detr import RFDETRMedium


# Class names the detector was trained on.
CLASSES = ['button', 'field', 'heading', 'iframe', 'image', 'label', 'link', 'text']

# Color used for bounding boxes and label backgrounds (green).
BOX_COLOR = (0, 255, 0)

# Module-level cache for the lazily loaded model.
model = None


def load_model(model_path: str = "model.pth"):
    """Load the RF-DETR model once and cache it in the module-level global."""
    global model
    if model is None:
        print("Loading RF-DETR model...")
        model = RFDETRMedium(pretrain_weights=model_path, resolution=1600)
        print("Model loaded successfully!")
    return model


def draw_detections(
    image: np.ndarray,
    boxes: List[Tuple[int, int, int, int]],
    scores: List[float],
    classes: List[int],
    thickness: int = 3,
    font_scale: float = 0.6
) -> np.ndarray:
    """Draw detection boxes and confidence labels on a BGR image."""
    img_with_boxes = image.copy()

    for box, score, cls_id in zip(boxes, scores, classes):
        x1, y1, x2, y2 = map(int, box)

        # Bounding box around the detected element.
        cv2.rectangle(img_with_boxes, (x1, y1), (x2, y2), BOX_COLOR, thickness)

        # Label text: the confidence score, drawn above the box.
        label = f"{score:.2f}"

        # Measure the label so its filled background fits the text.
        (label_width, label_height), baseline = cv2.getTextSize(
            label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness=2
        )

        # Keep the label inside the image when the box touches the top edge.
        label_y = max(y1 - 10, label_height + 10)

        # Filled background behind the label text.
        cv2.rectangle(
            img_with_boxes,
            (x1, label_y - label_height - baseline - 5),
            (x1 + label_width + 5, label_y + baseline - 5),
            BOX_COLOR,
            -1
        )

        # Label text in white on top of the filled background.
        cv2.putText(
            img_with_boxes,
            label,
            (x1 + 2, label_y - baseline - 5),
            cv2.FONT_HERSHEY_SIMPLEX,
            font_scale,
            (255, 255, 255),
            thickness=2
        )

    return img_with_boxes


@torch.inference_mode()
def detect_ui_elements(
    image: Image.Image,
    confidence_threshold: float,
    line_thickness: int
) -> Tuple[Image.Image, str]:
    """
    Detect UI elements in the uploaded image

    Args:
        image: Input PIL Image
        confidence_threshold: Minimum confidence score for detections
        line_thickness: Thickness of bounding box lines

    Returns:
        Annotated image and detection summary text
    """
    try:
        if image is None:
            return None, "Please upload an image first."

        model = load_model()

        # Ensure a 3-channel RGB array even if the upload is RGBA or grayscale.
        img_array = np.array(image.convert("RGB"))

        # OpenCV drawing expects BGR channel order.
        img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)

        # Run inference; predictions below the confidence threshold are discarded.
        detections = model.predict(img_array, threshold=confidence_threshold)

        filtered_boxes = detections.xyxy
        filtered_scores = detections.confidence
        filtered_classes = detections.class_id

        # Draw boxes on the BGR copy, then convert back to RGB for display.
        annotated_img = draw_detections(
            img_bgr,
            filtered_boxes.tolist(),
            filtered_scores.tolist(),
            filtered_classes.tolist(),
            thickness=line_thickness
        )

        annotated_img_rgb = cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB)
        annotated_pil = Image.fromarray(annotated_img_rgb)

        summary_text = f"**Total detections:** {len(filtered_boxes)}"

        return annotated_pil, summary_text

    except Exception as e:
        import traceback
        error_msg = f"**Error during detection:**\n\n```\n{str(e)}\n\n{traceback.format_exc()}\n```"
        print(error_msg)
        return None, error_msg


with gr.Blocks(title="UI-DETR-1 UI Element Detector", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # UI-DETR-1 UI Element Detector

    Upload a screenshot or UI mockup to automatically detect elements.
    """)

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                type="pil",
                label="Upload Screenshot",
                height=400,
                sources=["upload"]
            )

            with gr.Accordion("Detection Settings", open=True):
                confidence_slider = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=0.35,
                    step=0.05,
                    label="Confidence Threshold",
                    info="Higher values = fewer but more confident detections"
                )

                thickness_slider = gr.Slider(
                    minimum=1,
                    maximum=6,
                    value=2,
                    step=1,
                    label="Box Line Thickness"
                )

            detect_button = gr.Button("Detect Elements", variant="primary", size="lg")

        with gr.Column(scale=1):
            output_image = gr.Image(
                type="pil",
                label="Detected Elements",
                height=400
            )

            summary_output = gr.Markdown(label="Detection Summary")

    # Run detection when the button is clicked.
    detect_button.click(
        fn=detect_ui_elements,
        inputs=[input_image, confidence_slider, thickness_slider],
        outputs=[output_image, summary_output]
    )


if __name__ == "__main__":
    demo.queue().launch(share=False)