import os
import random
import time
import datetime
import cv2
import numpy as np
import torch
import imageio
from pathlib import Path
from tqdm import tqdm
import gradio as gr

# Project-specific modules
import utils.loss
import utils.samp
import utils.data
import utils.improc
import utils.misc
import utils.saveload
from nets.blocks import InputPadder
from nets.net34 import Net
from demo_dense_visualize import Tracker
import spaces

# Set torch matmul precision
torch.set_float32_matmul_precision('medium')
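# ('medium' allows PyTorch to use lower-precision internals (e.g. bfloat16) for
# float32 matmuls, trading a small amount of accuracy for speed.)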
# -------------------- Utility Functions -------------------- #

def count_parameters(model):
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue
        total_params += parameter.numel()
    print('Total params: %.2f M' % (total_params / 1e6))
    return total_params

def seed_everything(seed: int):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)
torch.set_grad_enabled(False)
# -------------------- Model Loading -------------------- #

url = "https://huggingface.co/aharley/alltracker/resolve/main/alltracker.pth"
state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu')

model = Net(16)
count_parameters(model)
model.load_state_dict(state_dict['model'], strict=True)
print('loaded ckpt')

device = 'cpu:0'
model.to(device)
for n, p in model.named_parameters():
    p.requires_grad = False
model.eval()

tracker = Tracker(
    model=model,
    mean=torch.tensor([0.485, 0.456, 0.406]).to(device).reshape(1, 3, 1, 1),
    std=torch.tensor([0.229, 0.224, 0.225]).to(device).reshape(1, 3, 1, 1),
    S=16,
    stride=8,
    inference_iters=4,
    target_res=1024,
    device=device,
)
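# Tracker configuration notes (inferred from how this script uses the tracker;
# see demo_dense_visualize.Tracker for the authoritative definitions):
# - mean/std: standard ImageNet normalization statistics.
# - S: sliding-window length in frames (overridden by the UI slider below).
# - target_res: largest image dimension after resizing (also UI-controlled).
# - stride, inference_iters: model-specific settings left at the demo defaults.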
# -------------------- Step 1: Extract the First Frame -------------------- #

def extract_first_frame(video_path, _):
    """
    Opens the video, extracts the first frame, resizes it (largest dimension = tracker.target_res),
    and returns:
      - the frame for display (to be annotated),
      - the video file path (to store in state),
      - a copy of the resized first frame (used when adding points),
      - an empty list that resets the accumulated click points.
    """
    empty_points = []
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None, None, None, empty_points
    ret, frame = cap.read()
    cap.release()
    if not ret:
        return None, video_path, None, empty_points
    # Convert BGR to RGB
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    scale = min(tracker.target_res / frame_rgb.shape[0], tracker.target_res / frame_rgb.shape[1])
    frame_resized = cv2.resize(frame_rgb, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
    # Return the displayed frame, the video path, a copy for point drawing, and reset points.
    return frame_resized, video_path, frame_resized.copy(), empty_points
# -------------------- Callback to Add a Clicked Point -------------------- #

def add_point(orig_frame, points, evt: gr.SelectData):
    """
    Called when the user clicks on the displayed first frame.
      - orig_frame: The original first frame image (numpy array).
      - points: The current list of point coordinates.
      - evt: Event data from the image click (expects evt.index as (x, y)).
    Returns the updated image (with circles drawn at all points)
    and the updated list of points.
    """
    if points is None:
        points = []
    # evt.index contains the (x, y) coordinates of the click.
    x, y = evt.index
    new_points = points.copy()
    new_points.append([x, y])
    # Draw circles on a copy of the original image.
    updated_frame = orig_frame.copy()
    for (px, py) in new_points:
        cv2.circle(updated_frame, (int(round(px)), int(round(py))), radius=5, color=(0, 255, 0), thickness=-1)
    return updated_frame, new_points
# -------------------- Step 2: Process Video & Track Points -------------------- #

def process_video_with_points(video_path, click_points):
    """
    Runs the dense flow prediction over the entire video, tracking the user-selected points.
    Args:
        video_path: Path to the uploaded video.
        click_points: List of [x, y] coordinates selected on the first frame.
            (Coordinates are in the same (resized) space as the displayed first frame.)
    Returns:
        A path to the output video with tracked points overlaid.
    """
    if len(click_points) == 0:
        print("No points selected for tracking.")
        return "Error: No points selected for tracking."
    # Open the video.
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return "Error: Could not open video."
    fps = cap.get(cv2.CAP_PROP_FPS)
    # List to store frames with overlaid points.
    output_frames = []
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    pbar = tqdm(total=total_frames, desc="Processing video")
    tracker.reset()
    frame_disps = []
    try:
        while True:
            if 'cuda' in device:
                torch.cuda.empty_cache()
            ret, frame = cap.read()
            if not ret:
                break
            # Convert frame from BGR to RGB and resize as in the first-frame step.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            scale = min(tracker.target_res / frame_rgb.shape[0], tracker.target_res / frame_rgb.shape[1])
            frame_disp = cv2.resize(frame_rgb, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
            frame_disps.append(frame_disp)
            flows = tracker.track(frame_rgb)
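            # tracker.track() buffers frames internally and returns a batch of
            # flow maps once a full sliding window is available (None otherwise);
            # each returned flow map corresponds to one buffered display frame.
            # As consumed below, every flow map is read at the original click
            # locations, i.e. points are re-derived from click_points each frame
            # rather than chained frame-to-frame. (This description is inferred
            # from how the outputs are used here; see demo_dense_visualize.Tracker
            # for the authoritative semantics.)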
            if flows is not None:
                flows_np = flows[0].cpu().numpy()
                for i, flow_np in enumerate(flows_np):
                    # --- Update tracked points using the flow ---
                    current_points = []
                    for (x, y) in click_points:
                        xi = int(round(x))
                        yi = int(round(y))
                        # print('xi, yi', xi, yi)
                        if 0 <= yi < flow_np.shape[1] and 0 <= xi < flow_np.shape[2]:
                            dx = flow_np[0, yi, xi]
                            dy = flow_np[1, yi, xi]
                            # print('dx, dy', dx, dy)
                        else:
                            dx, dy = 0.0, 0.0
                        current_points.append([x + dx, y + dy])
                    # Draw the updated points on the frame.
                    for (x, y) in current_points:
                        cv2.circle(frame_disps[i], (int(round(x)), int(round(y))), radius=5, color=(0, 255, 0), thickness=-1)
                    output_frames.append(frame_disps[i])
                frame_disps = []
            pbar.update(1)
    except RuntimeError as e:
        # Check if the error message indicates an OOM error.
        if "out of memory" in str(e).lower():
            if 'cuda' in device:
                torch.cuda.empty_cache()
            pbar.close()
            cap.release()
            print("Error: Out of Memory during video processing.")
            return "Error: Out of Memory during video processing. Please try a smaller video or lower resolution."
        else:
            # Re-raise if it's another type of error.
            raise
    pbar.close()
    cap.release()
    # -------------------- Save the Output Video -------------------- #
    output_path = "tracked_output.mp4"
    if len(output_frames) == 0:
        return "Error: No frames were processed."
    print(len(output_frames), output_frames[0].shape)
    imageio.mimwrite(output_path, output_frames, fps=fps)
    return output_path
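# For a quick headless sanity check (the file name and coordinates below are
# hypothetical), one could call e.g.:
#   out = process_video_with_points("some_video.mp4", [[100, 150], [320, 240]])
#   print(out)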
# -------------------- Wrappers to Update Tracker Based on UI Settings -------------------- #

def extract_with_config(video_path, points, resolution, window_index):
    """
    Update the tracker configuration using the slider values, then extract the first frame.
      - resolution: Target resolution from slider (e.g., 512, 768, 1024).
      - window_index: An index (0–3) to be mapped to sliding window lengths {0: 2, 1: 4, 2: 8, 3: 16}.
    """
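    # Note: `tracker` is a module-level singleton, so the settings applied here
    # persist for the subsequent tracking run as well.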
    tracker.target_res = resolution
    mapping = {0: 2, 1: 4, 2: 8, 3: 16}
    tracker.S = mapping.get(int(window_index), 16)
    return extract_first_frame(video_path, points)

def process_with_config(video_path, click_points, resolution, window_index):
    """
    Update the tracker configuration using the slider values, then process the video.
    """
    tracker.target_res = resolution
    mapping = {0: 2, 1: 4, 2: 8, 3: 16}
    tracker.S = mapping.get(int(window_index), 16)
    return process_video_with_points(video_path, click_points)
if __name__ == '__main__':
    # -------------------- Gradio Interface -------------------- #
    # The interface is built in two steps:
    #   1. Upload a video and extract the first frame.
    #   2. Annotate the first frame with multiple points (by clicking on the image),
    #      then run tracking on the video.
    with gr.Blocks() as demo:
        gr.Markdown("## Dense Flow Tracking with Clickable Points")
        with gr.Row():
            with gr.Column():
                video_input = gr.Video(label="Upload Video", value="172620-847860540_small.mp4")
                extract_btn = gr.Button("Extract First Frame")
                # Sliders for resolution and sliding window length.
                resolution_slider = gr.Slider(minimum=512, maximum=1024, step=256, value=1024, label="Target Resolution")
                # The slider below outputs an index 0-3; we map it to {0: 2, 1: 4, 2: 8, 3: 16}.
                window_slider = gr.Slider(minimum=0, maximum=3, step=1, value=3, label="Sliding Window Length (Index: 0->2, 1->4, 2->8, 3->16)")
            with gr.Column():
                # This image will display the first frame and be interactive.
                first_frame_display = gr.Image(label="First Frame (Click to add points)", interactive=True)
                clear_pts_btn = gr.Button("Clear Points")
        # Hidden states: video file path, original first frame, and accumulated click points.
        video_state = gr.State(None)
        orig_frame_state = gr.State(None)
        points_state = gr.State([])
        track_btn = gr.Button("Track Points")
        output_video = gr.Video(label="Tracked Video")

        # When "Extract First Frame" is clicked, extract and display the first frame.
        extract_btn.click(
            fn=extract_with_config,
            inputs=[video_input, points_state, resolution_slider, window_slider],
            outputs=[first_frame_display, video_state, orig_frame_state, points_state]
        )
        # When "Clear Points" is clicked, restore the clean first frame and empty the point list.
        clear_pts_btn.click(
            fn=lambda orig_frame, _points: (orig_frame, []),
            inputs=[orig_frame_state, points_state],
            outputs=[first_frame_display, points_state]
        )
        # When the user clicks on the image, add a point.
        first_frame_display.select(
            fn=add_point,
            inputs=[orig_frame_state, points_state],
            outputs=[first_frame_display, points_state]
        )
        # When "Track Points" is clicked, process the video using the accumulated points.
        track_btn.click(
            fn=process_with_config,
            inputs=[video_state, points_state, resolution_slider, window_slider],
            outputs=output_video
        )

    demo.launch()