#!/usr/bin/env python3
# -*- coding:utf-8 -*-
"""
YOLOX Traffic Analysis Demo
Combines ByteTrack's robust tracking with advanced traffic analysis features:
- 2D perspective mapping for accurate counting
- Multi-line crossing detection
- Trajectory analysis and origin prediction
- Real-time traffic statistics
"""

import argparse
import colorsys
import glob
import json
import math
import os
import sys
import threading
import time
from collections import defaultdict, deque
from datetime import datetime, timedelta
from queue import Empty, Full, Queue

import cv2
import numpy as np
import torch
from loguru import logger

# Add parent directory to path to handle imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Fallback class names, used when COCO_CLASSES cannot be imported from YOLOX.
# The 80 names are listed alphabetically, so an index into this tuple is the
# name's alphabetical rank (e.g. 16 = "bus", 18 = "car", 44 = "motorcycle",
# 74 = "truck").
COCO_CLASSES = (
    "airplane", "apple", "backpack", "banana", "baseball bat", "baseball glove", "bear", "bed", "bench", "bicycle",
    "bird", "boat", "book", "bottle", "bowl", "broccoli", "bus", "cake", "car", "carrot",
    "cat", "cell phone", "chair", "clock", "couch", "cow", "cup", "dining table", "dog", "donut",
    "elephant", "fire hydrant", "fork", "frisbee", "giraffe", "hair drier", "handbag", "horse", "hot dog", "keyboard",
    "kite", "knife", "laptop", "microwave", "motorcycle", "mouse", "orange", "oven", "parking meter", "person",
    "pizza", "potted plant", "refrigerator", "remote", "sandwich", "scissors", "sheep", "sink", "skateboard", "skis",
    "snowboard", "spoon", "sports ball", "stop sign", "suitcase", "surfboard", "teddy bear", "tennis racket", "tie", "toaster",
    "toilet", "toothbrush", "traffic light", "train", "truck", "tv", "umbrella", "vase", "wine glass", "zebra"
)

# Required YOLOX components. If any of them cannot be imported the script
# prints diagnostics and exits with status 1.
try:
    from yolox.data.data_augment import ValTransform
    from yolox.exp import get_exp
    from yolox.utils import fuse_model, get_model_info, postprocess
    from yolox.tracker.byte_tracker import BYTETracker
    # Prefer YOLOX's own COCO_CLASSES when one of the two known module paths
    # provides it; otherwise the fallback tuple assigned earlier stays in place.
    # NOTE(review): TrafficAnalyzer.TARGET_VEHICLE_CLASSES hard-codes class IDs
    # that match the fallback's order - confirm the imported tuple uses the
    # same index order.
    try:
        from yolox.data.datasets import COCO_CLASSES as COCO_CLASSES_IMPORT
        COCO_CLASSES = COCO_CLASSES_IMPORT
    except ImportError:
        try:
            from yolox.data.datasets.coco_classes import COCO_CLASSES as COCO_CLASSES_IMPORT
            COCO_CLASSES = COCO_CLASSES_IMPORT
        except ImportError:
            # Use our fallback definition
            pass
except ImportError as e:
    print(f"Error importing YOLOX modules: {e}")
    print("Make sure you're running from the YOLOX directory or have YOLOX installed.")
    print("Current directory:", os.getcwd())
    sys.exit(1)


def scan_mobius_input():
    """Locate the input video and its JSON configuration in ``MobiusInput``.

    The ``MobiusInput`` directory is looked up one level above the directory
    that contains this script.  ``glob`` returns matches in arbitrary order,
    so the candidates are sorted and the alphabetically first ``.mp4`` and
    ``.json`` are used; a warning is logged when more than one candidate of a
    kind exists.

    Returns:
        tuple: ``(video_path, json_path)``.

    Raises:
        FileNotFoundError: If the directory, a ``.mp4`` file or a ``.json``
            file is missing.
    """
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    mobius_dir = os.path.join(project_root, 'MobiusInput')

    if not os.path.isdir(mobius_dir):
        raise FileNotFoundError(f"MobiusInput directory not found: {mobius_dir}")

    # Sorted so the same file is picked on every run and every platform
    video_files = sorted(glob.glob(os.path.join(mobius_dir, "*.mp4")))
    if not video_files:
        raise FileNotFoundError("No .mp4 video files found in MobiusInput directory")

    json_files = sorted(glob.glob(os.path.join(mobius_dir, "*.json")))
    if not json_files:
        raise FileNotFoundError("No .json configuration files found in MobiusInput directory")

    for label, candidates in (("video", video_files), ("config", json_files)):
        if len(candidates) > 1:
            logger.warning(
                f"Found {len(candidates)} {label} files in MobiusInput; "
                f"using {os.path.basename(candidates[0])}"
            )

    video_path, json_path = video_files[0], json_files[0]

    logger.info(f"Found video file: {video_path}")
    logger.info(f"Found config file: {json_path}")

    return video_path, json_path


def parse_mobius_config(json_path):
    """Parse the MobiusInput JSON configuration file.

    Reads ``FileName_Instance``, ``startTime``, ``direction`` and the optional
    ``perspective_corners`` list from the JSON document.

    ``FileName_Instance`` is split on ``_``: the first token is the file name
    and the second is the integer instance number (leading zeros allowed).
    When there is no underscore, or the second token is not an integer, the
    whole value is used as the file name and the instance defaults to 0.

    Args:
        json_path: Path of the JSON configuration file.

    Returns:
        dict with keys ``file_name``, ``file_instance``, ``time_calc``,
        ``direction``, ``perspective_corners`` (list of points or ``None``)
        and ``original_config`` (the parsed JSON document).
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    # Extract required fields
    filename_instance = config.get('FileName_Instance', '')
    start_time = config.get('startTime', '')
    direction = config.get('direction', '')

    # Defaults apply when the value carries no usable instance suffix
    file_name, file_instance = filename_instance, 0
    if '_' in filename_instance:
        tokens = filename_instance.split('_')
        try:
            # int() also strips leading zeros ("003" -> 3)
            file_name, file_instance = tokens[0], int(tokens[1])
        except ValueError:
            logger.warning(
                f"FileName_Instance '{filename_instance}' has a non-numeric instance; "
                f"using instance 0"
            )

    # startTime is MM/DD/YYYY_HRMM; TimeCalc shifts it by file_instance hours
    time_calc = calculate_time_calc(start_time, file_instance)

    # Perspective corners are optional; a JSON null is treated like a missing key
    perspective_corners = None
    raw_corners = config.get('perspective_corners')
    if raw_corners is not None:
        perspective_corners = [corner['point'] for corner in raw_corners]

    return {
        'file_name': file_name,
        'file_instance': file_instance,
        'time_calc': time_calc,
        'direction': direction,
        'perspective_corners': perspective_corners,
        'original_config': config
    }


def calculate_time_calc(start_time, file_instance):
    """Shift a ``MM/DD/YYYY_HRMM`` timestamp forward by ``file_instance`` hours.

    Args:
        start_time: Timestamp text in ``MM/DD/YYYY_HRMM`` form.
        file_instance: Number of hours to add.

    Returns:
        The shifted timestamp in the same format.  If anything goes wrong
        while parsing or shifting, the error is logged and ``start_time`` is
        returned unchanged.
    """
    try:
        date_text, clock_text = start_time.split('_')
        month, day, year = (int(piece) for piece in date_text.split('/'))
        # First two characters are the hour, the remainder the minutes
        hour, minute = int(clock_text[:2]), int(clock_text[2:])

        shifted = datetime(year, month, day, hour, minute) + timedelta(hours=file_instance)
        time_calc = shifted.strftime('%m/%d/%Y_%H%M')

        logger.info(f"Calculated TimeCalc: {start_time} + {file_instance} hours = {time_calc}")

        return time_calc
    except Exception as e:
        logger.error(f"Error calculating time_calc: {e}")
        return start_time




class PerspectiveTransform:
    """Maps image coordinates onto a rectangular top-down view.

    The mapping is a perspective transform built from four source corners
    (top-left, top-right, bottom-right, bottom-left order, matching the
    destination rectangle).  When the number of corners is not four, no
    matrix is built and the transform methods return their input unchanged.
    """

    def __init__(self, corners, output_size=(600, 450), zoom_out_factor=1.0):
        """Store the settings and build the matrices when four corners are given.

        Args:
            corners: Sequence of ``(x, y)`` source points.
            output_size: ``(width, height)`` of the top-down view.
            zoom_out_factor: Scale applied to the corners about their mean
                point before the matrices are computed; 1.0 leaves them as is.
        """
        self.corners = corners
        self.output_size = output_size
        self.zoom_out_factor = zoom_out_factor
        self.transform_matrix = None
        self.inverse_matrix = None

        if len(corners) == 4:
            self._calculate_transform_matrix()

    def _calculate_transform_matrix(self):
        """Compute the forward and inverse perspective matrices."""
        if self.zoom_out_factor == 1.0:
            source = np.float32(self.corners)
        else:
            # Pushing the corners away from their mean widens the mapped area
            source = self._expand_corners(self.corners, self.zoom_out_factor)

        width = self.output_size[0]
        height = self.output_size[1]
        # Rectangle corners: top-left, top-right, bottom-right, bottom-left
        target = np.float32([[0, 0], [width, 0], [width, height], [0, height]])

        self.transform_matrix = cv2.getPerspectiveTransform(source, target)
        self.inverse_matrix = cv2.getPerspectiveTransform(target, source)

    def _expand_corners(self, corners, zoom_factor):
        """Scale ``corners`` about their mean point by ``zoom_factor``.

        Returns:
            ``float32`` array with one row per corner.
        """
        points = np.array(corners)
        centroid = np.mean(points, axis=0)
        # Each corner moves along its own centroid-to-corner direction
        return np.float32(centroid + (points - centroid) * zoom_factor)

    def transform_point(self, point):
        """Map ``point`` into the top-down view.

        Returns:
            ``(x, y)`` tuple of ints, or ``point`` itself when no matrix is set.
        """
        if self.transform_matrix is None:
            return point

        # cv2.perspectiveTransform expects an array shaped (1, N, 2)
        source = np.array([[[point[0], point[1]]]], dtype=np.float32)
        mapped = cv2.perspectiveTransform(source, self.transform_matrix)[0][0]
        return tuple(int(value) for value in mapped)

    def transform_frame(self, frame):
        """Warp ``frame`` into the top-down view (returned as is without a matrix)."""
        if self.transform_matrix is None:
            return frame

        return cv2.warpPerspective(frame, self.transform_matrix, self.output_size)


class TransformVisualizer:
    """Async visualizer for 2D transformed view with track dots and spawn/death analysis"""
    
    def __init__(self, perspective_transform, window_name="2D Transform View"):
        """Initialise visualizer state; no thread or window is created here.

        Args:
            perspective_transform: Transform used to map points and frames
                into the 2D view.
            window_name: Title of the OpenCV window.
        """
        # Collaborators
        self.perspective_transform = perspective_transform
        self.traffic_analyzer = None  # Reference to traffic analyzer for spawn/death data

        # Window state
        self.window_name = window_name
        self.window_positioned = False
        self.transform_view = None

        # Thread state and hand-over queue (bounded to 10 pending updates)
        self.running = False
        self.thread = None
        self.tracks_queue = Queue(maxsize=10)

        # Rendering state
        self.current_tracks = []
        self.track_colors = {}
        
    def start(self):
        """Start the async visualization thread"""
        self.running = True
        self.thread = threading.Thread(target=self._visualization_loop, daemon=True)
        self.thread.start()
        logger.info("Started 2D transform visualization thread")
    
    def stop(self):
        """Stop the visualization thread and close its window (best effort).

        Clears ``running`` so the render loop exits, waits up to one second
        for the thread, then tries to destroy the window.  Failure to destroy
        the window (headless build, window never created) is logged at debug
        level and otherwise ignored.
        """
        self.running = False
        if self.thread and self.thread.is_alive():
            self.thread.join(timeout=1.0)

        try:
            cv2.destroyWindow(self.window_name)
        except Exception as exc:
            # Not a bare except: KeyboardInterrupt/SystemExit must still propagate
            logger.debug(f"Could not destroy window '{self.window_name}': {exc}")
    
    def update_tracks(self, tracks, original_frame=None):
        """Update tracks data (non-blocking)"""
        if not self.tracks_queue.full():
            try:
                self.tracks_queue.put_nowait((tracks, original_frame))
            except:
                pass  # Queue full, skip this update
    
    def _get_track_color(self, track_id):
        """Return a stable, fully saturated BGR color for ``track_id``.

        The hue advances by 137 degrees per ID (close to the golden angle) so
        consecutive IDs get clearly different colors.  The conversion uses
        ``colorsys`` with the hue in degrees: OpenCV's 8-bit HSV format stores
        hue as 0-179, so a 0-359 value cannot be passed through ``np.uint8``
        and ``cv2.cvtColor`` without overflowing or yielding the wrong color.

        Returns:
            ``(blue, green, red)`` tuple of ints in 0-255; the same tuple
            object is returned for repeated calls with the same ID.
        """
        color = self.track_colors.get(track_id)
        if color is None:
            hue_degrees = (track_id * 137) % 360
            red, green, blue = colorsys.hsv_to_rgb(hue_degrees / 360.0, 1.0, 1.0)
            # OpenCV drawing functions expect BGR channel order
            color = (int(round(blue * 255)), int(round(green * 255)), int(round(red * 255)))
            self.track_colors[track_id] = color
        return color
    
    def _visualization_loop(self):
        """Main visualization loop running in separate thread"""
        while self.running:
            try:
                # Get latest tracks data (with timeout)
                if not self.tracks_queue.empty():
                    tracks, original_frame = self.tracks_queue.get(timeout=0.1)
                    self.current_tracks = tracks
                    
                    # Create transform view if we have original frame
                    if original_frame is not None and self.perspective_transform.transform_matrix is not None:
                        self.transform_view = self.perspective_transform.transform_frame(original_frame)
                
                # Create visualization
                self._create_visualization()
                
                # Small delay to prevent excessive CPU usage
                time.sleep(0.033)  # ~30 FPS
                
            except Exception as e:
                if self.running:  # Only log if not shutting down
                    logger.debug(f"Visualization loop error: {e}")
                time.sleep(0.1)
    
    def _create_visualization(self):
        """Render one frame of the 2D analysis view and display it if possible.

        Draws onto a black canvas sized to the transform's ``output_size``,
        in this order: the reference grid, spawn-death lines and active spawn
        markers (only when a traffic analyzer is attached), the current track
        dots, and the text overlay with counters.  The window is shown only
        when ``DISPLAY`` is set or the OS is Windows; OpenCV errors raised by
        the window calls are ignored.
        """
        canvas_w = self.perspective_transform.output_size[0]
        canvas_h = self.perspective_transform.output_size[1]
        vis_img = np.zeros((canvas_h, canvas_w, 3), dtype=np.uint8)

        # Draw grid lines for reference
        self._draw_grid(vis_img)

        analyzer = self.traffic_analyzer
        if analyzer:
            # Completed tracks first, then spawn markers of tracks still active
            self._draw_spawn_death_lines_2d(vis_img)
            self._draw_active_spawn_points_2d(vis_img)

        # Draw current tracks as dots
        for track in self.current_tracks:
            self._draw_track_dot(vis_img, track)

        # Add title and traffic info
        font = cv2.FONT_HERSHEY_SIMPLEX
        cv2.putText(vis_img, "2D Traffic Analysis", (10, 20), font, 0.5, (255, 255, 255), 2)

        if analyzer:
            counts = analyzer.traffic_counts_2d
            # (text, bottom-left origin, BGR color), drawn in listed order
            overlay = [
                (f"2D IN: {counts['IN']}", (10, 40), (0, 255, 0)),
                (f"OUT: {counts['OUT']}", (80, 40), (0, 0, 255)),
                (f"Rej2D: {analyzer.rejected_tracks_2d}", (130, 40), (128, 128, 128)),
                (f"2D Cycles: {len(analyzer.track_lifecycles_2d)}", (10, 60), (255, 255, 255)),
                (f"Active: {len(analyzer.active_tracks)}", (10, 80), (255, 255, 0)),
                (f"IN: {counts['IN']}", (85, 80), (0, 255, 0)),
                (f"OUT: {counts['OUT']}", (125, 80), (0, 0, 255)),
            ]
            for text, origin, color in overlay:
                cv2.putText(vis_img, text, origin, font, 0.4, color, 1)
        else:
            cv2.putText(vis_img, "NO ANALYZER", (10, 40), font, 0.4, (255, 0, 0), 1)

        # Skip display in headless mode: Linux without DISPLAY (Windows always shows)
        if not (os.environ.get('DISPLAY') or os.name == 'nt'):
            return

        try:
            cv2.namedWindow(self.window_name, cv2.WINDOW_NORMAL)
            cv2.resizeWindow(self.window_name, 600, 450)

            # Position window once; OpenCV offers no screen-size query, so the
            # coordinates are a fixed guess for the top-right area
            if not self.window_positioned:
                cv2.moveWindow(self.window_name, 1200, 50)  # Adjust coordinates as needed
                self.window_positioned = True

            cv2.imshow(self.window_name, vis_img)
            cv2.waitKey(1)
        except cv2.error:
            # Window might be closed or the OpenCV build has no GUI support
            pass
    
    def _draw_grid(self, img):
        """Draw reference grid on transform view"""
        h, w = img.shape[:2]
        
        # Draw grid lines
        for i in range(0, w, 50):
            cv2.line(img, (i, 0), (i, h), (100, 100, 100), 1)
        for i in range(0, h, 50):
            cv2.line(img, (0, i), (w, i), (100, 100, 100), 1)
    
    def _draw_track_dot(self, img, track):
        """Draw a track as a colored dot in transformed space"""
        # Get track center point
        tlwh = track.tlwh
        center = (int(tlwh[0] + tlwh[2]/2), int(tlwh[1] + tlwh[3]/2))
        
        # Transform point to 2D view
        transformed_center = self.perspective_transform.transform_point(center)
        
        # Check if point is within bounds
        if (0 <= transformed_center[0] < self.perspective_transform.output_size[0] and 
            0 <= transformed_center[1] < self.perspective_transform.output_size[1]):
            
            track_id = track.track_id
            color = self._get_track_color(track_id)
            
            # Draw dot
            cv2.circle(img, transformed_center, 8, color, -1)
            cv2.circle(img, transformed_center, 10, (255, 255, 255), 2)
            
            # Draw track ID
            cv2.putText(img, str(track_id), 
                       (transformed_center[0] + 12, transformed_center[1] - 12),
                       cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 1)
    
    def _draw_spawn_death_lines_2d(self, img):
        """Draw fading spawn-to-death lines for completed tracks onto ``img``.

        Each entry of the analyzer's ``track_lifecycles_2d`` is drawn as a line
        between its (optionally extended) spawn and death points, with an X
        and a "D" label at the death point.  Lines are green for direction
        "IN", red for "OUT" and light red otherwise, and they darken linearly
        over the analyzer's ``fade_duration`` frames after ``death_frame``;
        older entries are skipped.  Status and per-line debug text is written
        on the left side of the image.

        The lifecycle mapping is copied before iterating because the analyzer
        adds entries from another thread while this method draws.
        """
        font = cv2.FONT_HERSHEY_SIMPLEX
        analyzer = self.traffic_analyzer
        if not analyzer or not hasattr(analyzer, 'track_lifecycles_2d'):
            cv2.putText(img, "NO 2D ANALYZER", (10, 100), font, 0.3, (255, 0, 0), 1)
            return

        # Snapshot: iterating the live dict can raise RuntimeError if it grows meanwhile
        lifecycles = list(analyzer.track_lifecycles_2d.items())
        if not lifecycles:
            cv2.putText(img, "NO 2D LIFECYCLES", (10, 100), font, 0.3, (255, 255, 0), 1)
            return

        cv2.putText(img, f"2D LINES: {len(lifecycles)}", (10, 100), font, 0.3, (0, 255, 0), 1)

        # BGR colors per direction
        direction_colors = {"IN": (0, 255, 0), "OUT": (0, 0, 255)}
        unknown_color = (128, 128, 255)  # Light red for unknown direction
        fade_duration = analyzer.fade_duration

        line_count = 0
        for track_id, lifecycle_2d in lifecycles:
            spawn_2d = lifecycle_2d["spawn"]  # Already in 2D space
            death_2d = lifecycle_2d["death"]  # Already in 2D space
            # Use pre-calculated extended coordinates if available
            spawn_extended = lifecycle_2d.get("spawn_extended", spawn_2d)
            death_extended = lifecycle_2d.get("death_extended", death_2d)
            direction = lifecycle_2d["direction"]
            death_frame = lifecycle_2d["death_frame"]

            # Calculate fade factor
            frames_since_death = analyzer.frame_count - death_frame
            if frames_since_death > fade_duration:
                continue
            if fade_duration > 0:
                alpha = max(0.0, 1.0 - (frames_since_death / fade_duration))
            else:
                # No fade configured: draw at full intensity instead of dividing by zero
                alpha = 1.0

            # Debug output
            cv2.putText(img, f"T{track_id} S:{spawn_2d} D:{death_2d}",
                        (10, 140 + line_count * 15), font, 0.25, (255, 255, 255), 1)

            # Fading is done by darkening the color
            base_color = direction_colors.get(direction, unknown_color)
            color = tuple(int(c * alpha) for c in base_color)
            line_thickness = max(2, int(3 * alpha))  # Thicker lines for visibility

            # Draw the EXTENDED line; OpenCV needs integer pixel coordinates
            spawn_ext_int = (int(spawn_extended[0]), int(spawn_extended[1]))
            death_ext_int = (int(death_extended[0]), int(death_extended[1]))
            cv2.line(img, spawn_ext_int, death_ext_int, color, line_thickness)
            line_count += 1

            # Spawn points are intentionally not drawn here - only lines and death points

            # Draw death point as X (using actual death point, not extended)
            death_int = (int(death_2d[0]), int(death_2d[1]))
            self._draw_small_x_2d(img, death_int, color, max(2, int(3 * alpha)))
            cv2.putText(img, "D", (death_int[0] + 5, death_int[1] + 3), font, 0.3, color, 1)

            # Add track ID near the line's midpoint while the line is still bright
            if alpha > 0.5:
                label_origin = (int((spawn_2d[0] + death_2d[0]) // 2),
                                int((spawn_2d[1] + death_2d[1]) // 2) - 5)
                cv2.putText(img, str(track_id), label_origin, font, 0.4, color, 1)

        # Show total lines drawn
        cv2.putText(img, f"Lines drawn: {line_count}", (10, 120), font, 0.3, (255, 255, 255), 1)
    
    def _draw_active_spawn_points_2d(self, img):
            """Draw spawn points for active tracks in 2D space"""
            if not self.traffic_analyzer or not hasattr(self.traffic_analyzer, 'active_tracks'):
                return
            
            # Only draw spawn points for truly active tracks
            for track_id in self.traffic_analyzer.active_tracks:
                if track_id in self.traffic_analyzer.track_spawn_points_2d:
                    spawn_2d = self.traffic_analyzer.track_spawn_points_2d[track_id]
                    
                    # Check if point is within bounds
                    w, h = self.perspective_transform.output_size
                    if 0 <= spawn_2d[0] < w and 0 <= spawn_2d[1] < h:
                        # Draw cyan circle for active spawn
                        cv2.circle(img, tuple(map(int, spawn_2d)), 5, (0, 255, 255), 2)  # Cyan for visibility
                        cv2.putText(img, f"S{track_id}", 
                                   (int(spawn_2d[0]) + 8, int(spawn_2d[1]) - 8),
                                   cv2.FONT_HERSHEY_SIMPLEX, 0.3, (0, 255, 255), 1)

    
    # def _draw_stationary_tracks_2d(self, img):
    #     """Draw stationary tracks in 2D space - REMOVED to reduce clutter"""
    #     pass
    
    def _draw_small_x_2d(self, img, center, color, thickness):
        """Mark ``center`` on ``img`` with a small X made of two diagonals.

        Each diagonal reaches 3 pixels from the center along both axes.
        """
        cx, cy = center
        reach = 3
        # dy = +reach gives the top-left to bottom-right stroke, -reach the other one
        for dy in (reach, -reach):
            cv2.line(img, (cx - reach, cy - dy), (cx + reach, cy + dy), color, thickness)


class TrafficAnalyzer:
    """Enhanced analyzer with 2D transform visualization"""
    
    # Define target vehicle classes
    TARGET_VEHICLE_CLASSES = {
        18: "car",      # Your model's class ID 18 = car
        44: "motorcycle", # Your model's class ID 44 = motorcycle
        16: "bus",      # Your model's class ID 16 = bus
        74: "truck"     # Your model's class ID 74 = truck
    }
    
    def __init__(self, args, fps=30):
            self.args = args
            self.track_histories = defaultdict(lambda: deque(maxlen=50))
            self.frame_count = 0
            self.start_time = time.time()
            self.perspective_transform = None
            self.transform_visualizer = None
            
            # 2D Transform space tracking (only 2D calculations)
            self.track_spawn_points_2d = {}  # track_id -> (x, y) in 2D space
            self.track_death_points_2d = {}  # track_id -> (x, y) in 2D space
            self.track_lifecycles_2d = {}    # track_id -> 2D space lifecycle data
            self.active_tracks = set()    # Currently active track IDs
            self.traffic_counts_2d = {"IN": 0, "OUT": 0}  # 2D space direction counters
            self.rejected_tracks_2d = 0      # 2D space rejected tracks
            
            # Track last seen frame tracking
            self.track_last_seen = {}  # track_id -> frame_number
            
            # Track stability and recovery system
            self.track_stability = {}  # track_id -> {"absent_frames": int, "last_position": (x,y), "grace_period": bool}
            self.max_absence_frames = getattr(args, 'max_absence_frames', 10)  # Allow 10 frames of absence
            self.grace_period_frames = getattr(args, 'grace_period_frames', 30)  # 30 frame grace period for recovery
            self.track_recovery_history = {}  # track_id -> list of recovery events
            
            # Edge-based counting (spawn=OUT edge, death=IN edge)
            self.edge_counts = {
                "top": {"IN": 0, "OUT": 0},     # IN=leaving to top, OUT=entering from top
                "bottom": {"IN": 0, "OUT": 0},  # IN=leaving to bottom, OUT=entering from bottom
                "left": {"IN": 0, "OUT": 0},    # IN=leaving to left, OUT=entering from left
                "right": {"IN": 0, "OUT": 0}    # IN=leaving to right, OUT=entering from right
            }
            
            # Detailed traffic counters by direction and vehicle type (32 categories total)
            self.detailed_traffic_counts = {
                "inTop_car": 0, "inTop_motorcycle": 0, "inTop_bus": 0, "inTop_truck": 0,
                "inLeft_car": 0, "inLeft_motorcycle": 0, "inLeft_bus": 0, "inLeft_truck": 0,
                "inRight_car": 0, "inRight_motorcycle": 0, "inRight_bus": 0, "inRight_truck": 0,
                "inBottom_car": 0, "inBottom_motorcycle": 0, "inBottom_bus": 0, "inBottom_truck": 0,
                "outTop_car": 0, "outTop_motorcycle": 0, "outTop_bus": 0, "outTop_truck": 0,
                "outLeft_car": 0, "outLeft_motorcycle": 0, "outLeft_bus": 0, "outLeft_truck": 0,
                "outRight_car": 0, "outRight_motorcycle": 0, "outRight_bus": 0, "outRight_truck": 0,
                "outBottom_car": 0, "outBottom_motorcycle": 0, "outBottom_bus": 0, "outBottom_truck": 0
            }
            
            # NEW: Turning movement counters (origin_to_destination)
            self.turning_movements = defaultdict(int)
            # Will track movements like: "North_to_East_car", "South_to_West_truck", etc.
            self.fade_duration = getattr(args, 'fade_duration', 180)  # Number of frames to fade over (6 seconds at 30fps)
            
            # Movement detection data structures
            self.track_positions = {}     # track_id -> deque of recent positions
            self.track_movement_status = {}  # track_id -> {"is_moving": bool, "movement_start_frame": int}
            self.movement_threshold = getattr(args, 'movement_threshold', 15)  # Minimum distance to consider movement
            self.stability_frames = getattr(args, 'stability_frames', 30)  # Frames to check for stability
            self.min_spawn_death_distance = getattr(args, 'min_spawn_death_distance', 100)  # Minimum distance between spawn and death
            
            # Frame dimensions for reference
            self.frame_dimensions = None
            
            # Vehicle class tracking and stability
            self.track_class_histories = defaultdict(lambda: deque(maxlen=30))  # Class history for each track
            self.track_stable_classes = {}  # Locked-in stable class for each track
            self.class_stability_threshold = 0.8  # 80% threshold for class stability
            
            # Load perspective transform if available
            self._load_perspective_transform()

    
    def _load_perspective_transform(self):
        """Load perspective transform from MobiusInput config or calibration file"""
        # First, try to load from MobiusInput config
        if hasattr(self.args, 'perspective_corners') and self.args.perspective_corners:
            try:
                corners = self.args.perspective_corners
                zoom_out = getattr(self.args, 'zoom_out', 1.0)
                self.perspective_transform = PerspectiveTransform(corners, zoom_out_factor=zoom_out)
                
                # Start transform visualizer
                self.transform_visualizer = TransformVisualizer(self.perspective_transform)
                self.transform_visualizer.traffic_analyzer = self  # Pass reference to self
                self.transform_visualizer.start()
                
                logger.info("Loaded perspective transform from MobiusInput config")
                return
            except Exception as e:
                logger.warning(f"Failed to load perspective transform from MobiusInput: {e}")
        
        # Fallback to original method - try to find perspective calibration file
        if hasattr(self.args, 'path') and self.args.path:
            video_name = os.path.splitext(os.path.basename(self.args.path))[0]
            perspective_file = f"{video_name}_perspective.json"
            
            if os.path.exists(perspective_file):
                try:
                    with open(perspective_file, 'r') as f:
                        data = json.load(f)
                    
                    if 'perspective_corners' in data:
                        corners = [corner['point'] for corner in data['perspective_corners']]
                        zoom_out = getattr(self.args, 'zoom_out', 1.0)
                        self.perspective_transform = PerspectiveTransform(corners, zoom_out_factor=zoom_out)
                        
                        # Start transform visualizer
                        self.transform_visualizer = TransformVisualizer(self.perspective_transform)
                        self.transform_visualizer.traffic_analyzer = self  # Pass reference to self
                        self.transform_visualizer.start()
                        
                        logger.info(f"Loaded perspective transform from {perspective_file}")
                        return
                except Exception as e:
                    logger.warning(f"Failed to load perspective transform: {e}")
            
            logger.info(f"No perspective transform file found ({perspective_file})")
    
    def is_target_vehicle(self, cls_id):
        """Return True when ``cls_id`` is a key of ``TARGET_VEHICLE_CLASSES``."""
        known_ids = self.TARGET_VEHICLE_CLASSES.keys()
        return cls_id in known_ids
    
    def get_vehicle_class_name(self, cls_id):
        """Return the vehicle name for ``cls_id``, or "unknown" for other IDs."""
        try:
            return self.TARGET_VEHICLE_CLASSES[cls_id]
        except KeyError:
            return "unknown"
    
    def update_track_class_stability(self, track_id, cls_id):
        """Update class history and determine stable class"""
        if not self.is_target_vehicle(cls_id):
            return None
            
        # Add to class history
        self.track_class_histories[track_id].append(cls_id)
        
        # Check if we have enough history for stability analysis
        if len(self.track_class_histories[track_id]) >= 10:  # Need at least 10 samples
            class_counts = {}
            for class_id in self.track_class_histories[track_id]:
                class_counts[class_id] = class_counts.get(class_id, 0) + 1
            
            total_count = len(self.track_class_histories[track_id])
            
            # Find most frequent class
            most_frequent_class = max(class_counts, key=class_counts.get)
            frequency_ratio = class_counts[most_frequent_class] / total_count
            
            # Lock in stable class if threshold met
            if frequency_ratio >= self.class_stability_threshold:
                self.track_stable_classes[track_id] = most_frequent_class
                return most_frequent_class
        
        return None
    
    def get_track_display_class(self, track_id):
        """Return the class to show for a track.

        The locked stable class wins; otherwise the most recent entry of the
        track's class history is used. ``None`` is returned when neither exists.
        """
        locked = self.track_stable_classes
        if track_id in locked:
            return locked[track_id]

        histories = self.track_class_histories
        if track_id not in histories:
            return None

        samples = histories[track_id]
        return samples[-1] if samples else None
    
    def update(self, tracks, frame=None):
        """Advance the analyzer by one frame using the tracker's current tracks.

        For every target-vehicle track this updates class stability, position
        history, movement status, spawn/death points and last-seen frame, and
        finalizes tracks whose transformed position leaves the 2D area. After
        the per-track pass it runs the periodic spawn-point validation, the
        unseen-track check, the stability/recovery handling and, if present,
        refreshes the transform visualizer.

        Args:
            tracks: Iterable of track objects; each must expose ``track_id``
                and ``tlwh`` and may expose ``cls``.
            frame: Optional image; only ``frame.shape[:2]`` is read here, and
                the frame is forwarded to the transform visualizer.
        """
        self.frame_count += 1
        
        # Store frame dimensions if not set and frame is provided
        if self.frame_dimensions is None and frame is not None:
            h, w = frame.shape[:2]
            self.frame_dimensions = (w, h)
        
        # Get current track IDs
        current_track_ids = set()
        tracks_outside_boundary = set()
        
        # Update track histories and handle spawn/death tracking
        for track in tracks:
            track_id = track.track_id
            
            # Get class information and filter for vehicles only
            # NOTE(review): if a track has a `cls` attribute set to None, int() raises here.
            cls_id = int(track.cls) if hasattr(track, 'cls') else None
            
            # If no cls attribute, assume class 0 (since that's what we're detecting)
            if cls_id is None:
                cls_id = 0
            
            # Skip non-vehicle tracks
            # (cls_id cannot be None at this point; only the class filter is effective)
            if cls_id is None or not self.is_target_vehicle(cls_id):
                continue
                
            # Update class stability tracking
            self.update_track_class_stability(track_id, cls_id)
            
            # Registered before the finalized check, so finalized tracks still count as "current"
            current_track_ids.add(track_id)
            
            # Skip if track has already been finalized
            if track_id in self.track_lifecycles_2d:
                continue
            
            # Get center from tlwh
            tlwh = track.tlwh
            center = (int(tlwh[0] + tlwh[2]/2), int(tlwh[1] + tlwh[3]/2))
            
            # Initialize position tracking for new tracks
            if track_id not in self.track_positions:
                self.track_positions[track_id] = deque(maxlen=self.stability_frames)
                self.track_movement_status[track_id] = {"is_moving": False, "movement_start_frame": None}
            
            # Update position history
            self.track_positions[track_id].append(center)
            
            # Check if track is moving
            is_moving = self.check_track_movement(track_id)
            
            # Update movement status and create spawn point for moving tracks
            if is_moving:
                if not self.track_movement_status[track_id]["is_moving"]:
                    # Track just started moving
                    self.track_movement_status[track_id]["is_moving"] = True
                    self.track_movement_status[track_id]["movement_start_frame"] = self.frame_count
                
                # Create spawn point for moving tracks that don't have one yet
                if track_id not in self.track_spawn_points_2d and self.perspective_transform:
                    spawn_2d = self.perspective_transform.transform_point(center)
                    
                    # Only create spawn point if not too close to boundary (with padding)
                    if not self.is_at_boundary_2d(spawn_2d, margin=20):  # Increased margin for spawn points
                        self.track_spawn_points_2d[track_id] = spawn_2d
                        # A track becomes "active" only once it has a spawn point
                        self.active_tracks.add(track_id)
                    else:
                        pass
            
            elif not is_moving:
                self.track_movement_status[track_id]["is_moving"] = False
            
            # Update death point for all tracks (both moving and stationary for boundary detection)
            # This call may itself finalize the track and delete its spawn point.
            self.update_death_point(track_id, center)
            
            # Check if track has gone outside the 2D transform area
            if self.perspective_transform and track_id in self.track_spawn_points_2d:
                current_2d = self.perspective_transform.transform_point(center)
                if self.is_outside_boundary_2d(current_2d):
                    # Clamp death point back inside boundary for proper counting
                    clamped_death_2d = self.clamp_point_inside_boundary_2d(current_2d)
                    self.track_death_points_2d[track_id] = clamped_death_2d
                    self.finalize_track(track_id)
                    if track_id in self.active_tracks:
                        self.active_tracks.remove(track_id)
                    # Mark as outside boundary to remove from current tracks
                    tracks_outside_boundary.add(track_id)
                    continue  # Skip further processing for this track
            
            # Record last seen frame for this track
            self.track_last_seen[track_id] = self.frame_count
            
            # Store history for all tracks (for visualization)
            self.track_histories[track_id].append(center)
        
        # Remove tracks that went outside boundary from current tracks
        current_track_ids = current_track_ids - tracks_outside_boundary
        
        # Validate spawn points every 10 frames
        if self.frame_count % 10 == 0:
            self.validate_spawn_points(tracks)
        
        # Check for tracks not seen for 15 frames and finalize them
        self.check_unseen_tracks()
        
        # Handle tracks with stability and recovery system
        self.handle_unstable_tracks(current_track_ids)
        
        # Update transform visualizer if available and frame is provided
        if self.transform_visualizer and frame is not None:
            # Pass all the data the visualizer needs directly
            self.transform_visualizer.update_tracks(tracks, frame)
            # Force update the analyzer reference in case it's lost
            self.transform_visualizer.traffic_analyzer = self
    
    def validate_spawn_points(self, tracks):
        """Failsafe pass that gives moving, active tracks a missing spawn point.

        Does nothing without a perspective transform. A spawn point is only
        created for a track that is not finalized, is flagged as moving, is in
        ``active_tracks``, has no spawn point yet, and whose transformed center
        is not within 20 units of the 2D boundary.
        """
        transform = self.perspective_transform
        if not transform:
            return

        for candidate in tracks:
            tid = candidate.track_id

            # Finalized tracks are never touched again.
            if tid in self.track_lifecycles_2d:
                continue

            # Only tracks currently flagged as moving qualify.
            status = self.track_movement_status.get(tid) if tid in self.track_movement_status else None
            if status is None or not status["is_moving"]:
                continue

            needs_spawn = tid not in self.track_spawn_points_2d and tid in self.active_tracks
            if not needs_spawn:
                continue

            box = candidate.tlwh
            midpoint = (int(box[0] + box[2]/2), int(box[1] + box[3]/2))
            projected = self.perspective_transform.transform_point(midpoint)

            # Spawn points hugging the boundary are rejected.
            if not self.is_at_boundary_2d(projected, margin=20):
                self.track_spawn_points_2d[tid] = projected
    
    def check_unseen_tracks(self):
        """Finalize active tracks that have gone 15 or more frames without being seen.

        For each such track a death point is created from the last history
        position when one is missing (and a transform is available). Only
        tracks that have a spawn point are finalized and dropped from
        ``active_tracks``.
        """
        max_unseen_frames = 15
        expired = []

        for tid in list(self.active_tracks):
            if tid not in self.track_last_seen:
                continue

            if self.frame_count - self.track_last_seen[tid] < max_unseen_frames:
                continue

            # Fall back to the last recorded position when no death point exists yet.
            can_backfill = (
                tid not in self.track_death_points_2d
                and tid in self.track_histories
                and len(self.track_histories[tid]) > 0
                and self.perspective_transform
            )
            if can_backfill:
                final_position = self.track_histories[tid][-1]
                self.track_death_points_2d[tid] = self.perspective_transform.transform_point(final_position)

            # Without a spawn point there is nothing to finalize.
            if tid in self.track_spawn_points_2d:
                expired.append(tid)

        for tid in expired:
            self.finalize_track(tid)
            self.active_tracks.discard(tid)

    def handle_unstable_tracks(self, current_track_ids):
        """Track absences of active tracks, record recoveries and finalize expired ones.

        Active tracks missing from ``current_track_ids`` get (or advance) an
        entry in ``track_stability``. Tracks that reappear have the absence
        logged in ``track_recovery_history`` and their entry removed. Entries
        whose absence reaches the applicable threshold are finalized:
        ``grace_period_frames`` when the last position was not at the 2D
        boundary, ``max_absence_frames`` when it was.
        """
        missing_now = self.active_tracks - current_track_ids
        back_again = current_track_ids & set(self.track_stability.keys())

        # Start or advance the absence counter for every missing active track.
        for tid in missing_now:
            if tid in self.track_stability:
                self.track_stability[tid]["absent_frames"] += 1
                continue

            known_position = None
            if tid in self.track_histories and len(self.track_histories[tid]) > 0:
                known_position = self.track_histories[tid][-1]

            self.track_stability[tid] = {
                "absent_frames": 1,
                "last_position": known_position,
                "grace_period": True,
                "first_absent_frame": self.frame_count
            }

        # Log each recovery and clear its stability entry.
        for tid in back_again:
            if tid not in self.track_stability:
                continue
            gap = self.track_stability[tid]["absent_frames"]
            self.track_recovery_history.setdefault(tid, []).append({
                "frame": self.frame_count,
                "absent_duration": gap
            })
            del self.track_stability[tid]

        # Collect entries whose absence reached the applicable threshold.
        expired = []
        for tid, entry in list(self.track_stability.items()):
            gap = entry["absent_frames"]
            anchor = entry["last_position"]

            inside = True
            if (anchor and self.perspective_transform):
                anchor_2d = self.perspective_transform.transform_point(anchor)
                inside = not self.is_at_boundary_2d(anchor_2d)

            # Interior tracks use the grace period; boundary tracks use the max-absence limit.
            limit = self.grace_period_frames if inside else self.max_absence_frames
            if not gap >= limit:
                continue

            if (tid not in self.track_death_points_2d and
                    anchor and
                    self.perspective_transform):
                death_2d = self.perspective_transform.transform_point(anchor)
                self.track_death_points_2d[tid] = death_2d

                track_id, absent_frames = tid, gap
                if inside:
                    logger.info(f"Track {track_id} lost within boundary after {absent_frames} frames - creating death point at 2D: {death_2d}")
                else:
                    logger.info(f"Track {track_id} lost at boundary after {absent_frames} frames - creating death point at 2D: {death_2d}")

            expired.append(tid)
            del self.track_stability[tid]

        for tid in expired:
            self.finalize_track(tid)
            self.active_tracks.discard(tid)

    def check_track_movement(self, track_id):
        """Tell whether a track's recent path length exceeds ``movement_threshold``.

        Returns False for unknown tracks and for tracks with fewer than three
        recorded positions.
        """
        if track_id not in self.track_positions:
            return False

        trail = list(self.track_positions[track_id])
        if len(trail) < 3:
            return False

        # Sum the straight-line hops between consecutive recorded positions.
        travelled = sum(
            np.sqrt((after[0] - before[0])**2 + (after[1] - before[1])**2)
            for before, after in zip(trail, trail[1:])
        )
        return travelled > self.movement_threshold
    
    def update_death_point(self, track_id, center):
        """Refresh a track's 2D death point and finalize it on reaching the boundary.

        The death point is only maintained for tracks that are not finalized,
        already have a spawn point, and when a perspective transform exists.
        If the new death point lies at the 2D boundary and the track is active:

        * when it is less than 20 units from the spawn point, the death point
          is discarded and the track is left untouched (too little travel to
          count);
        * otherwise the track is finalized and removed from ``active_tracks``.

        Args:
            track_id: Identifier of the track being updated.
            center: Track center in image coordinates, passed to
                ``perspective_transform.transform_point``.
        """
        # Skip if track has already been finalized
        if track_id in self.track_lifecycles_2d:
            return

        # Only tracks that already have a spawn point are followed
        if track_id not in self.track_spawn_points_2d:
            return

        if not self.perspective_transform:
            return

        death_2d = self.perspective_transform.transform_point(center)
        self.track_death_points_2d[track_id] = death_2d

        # Nothing more to do unless an active track has reached the boundary
        if not (self.is_at_boundary_2d(death_2d) and track_id in self.active_tracks):
            return

        spawn_2d = self.track_spawn_points_2d[track_id]
        distance = np.sqrt((death_2d[0] - spawn_2d[0])**2 + (death_2d[1] - spawn_2d[1])**2)

        # Safeguard: spawn and death too close together - don't count, drop the death point
        min_travel_distance = 20
        if distance < min_travel_distance:
            self.track_death_points_2d.pop(track_id, None)
            return

        self.finalize_track(track_id)
        # discard() tolerates finalize_track having already removed the track
        self.active_tracks.discard(track_id)
    
    def is_at_boundary_2d(self, point_2d, margin=20):
        """Report whether a 2D point lies within ``margin`` of (or beyond) any edge.

        Always False when no perspective transform is configured.
        """
        transform = self.perspective_transform
        if not transform:
            return False

        width_2d, height_2d = self.perspective_transform.output_size
        px, py = point_2d

        # Horizontal band first, then vertical - same order as a single or-chain.
        hits_side = px <= margin or px >= width_2d - margin
        return hits_side or py <= margin or py >= height_2d - margin
    
    def is_outside_boundary_2d(self, point_2d):
        """Report whether a 2D point falls outside the transform's output area.

        The right and bottom limits are exclusive (``x == width`` counts as
        outside). Always False when no perspective transform is configured.
        """
        transform = self.perspective_transform
        if not transform:
            return False

        width_2d, height_2d = self.perspective_transform.output_size
        px, py = point_2d

        off_horizontally = px < 0 or px >= width_2d
        return off_horizontally or py < 0 or py >= height_2d
    
    def clamp_point_inside_boundary_2d(self, point_2d, margin=20):
        """Pull a 2D point back so it sits at least ``margin`` inside every edge.

        The point is returned unchanged when no perspective transform is set.
        """
        transform = self.perspective_transform
        if not transform:
            return point_2d

        width_2d, height_2d = self.perspective_transform.output_size
        px, py = point_2d

        # Each axis is limited to [margin, size - margin].
        clamped_x = max(margin, min(px, width_2d - margin))
        clamped_y = max(margin, min(py, height_2d - margin))
        return (clamped_x, clamped_y)
    
    def create_failsafe_spawn_point(self, track_id, current_death_2d):
        """Derive a spawn point from the track's history when none was recorded.

        Every history position is transformed to 2D and the one furthest from
        ``current_death_2d`` is chosen. It is stored as the spawn point only if
        it is more than 20 units away and not within 40 units of the boundary.

        Returns:
            True when a spawn point was stored, False otherwise.
        """
        has_history = track_id in self.track_histories and len(self.track_histories[track_id]) != 0
        if not has_history:
            logger.warning(f"Track {track_id}: No track history available for failsafe spawn point")
            return False

        # Transform the whole history into 2D space up front.
        projected_trail = [
            self.perspective_transform.transform_point(raw_point)
            for raw_point in self.track_histories[track_id]
        ]

        # Keep the first candidate with the strictly greatest distance.
        best_gap = 0
        best_candidate = None
        for candidate in projected_trail:
            gap = np.sqrt((candidate[0] - current_death_2d[0])**2 +
                          (candidate[1] - current_death_2d[1])**2)
            if gap > best_gap:
                best_gap, best_candidate = gap, candidate

        if not (best_candidate and best_gap > 20):
            return False

        # A candidate hugging the boundary is not a trustworthy origin.
        if self.is_at_boundary_2d(best_candidate, margin=40):
            return False

        self.track_spawn_points_2d[track_id] = best_candidate
        return True
    
    def finalize_track(self, track_id, skip_edge_counting=False):
        """Finalize a track and record its traffic direction (2D space only).

        Steps, in order:

        1. Bail out when there is no perspective transform or the track already
           has a lifecycle entry.
        2. If only a death point exists, try to recover a spawn point from the
           track history; give up if that fails or either point is missing.
        3. Drop the death point and return when spawn and death are less than
           100 units apart; reject the track (counted in
           ``rejected_tracks_2d``) when the distance is below half of
           ``min_spawn_death_distance``.
        4. Store the lifecycle record, update edge, per-class, turning-movement
           and direction counters, then delete the spawn and death points.

        Args:
            track_id: Identifier of the track to finalize.
            skip_edge_counting: When True, the edge, per-class and
                turning-movement counters are left untouched; the lifecycle
                record and the 2D direction counter are still updated.
        """
        if not self.perspective_transform:
            return

        # Check if track has already been finalized
        if track_id in self.track_lifecycles_2d:
            return

        # Check if we have death point but missing spawn point - use failsafe
        if track_id not in self.track_spawn_points_2d and track_id in self.track_death_points_2d:
            death_point_2d = self.track_death_points_2d[track_id]
            if not self.create_failsafe_spawn_point(track_id, death_point_2d):
                return

        # Check if still missing either point after failsafe attempt
        if track_id not in self.track_spawn_points_2d or track_id not in self.track_death_points_2d:
            return

        # Get 2D points
        spawn_point_2d = self.track_spawn_points_2d[track_id]
        death_point_2d = self.track_death_points_2d[track_id]

        # Calculate distance in 2D space (more accurate)
        spawn_death_distance_2d = np.sqrt((death_point_2d[0] - spawn_point_2d[0])**2 +
                                          (death_point_2d[1] - spawn_point_2d[1])**2)

        # Safeguard: If spawn and death are too close (stationary/barely moving vehicle)
        # the spawn point is kept so the track can still be finalized later.
        if spawn_death_distance_2d < 100:  # Close proximity threshold
            self.track_death_points_2d.pop(track_id, None)
            return

        # Apply distance filter in 2D space
        min_distance_2d = self.min_spawn_death_distance * 0.5  # Adjust for 2D scale
        if spawn_death_distance_2d < min_distance_2d:
            self.rejected_tracks_2d += 1

            # Remove spawn and death points for rejected tracks too
            self.track_spawn_points_2d.pop(track_id, None)
            self.track_death_points_2d.pop(track_id, None)
            return

        # Determine direction in 2D space
        direction_2d = self.determine_traffic_direction_2d(spawn_point_2d, death_point_2d)

        # Calculate extended line coordinates first
        frame_w_2d, frame_h_2d = self.perspective_transform.output_size
        spawn_extended, death_extended = self.extend_line_to_edges_2d(
            spawn_point_2d, death_point_2d, frame_w_2d, frame_h_2d)

        # Determine spawn and death edges based on where the extended line touches the boundaries
        spawn_edge = self.get_edge_from_extended_point(spawn_extended, frame_w_2d, frame_h_2d)
        death_edge = self.get_edge_from_extended_point(death_extended, frame_w_2d, frame_h_2d)

        # Store 2D lifecycle with extended coordinates
        self.track_lifecycles_2d[track_id] = {
            "spawn": spawn_point_2d,
            "death": death_point_2d,
            "spawn_extended": spawn_extended,
            "death_extended": death_extended,
            "direction": direction_2d,
            "death_frame": self.frame_count,
            "distance": spawn_death_distance_2d,
            "spawn_edge": spawn_edge,
            "death_edge": death_edge
        }

        # Get vehicle class for detailed counting.
        # Compare against None explicitly: class ID 0 is a valid ID and must
        # not be treated as "no class".
        vehicle_class_id = self.get_track_display_class(track_id)
        if vehicle_class_id is not None:
            vehicle_class = self.get_vehicle_class_name(vehicle_class_id)
        else:
            vehicle_class = "unknown"

        # Map edge names to compass directions
        edge_to_compass = {
            "top": "North",
            "bottom": "South",
            "left": "West",
            "right": "East"
        }

        # Update edge-based counters (only if not skipping due to boundary finalization)
        if not skip_edge_counting:
            # Spawn edge = OUT (entering from boundary), Death edge = IN (leaving to boundary)
            if spawn_edge in self.edge_counts:
                self.edge_counts[spawn_edge]["OUT"] += 1
            if death_edge in self.edge_counts:
                self.edge_counts[death_edge]["IN"] += 1

            # Update detailed counters by vehicle type and direction
            if vehicle_class != "unknown":
                # Spawn edge counting (OUT direction)
                if spawn_edge in ["top", "left", "right", "bottom"]:
                    counter_key = f"out{spawn_edge.title()}_{vehicle_class}"
                    if counter_key in self.detailed_traffic_counts:
                        self.detailed_traffic_counts[counter_key] += 1

                # Death edge counting (IN direction)
                if death_edge in ["top", "left", "right", "bottom"]:
                    counter_key = f"in{death_edge.title()}_{vehicle_class}"
                    if counter_key in self.detailed_traffic_counts:
                        self.detailed_traffic_counts[counter_key] += 1

                # Track turning movements (origin to destination)
                origin = edge_to_compass.get(spawn_edge, spawn_edge)
                destination = edge_to_compass.get(death_edge, death_edge)
                if origin and destination and origin != destination:  # No U-turns
                    movement_key = f"{origin}_to_{destination}_{vehicle_class}"
                    self.turning_movements[movement_key] += 1
                    logger.info(f"Track {track_id}: Movement {movement_key}")

        # Update 2D counters
        if direction_2d in self.traffic_counts_2d:
            self.traffic_counts_2d[direction_2d] += 1

        # Remove spawn and death points after counting
        self.track_spawn_points_2d.pop(track_id, None)
        self.track_death_points_2d.pop(track_id, None)
    
    
    def determine_traffic_direction_2d(self, spawn_point_2d, death_point_2d):
        """Classify a spawn-to-death movement as "IN" or "OUT" in 2D space.

        The spawn-death line is extended to the edges of the transform area.
        Bottom-to-top and left-to-right crossings are "IN"; the reverse
        crossings are "OUT". Any other edge pairing is decided by the dominant
        axis of the raw movement vector. Returns "UNKNOWN" when no perspective
        transform is configured.
        """
        if not self.perspective_transform:
            return "UNKNOWN"

        width_2d, height_2d = self.perspective_transform.output_size

        entry_point, exit_point = self.extend_line_to_edges_2d(
            spawn_point_2d, death_point_2d, width_2d, height_2d)
        entry_edge = self.get_closest_edge_2d(entry_point, width_2d, height_2d)
        exit_edge = self.get_closest_edge_2d(exit_point, width_2d, height_2d)

        # Straight crossings between opposite edges have a fixed verdict.
        opposite_edge_verdicts = {
            ("bottom", "top"): "IN",
            ("top", "bottom"): "OUT",
            ("left", "right"): "IN",
            ("right", "left"): "OUT",
        }
        verdict = opposite_edge_verdicts.get((entry_edge, exit_edge))
        if verdict is not None:
            return verdict

        # Everything else: fall back to the dominant axis of the movement vector.
        shift_x = death_point_2d[0] - spawn_point_2d[0]
        shift_y = death_point_2d[1] - spawn_point_2d[1]

        if abs(shift_x) > abs(shift_y):
            # Mostly horizontal: rightwards is IN, leftwards is OUT.
            return "IN" if shift_x > 0 else "OUT"

        # Mostly vertical: downwards is OUT, upwards is IN.
        return "OUT" if shift_y > 0 else "IN"
    
    def get_edge_from_extended_point(self, extended_point, frame_w_2d, frame_h_2d, tolerance=1):
        """Name the frame edge that an extended line endpoint lies on.

        Args:
            extended_point: ``(x, y)`` point in 2D transform space.
            frame_w_2d: Width of the 2D transform area.
            frame_h_2d: Height of the 2D transform area.
            tolerance: Maximum distance from an edge for the point to count as
                lying on it (absorbs floating-point error).

        Returns:
            ``"left"``, ``"right"``, ``"top"`` or ``"bottom"`` (checked in that
            order, so corners resolve to the left/right edge), or ``None`` when
            the point is outside the frame by more than ``tolerance`` or is not
            within ``tolerance`` of any edge.
        """
        x, y = extended_point

        # Points clearly outside the frame cannot be attributed to an edge.
        if x < -tolerance or x > frame_w_2d + tolerance or y < -tolerance or y > frame_h_2d + tolerance:
            return None

        if abs(x) <= tolerance:
            return "left"
        if abs(x - frame_w_2d) <= tolerance:
            return "right"
        if abs(y) <= tolerance:
            return "top"
        if abs(y - frame_h_2d) <= tolerance:
            return "bottom"

        # Inside the frame and further than `tolerance` from every edge.
        return None
    
    def get_closest_edge(self, point, frame_w, frame_h):
        """Return the name of the frame edge nearest to ``point``.

        Ties are resolved in the order top, bottom, left, right.
        """
        px, py = point

        # Distance from the point to each edge, listed in tie-break order.
        gaps = (
            ("top", py),
            ("bottom", frame_h - py),
            ("left", px),
            ("right", frame_w - px),
        )
        smallest = min(gap for _, gap in gaps)

        for label, gap in gaps[:-1]:
            if smallest == gap:
                return label
        return "right"
    
    def extend_line_to_edges_2d(self, spawn_point_2d, death_point_2d, frame_w_2d, frame_h_2d):
        """Stretch the spawn-death segment until it meets the edges of the 2D box.

        The spawn end is projected backwards along the line and the death end
        forwards. An end is only moved when the projected hit lands within the
        extent of the edge it targets; otherwise the original point is kept.
        A degenerate segment (both deltas below 1e-6) is returned unchanged.

        Returns:
            Tuple ``(spawn_extended, death_extended)`` of 2D points.
        """
        eps = 1e-6
        sx, sy = spawn_point_2d
        ex, ey = death_point_2d

        # Direction of travel from spawn to death
        vx = ex - sx
        vy = ey - sy

        if abs(vx) < eps and abs(vy) < eps:
            return spawn_point_2d, death_point_2d

        moves_in_x = abs(vx) > eps
        moves_in_y = abs(vy) > eps

        # --- Spawn end: walk backwards (negative t) to an edge ---
        spawn_extended = spawn_point_2d
        if moves_in_x:
            # Candidates in priority order: left edge, then right edge.
            for t, edge_x in ((-sx / vx, 0), ((frame_w_2d - sx) / vx, frame_w_2d)):
                if t < 0 and abs(t) > eps:
                    hit_y = sy + t * vy
                    if 0 <= hit_y <= frame_h_2d:
                        spawn_extended = (edge_x, hit_y)
                    break  # only the first backwards candidate is considered

        if moves_in_y:
            # Candidates in priority order: top edge, then bottom edge.
            for t, edge_y in ((-sy / vy, 0), ((frame_h_2d - sy) / vy, frame_h_2d)):
                if t < 0 and abs(t) > eps:
                    hit_x = sx + t * vx
                    if 0 <= hit_x <= frame_w_2d:
                        spawn_extended = (hit_x, edge_y)
                    break  # only the first backwards candidate is considered

        # --- Death end: walk forwards to an edge ---
        death_extended = death_point_2d
        if moves_in_x:
            if vx > 0 and ex < frame_w_2d:
                edge_x = frame_w_2d  # heading right
            elif vx < 0 and ex > 0:
                edge_x = 0  # heading left
            else:
                edge_x = None
            if edge_x is not None:
                hit_y = ey + ((edge_x - ex) / vx) * vy
                if 0 <= hit_y <= frame_h_2d:
                    death_extended = (edge_x, hit_y)

        if moves_in_y:
            if vy > 0 and ey < frame_h_2d:
                edge_y = frame_h_2d  # heading down
            elif vy < 0 and ey > 0:
                edge_y = 0  # heading up
            else:
                edge_y = None
            if edge_y is not None:
                hit_x = ex + ((edge_y - ey) / vy) * vx
                if 0 <= hit_x <= frame_w_2d:
                    death_extended = (hit_x, edge_y)

        return spawn_extended, death_extended
    
    def get_closest_edge_2d(self, point_2d, frame_w_2d, frame_h_2d):
        """Name the border of the 2D box that lies nearest to a point.

        Args:
            point_2d: (x, y) position in 2D space.
            frame_w_2d: Box width (x of the right border).
            frame_h_2d: Box height (y of the bottom border).

        Returns:
            One of "top", "bottom", "left", "right".  Equal distances are
            resolved in exactly that order.
        """
        px, py = point_2d

        # Listed in tie-break priority order.
        edge_distances = (
            ("top", py),
            ("bottom", frame_h_2d - py),
            ("left", px),
            ("right", frame_w_2d - px),
        )
        nearest = min(distance for _, distance in edge_distances)

        for edge_name, distance in edge_distances[:3]:
            if distance == nearest:
                return edge_name
        return "right"
    
    
    def cleanup_faded_lifecycles(self):
        """Drop 2D lifecycles whose fade-out has finished, plus their movement data.

        A lifecycle is expired once more than ``self.fade_duration`` frames
        have passed since its ``"death_frame"``.  For every expired track the
        entries in ``track_lifecycles_2d``, ``track_positions``,
        ``track_movement_status`` and ``track_last_seen`` are removed.
        Spawn and death points are deliberately left untouched here.
        """
        now = self.frame_count
        expired = [
            tid
            for tid, info in self.track_lifecycles_2d.items()
            if now - info["death_frame"] > self.fade_duration
        ]

        for tid in expired:
            self.track_lifecycles_2d.pop(tid, None)
            # Movement bookkeeping for the finished track
            self.track_positions.pop(tid, None)
            self.track_movement_status.pop(tid, None)
            self.track_last_seen.pop(tid, None)

    def cleanup_orphaned_spawn_points(self):
        """Forget spawn points of inactive tracks that have gone unseen for too long.

        A track with a stored spawn point is orphaned when it is not in
        ``self.active_tracks`` and either has no ``track_last_seen`` entry or
        was last seen more than 120 frames ago.  Orphaned tracks lose their
        spawn point, death point and stability record.
        """
        orphan_threshold = 120  # frames; 4 seconds at 30fps

        orphaned = []
        for tid in list(self.track_spawn_points_2d):
            if tid in self.active_tracks:
                continue
            if tid not in self.track_last_seen:
                # No sighting on record at all
                orphaned.append(tid)
                continue
            if self.frame_count - self.track_last_seen[tid] > orphan_threshold:
                orphaned.append(tid)

        for tid in orphaned:
            self.track_spawn_points_2d.pop(tid, None)
            self.track_death_points_2d.pop(tid, None)
            self.track_stability.pop(tid, None)

    def draw_visualization(self, frame, tracks):
        """Annotate ``frame`` in place with trajectories, the 2D inset and statistics.

        Housekeeping runs first: faded lifecycles are purged on every call and
        orphaned spawn points on every 30th frame.

        Args:
            frame: Image array to draw on; it is modified in place.
            tracks: Iterable of track objects exposing ``track_id`` (and
                whatever ``_draw_transform_overlay`` reads from them).

        Returns:
            The same ``frame`` object that was passed in.
        """
        # Keep an unannotated copy of the incoming frame on the instance
        self._last_frame = frame.copy()

        self.cleanup_faded_lifecycles()
        if self.frame_count % 30 == 0:
            self.cleanup_orphaned_spawn_points()

        # Trajectory polylines plus a dot at the newest position
        trail_color = (0, 255, 0)
        for track in tracks:
            tid = track.track_id
            if tid not in self.track_histories:
                continue
            history = self.track_histories[tid]
            if len(history) <= 1:
                continue

            pts = np.array(list(history), dtype=np.int32)
            for seg_start, seg_end in zip(pts[:-1], pts[1:]):
                cv2.line(frame, tuple(seg_start), tuple(seg_end), trail_color, 2)
            cv2.circle(frame, tuple(pts[-1]), 5, trail_color, -1)

        # Spawn points are intentionally not drawn in the camera view; they
        # only appear inside the 2D inset.
        transform = self.perspective_transform
        if transform and transform.transform_matrix is not None:
            self._draw_transform_overlay(frame, tracks)

        # Per-edge counter panel
        self._draw_edge_statistics_box(frame)

        # Headline numbers; shifted down one row when the transform banner is shown
        font = cv2.FONT_HERSHEY_SIMPLEX
        first_row_y = 60 if transform else 30
        cv2.putText(frame, f"Total Tracks: {len(self.track_histories)}",
                    (10, first_row_y), font, 0.7, (255, 255, 255), 2)
        cv2.putText(frame, f"Active: {len(self.active_tracks)}",
                    (10, first_row_y + 30), font, 0.7, (255, 255, 0), 2)
        cv2.putText(frame, f"2D Rejected: {self.rejected_tracks_2d}",
                    (10, first_row_y + 60), font, 0.5, (200, 200, 200), 1)

        if transform:
            cv2.putText(frame, "2D Transform: Active",
                        (10, 30), font, 0.6, (0, 255, 0), 2)

        return frame
    
    def _draw_transform_overlay(self, frame, tracks):
        """Render the 400x300 2D-transform inset into the top-right corner of ``frame``.

        The inset shows a grid, the fading spawn-death lines, spawn points of
        active tracks and one dot per current track (bounding-box centre mapped
        through ``self.perspective_transform``).

        The inset is clipped to the frame before it is pasted, so frames
        smaller than the inset plus its 10 px margin no longer fail with a
        shape-mismatch error; they simply show the part that fits.

        Args:
            frame: Image array to draw on; modified in place.
            tracks: Iterable of tracks exposing ``tlwh`` and ``track_id``.
        """
        overlay_size = (400, 300)  # (width, height) of the inset
        overlay_w, overlay_h = overlay_size
        frame_h, frame_w = frame.shape[:2]

        # Nominal position: top-right corner with a 10 px margin
        overlay_x = frame_w - overlay_w - 10
        overlay_y = 10

        # Black canvas for the inset
        mini_transform = np.zeros((overlay_h, overlay_w, 3), dtype=np.uint8)

        # Background grid, one line every 40 px
        grid_color = (60, 60, 60)
        for gx in range(0, overlay_w, 40):
            cv2.line(mini_transform, (gx, 0), (gx, overlay_h), grid_color, 1)
        for gy in range(0, overlay_h, 40):
            cv2.line(mini_transform, (0, gy), (overlay_w, gy), grid_color, 1)

        # Spawn-death lines of completed tracks, then spawn points of active ones
        self._draw_spawn_death_lines_mini(mini_transform, overlay_size)
        self._draw_active_spawn_points_mini(mini_transform, overlay_size)

        # One dot per current track
        for track in tracks:
            tlwh = track.tlwh
            center = (int(tlwh[0] + tlwh[2] / 2), int(tlwh[1] + tlwh[3] / 2))

            # Map the box centre into transform space, then scale to the inset
            transformed = self.perspective_transform.transform_point(center)
            output_size = self.perspective_transform.output_size
            mini_x = int((transformed[0] / output_size[0]) * overlay_w)
            mini_y = int((transformed[1] / output_size[1]) * overlay_h)

            if 0 <= mini_x < overlay_w and 0 <= mini_y < overlay_h:
                color = self._get_track_color(track.track_id)
                cv2.circle(mini_transform, (mini_x, mini_y), 3, color, -1)
                cv2.circle(mini_transform, (mini_x, mini_y), 4, (255, 255, 255), 1)

        # Border only (no text)
        cv2.rectangle(mini_transform, (0, 0), (overlay_w - 1, overlay_h - 1), (255, 255, 255), 2)

        # Clip the destination rectangle to the frame so the slice assignment
        # always has matching shapes, even for small frames.
        dst_x0 = max(overlay_x, 0)
        dst_y0 = max(overlay_y, 0)
        dst_x1 = min(overlay_x + overlay_w, frame_w)
        dst_y1 = min(overlay_y + overlay_h, frame_h)
        if dst_x1 <= dst_x0 or dst_y1 <= dst_y0:
            return  # nothing of the inset fits into this frame

        frame[dst_y0:dst_y1, dst_x0:dst_x1] = mini_transform[
            dst_y0 - overlay_y:dst_y1 - overlay_y,
            dst_x0 - overlay_x:dst_x1 - overlay_x,
        ]
    
    def _draw_edge_statistics_box(self, frame):
        """Paint the vehicle counter panel into the bottom-left corner of ``frame``.

        The panel is a 320x180 black box with a white border holding a title,
        a legend, a header row and one row per edge (Top, Left, Right, Bottom).
        Each row lists IN counts (green) and OUT counts (red) for car,
        motorcycle, bus and truck, read from ``self.detailed_traffic_counts``
        under keys of the form ``"in<Edge>_<vehicle>"`` / ``"out<Edge>_<vehicle>"``.
        Grand totals are printed on the title row.

        Args:
            frame: Image array to draw on; modified in place.
        """
        font = cv2.FONT_HERSHEY_SIMPLEX
        white = (255, 255, 255)
        green = (0, 255, 0)
        red = (0, 0, 255)

        panel_w, panel_h = 320, 180
        left = 10
        top = frame.shape[0] - panel_h - 10
        counts = self.detailed_traffic_counts

        # Filled background, then the outline
        cv2.rectangle(frame, (left, top), (left + panel_w, top + panel_h), (0, 0, 0), -1)
        cv2.rectangle(frame, (left, top), (left + panel_w, top + panel_h), white, 2)

        # Title and one-line legend
        cv2.putText(frame, "Vehicle Traffic Counter", (left + 5, top + 15), font, 0.45, white, 1)
        cv2.putText(frame, "C=Car  M=Motorcycle  B=Bus  T=Truck", (left + 5, top + 30),
                    font, 0.35, (200, 200, 200), 1)

        # Column headers
        header_y = top + 50
        row_step = 25
        cv2.putText(frame, "Dir ", (left + 5, header_y), font, 0.35, (150, 150, 150), 1)
        cv2.putText(frame, "IN  C M B T", (left + 40, header_y), font, 0.35, green, 1)
        cv2.putText(frame, "OUT C M B T", (left + 120, header_y), font, 0.35, red, 1)

        # One row per edge: label, four IN cells, four OUT cells (15 px apart)
        vehicle_types = ("car", "motorcycle", "bus", "truck")
        column_groups = (("in", 45, green), ("out", 125, red))
        for row, edge in enumerate(("Top", "Left", "Right", "Bottom"), start=1):
            row_y = header_y + row * row_step
            cv2.putText(frame, edge[:3], (left + 5, row_y), font, 0.35, white, 1)

            for prefix, first_col, cell_color in column_groups:
                for col, vehicle in enumerate(vehicle_types):
                    count = counts.get(f"{prefix}{edge}_{vehicle}", 0)
                    cv2.putText(frame, f"{count:2d}", (left + first_col + 15 * col, row_y),
                                font, 0.35, cell_color, 1)

        # Grand totals over every counter key, printed to the right of the title
        total_in = 0
        total_out = 0
        for key, value in counts.items():
            if key.startswith("in"):
                total_in += value
            if key.startswith("out"):
                total_out += value

        cv2.putText(frame, f"Total IN: {total_in}  OUT: {total_out}",
                    (left + 200, top + 15), font, 0.4, (255, 255, 0), 1)
        
    
    def _draw_spawn_death_lines_mini(self, img, overlay_size):
        """Draw spawn-death lines with edge extensions in the mini 2D overlay"""
        if not hasattr(self, 'track_lifecycles_2d') or not self.track_lifecycles_2d:
            return
        
        for track_id, lifecycle_2d in self.track_lifecycles_2d.items():
            spawn_2d = lifecycle_2d["spawn"]
            death_2d = lifecycle_2d["death"] 
            direction = lifecycle_2d["direction"]
            death_frame = lifecycle_2d["death_frame"]
            
            # Calculate fade factor
            frames_since_death = self.frame_count - death_frame
            if frames_since_death > self.fade_duration:
                continue
            
            alpha = max(0.0, 1.0 - (frames_since_death / self.fade_duration))
            
            # Use pre-calculated extended coordinates (fixed at finalization time)
            spawn_extended = lifecycle_2d.get("spawn_extended", spawn_2d)
            death_extended = lifecycle_2d.get("death_extended", death_2d)
            frame_w_2d, frame_h_2d = self.perspective_transform.output_size
            
            # Scale all 2D coordinates to mini view
            mini_spawn_x = int((spawn_2d[0] / frame_w_2d) * overlay_size[0])
            mini_spawn_y = int((spawn_2d[1] / frame_h_2d) * overlay_size[1])
            mini_death_x = int((death_2d[0] / frame_w_2d) * overlay_size[0])
            mini_death_y = int((death_2d[1] / frame_h_2d) * overlay_size[1])
            
            mini_spawn_ext_x = int((spawn_extended[0] / frame_w_2d) * overlay_size[0])
            mini_spawn_ext_y = int((spawn_extended[1] / frame_h_2d) * overlay_size[1])
            mini_death_ext_x = int((death_extended[0] / frame_w_2d) * overlay_size[0])
            mini_death_ext_y = int((death_extended[1] / frame_h_2d) * overlay_size[1])
            
            # All lines are green
            base_color = (0, 255, 0)  # Green
            
            # Apply alpha
            color = tuple(int(c * alpha) for c in base_color)
            line_thickness = max(1, int(2 * alpha))
            dash_thickness = max(1, int(1 * alpha))
            
            # Draw main spawn-death line
            cv2.line(img, (mini_spawn_x, mini_spawn_y), (mini_death_x, mini_death_y), color, line_thickness)
            
            # Draw extended lines (dashed style)
            self._draw_dashed_line_mini(img, (mini_spawn_x, mini_spawn_y), (mini_spawn_ext_x, mini_spawn_ext_y), color, dash_thickness)
            self._draw_dashed_line_mini(img, (mini_death_x, mini_death_y), (mini_death_ext_x, mini_death_ext_y), color, dash_thickness)
            
            # Draw spawn point
            if 0 <= mini_spawn_x < overlay_size[0] and 0 <= mini_spawn_y < overlay_size[1]:
                cv2.circle(img, (mini_spawn_x, mini_spawn_y), max(1, int(3 * alpha)), color, -1)
            
            # Draw death point
            if 0 <= mini_death_x < overlay_size[0] and 0 <= mini_death_y < overlay_size[1]:
                cv2.circle(img, (mini_death_x, mini_death_y), max(1, int(2 * alpha)), color, 2)
    
    def _draw_active_spawn_points_mini(self, img, overlay_size):
            """Draw spawn points for currently active tracks only in mini view"""
            if not hasattr(self, 'track_spawn_points_2d') or not hasattr(self, 'active_tracks'):
                return
            
            # Only draw spawn points for truly active tracks
            for track_id in self.active_tracks:
                if track_id in self.track_spawn_points_2d:
                    spawn_2d = self.track_spawn_points_2d[track_id]
                    
                    # Scale to mini view
                    scale_x = overlay_size[0] / self.perspective_transform.output_size[0]
                    scale_y = overlay_size[1] / self.perspective_transform.output_size[1]
                    mini_spawn_x = int(spawn_2d[0] * scale_x)
                    mini_spawn_y = int(spawn_2d[1] * scale_y)
                    
                    # Check bounds
                    if 0 <= mini_spawn_x < overlay_size[0] and 0 <= mini_spawn_y < overlay_size[1]:
                        # Draw cyan circle for active spawn points
                        cv2.circle(img, (mini_spawn_x, mini_spawn_y), 3, (255, 255, 0), -1)  # Yellow for active spawns

    
    def _draw_dashed_line_mini(self, img, pt1, pt2, color, thickness):
        """Draw a dashed line between two points in mini view"""
        x1, y1 = pt1
        x2, y2 = pt2
        
        # Calculate line length and direction
        dx = x2 - x1
        dy = y2 - y1
        length = np.sqrt(dx**2 + dy**2)
        
        if length == 0:
            return
        
        # Normalize direction
        dx_norm = dx / length
        dy_norm = dy / length
        
        # Draw dashed line with smaller segments for mini view
        dash_length = 3
        gap_length = 2
        current_length = 0
        
        while current_length < length:
            # Start of dash
            start_x = int(x1 + current_length * dx_norm)
            start_y = int(y1 + current_length * dy_norm)
            
            # End of dash
            end_length = min(current_length + dash_length, length)
            end_x = int(x1 + end_length * dx_norm)
            end_y = int(y1 + end_length * dy_norm)
            
            # Draw dash
            cv2.line(img, (start_x, start_y), (end_x, end_y), color, thickness)
            
            # Move to next dash
            current_length += dash_length + gap_length
    
    def _get_track_color(self, track_id):
        """Get consistent color for track ID (shared with TransformVisualizer)"""
        if not hasattr(self, '_track_colors'):
            self._track_colors = {}
        
        if track_id not in self._track_colors:
            # Generate color based on track ID
            hue = (track_id * 137) % 360  # Golden angle for good distribution
            color = cv2.cvtColor(np.uint8([[[hue, 255, 255]]]), cv2.COLOR_HSV2BGR)[0][0]
            self._track_colors[track_id] = tuple(map(int, color))
        return self._track_colors[track_id]
    
    
    def cleanup(self):
        """Stop the transform visualizer, if one is attached."""
        visualizer = self.transform_visualizer
        if not visualizer:
            return
        visualizer.stop()


class TrafficByteTracker:
    """Pairs a ``BYTETracker`` with a ``TrafficAnalyzer`` that is fed its results."""

    def __init__(self, args, frame_rate=30):
        """Create the tracker and the analyzer from the same ``args``.

        Args:
            args: Parsed options forwarded to both components.
            frame_rate: Frame rate handed to the tracker and, as ``fps``, to
                the analyzer.
        """
        self.tracker = BYTETracker(args, frame_rate)
        self.traffic_analyzer = TrafficAnalyzer(args, fps=frame_rate)

    def update(self, output_results, img_info, img_size):
        """Run one ByteTrack step and pass any resulting targets to the analyzer.

        The analyzer is only updated (with ``None`` as the frame) when the
        tracker returned at least one target.

        Returns:
            Whatever ``BYTETracker.update`` returned for this step.
        """
        targets = self.tracker.update(output_results, img_info, img_size)
        if not targets:
            return targets

        self.traffic_analyzer.update(targets, None)
        return targets


def make_parser():
    """Build the command-line parser for the traffic-analysis demo.

    Differences from a plain flag list worth knowing:

    * ``demo`` is an optional positional; omitting it selects ``"video"``.
    * TensorRT is enabled by default.  ``--trt`` is accepted for
      compatibility and ``--no-trt`` switches it off.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser("ByteTrack with traffic analysis demo")
    # nargs="?" makes the declared default effective; without it argparse
    # ignores the default of a positional and requires the argument.
    parser.add_argument("demo", nargs="?", default="video", help="demo type")
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")

    parser.add_argument(
        "--path", default="./videos/palace.mp4", help="path to video"
    )
    parser.add_argument("--camid", type=int, default=0, help="webcam demo camera id")
    parser.add_argument(
        "--save_result",
        action="store_true",
        help="whether to save the inference result of video",
    )

    # exp file
    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="please input your experiment description file",
    )
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")
    parser.add_argument(
        "--device",
        default="gpu",
        type=str,
        help="device to run our model, can either be cpu or gpu",
    )
    parser.add_argument("--conf", default=0.25, type=float, help="test conf")
    parser.add_argument("--nms", default=0.25, type=float, help="test nms threshold")
    parser.add_argument("--tsize", default=640, type=int, help="test img size")
    parser.add_argument(
        "--fps", default=30, type=int, help="frame rate (fps)"
    )
    parser.add_argument(
        "--fp16",
        dest="fp16",
        default=False,
        action="store_true",
        help="Adopting mix precision evaluating.",
    )
    parser.add_argument(
        "--legacy",
        dest="legacy",
        default=False,
        action="store_true",
        help="To be compatible with older versions",
    )
    parser.add_argument(
        "--fuse",
        dest="fuse",
        default=False,
        action="store_true",
        help="Fuse conv and bn for testing.",
    )
    # TensorRT stays on by default; --trt is therefore a no-op kept for
    # existing command lines, and --no-trt is the way to turn it off.
    parser.add_argument(
        "--trt",
        dest="trt",
        default=True,
        action="store_true",
        help="Using TensorRT model for testing (enabled by default).",
    )
    parser.add_argument(
        "--no-trt",
        dest="trt",
        action="store_false",
        help="Disable TensorRT and run the regular checkpoint instead.",
    )

    # Tracking arguments
    parser.add_argument("--track_thresh", type=float, default=0.5, help="tracking confidence threshold")
    parser.add_argument("--track_buffer", type=int, default=30, help="the frames for keep lost tracks")
    parser.add_argument("--match_thresh", type=float, default=0.8, help="matching threshold for tracking")
    parser.add_argument("--min_box_area", type=float, default=1, help="filter out tiny boxes")
    parser.add_argument("--mot20", dest="mot20", default=False, action="store_true", help="test mot20.")

    # Traffic analysis arguments
    parser.add_argument("--perspective_config", type=str, default=None,
                       help="Path to perspective transformation config file")
    parser.add_argument("--fade_duration", type=int, default=180,
                       help="Number of frames to fade visualization over (default: 180)")
    parser.add_argument("--movement_threshold", type=int, default=15,
                       help="Minimum distance to consider movement (default: 15 pixels)")
    parser.add_argument("--stability_frames", type=int, default=30,
                       help="Frames to check for stability (default: 30)")
    parser.add_argument("--min_spawn_death_distance", type=int, default=100,
                       help="Minimum pixel distance between spawn and death points to count (default: 100 pixels)")

    # Track stability and recovery arguments
    parser.add_argument("--max_absence_frames", type=int, default=10,
                       help="Maximum frames a track can be absent before being considered unstable (default: 10)")
    parser.add_argument("--grace_period_frames", type=int, default=30,
                       help="Grace period frames for track recovery within boundaries (default: 30)")
    parser.add_argument("--zoom_out", type=float, default=1.375,
                       help="Zoom out factor for perspective transform (default: 1.375, >1.0 zooms out)")
    parser.add_argument("--no-save", dest="no_save", default=False, action="store_true",
                       help="Skip video saving for maximum processing throughput")

    return parser


class Predictor(object):
    """Runs the detection model on single images and post-processes its output."""

    def __init__(
        self,
        model,
        exp,
        cls_names=COCO_CLASSES,
        trt_file=None,
        decoder=None,
        device="cpu",
        fp16=False,
        legacy=False,
    ):
        """Store the model and the test-time settings taken from ``exp``.

        Args:
            model: Detection model to call.
            exp: Experiment object providing ``num_classes``, ``test_conf``,
                ``nmsthre`` and ``test_size``.
            cls_names: Class-name sequence kept on the instance.
            trt_file: Path to a torch2trt state dict; when given, the model is
                replaced by the loaded ``TRTModule``.
            decoder: Optional callable applied to the raw model output.
            device: "gpu" moves inputs to CUDA in ``inference``.
            fp16: Convert GPU inputs to half precision.
            legacy: Forwarded to ``ValTransform``.
        """
        self.model = model
        self.cls_names = cls_names
        self.decoder = decoder
        self.num_classes = exp.num_classes
        self.confthre = exp.test_conf
        self.nmsthre = exp.nmsthre
        self.test_size = exp.test_size
        self.device = device
        self.fp16 = fp16
        self.preproc = ValTransform(legacy=legacy)

        if trt_file is not None:
            from torch2trt import TRTModule

            model_trt = TRTModule()
            model_trt.load_state_dict(torch.load(trt_file))

            # Run the original model once on a dummy CUDA batch before it is
            # swapped out (presumably so the head's decode state used by
            # ``decoder`` is initialised - TODO confirm).
            x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
            with torch.no_grad():
                self.model(x)

            self.model = model_trt

    def inference(self, img):
        """Detect objects in one image.

        Args:
            img: Either an image array or a path to an image file.

        Returns:
            Tuple ``(outputs, img_info)`` where ``outputs`` is the result of
            ``postprocess`` and ``img_info`` is a dict with the keys ``id``,
            ``file_name``, ``height``, ``width``, ``raw_img`` and ``ratio``.

        Raises:
            ValueError: If ``img`` is a path that OpenCV cannot read.
        """
        img_info = {"id": 0}
        if isinstance(img, str):
            img_path = img
            img_info["file_name"] = os.path.basename(img_path)
            img = cv2.imread(img_path)
            # cv2.imread signals failure by returning None instead of raising
            if img is None:
                raise ValueError(f"Could not read image file: {img_path}")
        else:
            img_info["file_name"] = None

        height, width = img.shape[:2]
        img_info["height"] = height
        img_info["width"] = width
        img_info["raw_img"] = img

        # Resize factor that fits the image inside the test size
        ratio = min(self.test_size[0] / img.shape[0], self.test_size[1] / img.shape[1])
        img_info["ratio"] = ratio

        img, _ = self.preproc(img, None, self.test_size)
        img = torch.from_numpy(img).unsqueeze(0)
        img = img.float()
        if self.device == "gpu":
            img = img.cuda()
            if self.fp16:
                img = img.half()  # to FP16

        with torch.no_grad():
            outputs = self.model(img)
            if self.decoder is not None:
                outputs = self.decoder(outputs, dtype=outputs.type())

            outputs = postprocess(
                outputs, self.num_classes, self.confthre,
                self.nmsthre, class_agnostic=True
            )

        return outputs, img_info


def _draw_track_boxes(result_frame, online_targets, analyzer):
    """Draw a class-coloured box and an "ID:<n> <CLASS>" label for each target-vehicle track."""
    vehicle_colors = {
        "car": (0, 255, 0),          # Green
        "motorcycle": (255, 0, 255),  # Magenta
        "bus": (0, 255, 255),        # Cyan
        "truck": (0, 165, 255),      # Orange
    }

    for track in online_targets:
        track_id = track.track_id

        # Tracks without a class attribute are treated as class 0
        cls_id = int(track.cls) if hasattr(track, 'cls') else 0
        if not analyzer.is_target_vehicle(cls_id):
            continue

        x1, y1, w, h = track.tlwh
        x2, y2 = x1 + w, y1 + h

        # Prefer the analyzer's display class; fall back to this frame's class
        stable_cls_id = analyzer.get_track_display_class(track_id)
        if stable_cls_id:
            shown_cls_id = stable_cls_id
            stability_indicator = "*" if track_id in analyzer.track_stable_classes else ""
        else:
            shown_cls_id = cls_id
            stability_indicator = ""

        cls_name = analyzer.get_vehicle_class_name(shown_cls_id)
        if not cls_name:  # not one of the vehicle classes
            if shown_cls_id < len(COCO_CLASSES):
                cls_name = COCO_CLASSES[shown_cls_id]
            else:
                cls_name = f"class_{shown_cls_id}"

        color = vehicle_colors.get(cls_name, (255, 255, 255))
        cv2.rectangle(result_frame, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)

        label = f"ID:{track_id} {cls_name.upper()}{stability_indicator}"
        cv2.putText(result_frame, label, (int(x1), int(y1) - 10),
                   cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)


def _report_traffic_to_database(analyzer, args):
    """Send the analyzer's counters to the SQL reporter; failures are logged, never raised."""
    try:
        from sql_reporter import SQLReporter

        # Run metadata; args.mobius_config is expected to be attached by the
        # caller (a missing attribute ends up in the except below).
        traffic_data = {
            'file_name': args.mobius_config.get('file_name', ''),
            'file_instance': args.mobius_config.get('file_instance', 0),
            'time_calc': args.mobius_config.get('time_calc', ''),
        }

        # Detailed per-edge counters, kept for backwards compatibility
        try:
            detailed_counts = analyzer.detailed_traffic_counts
            traffic_data.update(detailed_counts)
        except AttributeError as e:
            logger.error(f"AttributeError accessing detailed_traffic_counts: {e}")
            detailed_counts = {}

        # Turning movements plus summary totals
        try:
            turning_movements = dict(analyzer.turning_movements)
            traffic_data['turning_movements'] = turning_movements
            traffic_data['Total_Vehicles'] = sum(turning_movements.values())

            # The vehicle type is the last "_" separated part of the key,
            # e.g. "North_to_East_car" -> "car"
            totals_by_type = defaultdict(int)
            for movement_key, count in turning_movements.items():
                parts = movement_key.split('_')
                if len(parts) >= 4:
                    totals_by_type[parts[-1]] += count

            traffic_data['Total_Cars'] = totals_by_type.get('car', 0)
            traffic_data['Total_Motorcycles'] = totals_by_type.get('motorcycle', 0)
            traffic_data['Total_Buses'] = totals_by_type.get('bus', 0)
            traffic_data['Total_Trucks'] = totals_by_type.get('truck', 0)

        except AttributeError as e:
            logger.error(f"AttributeError accessing turning_movements: {e}")
            traffic_data['turning_movements'] = {}

        logger.info(f"Traffic data being sent to SQL: {traffic_data}")
        logger.info(f"Detailed counts: {detailed_counts}")

        reporter = SQLReporter()
        success = reporter.report_traffic_analysis(traffic_data)

        if success:
            logger.info("Successfully reported traffic analysis to database")
        else:
            logger.warning("Failed to report traffic analysis to database")

    except Exception as e:
        # Reporting is best-effort: a database problem must not fail the run
        logger.error(f"Error reporting to database: {e}")


def process_video(predictor, vis_folder, current_time, args):
    """Run detection, tracking and traffic analysis over a video file or webcam.

    Every frame is detected, tracked, analysed and annotated; annotated frames
    are written to disk when ``args.save_result`` is set and ``args.no_save``
    is not.  After the last frame the analyzer is cleaned up and its counters
    are reported to the SQL database.  The capture (and the writer, if one was
    opened) is released even when processing fails.

    Args:
        predictor: Object with ``inference(frame)`` and ``test_size``.
        vis_folder: Base folder for saved videos (only used when saving).
        current_time: ``time.struct_time`` used to name the output folder.
        args: Parsed command-line options.
    """
    cap = cv2.VideoCapture(args.path if args.demo == "video" else args.camid)
    vid_writer = None
    try:
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # Some sources (typically webcams) report no frame rate
            fps = getattr(args, "fps", 30)

        tracker = TrafficByteTracker(args, frame_rate=fps)
        analyzer = tracker.traffic_analyzer

        save_video = args.save_result and not args.no_save
        if save_video:
            save_folder = os.path.join(
                vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
            )
            os.makedirs(save_folder, exist_ok=True)
            save_path = os.path.join(save_folder, "traffic_analysis_" + os.path.basename(args.path))
            logger.info(f"Video save path: {save_path}")

            vid_writer = cv2.VideoWriter(
                save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
            )

        frame_id = 0
        while True:
            ret_val, frame = cap.read()
            if not ret_val:
                break

            frame_id += 1

            # Detections for this frame
            outputs, img_info = predictor.inference(frame)

            # Tracking
            online_targets = []
            if outputs[0] is not None:
                img_h, img_w = frame.shape[:2]
                online_targets = tracker.update(
                    outputs[0], [img_h, img_w], predictor.test_size
                )

            # Analyzer update with the frame for transform visualization.
            # NOTE(review): TrafficByteTracker.update already forwards
            # non-empty target lists to the analyzer (with frame=None), so the
            # analyzer is updated twice on such frames; kept unchanged here.
            analyzer.update(online_targets, frame)

            # Record the frame dimensions once
            if analyzer.frame_dimensions is None:
                h, w = frame.shape[:2]
                analyzer.frame_dimensions = (w, h)

            result_frame = analyzer.draw_visualization(frame, online_targets)

            # Boxes and labels for target vehicles only
            _draw_track_boxes(result_frame, online_targets, analyzer)

            # Redraw the 2D inset so the boxes cannot cover it
            if (analyzer.perspective_transform and
                    analyzer.perspective_transform.transform_matrix is not None):
                analyzer._draw_transform_overlay(result_frame, online_targets)

            # Frame counter just above the traffic counter panel
            cv2.putText(result_frame, f"Frame: {frame_id}", (10, height - 200),
                       cv2.FONT_HERSHEY_SIMPLEX, 0.6, (200, 200, 200), 2)

            if vid_writer is not None:
                vid_writer.write(result_frame)

        analyzer.cleanup()
        _report_traffic_to_database(analyzer, args)
    finally:
        cap.release()
        # Only release a writer that was actually created; --save_result
        # combined with --no-save never opens one.
        if vid_writer is not None:
            vid_writer.release()


def main(exp, args):
    """Prepare the detector described by ``exp``/``args`` and run video analysis.

    Steps, in order: resolve the experiment output directory, apply the
    optional ``conf``/``nms``/``tsize`` overrides onto ``exp``, build the
    model, move it to the GPU when requested, load checkpoint weights
    (skipped when ``args.trt`` is set), optionally fuse the model, configure
    TensorRT decoding, and finally hand a ``Predictor`` to ``process_video``.

    Args:
        exp: Experiment object; this function reads ``exp_name``,
            ``output_dir`` and ``test_size``, calls ``get_model()``, and may
            overwrite ``test_conf``, ``nmsthre`` and ``test_size``.
        args: Parsed argument namespace. ``experiment_name`` and ``device``
            may be modified in place.

    Raises:
        AssertionError: If ``args.trt`` is combined with ``args.fuse``, or the
            TensorRT model file is missing from the output directory.
    """
    # Fall back to the experiment's own name when none was supplied.
    args.experiment_name = args.experiment_name or exp.exp_name

    output_dir = os.path.join(exp.output_dir, args.experiment_name)
    os.makedirs(output_dir, exist_ok=True)

    # A visualisation folder is only created when results are being saved.
    vis_folder = (
        os.path.join(output_dir, "traffic_analysis") if args.save_result else None
    )
    if vis_folder is not None:
        os.makedirs(vis_folder, exist_ok=True)

    # Selecting TensorRT forces the device to "gpu".
    if args.trt:
        args.device = "gpu"

    logger.info("Args: {}".format(args))

    # Optional overrides of the experiment's test-time settings.
    if args.conf is not None:
        exp.test_conf = args.conf
    if args.nms is not None:
        exp.nmsthre = args.nms
    if args.tsize is not None:
        exp.test_size = (args.tsize, args.tsize)

    model = exp.get_model()
    logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))

    on_gpu = args.device == "gpu"
    if on_gpu:
        model.cuda()
    if on_gpu and args.fp16:
        model.half()
    model.eval()

    if not args.trt:
        # Default to the experiment's best checkpoint when no path is given.
        ckpt_file = (
            args.ckpt
            if args.ckpt is not None
            else os.path.join(output_dir, "best_ckpt.pth")
        )
        logger.info("loading checkpoint")
        # Deserialise onto the CPU first; the state dict is then copied into the model.
        checkpoint = torch.load(ckpt_file, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
        logger.info("loaded checkpoint done.")

    if args.fuse:
        logger.info("\tFusing model...")
        model = fuse_model(model)

    # Non-TensorRT runs pass no engine file and no external decoder.
    trt_file = decoder = None
    if args.trt:
        assert not args.fuse, "TensorRT model is not support model fusing!"
        trt_file = os.path.join(output_dir, "model_trt.pth")
        assert os.path.exists(
            trt_file
        ), "TensorRT model is not found!\n Run python3 tools/trt.py first!"
        # Disable in-head decoding and pass the head's decode function to the Predictor instead.
        model.head.decode_in_inference = False
        decoder = model.head.decode_outputs
        logger.info("Using TensorRT to inference")

    predictor = Predictor(
        model, exp, COCO_CLASSES, trt_file, decoder,
        args.device, args.fp16, args.legacy,
    )
    current_time = time.localtime()

    # Only video-like sources are handled by this demo.
    if args.demo not in ("video", "webcam"):
        logger.error("This demo only supports video input")
        return

    process_video(predictor, vis_folder, current_time, args)


if __name__ == "__main__":
    # Script entry point: inputs are discovered in the MobiusInput folder
    # instead of being read from the command line.
    try:
        # Locate the video and its JSON config, then parse the config.
        video_path, json_path = scan_mobius_input()
        mobius_config = parse_mobius_config(json_path)

        logger.info("Processing MobiusInput data:")
        # Log the identifying config fields, one per line.
        for log_label, config_key in (
            ("File_Name", "file_name"),
            ("File_Instance", "file_instance"),
            ("TimeCalc", "time_calc"),
            ("Direction", "direction"),
        ):
            logger.info(f"  {log_label}: {mobius_config[config_key]}")

        # Start from the parser defaults (demo type "video") ...
        parser = make_parser()
        args = parser.parse_args(["video"])

        # ... then point the run at the discovered video and force saving.
        args.path = video_path
        args.demo = "video"
        args.save_result = True

        # Fall back to a default model/checkpoint when the parser supplied none.
        args.name = args.name or "yolox-l"
        args.ckpt = args.ckpt or "yolox_l.pth"

        # Keep the parsed config on args so later stages can read it.
        args.mobius_config = mobius_config

        # Only set perspective corners when the config provides a non-empty value.
        perspective_corners = mobius_config["perspective_corners"]
        if perspective_corners:
            args.perspective_corners = perspective_corners

        exp = get_exp(args.exp_file, args.name)
        main(exp, args)

    except Exception as err:
        # Any failure is logged and turned into a process exit with a message.
        logger.error(f"Error processing MobiusInput: {err}")
        raise SystemExit(f"Failed to process MobiusInput: {err}")