Spaces:
Running
on
Zero
Running
on
Zero
| import spaces | |
| import gradio as gr | |
| import torch | |
| import matplotlib.pyplot as plt | |
| from PIL import Image, ImageDraw, ImageFont | |
| import requests | |
| from io import BytesIO | |
| import numpy as np | |
| # load a simple face detector | |
| from retinaface import RetinaFace | |
# Run on the GPU when torch can see one; otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# load Gaze-LLE model (torch.hub returns the network and its input transform)
model, transform = torch.hub.load(
    repo_or_dir="fkryan/gazelle",
    model="gazelle_dinov2_vitl14_inout",
)
model.to(device)
model.eval()  # inference mode
def visualize_heatmap(pil_image, heatmap, bbox=None, inout_score=None):
    """Overlay a gaze heatmap on an image, optionally marking one face.

    Args:
        pil_image: PIL image used as the background. It is not modified.
        heatmap: 2-D ``torch.Tensor`` or NumPy array. Values are multiplied by
            255 and cast to ``uint8``, so they are presumably expected to lie
            in [0, 1] -- TODO confirm against the model output.
        bbox: Optional ``(xmin, ymin, xmax, ymax)``. The values are multiplied
            by the image width/height, i.e. they are normalized coordinates.
        inout_score: Optional scalar (float or 0-dim tensor) rendered as
            ``in-frame: <score>`` under the box. Only drawn when ``bbox`` is
            given, because the label is anchored to the box.

    Returns:
        A new RGBA ``PIL.Image.Image`` with the heatmap blended over the
        input and, when requested, the face box and score label.
    """
    if isinstance(heatmap, torch.Tensor):
        heatmap = heatmap.detach().cpu().numpy()

    # Upsample the heatmap to the image size, colorize it with the "jet"
    # colormap and give it a constant alpha so the photo stays visible.
    heatmap = Image.fromarray((heatmap * 255).astype(np.uint8)).resize(pil_image.size, Image.Resampling.BILINEAR)
    heatmap = plt.cm.jet(np.array(heatmap) / 255.)
    heatmap = (heatmap[:, :, :3] * 255).astype(np.uint8)
    heatmap = Image.fromarray(heatmap).convert("RGBA")
    heatmap.putalpha(90)
    overlay_image = Image.alpha_composite(pil_image.convert("RGBA"), heatmap)

    if bbox is not None:
        width, height = pil_image.size
        short_side = min(width, height)
        xmin, ymin, xmax, ymax = bbox
        draw = ImageDraw.Draw(overlay_image)
        # Clamp to 1: on images smaller than 100 px the scaled width would
        # truncate to 0, and Pillow draws no outline at width 0.
        outline_width = max(1, int(short_side * 0.01))
        draw.rectangle([xmin * width, ymin * height, xmax * width, ymax * height], outline="lime", width=outline_width)

        if inout_score is not None:
            text = f"in-frame: {float(inout_score):.2f}"
            # Small vertical gap between the box and the label.
            text_offset = int(height * 0.01)
            text_x = xmin * width
            text_y = ymax * height + text_offset
            # Same clamp as above so tiny images never request a 0-size font.
            font_size = max(1, int(short_side * 0.05))
            draw.text((text_x, text_y), text, fill="lime", font=ImageFont.load_default(size=font_size))
    return overlay_image
def visualize_all(pil_image, heatmaps, bboxes, inout_scores, inout_thresh=0.5):
    """Draw every face box, its in-frame score and its predicted gaze target.

    Args:
        pil_image: PIL image used as the background. It is not modified.
        heatmaps: Indexable collection with one 2-D heatmap per box
            (``torch.Tensor`` or NumPy array).
        bboxes: Iterable of ``(xmin, ymin, xmax, ymax)``. The values are
            multiplied by the image width/height, i.e. normalized coordinates.
        inout_scores: Indexable collection with one scalar score per box, or
            ``None``. When ``None`` only the boxes are drawn (no labels and no
            gaze lines).
        inout_thresh: A gaze target is drawn only for faces whose score is
            strictly greater than this value.

    Returns:
        A new RGBA ``PIL.Image.Image`` with the annotations drawn on it.
    """
    colors = ['lime', 'tomato', 'cyan', 'fuchsia', 'yellow']
    overlay_image = pil_image.convert("RGBA")
    draw = ImageDraw.Draw(overlay_image)
    width, height = pil_image.size
    short_side = min(width, height)

    # Sizes scale with the image; clamp to 1 so small images (where the
    # scaled value truncates to 0) still get a visible outline and a valid font.
    box_width = max(1, int(short_side * 0.01))
    gaze_width = max(1, int(short_side * 0.005))
    text_offset = int(height * 0.01)
    # Loop-invariant, so load the font once instead of once per face.
    font = ImageFont.load_default(size=max(1, int(short_side * 0.05)))

    for i, (xmin, ymin, xmax, ymax) in enumerate(bboxes):
        color = colors[i % len(colors)]  # cycle through the palette
        draw.rectangle([xmin * width, ymin * height, xmax * width, ymax * height], outline=color, width=box_width)

        if inout_scores is None:
            continue

        inout_score = float(inout_scores[i])
        text = f"in-frame: {inout_score:.2f}"
        draw.text((xmin * width, ymax * height + text_offset), text, fill=color, font=font)

        if inout_score > inout_thresh:
            heatmap = heatmaps[i]
            # Accept tensors and arrays alike, matching visualize_heatmap.
            if isinstance(heatmap, torch.Tensor):
                heatmap = heatmap.detach().cpu().numpy()
            heatmap_np = np.asarray(heatmap)

            # The gaze target is the heatmap argmax, rescaled from heatmap
            # cells (row, col) to image pixels (y, x).
            max_index = np.unravel_index(np.argmax(heatmap_np), heatmap_np.shape)
            gaze_target_x = max_index[1] / heatmap_np.shape[1] * width
            gaze_target_y = max_index[0] / heatmap_np.shape[0] * height
            bbox_center_x = ((xmin + xmax) / 2) * width
            bbox_center_y = ((ymin + ymax) / 2) * height

            # Filled dot at the target (no outline, so no width is needed),
            # plus a line from the face center to the target.
            draw.ellipse([(gaze_target_x - 5, gaze_target_y - 5), (gaze_target_x + 5, gaze_target_y + 5)], fill=color)
            draw.line([(bbox_center_x, bbox_center_y), (gaze_target_x, gaze_target_y)], fill=color, width=gaze_width)

    return overlay_image
# ZeroGPU ready
# NOTE(review): `spaces` is imported by this file and the comment above says
# "ZeroGPU ready", but no decorator was visible here. `spaces.GPU` is what
# requests a GPU for the duration of the call on ZeroGPU hardware -- confirm
# this matches the deployed Space.
@spaces.GPU
def main(image_input, progress=gr.Progress(track_tqdm=True)):
    """Estimate gaze direction for detected faces in an image using Gaze-LLE.

    This function processes an input image to detect faces, estimates gaze heatmaps
    for each face using a pre-trained Gaze-LLE model, and then visualizes the results
    including gaze direction and whether each person's gaze is within the frame.

    Args:
        image_input: A filepath to the input image. Should be a photo containing one or more human faces.
        progress: Optional Gradio progress tracker for UI feedback (used during inference).

    Returns:
        result_gazed (PIL.Image.Image): A single composite image with bounding boxes around faces,
            lines indicating predicted gaze direction, and indicators of whether gaze is "in-frame".
        heatmap_results (List[PIL.Image.Image]): A list of individual images, one per face, each showing
            the original image overlaid with a heatmap of the predicted gaze target.

    Raises:
        gr.Error: If no image was provided or no face was detected in it.
    """
    if not image_input:
        raise gr.Error("Please provide an input image.")

    # load image
    # Convert to 3-channel RGB: PNG inputs may be RGBA, palette or grayscale,
    # and the visualization helpers composite onto an RGB(A) base. convert()
    # returns an in-memory copy, so the file handle can be closed right away.
    with Image.open(image_input) as image_file:
        image = image_file.convert("RGB")
    width, height = image.size

    # detect faces
    # NOTE(review): the array handed to RetinaFace is in RGB order; the
    # detector may expect BGR (OpenCV convention) -- confirm.
    resp = RetinaFace.detect_faces(np.array(image))
    print(resp)
    # Faces come back as a dict keyed per face. Anything else (presumably the
    # "no face" result in some detector versions) is treated as no detection.
    bboxes = [face["facial_area"] for face in resp.values()] if isinstance(resp, dict) else []
    print(bboxes)
    if not bboxes:
        # Without this guard the per-person indexing below fails with an
        # opaque IndexError.
        raise gr.Error("No faces were detected in the image.")

    # prepare gazelle input
    img_tensor = transform(image).unsqueeze(0).to(device)
    # Pixel boxes -> fractions of the image size; outer list = one entry per image.
    norm_bboxes = [[np.array(bbox) / np.array([width, height, width, height]) for bbox in bboxes]]

    model_input = {
        "images": img_tensor,  # [num_images, 3, 448, 448]
        "bboxes": norm_bboxes,  # [[img1_bbox1, img1_bbox2...], [img2_bbox1, img2_bbox2]...]
    }

    with torch.no_grad():
        output = model(model_input)

    img1_person1_heatmap = output['heatmap'][0][0]  # [64, 64] heatmap
    print(img1_person1_heatmap.shape)
    if model.inout:
        img1_person1_inout = output['inout'][0][0]  # gaze in frame score (if model supports inout prediction)
        print(img1_person1_inout.item())

    # Only one image is sent per call, so work with batch index 0 throughout.
    heatmaps = output['heatmap'][0]
    inout_scores = output['inout'][0] if output['inout'] is not None else None
    face_boxes = norm_bboxes[0]

    # visualize predicted gaze heatmap for each person and gaze in/out of frame score
    heatmap_results = [
        visualize_heatmap(
            image,
            heatmaps[i],
            face_boxes[i],
            inout_score=inout_scores[i] if inout_scores is not None else None,
        )
        for i in range(len(face_boxes))
    ]

    # combined visualization with maximal gaze points for each person
    result_gazed = visualize_all(image, heatmaps, face_boxes, inout_scores, inout_thresh=0.5)

    return result_gazed, heatmap_results
# Page-level CSS: centers the main column and caps its width.
css="""
div#col-container{
margin: 0 auto;
max-width: 982px;
}
"""

# Build the Gradio UI. Components are laid out in the order and nesting in
# which they are created inside the context managers below.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Title and short description of the model.
        gr.Markdown("# Gaze-LLE: Gaze Target Estimation via Large-Scale Learned Encoders")
        gr.Markdown("A transformer approach for estimating gaze targets that leverages the power of pretrained visual foundation models. Gaze-LLE provides a streamlined gaze architecture that learns only a lightweight gaze decoder on top of a frozen, pretrained visual encoder (DINOv2). Gaze-LLE learns 1-2 orders of magnitude fewer parameters than prior works and doesn't require any extra input modalities like depth and pose!")
        # Badge row: GitHub repo, arXiv paper, duplicate-this-Space, author profile.
        gr.HTML("""
<div style="display:flex;column-gap:4px;">
<a href="https://github.com/fkryan/gazelle">
<img src='https://img.shields.io/badge/GitHub-Repo-blue'>
</a>
<a href="https://arxiv.org/abs/2412.09586">
<img src='https://img.shields.io/badge/ArXiv-Paper-red'>
</a>
<a href="https://huggingface.co/spaces/fffiloni/Gaze-LLE?duplicate=true">
<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
</a>
<a href="https://huggingface.co/fffiloni">
<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
</a>
</div>
""")
        with gr.Row():
            # Left column: image input, submit button and clickable examples.
            with gr.Column():
                # type="filepath" means `main` receives a path string, not an array.
                input_image = gr.Image(label="Image Input", type="filepath")
                submit_button = gr.Button("Submit")
                gr.Examples(
                    examples = ["examples/the_office.png", "examples/succession.png"],
                    inputs = [input_image]
                )
            # Right column: combined result image and the per-face heatmap gallery.
            with gr.Column():
                result = gr.Image(label="Result")
                heatmaps = gr.Gallery(label="Heatmap", columns=3)
    # Wire the button to `main`; its two return values fill `result` and `heatmaps`.
    submit_button.click(
        fn = main,
        inputs = [input_image],
        outputs = [result, heatmaps]
    )
# Enable request queuing and start the server. The launch flags turn off SSR
# mode, enable error display and pass mcp_server=True (presumably exposing the
# app as an MCP server -- see the Gradio docs).
demo.queue().launch(ssr_mode=False, show_error=True, mcp_server=True)